Skip to main content

systemprompt_ai/
error.rs

//! Typed error hierarchy for the [`systemprompt-ai`](crate) crate.
//!
//! Two error families live here:
//!
//! - [`AiError`] — the top-level public error returned by [`crate::services`].
//!   It wraps provider-level failures ([`LlmProviderError`]) via `#[from]`,
//!   absorbs repository-level failures ([`RepositoryError`]) through a manual
//!   `From` impl that flattens them into `AiError::DatabaseError`, and wraps
//!   common transport / parsing errors ([`reqwest::Error`],
//!   [`serde_json::Error`], [`std::io::Error`], [`regex::Error`]) as well as
//!   tool-provider and secrets-bootstrap errors, all via `#[from]`.
//! - [`RepositoryError`] — the persistence-layer error returned by every
//!   `*Repository` type in [`crate::repository`]; it wraps [`sqlx::Error`].
//!
//! All public service signatures use [`Result<T>`] (i.e. `Result<T, AiError>`).
//! Provider-trait signatures continue to use the boxed
//! [`systemprompt_models::errors::ProviderResult`] and bridge through
//! `AiProvider for AiService` in
//! `crate::services::core::ai_service` (the `provider_impl` submodule).
18
19use std::time::Duration;
20
21use thiserror::Error;
22use uuid::Uuid;
23
24use systemprompt_database::resilience::Outcome;
25use systemprompt_identifiers::McpServerId;
26use systemprompt_provider_contracts::LlmProviderError;
27
28#[derive(Debug, Error)]
29pub enum AiError {
30    #[error("Model not specified and no default available for provider {provider}")]
31    ModelNotSpecified { provider: String },
32
33    #[error("Request metadata missing required field: {field}")]
34    MissingMetadata { field: String },
35
36    #[error("User context required for billing and audit trails")]
37    MissingUserContext,
38
39    #[error("Provider {provider} returned empty response")]
40    EmptyProviderResponse { provider: String },
41
42    #[error("Tool call schema validation failed: {reason}")]
43    InvalidToolSchema { reason: String },
44
45    #[error("Authentication required for service {service_id}")]
46    AuthenticationRequired { service_id: McpServerId },
47
48    #[error("Structured output validation failed after {retries} attempts: {details}")]
49    StructuredOutputFailed { retries: usize, details: String },
50
51    #[error("Provider {provider} error: {message}")]
52    ProviderError { provider: String, message: String },
53
54    #[error(transparent)]
55    Provider(#[from] LlmProviderError),
56
57    #[error("Serialization failed: {0}")]
58    SerializationError(#[from] serde_json::Error),
59
60    #[error("HTTP request failed: {0}")]
61    Http(#[from] reqwest::Error),
62
63    #[error("I/O error: {0}")]
64    Io(#[from] std::io::Error),
65
66    #[error("Message history cannot be serialized to JSON")]
67    MessageSerializationFailed,
68
69    #[error("Tool {tool_name} missing required field: {field}")]
70    MissingToolField { tool_name: String, field: String },
71
72    #[error("Tool description cannot be empty for tool: {tool_name}")]
73    EmptyToolDescription { tool_name: String },
74
75    #[error("No tool calls found in provider response")]
76    NoToolCalls,
77
78    #[error("Rate limit exceeded for provider {provider}: {details}")]
79    RateLimit { provider: String, details: String },
80
81    #[error("Provider {provider} returned HTTP {status}: {body}")]
82    HttpStatus {
83        provider: String,
84        status: u16,
85        retry_after: Option<Duration>,
86        body: String,
87    },
88
89    #[error("Provider {provider} request timed out after {after_ms}ms")]
90    Timeout { provider: String, after_ms: u64 },
91
92    #[error("Circuit breaker open for provider {provider}; failing fast")]
93    CircuitOpen { provider: String },
94
95    #[error("Provider {provider} unavailable: concurrency limit reached")]
96    DependencyUnavailable { provider: String },
97
98    #[error("Invalid API credentials for provider {provider}")]
99    AuthenticationFailed { provider: String },
100
101    #[error("Configuration error: {0}")]
102    ConfigurationError(String),
103
104    #[error("Database operation failed: {0}")]
105    DatabaseError(String),
106
107    #[error("MCP service {service_id} not found or not configured")]
108    McpServiceNotFound { service_id: McpServerId },
109
110    #[error("MCP service {service_id} requires OAuth authentication but no token available")]
111    McpAuthenticationMissing { service_id: McpServerId },
112
113    #[error("Failed to determine service authentication requirements: {details}")]
114    ServiceAuthCheckFailed { details: String },
115
116    #[error("Storage operation failed: {0}")]
117    StorageError(String),
118
119    #[error("Invalid input: {0}")]
120    InvalidInput(String),
121
122    #[error("Regex error: {0}")]
123    Regex(#[from] regex::Error),
124
125    #[error(transparent)]
126    ToolProvider(#[from] systemprompt_traits::ToolProviderError),
127
128    #[error(transparent)]
129    Secrets(#[from] systemprompt_config::SecretsBootstrapError),
130
131    #[error("internal: {0}")]
132    Internal(String),
133}
134
135#[derive(Debug, Error)]
136pub enum RepositoryError {
137    #[error("AI request not found: {0}")]
138    NotFound(Uuid),
139
140    #[error("Database error: {0}")]
141    Database(#[from] sqlx::Error),
142
143    #[error("Invalid data: {field} - {reason}")]
144    InvalidData { field: String, reason: String },
145
146    #[error("Database pool initialization failed: {0}")]
147    PoolInitialization(String),
148}
149
150impl AiError {
151    pub async fn from_error_response(provider: &str, response: reqwest::Response) -> Self {
152        let status = response.status().as_u16();
153        let retry_after = parse_retry_after(response.headers());
154        let body = response.text().await.unwrap_or_default();
155        Self::HttpStatus {
156            provider: provider.to_owned(),
157            status,
158            retry_after,
159            body,
160        }
161    }
162
163    #[must_use]
164    pub fn classify(&self) -> Outcome {
165        match self {
166            Self::HttpStatus {
167                status,
168                retry_after,
169                ..
170            } => {
171                if matches!(*status, 408 | 425 | 429 | 500 | 502 | 503 | 504) {
172                    Outcome::Transient {
173                        retry_after: *retry_after,
174                    }
175                } else {
176                    Outcome::Permanent
177                }
178            },
179            Self::RateLimit { .. } | Self::Timeout { .. } => {
180                Outcome::Transient { retry_after: None }
181            },
182            Self::Http(err) if err.is_timeout() || err.is_connect() => {
183                Outcome::Transient { retry_after: None }
184            },
185            _ => Outcome::Permanent,
186        }
187    }
188}
189
190/// Parse a `Retry-After` header expressed as an integer number of seconds.
191fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
192    headers
193        .get(reqwest::header::RETRY_AFTER)?
194        .to_str()
195        .ok()?
196        .trim()
197        .parse::<u64>()
198        .ok()
199        .map(Duration::from_secs)
200}
201
202pub type Result<T> = std::result::Result<T, AiError>;
203
204impl From<RepositoryError> for AiError {
205    fn from(error: RepositoryError) -> Self {
206        Self::DatabaseError(error.to_string())
207    }
208}