katu-core 0.1.1

Core traits and types for the Katu AI Agent framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
//! # katu_core::event
//!
//! ## 职责
//! 定义 LLM 层的流式事件类型(OpenCode 风格的细粒度 start/delta/end 三段式)。
//!
//! ## 设计
//! `StreamEvent` 是 provider 无关的 LLM 流式输出契约:
//! - Provider adapter 将 provider 特定的 SSE/WebSocket 帧**翻译**为 `StreamEvent`
//! - Agent loop **消费** `StreamEvent` 流,驱动工具执行和状态更新
//! - UI/持久化层可直接订阅 `StreamEvent` 做实时渲染
//!
//! ## 事件生命周期
//! ```text
//! StepStart → [TextStart → TextDelta* → TextEnd]
//!           → [ReasoningStart → ReasoningDelta* → ReasoningEnd]
//!           → [ToolCallStart → ToolCallDelta* → ToolCallEnd]
//!           → StepFinish
//! (重复多个 Step,或直到 Finish / ProviderError)
//! ```
//!
//! ## 对外接口
//! - `StreamEvent` — LLM 流式事件枚举
//! - `ToolResultValue` — 工具返回值(json / text / error)
//!
//! ## 调用者
//! - `katu-llm` (future) — provider adapter 产出 StreamEvent
//! - `katu-agent` (future) — agent loop 消费 StreamEvent
//! - UI 层 — 实时渲染流式输出

use serde::{Deserialize, Serialize};

use crate::types::{FinishReason, ToolCallId};
use crate::usage::Usage;

// ===========================================================================
// ToolResultValue
// ===========================================================================

/// Value produced by a tool execution: structured JSON, plain text, or an error.
///
/// The provider adapter builds this value when it receives a `tool-result`
/// frame. The agent loop uses it to decide whether to flag the call as a tool
/// error. It travels inside [`StreamEvent::ToolResult`].
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// snake_case variant name, e.g. `{"type":"text","value":"exit 0"}`.
///
/// See also `tool.rs`. NOTE(review): presumably the tool definitions live there;
/// confirm and turn this into an intra-doc link.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ToolResultValue {
    /// Structured JSON output.
    Json { value: serde_json::Value },
    /// Plain-text output.
    Text { value: String },
    /// The tool failed; `value` holds the error description.
    Error { value: String },
}

// ===========================================================================
// StreamEvent
// ===========================================================================

/// Provider-agnostic, fine-grained event emitted while an LLM streams its reply.
///
/// Follows OpenCode's three-phase start/delta/end design:
/// - **start**: a content block opens; consumers may create a UI placeholder.
/// - **delta**: incremental data; consumers append it to the current block.
/// - **end**: the content block closes; consumers may finalize it.
///
/// Text, reasoning and tool-call events carry a `content_index` naming the
/// content block they belong to. One LLM reply may contain several, possibly
/// parallel, content blocks. Step events use their own `index`. Tool-result,
/// finish and error events carry no position.
///
/// Serialized as an internally tagged object: a `"type"` field holding the
/// snake_case variant name (e.g. `"text_delta"`) next to the variant's fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum StreamEvent {
    // ----- Step lifecycle -----
    /// A reasoning step begins (a provider may produce several steps in one request).
    StepStart { index: u32 },

    /// A reasoning step ends, carrying its stop reason and token usage.
    /// `usage` is omitted from the serialized form when `None`.
    StepFinish {
        index: u32,
        finish_reason: FinishReason,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },

    // ----- Text stream -----
    /// A text content block begins.
    TextStart { content_index: usize },

    /// An incremental chunk of text.
    TextDelta {
        content_index: usize,
        delta: String,
    },

    /// A text content block ends.
    TextEnd { content_index: usize },

    // ----- Reasoning stream -----
    /// A reasoning/thinking content block begins.
    ReasoningStart {
        content_index: usize,
    },

    /// An incremental chunk of reasoning.
    ReasoningDelta {
        content_index: usize,
        delta: String,
    },

    /// A reasoning content block ends.
    ReasoningEnd {
        content_index: usize,
    },

    // ----- Tool-call argument stream -----
    /// A tool call begins; its `id` and `name` are already known.
    ToolCallStart {
        content_index: usize,
        id: ToolCallId,
        name: String,
    },

    /// An incremental fragment of the tool-call arguments (a slice of a JSON string).
    ToolCallDelta {
        content_index: usize,
        delta: String,
    },

    /// The tool-call argument stream ends.
    ToolCallEnd {
        content_index: usize,
    },

    // ----- Tool execution results (injected into the stream by the agent loop) -----
    /// A tool finished executing and produced `result`.
    ///
    /// NOTE(review): originally documented as the success case, but `result`
    /// may itself be [`ToolResultValue::Error`]; confirm how this relates to
    /// [`StreamEvent::ToolError`].
    ToolResult {
        id: ToolCallId,
        name: String,
        result: ToolResultValue,
    },

    /// A tool execution failed.
    ToolError {
        id: ToolCallId,
        name: String,
        message: String,
    },

    // ----- Terminal events -----
    /// The whole LLM request completed.
    /// `usage` is omitted from the serialized form when `None`.
    Finish {
        finish_reason: FinishReason,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },

    /// A provider-level error (network, authentication, rate limiting, ...).
    /// `retryable` presumably tells the caller whether retrying the request may
    /// succeed; confirm against the provider adapters.
    ProviderError {
        message: String,
        retryable: bool,
    },
}

