Skip to main content

katu_core/
message.rs

1//! # katu_core::message
2//!
3//! ## 职责
4//! 定义对话消息类型:内容块、角色特化的消息结构、以及统一的 `Message` 枚举。
5//!
6//! ## 设计原则
7//! - **Provider 无关** — 不绑定任何特定 LLM SDK
8//! - **角色特化** — User / Assistant / ToolResult 各有独立 struct,
9//!   编译期禁止非法组合(如 User 消息不可能包含 ToolCall)
10//! - **Serde 友好** — 所有类型可序列化/反序列化,便于持久化和网络传输
11//! - **双层 Content** — `ContentBlock`(文本+图片)用于 User/ToolResult,
12//!   `AssistantBlock`(文本+推理+工具调用)用于 Assistant
13//!
14//! ## 对外接口
15//! - `ContentBlock` — 基础内容块(文本、图片)
16//! - `AssistantBlock` — Assistant 专用内容块(文本、推理、工具调用)
17//! - `UserContent` — 用户内容(纯文本或多模态块)
18//! - `UserMessage`, `AssistantMessage`, `ToolResultMessage` — 角色特化消息
19//! - `Message` — 统一消息枚举
20//!
21//! ## 调用者
22//! - `katu_core::event` — StreamEvent 引用 AssistantMessage
23//! - `katu_core::request` — LlmRequest 持有 Vec<Message>
24//! - 所有上层 crate 通过 `katu_core::message::*` 使用
25
26use chrono::{DateTime, Utc};
27use serde::{Deserialize, Serialize};
28
29use crate::types::{FinishReason, MessageId, Role, ToolCallId};
30use crate::usage::Usage;
31
32// ===========================================================================
33// Content Blocks
34// ===========================================================================
35
36/// 基础内容块 — 可出现在 User 消息和 ToolResult 中。
37///
38/// 使用 `type` 字段作为 serde 标签区分变体。
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40#[serde(tag = "type", rename_all = "snake_case")]
41pub enum ContentBlock {
42    /// 纯文本
43    Text {
44        text: String,
45    },
46
47    /// 图片(base64 编码的二进制数据)
48    Image {
49        /// MIME 类型,如 `"image/png"`, `"image/jpeg"`
50        media_type: String,
51        /// base64 编码的图片数据
52        data: String,
53    },
54}
55
56/// Assistant 专用内容块 — 比基础块多了推理和工具调用。
57///
58/// 编译期保证只有 Assistant 消息可以包含 `ToolCall` 和 `Reasoning`。
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60#[serde(tag = "type", rename_all = "snake_case")]
61pub enum AssistantBlock {
62    /// 文本输出
63    Text {
64        text: String,
65    },
66
67    /// 推理/思考内容(extended thinking)
68    Reasoning {
69        text: String,
70        /// Provider 不透明签名(如 Anthropic redacted thinking, Google thought signature)
71        #[serde(skip_serializing_if = "Option::is_none")]
72        signature: Option<String>,
73    },
74
75    /// 工具调用请求
76    ToolCall {
77        /// LLM 生成的工具调用 ID,用于关联 ToolResultMessage
78        id: ToolCallId,
79        /// 工具名称
80        name: String,
81        /// 工具参数(JSON 值)
82        arguments: serde_json::Value,
83    },
84}
85
86// ===========================================================================
87// User Content
88// ===========================================================================
89
90/// 用户消息内容 — 支持纯文本快捷方式或多模态内容块。
91///
92/// 90% 的场景是纯文本,`Text` 变体避免了
93/// `vec![ContentBlock::Text { text }]` 的样板代码。
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95#[serde(untagged)]
96pub enum UserContent {
97    /// 纯文本内容
98    Text(String),
99    /// 多模态内容块(文本 + 图片)
100    Blocks(Vec<ContentBlock>),
101}
102
103impl UserContent {
104    /// 将内容规范化为 `ContentBlock` 数组。
105    pub fn into_blocks(self) -> Vec<ContentBlock> {
106        match self {
107            Self::Text(text) => vec![ContentBlock::Text { text }],
108            Self::Blocks(blocks) => blocks,
109        }
110    }
111
112    /// 提取纯文本内容;多模态时拼接所有文本块。
113    pub fn text(&self) -> String {
114        match self {
115            Self::Text(text) => text.clone(),
116            Self::Blocks(blocks) => blocks
117                .iter()
118                .filter_map(|b| match b {
119                    ContentBlock::Text { text } => Some(text.as_str()),
120                    _ => None,
121                })
122                .collect::<Vec<_>>()
123                .join("\n"),
124        }
125    }
126}
127
128impl From<String> for UserContent {
129    fn from(text: String) -> Self {
130        Self::Text(text)
131    }
132}
133
134impl From<&str> for UserContent {
135    fn from(text: &str) -> Self {
136        Self::Text(text.to_owned())
137    }
138}
139
140impl From<Vec<ContentBlock>> for UserContent {
141    fn from(blocks: Vec<ContentBlock>) -> Self {
142        Self::Blocks(blocks)
143    }
144}
145
146// ===========================================================================
147// Role-Specific Messages
148// ===========================================================================
149
150/// 用户消息。
151#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct UserMessage {
153    /// 消息唯一标识
154    pub id: MessageId,
155    /// 消息内容
156    pub content: UserContent,
157    /// 创建时间
158    pub timestamp: DateTime<Utc>,
159}
160
161/// LLM 回复消息。
162///
163/// 直接携带 `model`、`provider`、`usage` 等上下文信息,
164/// 便于多模型场景下追踪每轮回复的来源和开销。
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166pub struct AssistantMessage {
167    /// 消息唯一标识
168    pub id: MessageId,
169    /// 内容块数组(文本、推理、工具调用)
170    pub content: Vec<AssistantBlock>,
171    /// 使用的模型标识(如 `"gpt-4o"`, `"claude-sonnet-4-20250514"`)
172    pub model: String,
173    /// Provider 标识(如 `"openai"`, `"anthropic"`)
174    pub provider: String,
175    /// 停止原因
176    pub finish_reason: FinishReason,
177    /// Token 用量统计
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub usage: Option<Usage>,
180    /// 创建时间
181    pub timestamp: DateTime<Utc>,
182}
183
184/// 工具执行结果消息。
185///
186/// 作为独立角色存在(非嵌入 content 数组),
187/// 通过 `tool_call_id` 关联对应的 `AssistantBlock::ToolCall`。
188#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
189pub struct ToolResultMessage {
190    /// 消息唯一标识
191    pub id: MessageId,
192    /// 关联的工具调用 ID
193    pub tool_call_id: ToolCallId,
194    /// 工具名称
195    pub tool_name: String,
196    /// 执行结果内容(支持文本 + 图片)
197    pub content: Vec<ContentBlock>,
198    /// 是否为错误结果
199    pub is_error: bool,
200    /// 创建时间
201    pub timestamp: DateTime<Utc>,
202}
203
204// ===========================================================================
205// Unified Message Enum
206// ===========================================================================
207
208/// 统一消息枚举 — 对话历史中的一条消息。
209///
210/// 使用角色特化的变体,编译期保证类型安全:
211/// - `User` — 用户输入(文本或多模态)
212/// - `Assistant` — LLM 回复(文本、推理、工具调用)
213/// - `ToolResult` — 工具执行结果
214///
215/// System prompt 不在此枚举中,它属于请求级别。
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217#[serde(tag = "role", rename_all = "snake_case")]
218pub enum Message {
219    User(UserMessage),
220    Assistant(AssistantMessage),
221    ToolResult(ToolResultMessage),
222}
223
224// ---------------------------------------------------------------------------
225// Message — accessor methods
226// ---------------------------------------------------------------------------
227
228impl Message {
229    /// 返回消息的角色。
230    pub fn role(&self) -> Role {
231        match self {
232            Self::User(_) => Role::User,
233            Self::Assistant(_) => Role::Assistant,
234            Self::ToolResult(_) => Role::Tool,
235        }
236    }
237
238    /// 返回消息的唯一标识。
239    pub fn id(&self) -> &MessageId {
240        match self {
241            Self::User(m) => &m.id,
242            Self::Assistant(m) => &m.id,
243            Self::ToolResult(m) => &m.id,
244        }
245    }
246
247    /// 返回消息的时间戳。
248    pub fn timestamp(&self) -> DateTime<Utc> {
249        match self {
250            Self::User(m) => m.timestamp,
251            Self::Assistant(m) => m.timestamp,
252            Self::ToolResult(m) => m.timestamp,
253        }
254    }
255}
256
257// ---------------------------------------------------------------------------
258// Message — convenience constructors
259// ---------------------------------------------------------------------------
260
261impl Message {
262    /// 创建一个纯文本 User 消息。
263    pub fn user(content: impl Into<UserContent>) -> Self {
264        Self::User(UserMessage {
265            id: MessageId::new(),
266            content: content.into(),
267            timestamp: Utc::now(),
268        })
269    }
270
271    /// 创建一个纯文本 Assistant 消息(用于测试或合成消息)。
272    pub fn assistant(text: impl Into<String>) -> Self {
273        Self::Assistant(AssistantMessage {
274            id: MessageId::new(),
275            content: vec![AssistantBlock::Text {
276                text: text.into(),
277            }],
278            model: String::new(),
279            provider: String::new(),
280            finish_reason: FinishReason::Stop,
281            usage: None,
282            timestamp: Utc::now(),
283        })
284    }
285
286    /// 创建一个工具结果消息。
287    pub fn tool_result(
288        tool_call_id: ToolCallId,
289        tool_name: impl Into<String>,
290        content: impl Into<String>,
291        is_error: bool,
292    ) -> Self {
293        Self::ToolResult(ToolResultMessage {
294            id: MessageId::new(),
295            tool_call_id,
296            tool_name: tool_name.into(),
297            content: vec![ContentBlock::Text {
298                text: content.into(),
299            }],
300            is_error,
301            timestamp: Utc::now(),
302        })
303    }
304}
305
306// ---------------------------------------------------------------------------
307// AssistantMessage — helper methods
308// ---------------------------------------------------------------------------
309
310impl AssistantMessage {
311    /// 提取所有文本内容,拼接为单个字符串。
312    pub fn text(&self) -> String {
313        self.content
314            .iter()
315            .filter_map(|b| match b {
316                AssistantBlock::Text { text } => Some(text.as_str()),
317                _ => None,
318            })
319            .collect::<Vec<_>>()
320            .join("")
321    }
322
323    /// 提取所有推理内容,拼接为单个字符串。
324    pub fn reasoning(&self) -> String {
325        self.content
326            .iter()
327            .filter_map(|b| match b {
328                AssistantBlock::Reasoning { text, .. } => Some(text.as_str()),
329                _ => None,
330            })
331            .collect::<Vec<_>>()
332            .join("")
333    }
334
335    /// 提取所有工具调用。
336    pub fn tool_calls(&self) -> Vec<&AssistantBlock> {
337        self.content
338            .iter()
339            .filter(|b| matches!(b, AssistantBlock::ToolCall { .. }))
340            .collect()
341    }
342
343    /// 是否包含工具调用。
344    pub fn has_tool_calls(&self) -> bool {
345        self.content
346            .iter()
347            .any(|b| matches!(b, AssistantBlock::ToolCall { .. }))
348    }
349}
350
351// ===========================================================================
352// Tests
353// ===========================================================================
354
#[cfg(test)]
mod tests {
    use super::*;

