dspy_rs/core/lm/
chat.rs

1use anyhow::Result;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use async_openai::types::{
7    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
8    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
9    ChatCompletionResponseMessage, Role,
10};
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub enum Message {
14    System { content: String },
15    User { content: String },
16    Assistant { content: String },
17}
18
19impl Message {
20    pub fn new(role: &str, content: &str) -> Self {
21        match role {
22            "system" => Message::system(content),
23            "user" => Message::user(content),
24            "assistant" => Message::assistant(content),
25            _ => panic!("Invalid role: {role}"),
26        }
27    }
28
29    pub fn user(content: impl Into<String>) -> Self {
30        Message::User {
31            content: content.into(),
32        }
33    }
34
35    pub fn assistant(content: impl Into<String>) -> Self {
36        Message::Assistant {
37            content: content.into(),
38        }
39    }
40
41    pub fn system(content: impl Into<String>) -> Self {
42        Message::System {
43            content: content.into(),
44        }
45    }
46
47    pub fn content(&self) -> String {
48        match self {
49            Message::System { content } => content.clone(),
50            Message::User { content } => content.clone(),
51            Message::Assistant { content } => content.clone(),
52        }
53    }
54
55    pub fn get_message_turn(&self) -> ChatCompletionRequestMessage {
56        match self {
57            Message::System { content } => ChatCompletionRequestSystemMessageArgs::default()
58                .content(content.as_str())
59                .build()
60                .unwrap()
61                .into(),
62            Message::User { content } => ChatCompletionRequestUserMessageArgs::default()
63                .content(content.as_str())
64                .build()
65                .unwrap()
66                .into(),
67            Message::Assistant { content } => ChatCompletionRequestAssistantMessageArgs::default()
68                .content(content.as_str())
69                .build()
70                .unwrap()
71                .into(),
72        }
73    }
74
75    pub fn to_json(&self) -> Value {
76        match self {
77            Message::System { content } => json!({ "role": "system", "content": content }),
78            Message::User { content } => json!({ "role": "user", "content": content }),
79            Message::Assistant { content } => json!({ "role": "assistant", "content": content }),
80        }
81    }
82}
83
84impl From<ChatCompletionResponseMessage> for Message {
85    fn from(message: ChatCompletionResponseMessage) -> Self {
86        match message.role {
87            Role::System => Message::System {
88                content: message.content.unwrap(),
89            },
90            Role::User => Message::User {
91                content: message.content.unwrap(),
92            },
93            Role::Assistant => Message::Assistant {
94                content: message.content.unwrap(),
95            },
96            _ => panic!("Invalid role: {:?}", message.role),
97        }
98    }
99}
100
101#[derive(Clone, Debug)]
102pub struct Chat {
103    pub messages: Vec<Message>,
104}
105
106impl Chat {
107    pub fn new(messages: Vec<Message>) -> Self {
108        Self { messages }
109    }
110
111    pub fn len(&self) -> usize {
112        self.messages.len()
113    }
114
115    pub fn is_empty(&self) -> bool {
116        self.messages.is_empty()
117    }
118
119    pub fn push(&mut self, role: &str, content: &str) {
120        self.messages.push(Message::new(role, content));
121    }
122
123    pub fn push_all(&mut self, chat: &Chat) {
124        self.messages.extend(chat.messages.clone());
125    }
126
127    pub fn pop(&mut self) -> Option<Message> {
128        self.messages.pop()
129    }
130
131    pub fn from_json(&self, json_dump: Value) -> Result<Self> {
132        let messages = json_dump.as_array().unwrap();
133        let messages = messages
134            .iter()
135            .map(|message| {
136                Message::new(
137                    message["role"].as_str().unwrap(),
138                    message["content"].as_str().unwrap(),
139                )
140            })
141            .collect();
142        Ok(Self { messages })
143    }
144
145    pub fn to_json(&self) -> Value {
146        let messages = self
147            .messages
148            .iter()
149            .map(|message| message.to_json())
150            .collect::<Vec<Value>>();
151        json!(messages)
152    }
153
154    pub fn get_async_openai_messages(&self) -> Vec<ChatCompletionRequestMessage> {
155        self.messages
156            .iter()
157            .map(|message| message.get_message_turn())
158            .collect()
159    }
160}