// ---------------------------------------------------------------------------
// StreamEvent — helper methods
// ---------------------------------------------------------------------------

impl StreamEvent {
    /// Returns `true` for the events that end a stream:
    /// [`StreamEvent::Finish`] and [`StreamEvent::ProviderError`].
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::ProviderError { .. } | Self::Finish { .. })
    }

    /// Returns `true` if this event is a [`StreamEvent::TextDelta`].
    pub fn is_text_delta(&self) -> bool {
        self.as_text_delta().is_some()
    }

    /// Returns `true` if this event is a [`StreamEvent::ReasoningDelta`].
    pub fn is_reasoning_delta(&self) -> bool {
        self.as_reasoning_delta().is_some()
    }

    /// Borrows the text chunk of a [`StreamEvent::TextDelta`];
    /// every other variant yields `None`.
    pub fn as_text_delta(&self) -> Option<&str> {
        if let Self::TextDelta { delta: chunk, .. } = self {
            Some(chunk.as_str())
        } else {
            None
        }
    }

    /// Borrows the reasoning chunk of a [`StreamEvent::ReasoningDelta`];
    /// every other variant yields `None`.
    pub fn as_reasoning_delta(&self) -> Option<&str> {
        if let Self::ReasoningDelta { delta: chunk, .. } = self {
            Some(chunk.as_str())
        } else {
            None
        }
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `item`, checks that its `"type"` tag equals `tag`, then
    /// deserializes it back and checks that the round trip is lossless.
    ///
    /// Returns the JSON text so callers can make further assertions on it.
    fn assert_roundtrip<T>(item: &T, tag: &str) -> String
    where
        T: Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let json = serde_json::to_string(item).unwrap();
        // Check the tag structurally rather than by substring, so payload text
        // that happens to contain the tag cannot mask a wrong tag.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], tag, "unexpected type tag in {}", json);
        let restored: T = serde_json::from_str(&json).unwrap();
        assert_eq!(*item, restored);
        json
    }

    // -- StreamEvent serde --

    #[test]
    fn test_stream_event_step_lifecycle_serde() {
        assert_roundtrip(&StreamEvent::StepStart { index: 0 }, "step_start");

        let finish = StreamEvent::StepFinish {
            index: 0,
            finish_reason: FinishReason::Stop,
            usage: None,
        };
        let json = assert_roundtrip(&finish, "step_finish");
        // `usage: None` must be omitted from the wire format, not emitted as null.
        assert!(!json.contains("usage"));
    }

    #[test]
    fn test_stream_event_text_serde() {
        assert_roundtrip(&StreamEvent::TextStart { content_index: 0 }, "text_start");

        let delta = StreamEvent::TextDelta {
            content_index: 0,
            delta: "Hello".into(),
        };
        let json = assert_roundtrip(&delta, "text_delta");
        assert!(json.contains(r#""delta":"Hello""#));

        assert_roundtrip(&StreamEvent::TextEnd { content_index: 0 }, "text_end");
    }

    #[test]
    fn test_stream_event_reasoning_serde() {
        assert_roundtrip(
            &StreamEvent::ReasoningStart { content_index: 1 },
            "reasoning_start",
        );
        assert_roundtrip(
            &StreamEvent::ReasoningDelta {
                content_index: 1,
                delta: "thinking...".into(),
            },
            "reasoning_delta",
        );
        assert_roundtrip(
            &StreamEvent::ReasoningEnd { content_index: 1 },
            "reasoning_end",
        );
    }

    #[test]
    fn test_stream_event_tool_call_serde() {
        assert_roundtrip(
            &StreamEvent::ToolCallStart {
                content_index: 2,
                id: ToolCallId::new("call_abc"),
                name: "read_file".into(),
            },
            "tool_call_start",
        );
        assert_roundtrip(
            &StreamEvent::ToolCallDelta {
                content_index: 2,
                delta: r#"{"path":"a"#.into(),
            },
            "tool_call_delta",
        );
        assert_roundtrip(
            &StreamEvent::ToolCallEnd { content_index: 2 },
            "tool_call_end",
        );
    }

    #[test]
    fn test_stream_event_tool_result_serde() {
        let event = StreamEvent::ToolResult {
            id: ToolCallId::new("call_1"),
            name: "bash".into(),
            result: ToolResultValue::Text {
                value: "exit 0".into(),
            },
        };
        assert_roundtrip(&event, "tool_result");
    }

    #[test]
    fn test_stream_event_tool_error_serde() {
        let event = StreamEvent::ToolError {
            id: ToolCallId::new("call_2"),
            name: "write_file".into(),
            message: "permission denied".into(),
        };
        assert_roundtrip(&event, "tool_error");
    }

    #[test]
    fn test_stream_event_finish_with_usage() {
        let event = StreamEvent::Finish {
            finish_reason: FinishReason::ToolCalls,
            usage: Some(Usage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 150,
                ..Default::default()
            }),
        };
        let json = assert_roundtrip(&event, "finish");
        assert!(json.contains("usage"));
    }

    #[test]
    fn test_stream_event_provider_error_serde() {
        let event = StreamEvent::ProviderError {
            message: "rate limit exceeded".into(),
            retryable: true,
        };
        assert_roundtrip(&event, "provider_error");
    }

    #[test]
    fn test_stream_event_deserializes_wire_format() {
        // Pin the wire contract from hand-written JSON, independent of Serialize.
        let event: StreamEvent =
            serde_json::from_str(r#"{"type":"text_delta","content_index":3,"delta":"Hi"}"#)
                .unwrap();
        assert_eq!(
            event,
            StreamEvent::TextDelta {
                content_index: 3,
                delta: "Hi".into(),
            }
        );

        let event: StreamEvent =
            serde_json::from_str(r#"{"type":"provider_error","message":"boom","retryable":false}"#)
                .unwrap();
        assert_eq!(
            event,
            StreamEvent::ProviderError {
                message: "boom".into(),
                retryable: false,
            }
        );

        // Unknown tags must be rejected rather than silently mapped.
        assert!(serde_json::from_str::<StreamEvent>(r#"{"type":"no_such_event"}"#).is_err());
    }

    // -- ToolResultValue --

    #[test]
    fn test_tool_result_value_json_serde() {
        let v = ToolResultValue::Json {
            value: serde_json::json!({"count": 42}),
        };
        assert_roundtrip(&v, "json");
    }

    #[test]
    fn test_tool_result_value_text_serde() {
        let v = ToolResultValue::Text {
            value: "hello".into(),
        };
        assert_roundtrip(&v, "text");
    }

    #[test]
    fn test_tool_result_value_error_serde() {
        let v = ToolResultValue::Error {
            value: "not found".into(),
        };
        assert_roundtrip(&v, "error");
    }

    // -- Helper methods --

    #[test]
    fn test_is_terminal() {
        assert!(StreamEvent::Finish {
            finish_reason: FinishReason::Stop,
            usage: None,
        }
        .is_terminal());

        assert!(StreamEvent::ProviderError {
            message: "err".into(),
            retryable: false,
        }
        .is_terminal());

        assert!(!StreamEvent::TextDelta {
            content_index: 0,
            delta: "hi".into(),
        }
        .is_terminal());

        assert!(!StreamEvent::StepStart { index: 0 }.is_terminal());
    }

    #[test]
    fn test_is_text_and_reasoning_delta() {
        let text = StreamEvent::TextDelta {
            content_index: 0,
            delta: "hi".into(),
        };
        let reasoning = StreamEvent::ReasoningDelta {
            content_index: 0,
            delta: "think".into(),
        };
        let other = StreamEvent::TextStart { content_index: 0 };

        assert!(text.is_text_delta());
        assert!(!text.is_reasoning_delta());
        assert!(reasoning.is_reasoning_delta());
        assert!(!reasoning.is_text_delta());
        assert!(!other.is_text_delta());
        assert!(!other.is_reasoning_delta());
    }

    #[test]
    fn test_as_text_delta() {
        let event = StreamEvent::TextDelta {
            content_index: 0,
            delta: "hello".into(),
        };
        assert_eq!(event.as_text_delta(), Some("hello"));

        let other = StreamEvent::ReasoningDelta {
            content_index: 0,
            delta: "think".into(),
        };
        assert_eq!(other.as_text_delta(), None);
    }

    #[test]
    fn test_as_reasoning_delta() {
        let event = StreamEvent::ReasoningDelta {
            content_index: 1,
            delta: "hmm".into(),
        };
        assert_eq!(event.as_reasoning_delta(), Some("hmm"));

        let other = StreamEvent::TextDelta {
            content_index: 0,
            delta: "hi".into(),
        };
        assert_eq!(other.as_reasoning_delta(), None);
    }
}