1use anyhow::Result;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use rig::completion::{AssistantContent, Message as RigMessage, message::UserContent};
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub enum Message {
10 System { content: String },
11 User { content: String },
12 Assistant { content: String },
13}
14
15impl Message {
16 pub fn new(role: &str, content: &str) -> Self {
17 match role {
18 "system" => Message::system(content),
19 "user" => Message::user(content),
20 "assistant" => Message::assistant(content),
21 _ => panic!("Invalid role: {role}"),
22 }
23 }
24
25 pub fn user(content: impl Into<String>) -> Self {
26 Message::User {
27 content: content.into(),
28 }
29 }
30
31 pub fn assistant(content: impl Into<String>) -> Self {
32 Message::Assistant {
33 content: content.into(),
34 }
35 }
36
37 pub fn system(content: impl Into<String>) -> Self {
38 Message::System {
39 content: content.into(),
40 }
41 }
42
43 pub fn content(&self) -> String {
44 match self {
45 Message::System { content } => content.clone(),
46 Message::User { content } => content.clone(),
47 Message::Assistant { content } => content.clone(),
48 }
49 }
50
51 pub fn get_message_turn(&self) -> RigMessage {
52 match self {
53 Message::User { content } => RigMessage::user(content.clone()),
54 Message::Assistant { content } => RigMessage::assistant(content.clone()),
55 _ => panic!("Invalid role: {:?}", self),
56 }
57 }
58
59 pub fn to_json(&self) -> Value {
60 match self {
61 Message::System { content } => json!({ "role": "system", "content": content }),
62 Message::User { content } => json!({ "role": "user", "content": content }),
63 Message::Assistant { content } => json!({ "role": "assistant", "content": content }),
64 }
65 }
66}
67
68impl From<RigMessage> for Message {
69 fn from(message: RigMessage) -> Self {
70 match message {
71 RigMessage::User { content } => {
72 let text = content
73 .into_iter()
74 .find_map(|c| {
75 if let UserContent::Text(t) = c {
76 Some(t.text)
77 } else {
78 None
79 }
80 })
81 .unwrap_or_default();
82 Message::user(text)
83 }
84 RigMessage::Assistant { content, .. } => {
85 let text = content
86 .into_iter()
87 .find_map(|c| {
88 if let AssistantContent::Text(t) = c {
89 Some(t.text)
90 } else {
91 None
92 }
93 })
94 .unwrap_or_default();
95 Message::assistant(text)
96 }
97 }
98 }
99}
100
101pub struct RigChatMessage {
102 pub system: String,
103 pub conversation: Vec<RigMessage>,
104 pub prompt: RigMessage,
105}
106
107#[derive(Clone, Debug)]
108pub struct Chat {
109 pub messages: Vec<Message>,
110}
111
112impl Chat {
113 pub fn new(messages: Vec<Message>) -> Self {
114 Self { messages }
115 }
116
117 pub fn len(&self) -> usize {
118 self.messages.len()
119 }
120
121 pub fn is_empty(&self) -> bool {
122 self.messages.is_empty()
123 }
124
125 pub fn push(&mut self, role: &str, content: &str) {
126 self.messages.push(Message::new(role, content));
127 }
128
129 pub fn push_message(&mut self, message: Message) {
130 self.messages.push(message);
131 }
132
133 pub fn push_all(&mut self, chat: &Chat) {
134 self.messages.extend(chat.messages.clone());
135 }
136
137 pub fn pop(&mut self) -> Option<Message> {
138 self.messages.pop()
139 }
140
141 pub fn from_json(&self, json_dump: Value) -> Result<Self> {
142 let messages = json_dump.as_array().unwrap();
143 let messages = messages
144 .iter()
145 .map(|message| {
146 Message::new(
147 message["role"].as_str().unwrap(),
148 message["content"].as_str().unwrap(),
149 )
150 })
151 .collect();
152 Ok(Self { messages })
153 }
154
155 pub fn to_json(&self) -> Value {
156 let messages = self
157 .messages
158 .iter()
159 .map(|message| message.to_json())
160 .collect::<Vec<Value>>();
161 json!(messages)
162 }
163
164 pub fn get_rig_messages(&self) -> RigChatMessage {
165 let system: String = self.messages[0].content();
166 let conversation: Vec<RigMessage> = if self.messages.len() > 2 {
167 self.messages[1..self.messages.len() - 1]
168 .iter()
169 .map(|message| message.get_message_turn())
170 .collect::<Vec<RigMessage>>()
171 } else {
172 vec![]
173 };
174 let prompt = self.messages.last().unwrap().get_message_turn();
175
176 RigChatMessage {
177 system,
178 conversation,
179 prompt,
180 }
181 }
182}