1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9 System,
11 User,
13 Assistant,
15 Tool,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(untagged)]
22pub enum Content {
23 Text(String),
25 Parts(Vec<ContentPart>),
27}
28
29impl Role {
30 pub fn as_str(&self) -> &str {
31 match self {
32 Role::System => "system",
33 Role::User => "user",
34 Role::Assistant => "assistant",
35 Role::Tool => "tool",
36 }
37 }
38}
39
40impl Content {
41 pub fn text(text: impl Into<String>) -> Self {
43 Self::Text(text.into())
44 }
45
46 pub fn parts(parts: Vec<ContentPart>) -> Self {
48 Self::Parts(parts)
49 }
50
51 pub fn as_text(&self) -> String {
53 match self {
54 Self::Text(t) => t.clone(),
55 Self::Parts(parts) => parts
56 .iter()
57 .filter_map(|p| match p {
58 ContentPart::Text { text } => Some(text.as_str()),
59 _ => None,
60 })
61 .collect::<Vec<_>>()
62 .join("\n"),
63 }
64 }
65}
66
67impl From<String> for Content {
68 fn from(s: String) -> Self {
69 Self::Text(s)
70 }
71}
72
73impl From<&str> for Content {
74 fn from(s: &str) -> Self {
75 Self::Text(s.to_string())
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(tag = "type", rename_all = "snake_case")]
82pub enum ContentPart {
83 Text {
85 text: String,
87 },
88 Image {
90 source: ImageSource,
92 },
93 ToolCall {
95 id: String,
97 name: String,
99 arguments: serde_json::Value,
101 },
102 ToolResult {
104 tool_call_id: String,
106 #[serde(skip_serializing_if = "Option::is_none")]
108 name: Option<String>,
109 content: String,
111 },
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(tag = "type", rename_all = "snake_case")]
117pub enum ImageSource {
118 Base64 {
120 media_type: String,
122 data: String,
124 },
125 Url {
127 url: String,
129 },
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct Message {
135 pub role: Role,
137 pub content: Content,
139 #[serde(skip_serializing_if = "Option::is_none")]
141 pub name: Option<String>,
142}
143
144impl Message {
145 pub fn new(role: Role, content: impl Into<Content>) -> Self {
147 Self {
148 role,
149 content: content.into(),
150 name: None,
151 }
152 }
153
154 pub fn system(content: impl Into<Content>) -> Self {
156 Self::new(Role::System, content)
157 }
158
159 pub fn user(content: impl Into<Content>) -> Self {
161 Self::new(Role::User, content)
162 }
163
164 pub fn assistant(content: impl Into<Content>) -> Self {
166 Self::new(Role::Assistant, content)
167 }
168
169 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
171 Self {
172 role: Role::Tool,
173 content: Content::Parts(vec![ContentPart::ToolResult {
174 tool_call_id: tool_call_id.into(),
175 name: None,
176 content: content.into(),
177 }]),
178 name: None,
179 }
180 }
181
182 pub fn with_tool_name(mut self, tool_name: impl Into<String>) -> Self {
184 let tool_name = tool_name.into();
186
187 if let Content::Parts(parts) = &mut self.content {
188 for part in parts {
189 if let ContentPart::ToolResult { name, .. } = part {
190 *name = Some(tool_name.clone());
191 break;
193 }
194 }
195 }
196 self
197 }
198
199 pub fn with_name(mut self, name: impl Into<String>) -> Self {
201 self.name = Some(name.into());
202 self
203 }
204
205 pub fn text(&self) -> String {
207 self.content.as_text()
208 }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct ToolCall {
214 pub id: String,
216 pub name: String,
218 pub arguments: serde_json::Value,
220}
221
222impl ToolCall {
223 pub fn new(
225 id: impl Into<String>,
226 name: impl Into<String>,
227 arguments: serde_json::Value,
228 ) -> Self {
229 Self {
230 id: id.into(),
231 name: name.into(),
232 arguments,
233 }
234 }
235
236 pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
238 serde_json::from_value(self.arguments.clone())
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello");
    }

    #[test]
    fn test_tool_call_parse() {
        #[derive(Deserialize)]
        struct SwapArgs {
            from: String,
            to: String,
            amount: f64,
        }

        let call = ToolCall::new(
            "call_123",
            "swap_tokens",
            serde_json::json!({
                "from": "USDC",
                "to": "SOL",
                "amount": 100.0
            }),
        );

        let args: SwapArgs = call.parse_args().expect("parse should succeed");
        assert_eq!(args.from, "USDC");
        assert_eq!(args.to, "SOL");
        assert!((args.amount - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tool_result_name() {
        let msg = Message::tool_result("call_1", "result").with_tool_name("get_price");
        // Fail loudly, rather than passing vacuously, if the content shape changes.
        let parts = match msg.content {
            Content::Parts(parts) => parts,
            Content::Text(_) => panic!("tool_result should produce Content::Parts"),
        };
        match parts.first() {
            Some(ContentPart::ToolResult { name, .. }) => {
                assert_eq!(name.as_deref(), Some("get_price"));
            }
            other => panic!("expected a ToolResult part, got {:?}", other),
        }
    }

    #[test]
    fn test_as_text_joins_text_parts_only() {
        let content = Content::parts(vec![
            ContentPart::Text {
                text: "first".to_string(),
            },
            ContentPart::Image {
                source: ImageSource::Url {
                    url: "https://example.com/a.png".to_string(),
                },
            },
            ContentPart::Text {
                text: "second".to_string(),
            },
        ]);
        assert_eq!(content.as_text(), "first\nsecond");
    }

    #[test]
    fn test_role_serializes_lowercase() {
        let value = serde_json::to_value(Role::Assistant).expect("serialize role");
        assert_eq!(value, serde_json::json!("assistant"));
    }

    #[test]
    fn test_tool_result_wire_format() {
        let msg = Message::tool_result("call_1", "42");
        let value = serde_json::to_value(&msg).expect("serialize message");
        // Both `name` fields are `None` and therefore omitted.
        assert_eq!(
            value,
            serde_json::json!({
                "role": "tool",
                "content": [{
                    "type": "tool_result",
                    "tool_call_id": "call_1",
                    "content": "42"
                }]
            })
        );
    }

    #[test]
    fn test_content_deserializes_plain_string() {
        let content: Content = serde_json::from_str("\"hi\"").expect("deserialize content");
        assert_eq!(content.as_text(), "hi");
        assert!(matches!(content, Content::Text(_)));
    }
}