dspy_rs/core/lm/
chat.rs

1use anyhow::Result;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use rig::completion::{AssistantContent, Message as RigMessage, message::UserContent};
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub enum Message {
10    System { content: String },
11    User { content: String },
12    Assistant { content: String },
13}
14
15impl Message {
16    pub fn new(role: &str, content: &str) -> Self {
17        match role {
18            "system" => Message::system(content),
19            "user" => Message::user(content),
20            "assistant" => Message::assistant(content),
21            _ => panic!("Invalid role: {role}"),
22        }
23    }
24
25    pub fn user(content: impl Into<String>) -> Self {
26        Message::User {
27            content: content.into(),
28        }
29    }
30
31    pub fn assistant(content: impl Into<String>) -> Self {
32        Message::Assistant {
33            content: content.into(),
34        }
35    }
36
37    pub fn system(content: impl Into<String>) -> Self {
38        Message::System {
39            content: content.into(),
40        }
41    }
42
43    pub fn content(&self) -> String {
44        match self {
45            Message::System { content } => content.clone(),
46            Message::User { content } => content.clone(),
47            Message::Assistant { content } => content.clone(),
48        }
49    }
50
51    pub fn get_message_turn(&self) -> RigMessage {
52        match self {
53            Message::User { content } => RigMessage::user(content.clone()),
54            Message::Assistant { content } => RigMessage::assistant(content.clone()),
55            _ => panic!("Invalid role: {:?}", self),
56        }
57    }
58
59    pub fn to_json(&self) -> Value {
60        match self {
61            Message::System { content } => json!({ "role": "system", "content": content }),
62            Message::User { content } => json!({ "role": "user", "content": content }),
63            Message::Assistant { content } => json!({ "role": "assistant", "content": content }),
64        }
65    }
66}
67
68impl From<RigMessage> for Message {
69    fn from(message: RigMessage) -> Self {
70        match message {
71            RigMessage::User { content } => {
72                let text = content
73                    .into_iter()
74                    .find_map(|c| {
75                        if let UserContent::Text(t) = c {
76                            Some(t.text)
77                        } else {
78                            None
79                        }
80                    })
81                    .unwrap_or_default();
82                Message::user(text)
83            }
84            RigMessage::Assistant { content, .. } => {
85                let text = content
86                    .into_iter()
87                    .find_map(|c| {
88                        if let AssistantContent::Text(t) = c {
89                            Some(t.text)
90                        } else {
91                            None
92                        }
93                    })
94                    .unwrap_or_default();
95                Message::assistant(text)
96            }
97        }
98    }
99}
100
101pub struct RigChatMessage {
102    pub system: String,
103    pub conversation: Vec<RigMessage>,
104    pub prompt: RigMessage,
105}
106
107#[derive(Clone, Debug)]
108pub struct Chat {
109    pub messages: Vec<Message>,
110}
111
112impl Chat {
113    pub fn new(messages: Vec<Message>) -> Self {
114        Self { messages }
115    }
116
117    pub fn len(&self) -> usize {
118        self.messages.len()
119    }
120
121    pub fn is_empty(&self) -> bool {
122        self.messages.is_empty()
123    }
124
125    pub fn push(&mut self, role: &str, content: &str) {
126        self.messages.push(Message::new(role, content));
127    }
128
129    pub fn push_message(&mut self, message: Message) {
130        self.messages.push(message);
131    }
132
133    pub fn push_all(&mut self, chat: &Chat) {
134        self.messages.extend(chat.messages.clone());
135    }
136
137    pub fn pop(&mut self) -> Option<Message> {
138        self.messages.pop()
139    }
140
141    pub fn from_json(&self, json_dump: Value) -> Result<Self> {
142        let messages = json_dump.as_array().unwrap();
143        let messages = messages
144            .iter()
145            .map(|message| {
146                Message::new(
147                    message["role"].as_str().unwrap(),
148                    message["content"].as_str().unwrap(),
149                )
150            })
151            .collect();
152        Ok(Self { messages })
153    }
154
155    pub fn to_json(&self) -> Value {
156        let messages = self
157            .messages
158            .iter()
159            .map(|message| message.to_json())
160            .collect::<Vec<Value>>();
161        json!(messages)
162    }
163
164    pub fn get_rig_messages(&self) -> RigChatMessage {
165        let system: String = self.messages[0].content();
166        let conversation: Vec<RigMessage> = if self.messages.len() > 2 {
167            self.messages[1..self.messages.len() - 1]
168                .iter()
169                .map(|message| message.get_message_turn())
170                .collect::<Vec<RigMessage>>()
171        } else {
172            vec![]
173        };
174        let prompt = self.messages.last().unwrap().get_message_turn();
175
176        RigChatMessage {
177            system,
178            conversation,
179            prompt,
180        }
181    }
182}