1use std::time::Duration;
2
3use thiserror::Error;
4
5use crate::workflow::SuspendReason;
6
7#[derive(Error, Debug)]
9#[non_exhaustive]
10pub enum ForgeError {
11 #[error("Configuration error: {context}")]
12 Config {
13 context: String,
14 #[source]
15 source: Option<Box<dyn std::error::Error + Send + Sync>>,
16 },
17
18 #[error("Database error: {0}")]
19 Database(#[from] sqlx::Error),
20
21 #[error("Job cancelled: {0}")]
22 JobCancelled(String),
23
24 #[error("Serialization error: {0}")]
25 Serialization(String),
26
27 #[error("Deserialization error: {0}")]
28 Deserialization(String),
29
30 #[error("IO error: {0}")]
31 Io(#[from] std::io::Error),
32
33 #[error("Invalid argument: {0}")]
34 InvalidArgument(String),
35
36 #[error("Not found: {0}")]
37 NotFound(String),
38
39 #[error("Unauthorized: {0}")]
40 Unauthorized(String),
41
42 #[error("Forbidden: {0}")]
43 Forbidden(String),
44
45 #[error("Validation error: {0}")]
46 Validation(String),
47
48 #[error("Timeout: {0}")]
49 Timeout(String),
50
51 #[error("Internal error: {context}")]
52 Internal {
53 context: String,
54 #[source]
55 source: Option<Box<dyn std::error::Error + Send + Sync>>,
56 },
57
58 #[error("Invalid state: {0}")]
59 InvalidState(String),
60
61 #[error("Rate limit exceeded: retry after {retry_after:?}")]
62 RateLimitExceeded {
63 retry_after: Duration,
64 limit: u32,
65 remaining: u32,
66 },
67
68 #[error("Service unavailable: {0}")]
70 ServiceUnavailable(String),
71
72 #[error("Workflow suspended")]
76 WorkflowSuspended(SuspendReason),
77}
78
79impl ForgeError {
80 pub fn not_found(msg: impl Into<String>) -> Self {
81 ForgeError::NotFound(msg.into())
82 }
83
84 pub fn config(msg: impl Into<String>) -> Self {
85 ForgeError::Config {
86 context: msg.into(),
87 source: None,
88 }
89 }
90
91 pub fn unauthorized(msg: impl Into<String>) -> Self {
92 ForgeError::Unauthorized(msg.into())
93 }
94
95 pub fn forbidden(msg: impl Into<String>) -> Self {
96 ForgeError::Forbidden(msg.into())
97 }
98
99 pub fn validation(msg: impl Into<String>) -> Self {
100 ForgeError::Validation(msg.into())
101 }
102
103 pub fn timeout(msg: impl Into<String>) -> Self {
104 ForgeError::Timeout(msg.into())
105 }
106
107 pub fn internal(msg: impl Into<String>) -> Self {
108 ForgeError::Internal {
109 context: msg.into(),
110 source: None,
111 }
112 }
113
114 pub fn internal_with(
115 msg: impl Into<String>,
116 source: impl std::error::Error + Send + Sync + 'static,
117 ) -> Self {
118 ForgeError::Internal {
119 context: msg.into(),
120 source: Some(Box::new(source)),
121 }
122 }
123
124 pub fn config_with(
125 msg: impl Into<String>,
126 source: impl std::error::Error + Send + Sync + 'static,
127 ) -> Self {
128 ForgeError::Config {
129 context: msg.into(),
130 source: Some(Box::new(source)),
131 }
132 }
133
134 pub fn http_status(&self) -> u16 {
136 match self {
137 Self::NotFound(_) => 404,
138 Self::Unauthorized(_) => 401,
139 Self::Forbidden(_) => 403,
140 Self::Validation(_) => 400,
141 Self::InvalidArgument(_) => 400,
142 Self::Deserialization(_) => 400,
143 Self::Timeout(_) => 504,
144 Self::RateLimitExceeded { .. } => 429,
145 Self::JobCancelled(_) => 409,
146 Self::ServiceUnavailable(_) => 503,
147 _ => 500,
148 }
149 }
150
151 pub fn is_client_error(&self) -> bool {
152 let status = self.http_status();
153 (400..500).contains(&status)
154 }
155
156 pub fn is_server_error(&self) -> bool {
157 self.http_status() >= 500
158 }
159
160 pub fn is_retryable(&self) -> bool {
161 matches!(
162 self,
163 Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. }
164 )
165 }
166}
167
168impl From<serde_json::Error> for ForgeError {
169 fn from(e: serde_json::Error) -> Self {
170 ForgeError::Serialization(e.to_string())
171 }
172}
173
174impl From<crate::http::CircuitBreakerError> for ForgeError {
175 fn from(e: crate::http::CircuitBreakerError) -> Self {
176 match e {
177 crate::http::CircuitBreakerError::CircuitOpen(open) => {
178 ForgeError::Timeout(open.to_string())
179 }
180 crate::http::CircuitBreakerError::Request(err) if err.is_timeout() => {
181 ForgeError::Timeout(err.to_string())
182 }
183 crate::http::CircuitBreakerError::Request(err) => ForgeError::Internal {
184 context: "HTTP request failed".to_string(),
185 source: Some(Box::new(err)),
186 },
187 crate::http::CircuitBreakerError::PrivateHostBlocked(host) => {
188 ForgeError::Forbidden(format!("Outbound request to private host '{host}' blocked"))
189 }
190 }
191 }
192}
193
194pub type Result<T> = std::result::Result<T, ForgeError>;
196
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use std::error::Error as _;

    use super::*;

    #[test]
    fn display_preserves_inner_message() {
        let cases: Vec<(ForgeError, &str)> = vec![
            (
                ForgeError::config("bad toml"),
                "Configuration error: bad toml",
            ),
            (
                ForgeError::Database(sqlx::Error::RowNotFound),
                "Database error: no rows returned by a query that expected to return at least one row",
            ),
            (
                ForgeError::JobCancelled("user request".into()),
                "Job cancelled: user request",
            ),
            (
                ForgeError::Serialization("bad json".into()),
                "Serialization error: bad json",
            ),
            (
                ForgeError::Deserialization("missing field".into()),
                "Deserialization error: missing field",
            ),
            (
                ForgeError::Io(std::io::Error::new(std::io::ErrorKind::Other, "disk full")),
                "IO error: disk full",
            ),
            (
                ForgeError::InvalidArgument("negative id".into()),
                "Invalid argument: negative id",
            ),
            (ForgeError::NotFound("user 42".into()), "Not found: user 42"),
            (
                ForgeError::Unauthorized("expired token".into()),
                "Unauthorized: expired token",
            ),
            (
                ForgeError::Forbidden("admin only".into()),
                "Forbidden: admin only",
            ),
            (
                ForgeError::Validation("email required".into()),
                "Validation error: email required",
            ),
            (
                ForgeError::Timeout("5s exceeded".into()),
                "Timeout: 5s exceeded",
            ),
            (
                ForgeError::internal("null pointer"),
                "Internal error: null pointer",
            ),
            (
                ForgeError::InvalidState("already completed".into()),
                "Invalid state: already completed",
            ),
            (
                ForgeError::ServiceUnavailable("maintenance".into()),
                "Service unavailable: maintenance",
            ),
        ];

        for (error, expected) in cases {
            // Include the variant in the failure message so a mismatch is attributable.
            assert_eq!(error.to_string(), expected, "Display mismatch for {error:?}");
        }
    }

    #[test]
    fn display_rate_limit_includes_retry_after() {
        let err = ForgeError::RateLimitExceeded {
            retry_after: Duration::from_secs(30),
            limit: 100,
            remaining: 0,
        };
        // `retry_after` is rendered with `Duration`'s Debug format.
        assert_eq!(err.to_string(), "Rate limit exceeded: retry after 30s");
    }

    #[test]
    fn from_serde_json_error_maps_to_serialization() {
        let bad_json = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let forge_err: ForgeError = bad_json.into();
        match forge_err {
            ForgeError::Serialization(msg) => assert!(!msg.is_empty()),
            other => panic!("Expected Serialization, got: {other:?}"),
        }
    }

    #[test]
    fn from_io_error_maps_to_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let forge_err: ForgeError = io_err.into();
        match forge_err {
            ForgeError::Io(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound),
            other => panic!("Expected Io, got: {other:?}"),
        }
    }

    #[test]
    fn from_circuit_breaker_open_maps_to_timeout() {
        let open = crate::http::CircuitBreakerError::CircuitOpen(crate::http::CircuitBreakerOpen {
            host: "api.example.com".into(),
            retry_after: Duration::from_secs(60),
        });
        let forge_err: ForgeError = open.into();
        match forge_err {
            ForgeError::Timeout(msg) => {
                assert!(
                    msg.contains("api.example.com"),
                    "Expected host in message: {msg}"
                );
            }
            other => panic!("Expected Timeout, got: {other:?}"),
        }
    }

    #[test]
    fn variants_are_distinguishable_via_pattern_match() {
        let errors: Vec<ForgeError> = vec![
            ForgeError::NotFound("x".into()),
            ForgeError::Unauthorized("x".into()),
            ForgeError::Forbidden("x".into()),
            ForgeError::Validation("x".into()),
            ForgeError::InvalidArgument("x".into()),
            ForgeError::Timeout("x".into()),
            ForgeError::internal("x"),
        ];

        for (i, err) in errors.iter().enumerate() {
            let matched = match err {
                ForgeError::NotFound(_) => 0,
                ForgeError::Unauthorized(_) => 1,
                ForgeError::Forbidden(_) => 2,
                ForgeError::Validation(_) => 3,
                ForgeError::InvalidArgument(_) => 4,
                ForgeError::Timeout(_) => 5,
                ForgeError::Internal { .. } => 6,
                _ => usize::MAX,
            };
            assert_eq!(matched, i, "Variant at index {i} matched wrong pattern");
        }
    }

    #[test]
    fn rate_limit_fields_accessible() {
        let err = ForgeError::RateLimitExceeded {
            retry_after: Duration::from_secs(60),
            limit: 100,
            remaining: 0,
        };

        match err {
            ForgeError::RateLimitExceeded {
                retry_after,
                limit,
                remaining,
            } => {
                assert_eq!(retry_after, Duration::from_secs(60));
                assert_eq!(limit, 100);
                assert_eq!(remaining, 0);
            }
            _ => panic!("Expected RateLimitExceeded"),
        }
    }

    #[test]
    fn error_is_send_and_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<ForgeError>();
        assert_sync::<ForgeError>();
    }

    #[test]
    fn http_status_returns_correct_codes() {
        assert_eq!(ForgeError::NotFound("x".into()).http_status(), 404);
        assert_eq!(ForgeError::Unauthorized("x".into()).http_status(), 401);
        assert_eq!(ForgeError::Forbidden("x".into()).http_status(), 403);
        assert_eq!(ForgeError::Validation("x".into()).http_status(), 400);
        assert_eq!(ForgeError::InvalidArgument("x".into()).http_status(), 400);
        assert_eq!(ForgeError::Deserialization("x".into()).http_status(), 400);
        assert_eq!(ForgeError::Timeout("x".into()).http_status(), 504);
        assert_eq!(ForgeError::JobCancelled("x".into()).http_status(), 409);
        assert_eq!(ForgeError::ServiceUnavailable("x".into()).http_status(), 503);
        assert_eq!(
            ForgeError::RateLimitExceeded {
                retry_after: Duration::from_secs(1),
                limit: 10,
                remaining: 0,
            }
            .http_status(),
            429
        );
        for err in [
            ForgeError::internal("x"),
            ForgeError::Database(sqlx::Error::RowNotFound),
            ForgeError::config("x"),
            ForgeError::InvalidState("x".into()),
            ForgeError::Serialization("x".into()),
            ForgeError::Io(std::io::Error::new(std::io::ErrorKind::Other, "x")),
        ] {
            assert_eq!(err.http_status(), 500, "expected 500 for {err:?}");
        }
    }

    #[test]
    fn is_client_error_for_4xx() {
        assert!(ForgeError::not_found("x").is_client_error());
        assert!(ForgeError::unauthorized("x").is_client_error());
        assert!(ForgeError::forbidden("x").is_client_error());
        assert!(ForgeError::validation("x").is_client_error());
        assert!(!ForgeError::internal("x").is_client_error());
        assert!(!ForgeError::timeout("x").is_client_error());
    }

    #[test]
    fn is_server_error_for_5xx() {
        assert!(ForgeError::internal("x").is_server_error());
        assert!(ForgeError::timeout("x").is_server_error());
        assert!(ForgeError::config("x").is_server_error());
        assert!(ForgeError::ServiceUnavailable("x".into()).is_server_error());
        assert!(!ForgeError::not_found("x").is_server_error());
        assert!(!ForgeError::unauthorized("x").is_server_error());
    }

    #[test]
    fn is_retryable_for_transient_errors() {
        assert!(ForgeError::ServiceUnavailable("x".into()).is_retryable());
        assert!(ForgeError::timeout("x").is_retryable());
        assert!(
            ForgeError::RateLimitExceeded {
                retry_after: Duration::from_secs(1),
                limit: 10,
                remaining: 0,
            }
            .is_retryable()
        );
        assert!(!ForgeError::not_found("x").is_retryable());
        assert!(!ForgeError::internal("x").is_retryable());
        assert!(!ForgeError::validation("x").is_retryable());
    }

    #[test]
    fn plain_constructors_have_no_source() {
        assert!(ForgeError::internal("x").source().is_none());
        assert!(ForgeError::config("x").source().is_none());
    }

    #[test]
    fn internal_with_preserves_source_chain() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
        let err = ForgeError::internal_with("connection failed", io_err);
        assert_eq!(err.to_string(), "Internal error: connection failed");
        let source = err.source().expect("source should be preserved");
        let io = source
            .downcast_ref::<std::io::Error>()
            .expect("source should be the original io::Error");
        assert_eq!(io.kind(), std::io::ErrorKind::BrokenPipe);
    }

    #[test]
    fn config_with_preserves_source_chain() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "forge.toml missing");
        let err = ForgeError::config_with("cannot load config", io_err);
        assert_eq!(err.to_string(), "Configuration error: cannot load config");
        let source = err.source().expect("source should be preserved");
        let io = source
            .downcast_ref::<std::io::Error>()
            .expect("source should be the original io::Error");
        assert_eq!(io.kind(), std::io::ErrorKind::NotFound);
    }
}