Skip to main content

rucora_core/
error.rs

1//! rucora-core 的统一错误类型定义(增强版)
2//!
3//! # 概述
4//!
5//! 本模块提供细粒度的错误分类,支持:
6//! - 错误类型识别
7//! - 可重试性判断
8//! - 结构化诊断信息
9//! - 错误来源追踪
10//!
11//! # 使用示例
12//!
13//! ```rust
14//! use rucora_core::error::{ProviderError, ErrorCategory};
15//!
16//! let error = ProviderError::RateLimit {
17//!     retry_after: Some(std::time::Duration::from_secs(60)),
18//!     message: "API 限流".to_string(),
19//! };
20//!
21//! // 判断是否可重试
22//! if error.is_retriable() {
23//!     println!("可以重试");
24//! }
25//!
26//! // 获取错误类别
27//! match error.category() {
28//!     ErrorCategory::RateLimit => println!("限流错误"),
29//!     _ => println!("其他错误"),
30//! }
31//! ```
32
/// Coarse classification of an error's root cause.
///
/// Callers can branch on the category instead of inspecting concrete error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Network failure (connection timeout, DNS resolution failure, ...).
    Network,
    /// API failure (an HTTP status code indicating an error).
    Api,
    /// Authentication failure (invalid API key, expired token, ...).
    Authentication,
    /// Authorization failure (insufficient permissions).
    Authorization,
    /// Rate limiting (request frequency too high).
    RateLimit,
    /// The request timed out.
    Timeout,
    /// Model failure (model does not exist, model overloaded, ...).
    Model,
    /// A tool failed while executing.
    Tool,
    /// A security policy was violated.
    Policy,
    /// The configuration is invalid.
    Configuration,
    /// Anything that fits none of the categories above.
    Other,
}

impl ErrorCategory {
    /// Returns `true` for the categories considered retriable:
    /// `Network`, `Timeout` and `RateLimit`.
    pub fn is_retriable(self) -> bool {
        match self {
            Self::Network | Self::Timeout | Self::RateLimit => true,
            Self::Api
            | Self::Authentication
            | Self::Authorization
            | Self::Model
            | Self::Tool
            | Self::Policy
            | Self::Configuration
            | Self::Other => false,
        }
    }

    /// Returns `true` for `Authentication` and `Authorization`.
    pub fn is_authentication_error(self) -> bool {
        match self {
            Self::Authentication | Self::Authorization => true,
            Self::Network
            | Self::Api
            | Self::RateLimit
            | Self::Timeout
            | Self::Model
            | Self::Tool
            | Self::Policy
            | Self::Configuration
            | Self::Other => false,
        }
    }

