Skip to main content

katu_core/
error.rs

1//! # katu_core::error
2//!
3//! ## 职责
4//! 定义全局错误类型。所有 crate 的错误最终汇聚到 `Error` 枚举。
5//!
6//! ## 依赖
7//! - `katu_core::types` — ToolCallId
8//!
9//! ## 对外接口
10//! - `Error` — 顶层错误枚举
11//! - `Result<T>` — 类型别名
12//! - `ProviderErrorKind` — Provider 错误分类(retryable 判定依据)
13//! - `AuthErrorKind` — 认证错误细分
14//!
15//! ## 调用者
16//! - `katu_core::message` — Message 构造可能失败
17//! - `katu_core::event` — AgentEvent 携带错误信息
18//! - `katu-provider` — 将 HTTP 错误映射为 ProviderErrorKind
19//! - `katu-tools` — 工具执行失败时构造 Error::Tool
20//! - `katu-agent` — Agent loop match Error 决定下一步行为
21
22use std::fmt;
23use std::time::Duration;
24
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28use crate::types::ToolCallId;
29
30// ---------------------------------------------------------------------------
31// Result 别名
32// ---------------------------------------------------------------------------
33
34/// Katu 全局 Result 别名。
35///
36/// # Examples
37/// ```
38/// use katu_core::error::Result;
39///
40/// fn do_something() -> Result<()> {
41///     Ok(())
42/// }
43/// ```
44pub type Result<T, E = Error> = std::result::Result<T, E>;
45
46// ---------------------------------------------------------------------------
47// Error — 顶层错误枚举
48// ---------------------------------------------------------------------------
49
50/// 顶层错误枚举。
51///
52/// Agent loop 对此做 match 决定行为:
53/// - `Provider { kind: RateLimit { .. }, .. }` → 退避重试
54/// - `Provider { kind: Authentication { .. }, .. }` → 终止并提示用户
55/// - `Tool { .. }` → 将错误作为 ToolResult 回传给 LLM
56/// - `ContextOverflow { .. }` → 触发上下文压缩
57///
58/// # Examples
59/// ```
60/// use katu_core::error::{Error, ProviderErrorKind};
61/// use katu_core::types::ToolCallId;
62///
63/// // 构造一个工具错误
64/// let err = Error::tool("read_file", ToolCallId::new("call_1"), "file not found");
65/// assert!(err.to_string().contains("read_file"));
66/// assert!(!err.retryable());
67///
68/// // 构造一个可重试的 Provider 错误
69/// let err = Error::provider(
70///     ProviderErrorKind::RateLimit {
71///         message: "too many requests".into(),
72///         retry_after: None,
73///     },
74///     "please wait 30 seconds",
75/// );
76/// assert!(err.retryable());
77/// assert!(err.retry_after().is_some());
78/// ```
79#[derive(Debug, Error)]
80pub enum Error {
81    /// LLM Provider 返回的错误。
82    #[error("provider error: {kind}")]
83    Provider {
84        kind: ProviderErrorKind,
85        /// 可选的恢复建议,面向终端用户。
86        suggestion: Option<String>,
87    },
88
89    /// 工具执行失败。
90    /// 非致命:Agent loop 将此作为 tool_result 回传 LLM,让模型自行修正。
91    #[error("tool `{name}` failed: {message}")]
92    Tool {
93        name: String,
94        call_id: ToolCallId,
95        message: String,
96        #[source]
97        source: Option<Box<dyn std::error::Error + Send + Sync>>,
98    },
99
100    /// 上下文窗口溢出。
101    /// Agent loop 收到此错误后应触发上下文压缩策略。
102    #[error("context overflow: {used}/{limit} tokens")]
103    ContextOverflow { used: usize, limit: usize },
104
105    /// 配置错误(缺失字段、无效值)。
106    #[error("config error: {message}")]
107    Config { message: String },
108
109    /// 用户或系统取消操作。
110    #[error("cancelled")]
111    Cancelled,
112
113    /// 不可归类的内部错误。
114    #[error("{0}")]
115    Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
116}
117
118impl Error {
119    /// 构造一个 Tool 错误。
120    ///
121    /// # Examples
122    /// ```
123    /// use katu_core::error::Error;
124    /// use katu_core::types::ToolCallId;
125    ///
126    /// let err = Error::tool("bash", ToolCallId::new("call_42"), "command not found");
127    /// assert!(matches!(err, Error::Tool { .. }));
128    /// ```
129    pub fn tool(
130        name: impl Into<String>,
131        call_id: ToolCallId,
132        message: impl Into<String>,
133    ) -> Self {
134        Self::Tool {
135            name: name.into(),
136            call_id,
137            message: message.into(),
138            source: None,
139        }
140    }
141
142    /// 构造一个带 source 的 Tool 错误。
143    pub fn tool_with_source(
144        name: impl Into<String>,
145        call_id: ToolCallId,
146        message: impl Into<String>,
147        source: impl std::error::Error + Send + Sync + 'static,
148    ) -> Self {
149        Self::Tool {
150            name: name.into(),
151            call_id,
152            message: message.into(),
153            source: Some(Box::new(source)),
154        }
155    }
156
157    /// 构造一个带恢复建议的 Provider 错误。
158    ///
159    /// # Examples
160    /// ```
161    /// use katu_core::error::{Error, ProviderErrorKind, AuthErrorKind};
162    ///
163    /// let err = Error::provider(
164    ///     ProviderErrorKind::Authentication {
165    ///         message: "invalid key".into(),
166    ///         kind: AuthErrorKind::Invalid,
167    ///     },
168    ///     "check your API key in ~/.config/katu/config.toml",
169    /// );
170    /// assert!(!err.retryable());
171    /// ```
172    pub fn provider(kind: ProviderErrorKind, suggestion: impl Into<String>) -> Self {
173        Self::Provider {
174            kind,
175            suggestion: Some(suggestion.into()),
176        }
177    }
178
179    /// 该错误是否可重试。
180    ///
181    /// 仅 `RateLimit` 和 `ServerError` 返回 true。
182    pub fn retryable(&self) -> bool {
183        match self {
184            Self::Provider { kind, .. } => kind.retryable(),
185            _ => false,
186        }
187    }
188
189    /// 建议的退避时间。
190    ///
191    /// 如果 Provider 在 header 中指定了 retry-after,返回该值;
192    /// 否则按分类返回默认退避时间。
193    pub fn retry_after(&self) -> Option<Duration> {
194        match self {
195            Self::Provider { kind, .. } => kind.retry_after(),
196            _ => None,
197        }
198    }
199}
200
201// ---------------------------------------------------------------------------
202// ProviderErrorKind — Provider 错误分类
203// ---------------------------------------------------------------------------
204
205/// Provider 错误分类。
206///
207/// 每种分类自带 retryable 语义,Agent loop 据此决定重试策略。
208/// 可序列化以便持久化到 Session 记录中。
209#[derive(Debug, Error, Clone, Serialize, Deserialize)]
210pub enum ProviderErrorKind {
211    /// API Key 缺失、无效、过期、权限不足。
212    #[error("authentication failed: {message} ({kind})")]
213    Authentication {
214        message: String,
215        kind: AuthErrorKind,
216    },
217
218    /// 请求频率超限(HTTP 429)。
219    #[error("rate limited: {message}")]
220    RateLimit {
221        message: String,
222        /// Provider 建议的等待时间。
223        #[serde(with = "option_duration_millis")]
224        retry_after: Option<Duration>,
225    },
226
227    /// 配额/余额耗尽(需充值或切换账户)。
228    #[error("quota exceeded: {message}")]
229    QuotaExceeded { message: String },
230
231    /// 请求参数非法(prompt 过长、无效模型名等)。
232    #[error("invalid request: {message}")]
233    InvalidRequest { message: String },
234
235    /// 内容被安全策略过滤。
236    #[error("content filtered: {message}")]
237    ContentFiltered { message: String },
238
239    /// Provider 服务端错误(5xx)。
240    #[error("provider internal error ({status}): {message}")]
241    ServerError {
242        message: String,
243        status: u16,
244        #[serde(with = "option_duration_millis")]
245        retry_after: Option<Duration>,
246    },
247
248    /// 网络/传输层错误(DNS 失败、连接超时、SSL 等)。
249    #[error("transport error: {message}")]
250    Transport { message: String },
251
252    /// Provider 返回了无法解析的响应。
253    #[error("invalid response: {message}")]
254    InvalidResponse { message: String },
255
256    /// 未知 Provider 错误。
257    #[error("unknown provider error: {message}")]
258    Unknown { message: String },
259}
260
261impl ProviderErrorKind {
262    /// 该错误是否可重试。
263    ///
264    /// - `true` → RateLimit, ServerError
265    /// - `false` → Authentication, QuotaExceeded, InvalidRequest, 其他
266    pub fn retryable(&self) -> bool {
267        matches!(self, Self::RateLimit { .. } | Self::ServerError { .. })
268    }
269
270    /// 建议的退避等待时间。
271    ///
272    /// 优先使用 Provider 返回的 retry-after 值;
273    /// 无显式值时按分类使用默认值:
274    /// - RateLimit → 30s
275    /// - ServerError → 20s
276    pub fn retry_after(&self) -> Option<Duration> {
277        match self {
278            Self::RateLimit { retry_after, .. } => {
279                retry_after.or(Some(Duration::from_secs(30)))
280            }
281            Self::ServerError { retry_after, .. } => {
282                retry_after.or(Some(Duration::from_secs(20)))
283            }
284            _ => None,
285        }
286    }
287}
288
289// ---------------------------------------------------------------------------
290// AuthErrorKind — 认证错误细分
291// ---------------------------------------------------------------------------
292
293/// 认证错误细分。
294///
295/// 便于上层 UI 给出针对性的恢复建议:
296/// - `Missing` → "请设置 API Key"
297/// - `Expired` → "请刷新 Token"
298/// - `InsufficientPermissions` → "当前 Key 无权使用该模型"
299#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
300#[serde(rename_all = "snake_case")]
301pub enum AuthErrorKind {
302    /// API Key 未设置。
303    Missing,
304    /// API Key 格式正确但被拒绝。
305    Invalid,
306    /// Token 已过期。
307    Expired,
308    /// 权限不足(Key 有效但无权访问该模型)。
309    InsufficientPermissions,
310    /// 无法分类的认证错误。
311    Unknown,
312}
313
314impl fmt::Display for AuthErrorKind {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        match self {
317            Self::Missing => write!(f, "missing"),
318            Self::Invalid => write!(f, "invalid"),
319            Self::Expired => write!(f, "expired"),
320            Self::InsufficientPermissions => write!(f, "insufficient_permissions"),
321            Self::Unknown => write!(f, "unknown"),
322        }
323    }
324}
325
326// ---------------------------------------------------------------------------
327// Serde helper: Option<Duration> as milliseconds
328// ---------------------------------------------------------------------------
329
330/// 将 `Option<Duration>` 序列化为毫秒数(u64 或 null)。
331mod option_duration_millis {
332    use serde::{Deserialize, Deserializer, Serialize, Serializer};
333    use std::time::Duration;
334
335    pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
336    where
337        S: Serializer,
338    {
339        match value {
340            Some(d) => d.as_millis().serialize(serializer),
341            None => serializer.serialize_none(),
342        }
343    }
344
345    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
346    where
347        D: Deserializer<'de>,
348    {
349        let opt: Option<u64> = Option::deserialize(deserializer)?;
350        Ok(opt.map(Duration::from_millis))
351    }
352}
353
354// ---------------------------------------------------------------------------
355// Tests
356// ---------------------------------------------------------------------------
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361
362    // -- Error 顶层 --
363
364    /// 验证 Error::tool 构造正确
365    #[test]
366    fn test_error_tool_construction() {
367        let err = Error::tool("read_file", ToolCallId::new("call_1"), "not found");
368        match &err {
369            Error::Tool {
370                name,
371                call_id,
372                message,
373                source,
374            } => {
375                assert_eq!(name, "read_file");
376                assert_eq!(call_id.as_str(), "call_1");
377                assert_eq!(message, "not found");
378                assert!(source.is_none());
379            }
380            _ => panic!("expected Error::Tool"),
381        }
382    }
383
384    /// 验证 Error::tool_with_source 保留原始错误
385    #[test]
386    fn test_error_tool_with_source() {
387        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "no such file");
388        let err = Error::tool_with_source(
389            "read_file",
390            ToolCallId::new("call_2"),
391            "failed",
392            io_err,
393        );
394        match &err {
395            Error::Tool { source, .. } => {
396                assert!(source.is_some());
397            }
398            _ => panic!("expected Error::Tool"),
399        }
400    }
401
402    /// 验证 Error::provider 构造正确
403    #[test]
404    fn test_error_provider_construction() {
405        let err = Error::provider(
406            ProviderErrorKind::Authentication {
407                message: "bad key".into(),
408                kind: AuthErrorKind::Invalid,
409            },
410            "run /login",
411        );
412        match &err {
413            Error::Provider { kind, suggestion } => {
414                assert!(matches!(kind, ProviderErrorKind::Authentication { .. }));
415                assert_eq!(suggestion.as_deref(), Some("run /login"));
416            }
417            _ => panic!("expected Error::Provider"),
418        }
419    }
420
421    /// 验证 Error::Cancelled 不可重试
422    #[test]
423    fn test_cancelled_not_retryable() {
424        let err = Error::Cancelled;
425        assert!(!err.retryable());
426        assert!(err.retry_after().is_none());
427    }
428
429    /// 验证 Error::ContextOverflow 的 Display
430    #[test]
431    fn test_context_overflow_display() {
432        let err = Error::ContextOverflow {
433            used: 130000,
434            limit: 128000,
435        };
436        assert_eq!(err.to_string(), "context overflow: 130000/128000 tokens");
437    }
438
439    /// 验证 Error::Config 的 Display
440    #[test]
441    fn test_config_error_display() {
442        let err = Error::Config {
443            message: "missing api_key".into(),
444        };
445        assert!(err.to_string().contains("missing api_key"));
446    }
447
448    // -- ProviderErrorKind --
449
450    /// 验证 RateLimit 可重试
451    #[test]
452    fn test_rate_limit_retryable() {
453        let kind = ProviderErrorKind::RateLimit {
454            message: "429".into(),
455            retry_after: Some(Duration::from_secs(60)),
456        };
457        assert!(kind.retryable());
458        assert_eq!(kind.retry_after(), Some(Duration::from_secs(60)));
459    }
460
461    /// 验证 RateLimit 无显式 retry_after 时使用默认 30s
462    #[test]
463    fn test_rate_limit_default_retry_after() {
464        let kind = ProviderErrorKind::RateLimit {
465            message: "slow down".into(),
466            retry_after: None,
467        };
468        assert_eq!(kind.retry_after(), Some(Duration::from_secs(30)));
469    }
470
471    /// 验证 ServerError 可重试且默认 20s
472    #[test]
473    fn test_server_error_retryable() {
474        let kind = ProviderErrorKind::ServerError {
475            message: "internal".into(),
476            status: 500,
477            retry_after: None,
478        };
479        assert!(kind.retryable());
480        assert_eq!(kind.retry_after(), Some(Duration::from_secs(20)));
481    }
482
483    /// 验证 Authentication 不可重试
484    #[test]
485    fn test_authentication_not_retryable() {
486        let kind = ProviderErrorKind::Authentication {
487            message: "invalid".into(),
488            kind: AuthErrorKind::Invalid,
489        };
490        assert!(!kind.retryable());
491        assert!(kind.retry_after().is_none());
492    }
493
494    /// 验证 QuotaExceeded 不可重试
495    #[test]
496    fn test_quota_exceeded_not_retryable() {
497        let kind = ProviderErrorKind::QuotaExceeded {
498            message: "out of credits".into(),
499        };
500        assert!(!kind.retryable());
501    }
502
503    /// 验证 InvalidRequest 不可重试
504    #[test]
505    fn test_invalid_request_not_retryable() {
506        let kind = ProviderErrorKind::InvalidRequest {
507            message: "model not found".into(),
508        };
509        assert!(!kind.retryable());
510    }
511
512    /// 验证 ContentFiltered 不可重试
513    #[test]
514    fn test_content_filtered_not_retryable() {
515        let kind = ProviderErrorKind::ContentFiltered {
516            message: "blocked".into(),
517        };
518        assert!(!kind.retryable());
519    }
520
521    /// 验证 Transport 不可重试
522    #[test]
523    fn test_transport_not_retryable() {
524        let kind = ProviderErrorKind::Transport {
525            message: "dns failed".into(),
526        };
527        assert!(!kind.retryable());
528    }
529
530    /// 验证 InvalidResponse 不可重试
531    #[test]
532    fn test_invalid_response_not_retryable() {
533        let kind = ProviderErrorKind::InvalidResponse {
534            message: "bad json".into(),
535        };
536        assert!(!kind.retryable());
537    }
538
539    /// 验证 Unknown 不可重试
540    #[test]
541    fn test_unknown_not_retryable() {
542        let kind = ProviderErrorKind::Unknown {
543            message: "???".into(),
544        };
545        assert!(!kind.retryable());
546    }
547
548    // -- ProviderErrorKind serde --
549
550    /// 验证 ProviderErrorKind 序列化/反序列化 roundtrip
551    #[test]
552    fn test_provider_error_kind_serde_roundtrip() {
553        let kind = ProviderErrorKind::RateLimit {
554            message: "429 too many".into(),
555            retry_after: Some(Duration::from_millis(5000)),
556        };
557        let json = serde_json::to_string(&kind).unwrap();
558        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
559        assert!(matches!(
560            restored,
561            ProviderErrorKind::RateLimit {
562                retry_after: Some(d),
563                ..
564            } if d == Duration::from_millis(5000)
565        ));
566    }
567
568    /// 验证 ProviderErrorKind::ServerError serde roundtrip
569    #[test]
570    fn test_server_error_serde_roundtrip() {
571        let kind = ProviderErrorKind::ServerError {
572            message: "overloaded".into(),
573            status: 529,
574            retry_after: None,
575        };
576        let json = serde_json::to_string(&kind).unwrap();
577        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
578        assert!(matches!(
579            restored,
580            ProviderErrorKind::ServerError { status: 529, retry_after: None, .. }
581        ));
582    }
583
584    // -- AuthErrorKind --
585
586    /// 验证 AuthErrorKind 的 Display 输出
587    #[test]
588    fn test_auth_error_kind_display() {
589        assert_eq!(AuthErrorKind::Missing.to_string(), "missing");
590        assert_eq!(AuthErrorKind::Invalid.to_string(), "invalid");
591        assert_eq!(AuthErrorKind::Expired.to_string(), "expired");
592        assert_eq!(
593            AuthErrorKind::InsufficientPermissions.to_string(),
594            "insufficient_permissions"
595        );
596        assert_eq!(AuthErrorKind::Unknown.to_string(), "unknown");
597    }
598
599    /// 验证 AuthErrorKind serde roundtrip
600    #[test]
601    fn test_auth_error_kind_serde_roundtrip() {
602        for kind in [
603            AuthErrorKind::Missing,
604            AuthErrorKind::Invalid,
605            AuthErrorKind::Expired,
606            AuthErrorKind::InsufficientPermissions,
607            AuthErrorKind::Unknown,
608        ] {
609            let json = serde_json::to_string(&kind).unwrap();
610            let restored: AuthErrorKind = serde_json::from_str(&json).unwrap();
611            assert_eq!(kind, restored);
612        }
613    }
614
615    // -- Error::retryable / retry_after 集成 --
616
617    /// 验证 Error 层面的 retryable 委托到 ProviderErrorKind
618    #[test]
619    fn test_error_retryable_delegates() {
620        let retryable_err = Error::Provider {
621            kind: ProviderErrorKind::RateLimit {
622                message: "wait".into(),
623                retry_after: Some(Duration::from_secs(10)),
624            },
625            suggestion: None,
626        };
627        assert!(retryable_err.retryable());
628        assert_eq!(retryable_err.retry_after(), Some(Duration::from_secs(10)));
629
630        let non_retryable_err = Error::Provider {
631            kind: ProviderErrorKind::Authentication {
632                message: "bad".into(),
633                kind: AuthErrorKind::Invalid,
634            },
635            suggestion: None,
636        };
637        assert!(!non_retryable_err.retryable());
638        assert!(non_retryable_err.retry_after().is_none());
639    }
640
641    /// 验证非 Provider 错误的 retryable 始终为 false
642    #[test]
643    fn test_non_provider_errors_not_retryable() {
644        assert!(!Error::Cancelled.retryable());
645        assert!(!Error::Config { message: "x".into() }.retryable());
646        assert!(!Error::ContextOverflow { used: 1, limit: 1 }.retryable());
647        assert!(!Error::tool("t", ToolCallId::new("c"), "m").retryable());
648    }
649
650    // -- Display --
651
652    /// 验证各错误变体的 Display 格式
653    #[test]
654    fn test_error_display_formats() {
655        let tool_err = Error::tool("bash", ToolCallId::new("c1"), "permission denied");
656        assert_eq!(
657            tool_err.to_string(),
658            "tool `bash` failed: permission denied"
659        );
660
661        let provider_err = Error::Provider {
662            kind: ProviderErrorKind::Transport {
663                message: "connection reset".into(),
664            },
665            suggestion: Some("check your network".into()),
666        };
667        assert_eq!(
668            provider_err.to_string(),
669            "provider error: transport error: connection reset"
670        );
671    }
672}