Skip to main content

lellm_core/
message.rs

//! Message and content-block types.

3use crate::error::{ParseError, ToolResult};
4use serde::{Deserialize, Serialize};
5
6/// 缓存控制标记 — Provider 无关的语义抽象。
7///
8/// 由 Provider Codec 映射为各 Provider 的具体格式:
9/// - Anthropic: `{"type": "ephemeral"}`
10/// - OpenAI: ignore(隐式缓存)
11/// - Google: ignore
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
13pub enum CacheControl {
14    /// 缓存断点 — 标记此处为缓存边界。
15    /// 业务层在稳定性递减的层边界处插入。
16    Breakpoint,
17}
18
19/// 纯文本块
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct TextBlock {
22    pub text: String,
23
24    /// 缓存控制标记。业务层在 System prompt 的稳定性层边界处设置。
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub cache_control: Option<CacheControl>,
27}
28
29/// 思考块(Claude thinking / OpenAI reasoning)
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct ThinkingBlock {
32    pub thinking: String,
33    /// 部分 provider 支持 redacted thinking
34    pub redacted: Option<String>,
35}
36
37/// 图片资源
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub struct ImageSource {
40    /// base64 编码的图片数据
41    pub data: String,
42    /// MIME 类型,如 "image/png"
43    pub media_type: String,
44}
45
46/// LLM 请求的工具调用。
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48pub struct ToolCall {
49    pub id: String,
50    pub name: String,
51    pub arguments: serde_json::Value,
52}
53
54/// 内容块 — Message 和 ChatResponse 的基本组成单元。
55/// 核心层极简,无 provider 特有标记。
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57#[serde(tag = "type", rename_all = "snake_case")]
58pub enum ContentBlock {
59    Text(TextBlock),
60    Thinking(ThinkingBlock),
61    Image { source: ImageSource },
62    ToolCall(ToolCall),
63}
64
65impl ContentBlock {
66    pub fn text(s: String) -> Self {
67        ContentBlock::Text(TextBlock {
68            text: s,
69            cache_control: None,
70        })
71    }
72
73    /// 创建带缓存标记的文本块。
74    pub fn text_with_cache(s: String, cache: CacheControl) -> Self {
75        ContentBlock::Text(TextBlock {
76            text: s,
77            cache_control: Some(cache),
78        })
79    }
80
81    pub fn as_text(&self) -> Option<&str> {
82        match self {
83            ContentBlock::Text(block) => Some(&block.text),
84            _ => None,
85        }
86    }
87}
88
89/// 对话中的单条消息。
90#[derive(Debug, Clone, Serialize, Deserialize)]
91#[serde(tag = "type", rename_all = "snake_case")]
92pub enum Message {
93    System {
94        content: Vec<ContentBlock>,
95    },
96    User {
97        content: Vec<ContentBlock>,
98    },
99    Assistant {
100        content: Vec<ContentBlock>,
101    },
102    ToolResult {
103        tool_call_id: String,
104        /// 工具执行是否失败(供 Provider API 映射,如 Anthropic `is_error: true`)
105        is_error: bool,
106        content: Vec<ContentBlock>,
107    },
108}
109
110impl Message {
111    /// 返回内容块的引用(用于 provider 适配器序列化)
112    pub fn content(&self) -> &Vec<ContentBlock> {
113        match self {
114            Message::System { content }
115            | Message::User { content }
116            | Message::Assistant { content }
117            | Message::ToolResult { content, .. } => content,
118        }
119    }
120
121    /// 返回 ToolResult 的 tool_call_id(仅 ToolResult 变体有效,其他返回 None)
122    pub fn tool_call_id(&self) -> String {
123        match self {
124            Message::ToolResult { tool_call_id, .. } => tool_call_id.clone(),
125            _ => String::new(),
126        }
127    }
128
129    /// 返回 ToolResult 的 is_error 标记(仅 ToolResult 变体有效)
130    pub fn is_tool_error(&self) -> bool {
131        matches!(self, Message::ToolResult { is_error: true, .. })
132    }
133
134    /// 从工具调用结果构建 Message::ToolResult
135    ///
136    /// 成功 → 序列化 `serde_json::Value` 为文本,`is_error: false`
137    /// 失败 → `"tool error: {e}"` 文本 content,`is_error: true`
138    pub fn tool_result(call: &ToolCall, result: &ToolResult) -> Self {
139        let (content_str, is_error) = match result {
140            Ok(v) => (
141                serde_json::to_string(v).unwrap_or_else(|_| v.to_string()),
142                false,
143            ),
144            Err(e) => (format!("tool error: {e}"), true),
145        };
146        Message::ToolResult {
147            tool_call_id: call.id.clone(),
148            is_error,
149            content: text_block(content_str),
150        }
151    }
152
153    /// 语义校验 — 检查 Message 变体与 ContentBlock 的合法性。
154    ///
155    /// v0.1 核心规则:
156    /// 1. `ToolResult` 禁止包含 `ToolCall` 或 `Thinking`
157    /// 2. `ToolResult.tool_call_id` 非空
158    /// 3. `Assistant` 中的 `ToolCall.id` 非空
159    /// 4. `User` 禁止包含 `Thinking`
160    pub fn validate(&self) -> Result<(), ParseError> {
161        match self {
162            Message::ToolResult {
163                tool_call_id,
164                is_error: _,
165                content,
166            } => {
167                if tool_call_id.is_empty() {
168                    return Err(ParseError {
169                        detail: "ToolResult.tool_call_id must not be empty".into(),
170                    });
171                }
172                for block in content {
173                    match block {
174                        ContentBlock::ToolCall(_) => {
175                            return Err(ParseError {
176                                detail: "ToolResult must not contain ToolCall blocks".into(),
177                            });
178                        }
179                        ContentBlock::Thinking(_) => {
180                            return Err(ParseError {
181                                detail: "ToolResult must not contain Thinking blocks".into(),
182                            });
183                        }
184                        _ => {}
185                    }
186                }
187            }
188            Message::Assistant { content } => {
189                for block in content {
190                    if let ContentBlock::ToolCall(tc) = block
191                        && tc.id.is_empty()
192                    {
193                        return Err(ParseError {
194                            detail: "Assistant ToolCall.id must not be empty".into(),
195                        });
196                    }
197                }
198            }
199            Message::User { content } => {
200                for block in content {
201                    if let ContentBlock::Thinking(_) = block {
202                        return Err(ParseError {
203                            detail: "User must not contain Thinking blocks".into(),
204                        });
205                    }
206                }
207            }
208            Message::System { .. } => {}
209        }
210        Ok(())
211    }
212
213    /// 提取所有 ToolCall(仅 Assistant 消息包含)
214    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
215        match self {
216            Message::Assistant { content } => content
217                .iter()
218                .filter_map(|b| {
219                    if let ContentBlock::ToolCall(tc) = b {
220                        Some(tc.clone())
221                    } else {
222                        None
223                    }
224                })
225                .collect(),
226            _ => Vec::new(),
227        }
228    }
229}
230
231/// 便捷函数:创建纯文本块
232pub fn text_block(s: String) -> Vec<ContentBlock> {
233    vec![ContentBlock::text(s)]
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: serde_json::json!({}),
        }
    }

    #[test]
    fn test_content_block_text() {
        let block = ContentBlock::text("hello".to_string());
        assert_eq!(block.as_text(), Some("hello"));
    }

    #[test]
    fn test_content_block_tool_call_no_as_text() {
        let block = ContentBlock::ToolCall(tool_call("1", "test"));
        assert_eq!(block.as_text(), None);
    }

    #[test]
    fn test_content_block_text_with_cache() {
        let block = ContentBlock::text_with_cache("sys".to_string(), CacheControl::Breakpoint);
        assert_eq!(block.as_text(), Some("sys"));
        assert_eq!(
            block,
            ContentBlock::Text(TextBlock {
                text: "sys".into(),
                cache_control: Some(CacheControl::Breakpoint),
            })
        );
    }

    #[test]
    fn test_content_block_serde_wire_format() {
        // `cache_control` is omitted entirely when absent.
        let plain = serde_json::to_value(ContentBlock::text("hi".to_string())).unwrap();
        assert_eq!(plain, serde_json::json!({"type": "text", "text": "hi"}));
        let back: ContentBlock = serde_json::from_value(plain).unwrap();
        assert_eq!(back, ContentBlock::text("hi".to_string()));

        let cached = serde_json::to_value(ContentBlock::text_with_cache(
            "sys".to_string(),
            CacheControl::Breakpoint,
        ))
        .unwrap();
        assert_eq!(
            cached,
            serde_json::json!({"type": "text", "text": "sys", "cache_control": "Breakpoint"})
        );

        let call = serde_json::to_value(ContentBlock::ToolCall(tool_call("c1", "search"))).unwrap();
        assert_eq!(call["type"], "tool_call");
        assert_eq!(call["id"], "c1");
    }

    #[test]
    fn test_message_content() {
        let msg = Message::User {
            content: text_block("hello world".to_string()),
        };
        assert_eq!(msg.content().len(), 1);
        assert_eq!(msg.content()[0].as_text(), Some("hello world"));
    }

    #[test]
    fn test_message_extract_tool_calls() {
        let tc = tool_call("1", "test");
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(tc.clone())],
        };
        let calls = msg.extract_tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test");
    }

    #[test]
    fn test_message_extract_tool_calls_skips_other_blocks_and_keeps_order() {
        let msg = Message::Assistant {
            content: vec![
                ContentBlock::text("let me check".to_string()),
                ContentBlock::ToolCall(tool_call("a", "first")),
                ContentBlock::ToolCall(tool_call("b", "second")),
            ],
        };
        let names: Vec<String> = msg
            .extract_tool_calls()
            .into_iter()
            .map(|tc| tc.name)
            .collect();
        assert_eq!(names, ["first", "second"]);
    }

    #[test]
    fn test_message_extract_tool_calls_non_assistant_is_empty() {
        let msg = Message::User {
            content: vec![ContentBlock::ToolCall(tool_call("1", "test"))],
        };
        assert!(msg.extract_tool_calls().is_empty());
    }

    #[test]
    fn test_message_tool_call_id_and_is_tool_error() {
        let failed = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: true,
            content: text_block("boom".to_string()),
        };
        assert_eq!(failed.tool_call_id(), "call_1");
        assert!(failed.is_tool_error());

        let succeeded = Message::ToolResult {
            tool_call_id: "call_2".to_string(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(!succeeded.is_tool_error());

        // Non-ToolResult variants report an empty id and never an error.
        let user = Message::User {
            content: text_block("hi".to_string()),
        };
        assert_eq!(user.tool_call_id(), "");
        assert!(!user.is_tool_error());
    }

    #[test]
    fn test_message_serde_round_trip() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["type"], "tool_result");
        assert_eq!(json["tool_call_id"], "call_1");
        assert_eq!(json["is_error"], false);

        let back: Message = serde_json::from_value(json).unwrap();
        assert_eq!(back.tool_call_id(), "call_1");
        assert!(!back.is_tool_error());
        assert_eq!(back.content()[0].as_text(), Some("ok"));
    }

    // ─── validate() tests ───

    #[test]
    fn test_validate_user_ok() {
        let msg = Message::User {
            content: text_block("hello".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_user_reject_thinking() {
        let msg = Message::User {
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_assistant_ok() {
        let msg = Message::Assistant {
            content: text_block("hi".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_assistant_tool_call_ok() {
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(tool_call("call_1", "test"))],
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_assistant_tool_call_empty_id() {
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(tool_call("", "test"))],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_ok() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_tool_result_empty_id() {
        let msg = Message::ToolResult {
            tool_call_id: String::new(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_tool_call() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::ToolCall(tool_call("x", "y"))],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_thinking() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_system_ok() {
        let msg = Message::System {
            content: text_block("you are helpful".to_string()),
        };
        assert!(msg.validate().is_ok());
    }
}