    /// Returns `true` for the categories attributed to the client:
    /// `Authentication`, `Authorization`, `Configuration` and `Policy`.
    pub fn is_client_error(self) -> bool {
        match self {
            Self::Authentication | Self::Authorization | Self::Configuration | Self::Policy => {
                true
            }
            Self::Network
            | Self::Api
            | Self::RateLimit
            | Self::Timeout
            | Self::Model
            | Self::Tool
            | Self::Other => false,
        }
    }
}
90
91/// 统一的错误诊断信息
92///
93/// 设计目标:
94/// - 不破坏现有错误枚举与调用点
95/// - 让上层能拿到结构化的诊断字段
96///
97/// # 字段说明
98///
99/// - `kind`: 错误大类(provider/tool/runtime/skill/memory/channel)
100/// - `message`: 人类可读信息
101/// - `retriable`: 是否建议重试
102/// - `source`: 可选的错误来源字符串
103/// - `category`: 错误详细类别
104/// - `status_code`: 可选的 HTTP 状态码
105/// - `retry_after`: 可选的重试等待时间
106#[derive(Debug, Clone, PartialEq, Eq)]
107pub struct ErrorDiagnostic {
108    /// 错误类型
109    pub kind: String,
110    /// 错误消息
111    pub message: String,
112    /// 是否可重试
113    pub retriable: bool,
114    /// 错误来源
115    pub source: Option<String>,
116    /// 错误详细类别
117    pub category: ErrorCategory,
118    /// HTTP 状态码(如果有)
119    pub status_code: Option<u16>,
120    /// 建议重试等待时间
121    pub retry_after: Option<std::time::Duration>,
122}
123
124impl Default for ErrorDiagnostic {
125    fn default() -> Self {
126        Self {
127            kind: "unknown".to_string(),
128            message: String::new(),
129            retriable: false,
130            source: None,
131            category: ErrorCategory::Other,
132            status_code: None,
133            retry_after: None,
134        }
135    }
136}
137
138/// 为 core 层错误提供统一诊断能力
139pub trait DiagnosticError {
140    /// 获取错误的诊断信息
141    fn diagnostic(&self) -> ErrorDiagnostic;
142
143    /// 判断是否可重试
144    fn is_retriable(&self) -> bool {
145        self.diagnostic().retriable
146    }
147
148    /// 获取错误类别
149    fn category(&self) -> ErrorCategory {
150        self.diagnostic().category
151    }
152}
153
154/// Provider 错误(增强版)
155///
156/// # 变体说明
157///
158/// - `Network`: 网络错误(可重试)
159/// - `Api`: API 错误(根据状态码判断)
160/// - `Authentication`: 认证错误(不可重试)
161/// - `RateLimit`: 限流错误(可重试,带等待时间)
162/// - `Timeout`: 超时错误(可重试)
163/// - `Model`: 模型错误
164/// - `Message`: 通用错误(向后兼容)
165#[derive(thiserror::Error, Debug)]
166pub enum ProviderError {
167    /// 网络错误
168    #[error("网络错误:{message}")]
169    Network {
170        message: String,
171        #[source]
172        source: Option<Box<dyn std::error::Error + Send + Sync>>,
173        retriable: bool,
174    },
175
176    /// API 错误
177    #[error("API 错误 ({status}): {message}")]
178    Api {
179        status: u16,
180        message: String,
181        code: Option<String>,
182    },
183
184    /// 认证错误
185    #[error("认证失败:{message}")]
186    Authentication { message: String },
187
188    /// 限流错误
189    #[error("请求频率过高:{message}")]
190    RateLimit {
191        message: String,
192        retry_after: Option<std::time::Duration>,
193    },
194
195    /// 超时错误
196    #[error("请求超时:{message}")]
197    Timeout {
198        message: String,
199        elapsed: std::time::Duration,
200    },
201
202    /// 模型错误
203    #[error("模型错误:{message}")]
204    Model { message: String },
205
206    /// 通用错误(向后兼容)
207    #[error("provider error: {0}")]
208    Message(String),
209}
210
211impl ProviderError {
212    /// 创建网络错误
213    pub fn network(message: impl Into<String>) -> Self {
214        Self::Network {
215            message: message.into(),
216            source: None,
217            retriable: true,
218        }
219    }
220
221    /// 创建认证错误
222    pub fn authentication(message: impl Into<String>) -> Self {
223        Self::Authentication {
224            message: message.into(),
225        }
226    }
227
228    /// 创建限流错误
229    pub fn rate_limit(
230        message: impl Into<String>,
231        retry_after: Option<std::time::Duration>,
232    ) -> Self {
233        Self::RateLimit {
234            message: message.into(),
235            retry_after,
236        }
237    }
238
239    /// 判断是否可重试
240    pub fn is_retriable(&self) -> bool {
241        match self {
242            ProviderError::Network { retriable, .. } => *retriable,
243            ProviderError::RateLimit { .. } => true,
244            ProviderError::Timeout { .. } => true,
245            ProviderError::Model { .. } => true,
246            ProviderError::Api { status, .. } => {
247                // 5xx 错误可重试
248                *status >= 500 && *status < 600
249            }
250            _ => false,
251        }
252    }
253
254    /// 获取错误类别
255    pub fn category(&self) -> ErrorCategory {
256        match self {
257            ProviderError::Network { .. } => ErrorCategory::Network,
258            ProviderError::Api { .. } => ErrorCategory::Api,
259            ProviderError::Authentication { .. } => ErrorCategory::Authentication,
260            ProviderError::RateLimit { .. } => ErrorCategory::RateLimit,
261            ProviderError::Timeout { .. } => ErrorCategory::Timeout,
262            ProviderError::Model { .. } => ErrorCategory::Model,
263            ProviderError::Message(_) => ErrorCategory::Other,
264        }
265    }
266}
267
268impl DiagnosticError for ProviderError {
269    fn diagnostic(&self) -> ErrorDiagnostic {
270        match self {
271            ProviderError::Network {
272                message, retriable, ..
273            } => ErrorDiagnostic {
274                kind: "provider".to_string(),
275                message: message.clone(),
276                retriable: *retriable,
277                source: None,
278                category: ErrorCategory::Network,
279                status_code: None,
280                retry_after: None,
281            },
282            ProviderError::Api {
283                status,
284                message,
285                code,
286            } => ErrorDiagnostic {
287                kind: "provider".to_string(),
288                message: message.clone(),
289                retriable: *status >= 500,
290                source: code.clone(),
291                category: ErrorCategory::Api,
292                status_code: Some(*status),
293                retry_after: None,
294            },
295            ProviderError::Authentication { message } => ErrorDiagnostic {
296                kind: "provider".to_string(),
297                message: message.clone(),
298                retriable: false,
299                source: None,
300                category: ErrorCategory::Authentication,
301                status_code: None,
302                retry_after: None,
303            },
304            ProviderError::RateLimit {
305                message,
306                retry_after,
307            } => ErrorDiagnostic {
308                kind: "provider".to_string(),
309                message: message.clone(),
310                retriable: true,
311                source: None,
312                category: ErrorCategory::RateLimit,
313                status_code: Some(429),
314                retry_after: *retry_after,
315            },
316            ProviderError::Timeout {
317                message,
318                elapsed: _,
319            } => ErrorDiagnostic {
320                kind: "provider".to_string(),
321                message: message.clone(),
322                retriable: true,
323                source: None,
324                category: ErrorCategory::Timeout,
325                status_code: None,
326                retry_after: None, // elapsed 是已消耗时间,不是建议等待时间
327            },
328            ProviderError::Model { message } => ErrorDiagnostic {
329                kind: "provider".to_string(),
330                message: message.clone(),
331                retriable: false, // 模型错误通常是永久性错误(如模型不存在)
332                source: None,
333                category: ErrorCategory::Model,
334                status_code: None,
335                retry_after: None,
336            },
337            ProviderError::Message(msg) => ErrorDiagnostic {
338                kind: "provider".to_string(),
339                message: msg.clone(),
340                retriable: true,
341                source: None,
342                category: ErrorCategory::Other,
343                status_code: None,
344                retry_after: None,
345            },
346        }
347    }
348}
349
350/// Tool 错误(增强版)
351#[derive(thiserror::Error, Debug)]
352pub enum ToolError {
353    /// 通用错误
354    #[error("tool error: {0}")]
355    Message(String),
356
357    /// 策略拒绝
358    #[error("tool policy denied (rule_id={rule_id}): {reason}")]
359    PolicyDenied { rule_id: String, reason: String },
360
361    /// 工具不存在
362    #[error("工具不存在:{name}")]
363    NotFound { name: String },
364
365    /// 输入验证失败
366    #[error("输入验证失败:{message}")]
367    ValidationError { message: String },
368
369    /// 执行超时
370    #[error("工具执行超时:{message}")]
371    Timeout { message: String },
372}
373
374impl DiagnosticError for ToolError {
375    fn diagnostic(&self) -> ErrorDiagnostic {
376        match self {
377            ToolError::Message(msg) => ErrorDiagnostic {
378                kind: "tool".to_string(),
379                message: msg.clone(),
380                retriable: false,
381                source: None,
382                category: ErrorCategory::Tool,
383                status_code: None,
384                retry_after: None,
385            },
386            ToolError::PolicyDenied { rule_id, reason } => ErrorDiagnostic {
387                kind: "tool".to_string(),
388                message: format!("policy denied (rule_id={rule_id}): {reason}"),
389                retriable: false,
390                source: Some(rule_id.clone()),
391                category: ErrorCategory::Policy,
392                status_code: None,
393                retry_after: None,
394            },
395            ToolError::NotFound { name } => ErrorDiagnostic {
396                kind: "tool".to_string(),
397                message: format!("工具不存在:{name}"),
398                retriable: false,
399                source: None,
400                category: ErrorCategory::Configuration,
401                status_code: None,
402                retry_after: None,
403            },
404            ToolError::ValidationError { message } => ErrorDiagnostic {
405                kind: "tool".to_string(),
406                message: format!("输入验证失败:{message}"),
407                retriable: false,
408                source: None,
409                category: ErrorCategory::Configuration,
410                status_code: None,
411                retry_after: None,
412            },
413            ToolError::Timeout { message } => ErrorDiagnostic {
414                kind: "tool".to_string(),
415                message: format!("工具执行超时:{message}"),
416                retriable: false, // 工具超时通常不应该重试
417                source: None,
418                category: ErrorCategory::Timeout,
419                status_code: None,
420                retry_after: None,
421            },
422        }
423    }
424}
425
426/// Skill 错误
427#[derive(thiserror::Error, Debug)]
428pub enum SkillError {
429    #[error("skill error: {0}")]
430    Message(String),
431
432    #[error("技能不存在:{name}")]
433    NotFound { name: String },
434
435    #[error("技能执行超时:{message}")]
436    Timeout { message: String },
437}
438
439impl DiagnosticError for SkillError {
440    fn diagnostic(&self) -> ErrorDiagnostic {
441        match self {
442            SkillError::Message(msg) => ErrorDiagnostic {
443                kind: "skill".to_string(),
444                message: msg.clone(),
445                retriable: false,
446                source: None,
447                category: ErrorCategory::Other,
448                status_code: None,
449                retry_after: None,
450            },
451            SkillError::NotFound { name } => ErrorDiagnostic {
452                kind: "skill".to_string(),
453                message: format!("技能不存在:{name}"),
454                retriable: false,
455                source: None,
456                category: ErrorCategory::Configuration,
457                status_code: None,
458                retry_after: None,
459            },
460            SkillError::Timeout { message } => ErrorDiagnostic {
461                kind: "skill".to_string(),
462                message: format!("技能执行超时:{message}"),
463                retriable: true,
464                source: None,
465                category: ErrorCategory::Timeout,
466                status_code: None,
467                retry_after: None,
468            },
469        }
470    }
471}
472
473/// Agent 错误
474#[derive(thiserror::Error, Debug)]
475pub enum AgentError {
476    /// 通用错误消息
477    #[error("agent error: {0}")]
478    Message(String),
479
480    /// 超过最大步数限制
481    #[error("超过最大步数限制:{max_steps}")]
482    MaxStepsExceeded { max_steps: usize },
483
484    /// Provider 错误
485    #[error("Provider 错误:{source}")]
486    ProviderError {
487        #[source]
488        source: ProviderError,
489    },
490
491    /// 需要 Runtime 支持(Agent 返回了需要 Runtime 执行的决策)
492    #[error("此决策需要 Runtime 支持,请使用 Runtime 模式运行")]
493    RequiresRuntime,
494}
495
496impl DiagnosticError for AgentError {
497    fn diagnostic(&self) -> ErrorDiagnostic {
498        match self {
499            AgentError::Message(msg) => ErrorDiagnostic {
500                kind: "runtime".to_string(),
501                message: msg.clone(),
502                retriable: false,
503                source: None,
504                category: ErrorCategory::Other,
505                status_code: None,
506                retry_after: None,
507            },
508            AgentError::MaxStepsExceeded { max_steps } => ErrorDiagnostic {
509                kind: "runtime".to_string(),
510                message: format!("超过最大步数限制:{max_steps}"),
511                retriable: false,
512                source: None,
513                category: ErrorCategory::Configuration,
514                status_code: None,
515                retry_after: None,
516            },
517            AgentError::ProviderError { source } => {
518                let mut diag = source.diagnostic();
519                diag.kind = "runtime".to_string();
520                diag
521            }
522            AgentError::RequiresRuntime => ErrorDiagnostic {
523                kind: "runtime".to_string(),
524                message: "此决策需要 Runtime 支持,请使用 Runtime 模式运行".to_string(),
525                retriable: false,
526                source: None,
527                category: ErrorCategory::Configuration,
528                status_code: None,
529                retry_after: None,
530            },
531        }
532    }
533}
534
535/// Memory 错误
536#[derive(thiserror::Error, Debug)]
537pub enum MemoryError {
538    #[error("memory error: {0}")]
539    Message(String),
540
541    #[error("记忆不存在:{id}")]
542    NotFound { id: String },
543}
544
545impl DiagnosticError for MemoryError {
546    fn diagnostic(&self) -> ErrorDiagnostic {
547        match self {
548            MemoryError::Message(msg) => ErrorDiagnostic {
549                kind: "memory".to_string(),
550                message: msg.clone(),
551                retriable: false,
552                source: None,
553                category: ErrorCategory::Other,
554                status_code: None,
555                retry_after: None,
556            },
557            MemoryError::NotFound { id } => ErrorDiagnostic {
558                kind: "memory".to_string(),
559                message: format!("记忆不存在:{id}"),
560                retriable: false,
561                source: None,
562                category: ErrorCategory::Configuration,
563                status_code: None,
564                retry_after: None,
565            },
566        }
567    }
568}
569
570/// Channel 错误
571#[derive(thiserror::Error, Debug)]
572pub enum ChannelError {
573    #[error("channel error: {0}")]
574    Message(String),
575}
576
577impl DiagnosticError for ChannelError {
578    fn diagnostic(&self) -> ErrorDiagnostic {
579        match self {
580            ChannelError::Message(msg) => ErrorDiagnostic {
581                kind: "channel".to_string(),
582                message: msg.clone(),
583                retriable: false,
584                source: None,
585                category: ErrorCategory::Other,
586                status_code: None,
587                retry_after: None,
588            },
589        }
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The convenience constructors yield the expected retry flag and category.
    #[test]
    fn test_provider_error_retriable() {
        let cases = [
            (
                ProviderError::network("连接失败"),
                true,
                ErrorCategory::Network,
            ),
            (
                ProviderError::authentication("API Key 无效"),
                false,
                ErrorCategory::Authentication,
            ),
            (
                ProviderError::rate_limit("限流", Some(Duration::from_secs(60))),
                true,
                ErrorCategory::RateLimit,
            ),
        ];

        for (error, expect_retriable, expect_category) in cases {
            assert_eq!(error.is_retriable(), expect_retriable);
            assert_eq!(error.category(), expect_category);
        }
    }

    /// Category-level retry classification.
    #[test]
    fn test_error_category() {
        for category in [
            ErrorCategory::Network,
            ErrorCategory::Timeout,
            ErrorCategory::RateLimit,
        ] {
            assert!(category.is_retriable());
        }

        for category in [ErrorCategory::Authentication, ErrorCategory::Policy] {
            assert!(!category.is_retriable());
        }
    }

    /// A 503 API error produces a retriable provider diagnostic carrying the status.
    #[test]
    fn test_diagnostic() {
        let diag = ProviderError::Api {
            status: 503,
            message: "服务不可用".to_string(),
            code: None,
        }
        .diagnostic();

        assert_eq!(diag.kind, "provider");
        assert_eq!(diag.status_code, Some(503));
        assert!(diag.retriable);
        assert_eq!(diag.category, ErrorCategory::Api);
    }
}