// forge_core/error.rs

use std::time::Duration;

use thiserror::Error;

use crate::workflow::SuspendReason;

7/// Core error type mapping variants to HTTP status codes.
8#[derive(Error, Debug)]
9#[non_exhaustive]
10pub enum ForgeError {
11    #[error("Configuration error: {context}")]
12    Config {
13        context: String,
14        #[source]
15        source: Option<Box<dyn std::error::Error + Send + Sync>>,
16    },
17
18    #[error("Database error: {0}")]
19    Database(#[from] sqlx::Error),
20
21    #[error("Job cancelled: {0}")]
22    JobCancelled(String),
23
24    #[error("Serialization error: {0}")]
25    Serialization(String),
26
27    #[error("Deserialization error: {0}")]
28    Deserialization(String),
29
30    #[error("IO error: {0}")]
31    Io(#[from] std::io::Error),
32
33    #[error("Invalid argument: {0}")]
34    InvalidArgument(String),
35
36    #[error("Not found: {0}")]
37    NotFound(String),
38
39    #[error("Unauthorized: {0}")]
40    Unauthorized(String),
41
42    #[error("Forbidden: {0}")]
43    Forbidden(String),
44
45    #[error("Validation error: {0}")]
46    Validation(String),
47
48    #[error("Timeout: {0}")]
49    Timeout(String),
50
51    #[error("Internal error: {context}")]
52    Internal {
53        context: String,
54        #[source]
55        source: Option<Box<dyn std::error::Error + Send + Sync>>,
56    },
57
58    #[error("Invalid state: {0}")]
59    InvalidState(String),
60
61    #[error("Rate limit exceeded: retry after {retry_after:?}")]
62    RateLimitExceeded {
63        retry_after: Duration,
64        limit: u32,
65        remaining: u32,
66    },
67
68    /// Service unavailable (503).
69    #[error("Service unavailable: {0}")]
70    ServiceUnavailable(String),
71
72    /// Internal control signal raised when a workflow handler suspends
73    /// (`ctx.sleep(...)` / `ctx.wait_for_event(...)`). The executor handles
74    /// it before any HTTP mapping layer; it is never returned to a client.
75    #[error("Workflow suspended")]
76    WorkflowSuspended(SuspendReason),
77}
78
79impl ForgeError {
80    pub fn not_found(msg: impl Into<String>) -> Self {
81        ForgeError::NotFound(msg.into())
82    }
83
84    pub fn config(msg: impl Into<String>) -> Self {
85        ForgeError::Config {
86            context: msg.into(),
87            source: None,
88        }
89    }
90
91    pub fn unauthorized(msg: impl Into<String>) -> Self {
92        ForgeError::Unauthorized(msg.into())
93    }
94
95    pub fn forbidden(msg: impl Into<String>) -> Self {
96        ForgeError::Forbidden(msg.into())
97    }
98
99    pub fn validation(msg: impl Into<String>) -> Self {
100        ForgeError::Validation(msg.into())
101    }
102
103    pub fn timeout(msg: impl Into<String>) -> Self {
104        ForgeError::Timeout(msg.into())
105    }
106
107    pub fn internal(msg: impl Into<String>) -> Self {
108        ForgeError::Internal {
109            context: msg.into(),
110            source: None,
111        }
112    }
113
114    pub fn internal_with(
115        msg: impl Into<String>,
116        source: impl std::error::Error + Send + Sync + 'static,
117    ) -> Self {
118        ForgeError::Internal {
119            context: msg.into(),
120            source: Some(Box::new(source)),
121        }
122    }
123
124    pub fn config_with(
125        msg: impl Into<String>,
126        source: impl std::error::Error + Send + Sync + 'static,
127    ) -> Self {
128        ForgeError::Config {
129            context: msg.into(),
130            source: Some(Box::new(source)),
131        }
132    }
133
134    /// Canonical variant-to-HTTP-status mapping.
135    pub fn http_status(&self) -> u16 {
136        match self {
137            Self::NotFound(_) => 404,
138            Self::Unauthorized(_) => 401,
139            Self::Forbidden(_) => 403,
140            Self::Validation(_) => 400,
141            Self::InvalidArgument(_) => 400,
142            Self::Deserialization(_) => 400,
143            Self::Timeout(_) => 504,
144            Self::RateLimitExceeded { .. } => 429,
145            Self::JobCancelled(_) => 409,
146            Self::ServiceUnavailable(_) => 503,
147            _ => 500,
148        }
149    }
150
151    pub fn is_client_error(&self) -> bool {
152        let status = self.http_status();
153        (400..500).contains(&status)
154    }
155
156    pub fn is_server_error(&self) -> bool {
157        self.http_status() >= 500
158    }
159
160    pub fn is_retryable(&self) -> bool {
161        matches!(
162            self,
163            Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. }
164        )
165    }
166}
167
168impl From<serde_json::Error> for ForgeError {
169    fn from(e: serde_json::Error) -> Self {
170        ForgeError::Serialization(e.to_string())
171    }
172}
173
174impl From<crate::http::CircuitBreakerError> for ForgeError {
175    fn from(e: crate::http::CircuitBreakerError) -> Self {
176        match e {
177            crate::http::CircuitBreakerError::CircuitOpen(open) => {
178                ForgeError::Timeout(open.to_string())
179            }
180            crate::http::CircuitBreakerError::Request(err) if err.is_timeout() => {
181                ForgeError::Timeout(err.to_string())
182            }
183            crate::http::CircuitBreakerError::Request(err) => ForgeError::Internal {
184                context: "HTTP request failed".to_string(),
185                source: Some(Box::new(err)),
186            },
187            crate::http::CircuitBreakerError::PrivateHostBlocked(host) => {
188                ForgeError::Forbidden(format!("Outbound request to private host '{host}' blocked"))
189            }
190        }
191    }
192}
193
194/// Result type alias using ForgeError.
195pub type Result<T> = std::result::Result<T, ForgeError>;
196
#[cfg(test)]
// Tests intentionally unwrap, index and panic so unexpected shapes fail loudly.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use std::error::Error as _;

    use super::*;

    #[test]
    fn display_preserves_inner_message() {
        let cases: Vec<(ForgeError, &str)> = vec![
            (
                ForgeError::config("bad toml"),
                "Configuration error: bad toml",
            ),
            (
                ForgeError::Database(sqlx::Error::RowNotFound),
                "Database error: no rows returned by a query that expected to return at least one row",
            ),
            (
                ForgeError::JobCancelled("user request".into()),
                "Job cancelled: user request",
            ),
            (
                ForgeError::Serialization("bad json".into()),
                "Serialization error: bad json",
            ),
            (
                ForgeError::Deserialization("missing field".into()),
                "Deserialization error: missing field",
            ),
            (
                ForgeError::Io(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    "access denied",
                )),
                "IO error: access denied",
            ),
            (
                ForgeError::InvalidArgument("negative id".into()),
                "Invalid argument: negative id",
            ),
            (ForgeError::NotFound("user 42".into()), "Not found: user 42"),
            (
                ForgeError::Unauthorized("expired token".into()),
                "Unauthorized: expired token",
            ),
            (
                ForgeError::Forbidden("admin only".into()),
                "Forbidden: admin only",
            ),
            (
                ForgeError::Validation("email required".into()),
                "Validation error: email required",
            ),
            (
                ForgeError::Timeout("5s exceeded".into()),
                "Timeout: 5s exceeded",
            ),
            (
                ForgeError::internal("null pointer"),
                "Internal error: null pointer",
            ),
            (
                ForgeError::InvalidState("already completed".into()),
                "Invalid state: already completed",
            ),
            (
                ForgeError::ServiceUnavailable("maintenance".into()),
                "Service unavailable: maintenance",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected, "Display mismatch for {error:?}");
        }
    }

    #[test]
    fn display_rate_limit_includes_retry_after() {
        let err = ForgeError::RateLimitExceeded {
            retry_after: Duration::from_secs(30),
            limit: 100,
            remaining: 0,
        };
        assert_eq!(err.to_string(), "Rate limit exceeded: retry after 30s");
    }

    #[test]
    fn from_serde_json_error_maps_to_serialization() {
        let bad_json = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let forge_err: ForgeError = bad_json.into();
        match forge_err {
            ForgeError::Serialization(msg) => assert!(!msg.is_empty()),
            other => panic!("Expected Serialization, got: {other:?}"),
        }
    }

    #[test]
    fn from_io_error_maps_to_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let forge_err: ForgeError = io_err.into();
        match forge_err {
            ForgeError::Io(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound),
            other => panic!("Expected Io, got: {other:?}"),
        }
    }

    #[test]
    fn from_circuit_breaker_open_maps_to_timeout() {
        let open = crate::http::CircuitBreakerError::CircuitOpen(crate::http::CircuitBreakerOpen {
            host: "api.example.com".into(),
            retry_after: Duration::from_secs(60),
        });
        let forge_err: ForgeError = open.into();
        match forge_err {
            ForgeError::Timeout(msg) => {
                assert!(
                    msg.contains("api.example.com"),
                    "Expected host in message: {msg}"
                );
            }
            other => panic!("Expected Timeout, got: {other:?}"),
        }
    }

    #[test]
    fn variants_are_distinguishable_via_pattern_match() {
        let errors: Vec<ForgeError> = vec![
            ForgeError::NotFound("x".into()),
            ForgeError::Unauthorized("x".into()),
            ForgeError::Forbidden("x".into()),
            ForgeError::Validation("x".into()),
            ForgeError::InvalidArgument("x".into()),
            ForgeError::Timeout("x".into()),
            ForgeError::internal("x"),
        ];

        for (i, err) in errors.iter().enumerate() {
            let matched = match err {
                ForgeError::NotFound(_) => 0,
                ForgeError::Unauthorized(_) => 1,
                ForgeError::Forbidden(_) => 2,
                ForgeError::Validation(_) => 3,
                ForgeError::InvalidArgument(_) => 4,
                ForgeError::Timeout(_) => 5,
                ForgeError::Internal { .. } => 6,
                _ => usize::MAX,
            };
            assert_eq!(matched, i, "Variant at index {i} matched wrong pattern");
        }
    }

    #[test]
    fn rate_limit_fields_accessible() {
        let err = ForgeError::RateLimitExceeded {
            retry_after: Duration::from_secs(60),
            limit: 100,
            remaining: 0,
        };

        match err {
            ForgeError::RateLimitExceeded {
                retry_after,
                limit,
                remaining,
            } => {
                assert_eq!(retry_after, Duration::from_secs(60));
                assert_eq!(limit, 100);
                assert_eq!(remaining, 0);
            }
            _ => panic!("Expected RateLimitExceeded"),
        }
    }

    #[test]
    fn error_is_send_and_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<ForgeError>();
        assert_sync::<ForgeError>();
    }

    #[test]
    fn http_status_returns_correct_codes() {
        assert_eq!(ForgeError::NotFound("x".into()).http_status(), 404);
        assert_eq!(ForgeError::Unauthorized("x".into()).http_status(), 401);
        assert_eq!(ForgeError::Forbidden("x".into()).http_status(), 403);
        assert_eq!(ForgeError::Validation("x".into()).http_status(), 400);
        assert_eq!(ForgeError::InvalidArgument("x".into()).http_status(), 400);
        assert_eq!(ForgeError::Deserialization("x".into()).http_status(), 400);
        assert_eq!(ForgeError::Timeout("x".into()).http_status(), 504);
        assert_eq!(ForgeError::JobCancelled("x".into()).http_status(), 409);
        assert_eq!(ForgeError::ServiceUnavailable("x".into()).http_status(), 503);
        assert_eq!(
            ForgeError::RateLimitExceeded {
                retry_after: Duration::from_secs(1),
                limit: 10,
                remaining: 0,
            }
            .http_status(),
            429
        );
        for err in [
            ForgeError::internal("x"),
            ForgeError::Database(sqlx::Error::RowNotFound),
            ForgeError::config("x"),
            ForgeError::InvalidState("x".into()),
            ForgeError::Serialization("x".into()),
            ForgeError::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "x")),
        ] {
            assert_eq!(err.http_status(), 500, "expected 500 for {err:?}");
        }
    }

    #[test]
    fn is_client_error_for_4xx() {
        assert!(ForgeError::not_found("x").is_client_error());
        assert!(ForgeError::unauthorized("x").is_client_error());
        assert!(ForgeError::forbidden("x").is_client_error());
        assert!(ForgeError::validation("x").is_client_error());
        assert!(ForgeError::JobCancelled("x".into()).is_client_error());
        assert!(!ForgeError::internal("x").is_client_error());
        assert!(!ForgeError::timeout("x").is_client_error());
    }

    #[test]
    fn is_server_error_for_5xx() {
        assert!(ForgeError::internal("x").is_server_error());
        assert!(ForgeError::timeout("x").is_server_error());
        assert!(ForgeError::config("x").is_server_error());
        assert!(ForgeError::ServiceUnavailable("x".into()).is_server_error());
        assert!(!ForgeError::not_found("x").is_server_error());
        assert!(!ForgeError::unauthorized("x").is_server_error());
    }

    #[test]
    fn is_retryable_for_transient_errors() {
        assert!(ForgeError::ServiceUnavailable("x".into()).is_retryable());
        assert!(ForgeError::timeout("x").is_retryable());
        assert!(
            ForgeError::RateLimitExceeded {
                retry_after: Duration::from_secs(1),
                limit: 10,
                remaining: 0,
            }
            .is_retryable()
        );
        assert!(!ForgeError::not_found("x").is_retryable());
        assert!(!ForgeError::internal("x").is_retryable());
        assert!(!ForgeError::validation("x").is_retryable());
    }

    #[test]
    fn internal_with_preserves_source_chain() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
        let err = ForgeError::internal_with("connection failed", io_err);
        assert_eq!(err.to_string(), "Internal error: connection failed");
        assert!(err.source().is_some(), "source should be preserved");
        assert!(
            ForgeError::internal("no cause").source().is_none(),
            "plain constructor should carry no source"
        );
    }

    #[test]
    fn config_with_preserves_source_chain() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "forge.toml missing");
        let err = ForgeError::config_with("cannot load config", io_err);
        assert_eq!(err.to_string(), "Configuration error: cannot load config");
        assert_eq!(err.source().unwrap().to_string(), "forge.toml missing");
        assert!(
            ForgeError::config("no cause").source().is_none(),
            "plain constructor should carry no source"
        );
    }
}