    // -- Shared helpers --

    /// Serializes `value` to JSON and parses it back, returning both the JSON
    /// text and the restored value. Panics if either step fails.
    fn json_roundtrip<T>(value: &T) -> (String, T)
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(value).unwrap();
        let restored = serde_json::from_str(&json).unwrap();
        (json, restored)
    }

    /// Builds an `AssistantMessage` with the given blocks and metadata, a
    /// fresh ID, no usage, and the current time.
    fn assistant_with(
        content: Vec<AssistantBlock>,
        model: &str,
        provider: &str,
        finish_reason: FinishReason,
    ) -> AssistantMessage {
        AssistantMessage {
            id: MessageId::new(),
            content,
            model: model.to_owned(),
            provider: provider.to_owned(),
            finish_reason,
            usage: None,
            timestamp: Utc::now(),
        }
    }

    // -- ContentBlock --

    #[test]
    fn test_content_block_text_serde() {
        let original = ContentBlock::Text {
            text: String::from("hello"),
        };
        let (json, restored) = json_roundtrip(&original);
        assert!(json.contains(r#""type":"text""#));
        assert_eq!(original, restored);
    }

    #[test]
    fn test_content_block_image_serde() {
        let original = ContentBlock::Image {
            media_type: String::from("image/png"),
            data: String::from("iVBOR..."),
        };
        let (json, restored) = json_roundtrip(&original);
        assert!(json.contains(r#""type":"image""#));
        assert_eq!(original, restored);
    }

    // -- AssistantBlock --

    #[test]
    fn test_assistant_block_text_serde() {
        let original = AssistantBlock::Text {
            text: String::from("hi"),
        };
        let (json, restored) = json_roundtrip(&original);
        assert!(json.contains(r#""type":"text""#));
        assert_eq!(original, restored);
    }

    #[test]
    fn test_assistant_block_reasoning_serde() {
        let original = AssistantBlock::Reasoning {
            text: String::from("let me think..."),
            signature: Some(String::from("sig123")),
        };
        let (json, restored) = json_roundtrip(&original);
        assert!(json.contains(r#""type":"reasoning""#));
        assert_eq!(original, restored);
    }

    #[test]
    fn test_assistant_block_reasoning_no_signature() {
        // A missing signature must not show up in the serialized form at all.
        let unsigned = AssistantBlock::Reasoning {
            text: String::from("thinking"),
            signature: None,
        };
        let json = serde_json::to_string(&unsigned).unwrap();
        assert!(!json.contains("signature"));
    }

    #[test]
    fn test_assistant_block_tool_call_serde() {
        let original = AssistantBlock::ToolCall {
            id: ToolCallId::new("call_123"),
            name: String::from("read_file"),
            arguments: serde_json::json!({"path": "/tmp/test.rs"}),
        };
        let (json, restored) = json_roundtrip(&original);
        assert!(json.contains(r#""type":"tool_call""#));
        assert_eq!(original, restored);
    }

    // -- UserContent --

    #[test]
    fn test_user_content_text() {
        let content: UserContent = "hello".into();
        assert_eq!(content.text(), "hello");

        let blocks = content.into_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    }

    #[test]
    fn test_user_content_blocks() {
        // Only the text block contributes to `text()`; the image is skipped.
        let text_block = ContentBlock::Text {
            text: String::from("look at this"),
        };
        let image_block = ContentBlock::Image {
            media_type: String::from("image/png"),
            data: String::from("base64data"),
        };
        let content = UserContent::Blocks(vec![text_block, image_block]);
        assert_eq!(content.text(), "look at this");
    }

    #[test]
    fn test_user_content_serde_text() {
        let original = UserContent::Text(String::from("hello"));
        let (json, restored) = json_roundtrip(&original);
        assert_eq!(json, r#""hello""#);
        assert_eq!(original, restored);
    }

    #[test]
    fn test_user_content_serde_blocks() {
        let original = UserContent::Blocks(vec![ContentBlock::Text {
            text: String::from("hi"),
        }]);
        let (_json, restored) = json_roundtrip(&original);
        assert_eq!(original, restored);
    }

    // -- Message constructors --

    #[test]
    fn test_message_user_constructor() {
        let msg = Message::user("hello");
        assert_eq!(msg.role(), Role::User);
        match &msg {
            Message::User(user) => assert_eq!(user.content.text(), "hello"),
            _ => panic!("expected User"),
        }
    }

    #[test]
    fn test_message_assistant_constructor() {
        let msg = Message::assistant("hi there");
        assert_eq!(msg.role(), Role::Assistant);
        match &msg {
            Message::Assistant(reply) => {
                assert_eq!(reply.text(), "hi there");
                assert_eq!(reply.finish_reason, FinishReason::Stop);
            }
            _ => panic!("expected Assistant"),
        }
    }

    #[test]
    fn test_message_tool_result_constructor() {
        let call_id = ToolCallId::new("call_1");
        let msg = Message::tool_result(call_id, "read_file", "file contents here", false);
        assert_eq!(msg.role(), Role::Tool);
        match &msg {
            Message::ToolResult(result) => {
                assert_eq!(result.tool_call_id, ToolCallId::new("call_1"));
                assert_eq!(result.tool_name, "read_file");
                assert!(!result.is_error);
            }
            _ => panic!("expected ToolResult"),
        }
    }

    #[test]
    fn test_message_tool_result_error() {
        let call_id = ToolCallId::new("call_2");
        let msg = Message::tool_result(call_id, "write_file", "permission denied", true);
        match &msg {
            Message::ToolResult(result) => assert!(result.is_error),
            _ => panic!("expected ToolResult"),
        }
    }

    // -- Message accessors --

    #[test]
    fn test_message_id_accessor() {
        let msg = Message::user("test");
        let expected = msg.id().clone();
        assert_eq!(msg.id(), &expected);
    }

    #[test]
    fn test_message_timestamp_accessor() {
        // The constructor stamps the message between the two clock reads.
        let before = Utc::now();
        let msg = Message::user("test");
        let after = Utc::now();

        let stamp = msg.timestamp();
        assert!(stamp >= before);
        assert!(stamp <= after);
    }

    // -- AssistantMessage helpers --

    #[test]
    fn test_assistant_message_text() {
        let msg = assistant_with(
            vec![
                AssistantBlock::Reasoning {
                    text: String::from("hmm"),
                    signature: None,
                },
                AssistantBlock::Text {
                    text: String::from("Hello "),
                },
                AssistantBlock::Text {
                    text: String::from("World"),
                },
            ],
            "gpt-4o",
            "openai",
            FinishReason::Stop,
        );
        assert_eq!(msg.text(), "Hello World");
        assert_eq!(msg.reasoning(), "hmm");
    }

    #[test]
    fn test_assistant_message_tool_calls() {
        let msg = assistant_with(
            vec![
                AssistantBlock::Text {
                    text: String::from("I'll read that file."),
                },
                AssistantBlock::ToolCall {
                    id: ToolCallId::new("call_1"),
                    name: String::from("read_file"),
                    arguments: serde_json::json!({"path": "foo.rs"}),
                },
                AssistantBlock::ToolCall {
                    id: ToolCallId::new("call_2"),
                    name: String::from("grep"),
                    arguments: serde_json::json!({"pattern": "fn main"}),
                },
            ],
            "claude-sonnet-4-20250514",
            "anthropic",
            FinishReason::ToolCalls,
        );
        assert!(msg.has_tool_calls());
        assert_eq!(msg.tool_calls().len(), 2);
    }

    #[test]
    fn test_assistant_message_no_tool_calls() {
        let msg = assistant_with(
            vec![AssistantBlock::Text {
                text: String::from("done"),
            }],
            "",
            "",
            FinishReason::Stop,
        );
        assert!(!msg.has_tool_calls());
        assert!(msg.tool_calls().is_empty());
    }

    // -- Message serde roundtrip --

    #[test]
    fn test_message_serde_roundtrip_user() {
        let msg = Message::user("hello world");
        let (json, restored) = json_roundtrip(&msg);
        assert!(json.contains(r#""role":"user""#));
        assert_eq!(msg.role(), restored.role());
    }

    #[test]
    fn test_message_serde_roundtrip_assistant() {
        let msg = Message::assistant("reply");
        let (json, restored) = json_roundtrip(&msg);
        assert!(json.contains(r#""role":"assistant""#));
        assert_eq!(msg.role(), restored.role());
    }

    #[test]
    fn test_message_serde_roundtrip_tool_result() {
        let msg = Message::tool_result(ToolCallId::new("call_x"), "bash", "exit code 0", false);
        let (json, restored) = json_roundtrip(&msg);
        assert!(json.contains(r#""role":"tool_result""#));
        assert_eq!(msg.role(), restored.role());
    }
}