llm_shield_core/
error.rs

1//! Error types for LLM Shield
2//!
3//! ## SPARC Specification
4//!
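//! All fallible operations report failures through a single error enum:
//! - Specific error variants for different failure modes
//! - Contextual details carried in the variants (e.g. the failing scanner's name)
//! - Compatibility with `anyhow`: the error is `Send + Sync + 'static`, so `?`
//!   converts it into `anyhow::Error`
//! - Source chaining through `std::error::Error::source`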
10
11use std::fmt;
12use thiserror::Error;
13
14/// Result type alias for LLM Shield operations
15pub type Result<T> = std::result::Result<T, Error>;
16
17/// Core error type for LLM Shield operations
18///
19/// ## Design Principles
20///
21/// 1. **Specific Variants**: Each error type has a specific variant
22/// 2. **Context**: All errors include contextual information
23/// 3. **Source Chaining**: Errors properly chain their sources
24/// 4. **Display**: Human-readable error messages
25#[derive(Debug, Error)]
26pub enum Error {
27    /// Scanner-specific errors
28    #[error("Scanner error in {scanner}: {message}")]
29    Scanner {
30        scanner: String,
31        message: String,
32        #[source]
33        source: Option<Box<dyn std::error::Error + Send + Sync>>,
34    },
35
36    /// Model loading or inference errors
37    #[error("Model error: {0}")]
38    Model(String),
39
40    /// Configuration errors
41    #[error("Configuration error: {0}")]
42    Config(String),
43
44    /// Invalid input data
45    #[error("Invalid input: {0}")]
46    InvalidInput(String),
47
48    /// I/O errors
49    #[error("I/O error: {0}")]
50    Io(#[from] std::io::Error),
51
52    /// Serialization errors
53    #[error("Serialization error: {0}")]
54    Serialization(#[from] serde_json::Error),
55
56    /// Vault errors (state management)
57    #[error("Vault error: {0}")]
58    Vault(String),
59
60    /// Timeout errors
61    #[error("Operation timed out after {0}ms")]
62    Timeout(u64),
63
64    /// Resource exhaustion
65    #[error("Resource exhausted: {0}")]
66    ResourceExhausted(String),
67
68    /// Internal errors (should not happen in production)
69    #[error("Internal error: {0}")]
70    Internal(String),
71
72    /// Authentication errors
73    #[error("Authentication error: {0}")]
74    Auth(String),
75
76    /// Unauthorized access
77    #[error("Unauthorized: {0}")]
78    Unauthorized(String),
79
80    /// Resource not found
81    #[error("Not found: {0}")]
82    NotFound(String),
83}
84
85impl Error {
86    /// Create a scanner error with context
87    pub fn scanner<S: Into<String>, M: Into<String>>(scanner: S, message: M) -> Self {
88        Self::Scanner {
89            scanner: scanner.into(),
90            message: message.into(),
91            source: None,
92        }
93    }
94
95    /// Create a scanner error with source
96    pub fn scanner_with_source<S: Into<String>, M: Into<String>>(
97        scanner: S,
98        message: M,
99        source: Box<dyn std::error::Error + Send + Sync>,
100    ) -> Self {
101        Self::Scanner {
102            scanner: scanner.into(),
103            message: message.into(),
104            source: Some(source),
105        }
106    }
107
108    /// Create a model error
109    pub fn model<S: Into<String>>(message: S) -> Self {
110        Self::Model(message.into())
111    }
112
113    /// Create a configuration error
114    pub fn config<S: Into<String>>(message: S) -> Self {
115        Self::Config(message.into())
116    }
117
118    /// Create an invalid input error
119    pub fn invalid_input<S: Into<String>>(message: S) -> Self {
120        Self::InvalidInput(message.into())
121    }
122
123    /// Create a vault error
124    pub fn vault<S: Into<String>>(message: S) -> Self {
125        Self::Vault(message.into())
126    }
127
128    /// Create a timeout error
129    pub fn timeout(duration_ms: u64) -> Self {
130        Self::Timeout(duration_ms)
131    }
132
133    /// Create a resource exhausted error
134    pub fn resource_exhausted<S: Into<String>>(resource: S) -> Self {
135        Self::ResourceExhausted(resource.into())
136    }
137
138    /// Create an internal error
139    pub fn internal<S: Into<String>>(message: S) -> Self {
140        Self::Internal(message.into())
141    }
142
143    /// Create an authentication error
144    pub fn auth<S: Into<String>>(message: S) -> Self {
145        Self::Auth(message.into())
146    }
147
148    /// Create an unauthorized error
149    pub fn unauthorized<S: Into<String>>(message: S) -> Self {
150        Self::Unauthorized(message.into())
151    }
152
153    /// Create a not found error
154    pub fn not_found<S: Into<String>>(message: S) -> Self {
155        Self::NotFound(message.into())
156    }
157
158    /// Check if error is retryable
159    pub fn is_retryable(&self) -> bool {
160        matches!(
161            self,
162            Error::Timeout(_) | Error::ResourceExhausted(_) | Error::Io(_)
163        )
164    }
165
166    /// Get error category for metrics
167    pub fn category(&self) -> &'static str {
168        match self {
169            Error::Scanner { .. } => "scanner",
170            Error::Model(_) => "model",
171            Error::Config(_) => "config",
172            Error::InvalidInput(_) => "invalid_input",
173            Error::Io(_) => "io",
174            Error::Serialization(_) => "serialization",
175            Error::Vault(_) => "vault",
176            Error::Timeout(_) => "timeout",
177            Error::ResourceExhausted(_) => "resource_exhausted",
178            Error::Internal(_) => "internal",
179            Error::Auth(_) => "auth",
180            Error::Unauthorized(_) => "unauthorized",
181            Error::NotFound(_) => "not_found",
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let err = Error::scanner("test_scanner", "test message");
        assert!(matches!(err, Error::Scanner { .. }));
        assert_eq!(err.category(), "scanner");
    }

    #[test]
    fn test_error_retryable() {
        assert!(Error::timeout(5000).is_retryable());
        assert!(Error::resource_exhausted("gpu memory").is_retryable());
        // A timed-out I/O operation is the canonical transient failure.
        let io = std::io::Error::new(std::io::ErrorKind::TimedOut, "slow disk");
        assert!(Error::from(io).is_retryable());
        assert!(!Error::config("bad config").is_retryable());
        assert!(!Error::invalid_input("empty prompt").is_retryable());
    }

    #[test]
    fn test_error_display() {
        let err = Error::scanner("ban_substrings", "pattern not found");
        let msg = format!("{}", err);
        assert!(msg.contains("ban_substrings"));
        assert!(msg.contains("pattern not found"));
    }

    #[test]
    fn test_timeout_display_includes_unit() {
        assert_eq!(
            Error::timeout(250).to_string(),
            "Operation timed out after 250ms"
        );
    }

    #[test]
    fn test_scanner_source_is_chained() {
        // Bring the trait into scope anonymously so `.source()` resolves
        // without clashing with this crate's `Error` enum.
        use std::error::Error as _;

        let cause = std::io::Error::new(std::io::ErrorKind::Other, "disk unplugged");
        let err = Error::scanner_with_source("toxicity", "model load failed", Box::new(cause));
        let source = err
            .source()
            .expect("scanner_with_source must expose its cause");
        assert_eq!(source.to_string(), "disk unplugged");

        assert!(Error::scanner("toxicity", "no cause").source().is_none());
    }

    #[test]
    fn test_from_conversions() {
        let io: Error = std::io::Error::new(std::io::ErrorKind::Other, "boom").into();
        assert_eq!(io.category(), "io");

        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let ser: Error = json_err.into();
        assert_eq!(ser.category(), "serialization");
    }

    #[test]
    fn test_categories() {
        let cases = [
            (Error::model("m"), "model"),
            (Error::config("c"), "config"),
            (Error::invalid_input("i"), "invalid_input"),
            (Error::vault("v"), "vault"),
            (Error::timeout(1), "timeout"),
            (Error::resource_exhausted("r"), "resource_exhausted"),
            (Error::internal("x"), "internal"),
            (Error::auth("a"), "auth"),
            (Error::unauthorized("u"), "unauthorized"),
            (Error::not_found("n"), "not_found"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.category(), *expected, "{:?}", err);
        }
    }
}