// openglm/chat/message.rs

1use serde::ser::SerializeMap;
2
3use crate::error::Error;
4
5#[derive(serde::Serialize, serde::Deserialize, Debug)]
6pub struct Function {
7    pub name: String,
8    pub arguments: serde_json::Value,
9}
10
11#[derive(serde::Serialize, serde::Deserialize, Debug)]
12pub struct ToolCall {
13    pub id: String,
14    #[serde(rename = "type")]
15    pub ty: String,
16    pub function: Function,
17}
18
/// The output of a tool run, sent back to the model as a message with
/// role `"tool"`.
#[derive(Debug)]
pub struct ToolMessage {
    /// Text produced by the tool.
    pub content: String,
    /// Identifier of the call being answered (presumably the `id` of the
    /// corresponding `ToolCall`).
    pub tool_call_id: String,
}

/// One part of a multi-part user message: a piece of text or an image URL.
#[derive(Debug)]
pub enum ImageMessage {
    /// A text part, serialized as `{"type": "text", "text": ...}`.
    Text(String),
    /// An image referenced by URL, serialized as
    /// `{"type": "image_url", "image_url": {"url": ...}}`.
    ImageUrl(String),
}

31// 使用一个辅助结构体来正确地序列化ImageUrl
32#[derive(serde::Serialize)]
33struct ImageUrlWrapper<'a> {
34    url: &'a str,
35}
36
37impl serde::Serialize for ImageMessage {
38    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39    where
40        S: serde::Serializer,
41    {
42        let mut map = serializer.serialize_map(Some(2))?;
43        match self {
44            ImageMessage::Text(text) => {
45                map.serialize_entry("type", "text")?;
46                map.serialize_entry("text", text)?;
47            },
48            ImageMessage::ImageUrl(url) => {
49                map.serialize_entry("type", "image_url")?;
50                map.serialize_entry("image_url", &ImageUrlWrapper { url: url })?;
51            },
52        }
53        map.end()
54    }
55
56}
57
58#[derive(Debug)]
59pub enum ChatMessage {
60    System(String),
61    User(String),
62    Image(Vec<ImageMessage>),
63    Assistant(String),
64    ToolCall(Vec<ToolCall>),
65    Tool(ToolMessage),
66}
67
68impl serde::Serialize for ChatMessage {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::Serializer,
72    {
73        let mut map = serializer.serialize_map(Some(2))?;
74        match self {
75            ChatMessage::System(content) => {
76                map.serialize_entry("role", "system")?;
77                map.serialize_entry("content", content)?;
78            },
79            ChatMessage::User(content) => {
80                map.serialize_entry("role", "user")?;
81                map.serialize_entry("content", content)?;
82            },
83            ChatMessage::Image(images) => {
84                map.serialize_entry("role", "user")?;
85                map.serialize_entry("content", images)?;
86            },
87            ChatMessage::Assistant(content) => {
88                map.serialize_entry("role", "assistant")?;
89                map.serialize_entry("content", content)?;
90            },
91            ChatMessage::ToolCall(tool_calls) => {
92                map.serialize_entry("role", "assistant")?;
93                map.serialize_entry("tool_calls", tool_calls)?;
94            },
95            ChatMessage::Tool(tool_message) => {
96                map.serialize_entry("role", "tool")?;
97                map.serialize_entry("tool_call_id", &tool_message.tool_call_id)?;
98                map.serialize_entry("content", &tool_message.content)?;
99            },
100        }
101
102        map.end()
103    }
104}
105
106impl <'de> serde::Deserialize<'de> for ChatMessage {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        struct MessageVisitor;
112
113        impl<'de> serde::de::Visitor<'de> for MessageVisitor {
114            type Value = ChatMessage;
115
116            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
117                formatter.write_str("chat message")
118            }
119
120            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
121            where
122                A: serde::de::MapAccess<'de>, 
123            {
124                let mut role = None;
125                let mut content: Option<serde_json::Value> = None;
126                let mut tool_calls: Option<Vec<ToolCall>> = None;
127                let mut tool_call_id: Option<String> = None;
128
129                while let Some(key) = map.next_key()? {
130                    match key {
131                        "role" => role = map.next_value()?,
132                        "content" => content = map.next_value()?,
133                        "tool_calls" => tool_calls = map.next_value()?,
134                        "tool_call_id" => tool_call_id = map.next_value()?,
135                        _ => return Err(serde::de::Error::unknown_field(key, &["role", "content", "tool_calls", "tool_call_id"])),
136                    }
137                }
138
139                let role: String = role.ok_or_else(|| serde::de::Error::missing_field("role"))?;
140
141                match (role.as_str(), content, tool_calls, tool_call_id) {
142                    ("system", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::System(content)),
143                    ("user", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::User(content)),
144                    ("user", Some(serde_json::Value::Array(content)), None, None) => {
145                        let mut images = Vec::new();
146                        for image in content {
147                            let image = image.as_object().ok_or(serde::de::Error::custom("invalid image format"))?;
148                            match image.get("type") {
149                                Some(serde_json::Value::String(ty)) => {
150                                    match ty.as_str() {
151                                        "text" => {
152                                            let content = image.get("text").ok_or(serde::de::Error::custom("missing text field"))?;
153                                            let content = content.as_str().ok_or(serde::de::Error::custom("invalid text field"))?;
154                                            images.push(ImageMessage::Text(content.to_string()));
155                                        },
156                                        "image_url" => {
157                                            let image_url = image.get("image_url").ok_or(serde::de::Error::custom("missing image_url field"))?;
158                                            let image_url = image_url.as_object().ok_or(serde::de::Error::custom("invalid image_url field"))?;
159                                            let url = image_url.get("url").ok_or(serde::de::Error::custom("missing url field"))?;
160                                            let url = url.as_str().ok_or(serde::de::Error::custom("invalid url field"))?;
161                                            images.push(ImageMessage::ImageUrl(url.to_string()));
162                                        }
163                                        _ => return Err(serde::de::Error::custom("invalid image type")),
164                                    }
165                                },
166                                _ => return Err(serde::de::Error::custom("invalid image type")),
167                            }
168                        }
169                        Ok(ChatMessage::Image(images))
170                    },
171                    ("assistant", Some(serde_json::Value::String(content)), None, None) => Ok(ChatMessage::Assistant(content)),
172                    ("assistant", None, Some(tool_calls), None) => Ok(ChatMessage::ToolCall(tool_calls)),
173                    ("tool", Some(serde_json::Value::String(content)), None, Some(tool_call_id)) => Ok(ChatMessage::Tool(ToolMessage { content, tool_call_id })),
174                    _ => Err(serde::de::Error::custom("invalid message")),
175                }
176            }
177        }
178
179        deserializer.deserialize_map(MessageVisitor)
180    }
181}
182
183#[derive(Debug)]
184pub enum AssistantMessageDelta {
185    Content(String),
186    ToolCall(Vec<ToolCall>),
187}
188
189impl <'de> serde::Deserialize<'de> for AssistantMessageDelta {
190    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191    where
192        D: serde::Deserializer<'de> 
193    {
194        struct DeltaVisitor;
195
196        impl<'de> serde::de::Visitor<'de> for DeltaVisitor {
197            type Value = AssistantMessageDelta;
198
199            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
200                formatter.write_str("assistant message delta")
201            }
202
203            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
204            where
205                A: serde::de::MapAccess<'de>, 
206            {
207                let mut role = None;
208                let mut content = None;
209                let mut tool_calls: Option<Vec<ToolCall>> = None;
210
211                while let Some(key) = map.next_key()? {
212                    match key {
213                        "role" => role = map.next_value()?,
214                        "content" => content = map.next_value()?,
215                        "tool_calls" => tool_calls = map.next_value()?,
216                        _ => return Err(serde::de::Error::unknown_field(key, &["role", "content", "tool_calls"])),
217                    }
218                }
219
220                let role: String = role.ok_or_else(|| serde::de::Error::missing_field("role"))?;
221
222                match (role.as_str(), content, tool_calls) {
223                    ("assistant", Some(content), None) => Ok(AssistantMessageDelta::Content(content)),
224                    ("assistant", None, Some(tool_calls)) => Ok(AssistantMessageDelta::ToolCall(tool_calls)),
225                    _ => Err(serde::de::Error::custom("invalid message")),
226                }
227            }
228        }
229
230        deserializer.deserialize_map(DeltaVisitor)
231    }
232}
233
234impl TryFrom<Vec<AssistantMessageDelta>> for ChatMessage {
235    type Error = Error;
236    
237    fn try_from(value: Vec<AssistantMessageDelta>) -> Result<Self, Self::Error> {
238        let mut message = None;
239        for delta in value {
240            match delta {
241                AssistantMessageDelta::Content(income) => {
242                    match message {
243                        Some(ChatMessage::Assistant(ref mut content)) => {
244                            content.push_str(&income);
245                        },
246                        None => {
247                            message = Some(ChatMessage::Assistant(income));
248                        },
249                        _ => return Err(Error::Conflict),
250                    }
251                },
252                AssistantMessageDelta::ToolCall(mut income) => {
253                    match message {
254                        Some(ChatMessage::ToolCall(ref mut tool_calls)) => {
255                            tool_calls.append(&mut income);
256                        },
257                        None => {
258                            message = Some(ChatMessage::ToolCall(income));
259                        },
260                        _ => return Err(Error::Conflict),
261                    }
262                }
263            }
264        }
265
266        message.ok_or(Error::EmptyDeltaList)
267    }
268}