Skip to main content

aagt_core/agent/
message.rs

1//! Message types for LLM communication
2
3use serde::{Deserialize, Serialize};
4
5/// Role of the message sender
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    /// System message (instructions)
10    System,
11    /// User message
12    User,
13    /// Assistant (AI) message
14    Assistant,
15    /// Tool result message
16    Tool,
17}
18
19/// Content of a message
20#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(untagged)]
22pub enum Content {
23    /// Simple text content
24    Text(String),
25    /// Structured content with multiple parts
26    Parts(Vec<ContentPart>),
27}
28
29impl Role {
30    pub fn as_str(&self) -> &str {
31        match self {
32            Role::System => "system",
33            Role::User => "user",
34            Role::Assistant => "assistant",
35            Role::Tool => "tool",
36        }
37    }
38}
39
40impl Content {
41    /// Create text content
42    pub fn text(text: impl Into<String>) -> Self {
43        Self::Text(text.into())
44    }
45
46    /// Create multi-part content
47    pub fn parts(parts: Vec<ContentPart>) -> Self {
48        Self::Parts(parts)
49    }
50
51    /// Get as text (concatenates parts if needed)
52    pub fn as_text(&self) -> String {
53        match self {
54            Self::Text(t) => t.clone(),
55            Self::Parts(parts) => parts
56                .iter()
57                .filter_map(|p| match p {
58                    ContentPart::Text { text } => Some(text.as_str()),
59                    _ => None,
60                })
61                .collect::<Vec<_>>()
62                .join("\n"),
63        }
64    }
65}
66
67impl From<String> for Content {
68    fn from(s: String) -> Self {
69        Self::Text(s)
70    }
71}
72
73impl From<&str> for Content {
74    fn from(s: &str) -> Self {
75        Self::Text(s.to_string())
76    }
77}
78
79/// A part of structured content
80#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(tag = "type", rename_all = "snake_case")]
82pub enum ContentPart {
83    /// Text content
84    Text {
85        /// The text
86        text: String,
87    },
88    /// Image content (base64 or URL)
89    Image {
90        /// Image source (base64 data or URL)
91        source: ImageSource,
92    },
93    /// Tool call from assistant
94    ToolCall {
95        /// Unique ID for this tool call
96        id: String,
97        /// Name of the tool to call
98        name: String,
99        /// Arguments as JSON
100        arguments: serde_json::Value,
101    },
102    /// Tool result from user
103    ToolResult {
104        /// ID of the tool call this is responding to
105        tool_call_id: String,
106        /// Optional name of the tool (required by some providers like Gemini)
107        #[serde(skip_serializing_if = "Option::is_none")]
108        name: Option<String>,
109        /// Result content
110        content: String,
111    },
112}
113
114/// Source for image content
115#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(tag = "type", rename_all = "snake_case")]
117pub enum ImageSource {
118    /// Base64 encoded image
119    Base64 {
120        /// Media type (e.g., "image/png")
121        media_type: String,
122        /// Base64 encoded data
123        data: String,
124    },
125    /// URL to an image
126    Url {
127        /// Image URL
128        url: String,
129    },
130}
131
132/// A message in the conversation
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct Message {
135    /// Role of the sender
136    pub role: Role,
137    /// Content of the message
138    pub content: Content,
139    /// Optional name (for multi-agent scenarios)
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub name: Option<String>,
142}
143
144impl Message {
145    /// Create a new message
146    pub fn new(role: Role, content: impl Into<Content>) -> Self {
147        Self {
148            role,
149            content: content.into(),
150            name: None,
151        }
152    }
153
154    /// Create a system message
155    pub fn system(content: impl Into<Content>) -> Self {
156        Self::new(Role::System, content)
157    }
158
159    /// Create a user message
160    pub fn user(content: impl Into<Content>) -> Self {
161        Self::new(Role::User, content)
162    }
163
164    /// Create an assistant message
165    pub fn assistant(content: impl Into<Content>) -> Self {
166        Self::new(Role::Assistant, content)
167    }
168
169    /// Create a tool result message
170    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
171        Self {
172            role: Role::Tool,
173            content: Content::Parts(vec![ContentPart::ToolResult {
174                tool_call_id: tool_call_id.into(),
175                name: None,
176                content: content.into(),
177            }]),
178            name: None,
179        }
180    }
181
182    /// Set the tool name for a tool result message (required for Gemini)
183    pub fn with_tool_name(mut self, tool_name: impl Into<String>) -> Self {
184        // Since 'name' is already a field in this method (from self.name), lets use tool_name
185        let tool_name = tool_name.into();
186
187        if let Content::Parts(parts) = &mut self.content {
188            for part in parts {
189                if let ContentPart::ToolResult { name, .. } = part {
190                    *name = Some(tool_name.clone());
191                    // Typically only one tool result per message, so break
192                    break;
193                }
194            }
195        }
196        self
197    }
198
199    /// Set the name for this message
200    pub fn with_name(mut self, name: impl Into<String>) -> Self {
201        self.name = Some(name.into());
202        self
203    }
204
205    /// Get the text content of this message
206    pub fn text(&self) -> String {
207        self.content.as_text()
208    }
209}
210
211/// Tool call extracted from assistant response
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct ToolCall {
214    /// Unique ID for this tool call
215    pub id: String,
216    /// Name of the tool
217    pub name: String,
218    /// Arguments as JSON
219    pub arguments: serde_json::Value,
220}
221
222impl ToolCall {
223    /// Create a new tool call
224    pub fn new(
225        id: impl Into<String>,
226        name: impl Into<String>,
227        arguments: serde_json::Value,
228    ) -> Self {
229        Self {
230            id: id.into(),
231            name: name.into(),
232            arguments,
233        }
234    }
235
236    /// Parse arguments into a typed struct
237    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
238        serde_json::from_value(self.arguments.clone())
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello");
    }

    #[test]
    fn test_tool_call_parse() {
        #[derive(Deserialize)]
        struct SwapArgs {
            from: String,
            to: String,
            amount: f64,
        }

        let call = ToolCall::new(
            "call_123",
            "swap_tokens",
            serde_json::json!({
                "from": "USDC",
                "to": "SOL",
                "amount": 100.0
            }),
        );

        let args: SwapArgs = call.parse_args().expect("parse should succeed");
        assert_eq!(args.from, "USDC");
        assert_eq!(args.to, "SOL");
        assert!((args.amount - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tool_result_name() {
        let msg = Message::tool_result("call_1", "result").with_tool_name("get_price");
        // Fail loudly if the message shape changes instead of silently passing.
        let parts = match &msg.content {
            Content::Parts(parts) => parts,
            other => panic!("expected multi-part content, got {:?}", other),
        };
        match parts.first() {
            Some(ContentPart::ToolResult { name, .. }) => {
                assert_eq!(name.as_deref(), Some("get_price"));
            }
            other => panic!("expected a tool result part, got {:?}", other),
        }
    }

    #[test]
    fn test_parts_as_text_skips_non_text() {
        let content = Content::parts(vec![
            ContentPart::Text {
                text: "first".to_string(),
            },
            ContentPart::Image {
                source: ImageSource::Url {
                    url: "https://example.com/a.png".to_string(),
                },
            },
            ContentPart::Text {
                text: "second".to_string(),
            },
        ]);
        assert_eq!(content.as_text(), "first\nsecond");
    }

    #[test]
    fn test_role_as_str_matches_serde() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            let json = serde_json::to_value(&role).expect("role should serialize");
            assert_eq!(json, serde_json::Value::String(role.as_str().to_string()));
            let back: Role = serde_json::from_value(json).expect("role should deserialize");
            assert_eq!(back, role);
        }
    }

    #[test]
    fn test_tool_result_serialization_omits_missing_name() {
        let msg = Message::tool_result("call_1", "ok");
        let json = serde_json::to_value(&msg).expect("message should serialize");
        assert_eq!(
            json,
            serde_json::json!({
                "role": "tool",
                "content": [
                    { "type": "tool_result", "tool_call_id": "call_1", "content": "ok" }
                ]
            })
        );
    }

    #[test]
    fn test_content_deserializes_untagged() {
        let text: Content =
            serde_json::from_value(serde_json::json!("hi")).expect("string should parse");
        assert_eq!(text.as_text(), "hi");
        assert!(matches!(&text, Content::Text(_)));

        let parts: Content =
            serde_json::from_value(serde_json::json!([{ "type": "text", "text": "a" }]))
                .expect("array should parse");
        assert_eq!(parts.as_text(), "a");
        assert!(matches!(&parts, Content::Parts(_)));
    }
}