// call_agent/chat/prompt.rs

1use std::fmt;
2
3use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::Value;
5
6use super::function::FunctionCall;
7
8/// Represents a prompt message with different roles.
9///
10/// This enum describes various types of messages used in prompts.
11/// It supports user messages, function messages, and assistant messages.
12/// Each variant holds the content of the message.
13#[derive(Clone)]
14pub enum Message {
15    /// A message sent by a user.
16    /// should the name matches the pattern '^[a-zA-Z0-9_-]+$'."
17    User { 
18        name: Option<String>,
19        content: Vec<MessageContext> 
20    },
21    /// A message sent by a function, including its name.
22    Tool { 
23        tool_call_id: String,
24        content: Vec<MessageContext> 
25    },
26    /// A message from the assistant.
27    /// should the name matches the pattern '^[a-zA-Z0-9_-]+$'."
28    Assistant { 
29        name: Option<String>,
30        content: Vec<MessageContext>, 
31        tool_calls: Option<Vec<FunctionCall>>,
32    },
33    /// A system prompt.
34    /// should the name matches the pattern '^[a-zA-Z0-9_-]+$'."
35    System { 
36        name: Option<String>,
37        content: String
38    },
39    /// A message from the developer.
40    /// Treated as a system message in unsupported models.
41    /// should the name matches the pattern '^[a-zA-Z0-9_-]+$'."
42    Developer { 
43        name: Option<String>,
44        content: String
45    },
46}
47
48impl fmt::Debug for Message {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            Message::User { name, content } => {
52                writeln!(f, "User: {}", name.as_deref().unwrap_or("Anonymous"))?;
53                for ctx in content {
54                    match ctx {
55                        MessageContext::Text(text) => writeln!(f, "    {}", text)?,
56                        MessageContext::Image(image) => writeln!(f, "    [Image URL: {}]", image.url)?,
57                    }
58                }
59                Ok(())
60            }
61            Message::Tool { tool_call_id, content } => {
62                writeln!(f, "Tool: {} - Tool Call", tool_call_id)?;
63                for ctx in content {
64                    match ctx {
65                        MessageContext::Text(text) => writeln!(f, "    {}", text)?,
66                        MessageContext::Image(image) => writeln!(f, "    [Image URL: {}]", image.url)?,
67                    }
68                }
69                Ok(())
70            }
71            Message::Assistant { name, content, tool_calls } => {
72                writeln!(f, "Assistant: {}", name.as_deref().unwrap_or("Assistant"))?;
73                for ctx in content {
74                    match ctx {
75                        MessageContext::Text(text) => writeln!(f, "    {}", text)?,
76                        MessageContext::Image(image) => writeln!(f, "    [Image URL: {}]", image.url)?,
77                    }
78                }
79                if let Some(calls) = tool_calls {
80                    for call in calls {
81                        writeln!(f, "    Tool Call: {:?}", call)?;
82                    }
83                }
84                Ok(())
85            }
86            Message::System { name, content } => {
87                writeln!(f, "System: {}", name.as_deref().unwrap_or("System"))?;
88                writeln!(f, "    {}", content)
89            }
90            Message::Developer { name, content } => {
91                writeln!(f, "Developer: {}", name.as_deref().unwrap_or("Developer"))?;
92                writeln!(f, "    {}", content)
93            }
94        }
95    }
96}
97
98// Custom serialization implementation for Message.
99impl Serialize for Message {
100    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101    where
102        S: Serializer,
103    {
104        let state = match self {
105            Message::User { name, content } => {
106                let mut s = serializer.serialize_struct("Message", 3)?;
107                s.serialize_field("role", "user")?;
108                if let Some(name) = name {
109                    s.serialize_field("name", name)?;
110                }
111                serialize_content_field(&mut s, content)?;
112                s
113            }
114            Message::Tool { tool_call_id, content } => {
115                let mut s = serializer.serialize_struct("Message", 2)?;
116                s.serialize_field("role", "tool")?;
117                s.serialize_field("tool_call_id", tool_call_id)?;
118
119                serialize_content_field(&mut s, content)?;
120                s
121            }
122            Message::Assistant { name, content, tool_calls } => {
123                let mut s = serializer.serialize_struct("Message", 3)?;
124                s.serialize_field("role", "assistant")?;
125                if let Some(name) = name {
126                    s.serialize_field("name", name)?;
127                }
128                serialize_content_field(&mut s, content)?;
129                if let Some(tool_calls) = tool_calls {
130                    s.serialize_field("tool_calls", tool_calls)?;
131                }
132                s
133            }
134            Message::System { name, content } => {
135                let mut s = serializer.serialize_struct("Message", 3)?;
136                s.serialize_field("role", "system")?;
137                if let Some(name) = name {
138                    s.serialize_field("name", name)?;
139                }
140                s.serialize_field("content", content)?;
141                s
142            }
143            Message::Developer { name, content } => {
144                let mut s = serializer.serialize_struct("Message", 3)?;
145                s.serialize_field("role", "developer")?;
146                if let Some(name) = name {
147                    s.serialize_field("name", name)?;
148                }
149                s.serialize_field("content", content)?;
150                s
151            }
152        };
153        state.end()
154    }
155}
156
157/// Helper function for serializing the "content" field of a message.
158///
159/// If the `content` vector has exactly one element and it is a text message, it serializes the
160/// element directly. Otherwise, it serializes the entire vector.
161fn serialize_content_field<S>(
162    state: &mut S,
163    content: &Vec<MessageContext>,
164) -> Result<(), S::Error>
165where
166    S: SerializeStruct,
167{
168    if content.len() == 1 {
169        if let MessageContext::Text(text) = &content[0] {
170            state.serialize_field("content", text)?;
171        } else {
172            state.serialize_field("content", content)?;
173        }
174    } else {
175        state.serialize_field("content", content)?;
176    }
177    Ok(())
178}
179
180// Custom deserialization implementation for Message.
181impl<'de> Deserialize<'de> for Message {
182    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
183    where
184        D: Deserializer<'de>,
185    {
186        let value: Value = Deserialize::deserialize(deserializer)?;
187
188        let role = value.get("role").and_then(Value::as_str).unwrap_or("");
189
190        match role {
191            "user" => {
192            let name = value.get("name").and_then(Value::as_str).map(String::from);
193            let content = serde_json::from_value(
194                value.get("content").cloned().unwrap_or_default(),
195            )
196            .map_err(serde::de::Error::custom)?;
197            Ok(Message::User { name, content })
198            }
199            "tool" => {
200            let tool_call_id = value
201                .get("tool_call_id")
202                .and_then(Value::as_str)
203                .ok_or_else(|| serde::de::Error::missing_field("tool_call_id"))?
204                .to_string();
205            let content = serde_json::from_value(
206                value.get("content").cloned().unwrap_or_default(),
207            )
208            .map_err(serde::de::Error::custom)?;
209            Ok(Message::Tool { tool_call_id, content })
210            }
211            "assistant" => {
212                let name = value.get("name").and_then(Value::as_str).map(String::from);
213                let content = serde_json::from_value(
214                    value.get("content").cloned().unwrap_or_default(),
215                )
216                .map_err(serde::de::Error::custom)?;
217                let tool_calls = value.get("tool_calls").map_or(Ok(None), |v| {
218                    serde_json::from_value(v.clone()).map(Some)
219                }).map_err(serde::de::Error::custom)?;
220                Ok(Message::Assistant { name, content, tool_calls })
221            }
222            "system" => {
223                let name = value.get("name").and_then(Value::as_str).map(String::from);
224                let content = value
225                    .get("content")
226                    .and_then(Value::as_str)
227                    .ok_or_else(|| serde::de::Error::missing_field("content"))?
228                    .to_string();
229                Ok(Message::System { name, content })
230            }
231            "developer" => {
232                let name = value.get("name").and_then(Value::as_str).map(String::from);
233                let content = value
234                    .get("content")
235                    .and_then(Value::as_str)
236                    .ok_or_else(|| serde::de::Error::missing_field("content"))?
237                    .to_string();
238                Ok(Message::Developer { name, content })
239            }
240            _ => Err(serde::de::Error::custom("Invalid message type")),
241        }
242    }
243}
244
245/// Represents a context within a message.
246///
247/// This enum supports either textual content or image content.
248#[derive(Debug, Deserialize, Clone)]
249pub enum MessageContext {
250    /// A text message context.
251    Text(String),
252    /// An image message context.
253    Image(MessageImage),
254}
255
256// Custom serialization implementation for MessageContext.
257impl Serialize for MessageContext {
258    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
259    where
260        S: Serializer,
261    {
262        match self {
263            MessageContext::Text(text) => {
264                let mut state = serializer.serialize_struct("MessageContext", 2)?;
265                state.serialize_field("type", "text")?;
266                state.serialize_field("text", text)?;
267                state.end()
268            }
269            MessageContext::Image(image) => {
270                let mut state = serializer.serialize_struct("MessageContext", 2)?;
271                state.serialize_field("type", "image_url")?;
272                state.serialize_field("image_url", image)?;
273                state.end()
274            }
275        }
276    }
277}
278
/// Represents an image used within a message.
///
/// Contains a URL for the image and an optional detail representing the image resolution.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MessageImage {
    /// The image URL, which may be an HTTP URL or a base64-encoded data URI.
    ///
    /// For example:
    /// - "data:image/jpeg;base64,{IMAGE_DATA}"
    /// - "https://example.com/image.jpg"
    pub url: String,

    /// The resolution detail of the image.
    ///
    /// For example, for OpenAI API, valid values are:
    /// - "low"
    /// - "high"
    /// - "auto" (default)
    ///
    /// The field is left out of the serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
300
/// Represents a choice from the API response.
///
/// A choice contains its index, a response message and a finish reason.
/// All three fields are required when deserializing.
#[derive(Debug, Deserialize, Clone)]
pub struct Choice {
    /// The index of the choice in the response.
    pub index: usize,

    /// The message associated with this choice.
    pub message: ResponseMessage,

    /// The reason for finishing, as a string.
    pub finish_reason: String,
}
315
/// Represents a response message from the API.
///
/// Contains the role of the responder, optional text content, optional tool
/// calls, an optional refusal message and optional annotations.
#[derive(Debug, Deserialize, Clone)]
pub struct ResponseMessage {
    /// The role of the message sender.
    pub role: String,

    /// The text content of the message (if any).
    pub content: Option<String>,

    /// The tool calls associated with the message (if any).
    pub tool_calls: Option<Vec<FunctionCall>>,

    /// An optional refusal message.
    pub refusal: Option<String>,

    /// Annotations for web search options, kept as untyped JSON.
    ///
    /// `#[serde(default)]` makes a missing field deserialize as `None`.
    #[serde(default)]
    pub annotations: Option<serde_json::Value>
}