1use anyhow::Result;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use async_openai::types::{
7 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
8 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
9 ChatCompletionResponseMessage, Role,
10};
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub enum Message {
14 System { content: String },
15 User { content: String },
16 Assistant { content: String },
17}
18
19impl Message {
20 pub fn new(role: &str, content: &str) -> Self {
21 match role {
22 "system" => Message::system(content),
23 "user" => Message::user(content),
24 "assistant" => Message::assistant(content),
25 _ => panic!("Invalid role: {role}"),
26 }
27 }
28
29 pub fn user(content: impl Into<String>) -> Self {
30 Message::User {
31 content: content.into(),
32 }
33 }
34
35 pub fn assistant(content: impl Into<String>) -> Self {
36 Message::Assistant {
37 content: content.into(),
38 }
39 }
40
41 pub fn system(content: impl Into<String>) -> Self {
42 Message::System {
43 content: content.into(),
44 }
45 }
46
47 pub fn content(&self) -> String {
48 match self {
49 Message::System { content } => content.clone(),
50 Message::User { content } => content.clone(),
51 Message::Assistant { content } => content.clone(),
52 }
53 }
54
55 pub fn get_message_turn(&self) -> ChatCompletionRequestMessage {
56 match self {
57 Message::System { content } => ChatCompletionRequestSystemMessageArgs::default()
58 .content(content.as_str())
59 .build()
60 .unwrap()
61 .into(),
62 Message::User { content } => ChatCompletionRequestUserMessageArgs::default()
63 .content(content.as_str())
64 .build()
65 .unwrap()
66 .into(),
67 Message::Assistant { content } => ChatCompletionRequestAssistantMessageArgs::default()
68 .content(content.as_str())
69 .build()
70 .unwrap()
71 .into(),
72 }
73 }
74
75 pub fn to_json(&self) -> Value {
76 match self {
77 Message::System { content } => json!({ "role": "system", "content": content }),
78 Message::User { content } => json!({ "role": "user", "content": content }),
79 Message::Assistant { content } => json!({ "role": "assistant", "content": content }),
80 }
81 }
82}
83
84impl From<ChatCompletionResponseMessage> for Message {
85 fn from(message: ChatCompletionResponseMessage) -> Self {
86 match message.role {
87 Role::System => Message::System {
88 content: message.content.unwrap(),
89 },
90 Role::User => Message::User {
91 content: message.content.unwrap(),
92 },
93 Role::Assistant => Message::Assistant {
94 content: message.content.unwrap(),
95 },
96 _ => panic!("Invalid role: {:?}", message.role),
97 }
98 }
99}
100
101#[derive(Clone, Debug)]
102pub struct Chat {
103 pub messages: Vec<Message>,
104}
105
106impl Chat {
107 pub fn new(messages: Vec<Message>) -> Self {
108 Self { messages }
109 }
110
111 pub fn len(&self) -> usize {
112 self.messages.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
116 self.messages.is_empty()
117 }
118
119 pub fn push(&mut self, role: &str, content: &str) {
120 self.messages.push(Message::new(role, content));
121 }
122
123 pub fn push_all(&mut self, chat: &Chat) {
124 self.messages.extend(chat.messages.clone());
125 }
126
127 pub fn pop(&mut self) -> Option<Message> {
128 self.messages.pop()
129 }
130
131 pub fn from_json(&self, json_dump: Value) -> Result<Self> {
132 let messages = json_dump.as_array().unwrap();
133 let messages = messages
134 .iter()
135 .map(|message| {
136 Message::new(
137 message["role"].as_str().unwrap(),
138 message["content"].as_str().unwrap(),
139 )
140 })
141 .collect();
142 Ok(Self { messages })
143 }
144
145 pub fn to_json(&self) -> Value {
146 let messages = self
147 .messages
148 .iter()
149 .map(|message| message.to_json())
150 .collect::<Vec<Value>>();
151 json!(messages)
152 }
153
154 pub fn get_async_openai_messages(&self) -> Vec<ChatCompletionRequestMessage> {
155 self.messages
156 .iter()
157 .map(|message| message.get_message_turn())
158 .collect()
159 }
160}