Skip to main content

recoco_utils/
error.rs

// ReCoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
// Original code from CocoIndex is copyrighted by CocoIndex
// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
// SPDX-FileContributor: CocoIndex Contributors
//
// All modifications from the upstream for ReCoco are copyrighted by Knitli Inc.
// SPDX-FileCopyrightText: 2026 Knitli Inc. (ReCoco)
// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
//
// Both the upstream CocoIndex code and the ReCoco modifications are licensed under the Apache-2.0 License.
// SPDX-License-Identifier: Apache-2.0
12
13#[cfg(feature = "server")]
14use axum::{
15    Json,
16    response::{IntoResponse, Response},
17};
18
19#[cfg(feature = "http")]
20pub use http::StatusCode;
21#[cfg(feature = "server")]
22use serde::Serialize;
23use std::{
24    any::Any,
25    backtrace::Backtrace,
26    error::Error as StdError,
27    fmt::{Debug, Display},
28    sync::{Arc, Mutex},
29};
30
/// An error raised on the host-language side and carried by [`Error::HostLang`].
///
/// Every `'static`, thread-safe `std::error::Error` type gets this trait through the blanket
/// impl below, so it behaves like a trait alias. The `Any` supertrait lets a `&dyn HostError`
/// be upcast to `&dyn Any` and downcast back to its concrete type.
pub trait HostError: Any + StdError + Send + Sync + 'static {}

impl<T> HostError for T where T: Any + StdError + Send + Sync + 'static {}
33
34pub enum Error {
35    Context { msg: String, source: Box<SError> },
36    HostLang(Box<dyn HostError>),
37    Client { msg: String, bt: Backtrace },
38    Internal(anyhow::Error),
39}
40
41impl Display for Error {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self.format_context(f)? {
44            Error::Context { .. } => Ok(()),
45            Error::HostLang(e) => write!(f, "{}", e),
46            Error::Client { msg, .. } => write!(f, "Invalid Request: {}", msg),
47            Error::Internal(e) => write!(f, "{}", e),
48        }
49    }
50}
51impl Debug for Error {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self.format_context(f)? {
54            Error::Context { .. } => Ok(()),
55            Error::HostLang(e) => write!(f, "{:?}", e),
56            Error::Client { msg, bt } => {
57                write!(f, "Invalid Request: {msg}\n\n{bt}\n")
58            }
59            Error::Internal(e) => write!(f, "{e:?}"),
60        }
61    }
62}
63
64pub type Result<T, E = Error> = std::result::Result<T, E>;
65
66// Backwards compatibility aliases
67pub type CError = Error;
68pub type CResult<T> = Result<T>;
69
70impl Error {
71    pub fn host(e: impl HostError) -> Self {
72        Self::HostLang(Box::new(e))
73    }
74
75    pub fn client(msg: impl Into<String>) -> Self {
76        Self::Client {
77            msg: msg.into(),
78            bt: Backtrace::capture(),
79        }
80    }
81
82    pub fn internal(e: impl Into<anyhow::Error>) -> Self {
83        Self::Internal(e.into())
84    }
85
86    pub fn internal_msg(msg: impl Into<String>) -> Self {
87        Self::Internal(anyhow::anyhow!("{}", msg.into()))
88    }
89
90    pub fn backtrace(&self) -> Option<&Backtrace> {
91        match self {
92            Error::Client { bt, .. } => Some(bt),
93            Error::Internal(e) => Some(e.backtrace()),
94            Error::Context { source, .. } => source.0.backtrace(),
95            Error::HostLang(_) => None,
96        }
97    }
98
99    pub fn without_contexts(&self) -> &Error {
100        match self {
101            Error::Context { source, .. } => source.0.without_contexts(),
102            other => other,
103        }
104    }
105
106    pub fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
107        match self {
108            Error::Context { source, .. } => Some(source.as_ref()),
109            Error::HostLang(e) => Some(e.as_ref()),
110            Error::Internal(e) => e.source(),
111            Error::Client { .. } => None,
112        }
113    }
114
115    pub fn context<C: Into<String>>(self, context: C) -> Self {
116        Self::Context {
117            msg: context.into(),
118            source: Box::new(SError(self)),
119        }
120    }
121
122    pub fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Self {
123        Self::Context {
124            msg: f().into(),
125            source: Box::new(SError(self)),
126        }
127    }
128
129    pub fn std_error(self) -> SError {
130        SError(self)
131    }
132
133    fn format_context(&self, f: &mut std::fmt::Formatter<'_>) -> Result<&Error, std::fmt::Error> {
134        let mut current = self;
135        if matches!(current, Error::Context { .. }) {
136            write!(f, "\nContext:\n")?;
137            let mut next_id = 1;
138            while let Error::Context { msg, source } = current {
139                writeln!(f, "  {next_id}: {msg}")?;
140                current = source.inner();
141                next_id += 1;
142            }
143        }
144        Ok(current)
145    }
146}
147
148impl StdError for Error {
149    fn source(&self) -> Option<&(dyn StdError + 'static)> {
150        self.source()
151    }
152}
153
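// A blanket conversion such as
//
//     impl<E: Into<anyhow::Error>> From<E> for Error {
//         fn from(e: E) -> Self {
//             Error::Internal(e.into())
//         }
//     }
//
// is not possible. `Error` implements `std::error::Error` and is `Send + Sync + 'static`, so it
// satisfies `Into<anyhow::Error>` itself. The blanket impl would therefore conflict with the
// standard library's reflexive `impl<T> From<T> for T`.
//
// Instead, `From` is implemented explicitly below for the common error types used in
// recoco_utils. Many of these impls are gated on the Cargo feature that pulls in the
// corresponding crate.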
167impl From<anyhow::Error> for Error {
168    fn from(e: anyhow::Error) -> Self {
169        Error::Internal(e)
170    }
171}
172
173impl From<std::io::Error> for Error {
174    fn from(e: std::io::Error) -> Self {
175        Error::Internal(e.into())
176    }
177}
// tokio task and oneshot-channel failures become internal errors (only compiled for the
// features listed in the `cfg` attributes).
#[cfg(any(
    feature = "server",
    feature = "concur_control",
    feature = "retryable",
    feature = "batching",
    feature = "http"
))]
impl From<tokio::task::JoinError> for Error {
    fn from(err: tokio::task::JoinError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(any(
    feature = "server",
    feature = "concur_control",
    feature = "retryable",
    feature = "batching",
    feature = "http"
))]
impl From<tokio::sync::oneshot::error::RecvError> for Error {
    fn from(err: tokio::sync::oneshot::error::RecvError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
// Decoding failures used by the `fingerprint` feature become internal errors.
#[cfg(feature = "fingerprint")]
impl From<base64::DecodeError> for Error {
    fn from(err: base64::DecodeError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "fingerprint")]
impl From<hex::FromHexError> for Error {
    fn from(err: hex::FromHexError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
214
215impl From<ResidualError> for Error {
216    fn from(e: ResidualError) -> Self {
217        Error::Internal(anyhow::Error::from(e))
218    }
219}
/// Fingerprinting failures become internal errors.
#[cfg(feature = "fingerprint")]
impl From<crate::fingerprint::FingerprinterError> for Error {
    fn from(err: crate::fingerprint::FingerprinterError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
226
227impl From<ApiError> for Error {
228    fn from(e: ApiError) -> Self {
229        Error::Internal(e.err)
230    }
231}
// Parsing / pattern-compilation failures from optional dependencies become internal errors.
#[cfg(feature = "deserialize")]
impl From<serde_json::Error> for Error {
    fn from(err: serde_json::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(any(
    feature = "local-file",
    feature = "google-drive",
    feature = "azure",
    feature = "s3"
))]
impl From<globset::Error> for Error {
    fn from(err: globset::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "regex")]
impl From<regex::Error> for Error {
    fn from(err: regex::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
255
256impl<T> From<std::sync::PoisonError<T>> for Error {
257    fn from(e: std::sync::PoisonError<T>) -> Self {
258        Error::Internal(anyhow::anyhow!("Mutex poison error: {}", e))
259    }
260}
// Value-parsing failures from optional dependencies become internal errors.
#[cfg(feature = "chrono")]
impl From<chrono::ParseError> for Error {
    fn from(err: chrono::ParseError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "uuid")]
impl From<uuid::Error> for Error {
    fn from(err: uuid::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(any(feature = "server", feature = "http", feature = "reqwest"))]
impl From<http::header::InvalidHeaderValue> for Error {
    fn from(err: http::header::InvalidHeaderValue) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
279
280impl From<std::num::ParseIntError> for Error {
281    fn from(e: std::num::ParseIntError) -> Self {
282        Error::Internal(e.into())
283    }
284}
285
286impl From<std::str::ParseBoolError> for Error {
287    fn from(e: std::str::ParseBoolError) -> Self {
288        Error::Internal(e.into())
289    }
290}
291
292impl From<std::fmt::Error> for Error {
293    fn from(e: std::fmt::Error) -> Self {
294        Error::Internal(e.into())
295    }
296}
297
298impl From<std::string::FromUtf8Error> for Error {
299    fn from(e: std::string::FromUtf8Error) -> Self {
300        Error::Internal(e.into())
301    }
302}
303
304impl From<std::borrow::Cow<'_, str>> for Error {
305    fn from(e: std::borrow::Cow<'_, str>) -> Self {
306        Error::Internal(anyhow::anyhow!("{}", e))
307    }
308}
// tokio semaphore and watch-channel failures become internal errors.
#[cfg(any(
    feature = "server",
    feature = "concur_control",
    feature = "retryable",
    feature = "batching"
))]
impl From<tokio::sync::AcquireError> for Error {
    fn from(err: tokio::sync::AcquireError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(any(
    feature = "server",
    feature = "concur_control",
    feature = "retryable",
    feature = "batching"
))]
impl From<tokio::sync::watch::error::RecvError> for Error {
    fn from(err: tokio::sync::watch::error::RecvError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
331
// YAML emission / serialization failures become internal errors.
#[cfg(feature = "yaml")]
impl From<yaml_rust2::EmitError> for Error {
    fn from(err: yaml_rust2::EmitError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "yaml")]
impl From<crate::yaml_ser::YamlSerializerError> for Error {
    fn from(err: crate::yaml_ser::YamlSerializerError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
345
// Errors from optional service clients become internal errors.
#[cfg(feature = "reqwest")]
impl From<reqwest::Error> for Error {
    fn from(err: reqwest::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for Error {
    fn from(err: sqlx::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "neo4rs")]
impl From<neo4rs::Error> for Error {
    fn from(err: neo4rs::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "openai")]
impl From<async_openai::error::OpenAIError> for Error {
    fn from(err: async_openai::error::OpenAIError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

// NOTE(review): the qdrant, azure and google-drive errors below are converted through their
// message only. Presumably these types do not meet anyhow's `StdError + Send + Sync + 'static`
// requirements. TODO confirm.
#[cfg(feature = "qdrant")]
impl From<qdrant_client::QdrantError> for Error {
    fn from(err: qdrant_client::QdrantError) -> Self {
        let rendered = err.to_string();
        Self::Internal(anyhow::Error::msg(rendered))
    }
}

#[cfg(feature = "redis")]
impl From<redis::RedisError> for Error {
    fn from(err: redis::RedisError) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}

#[cfg(feature = "azure")]
impl From<azure_storage::Error> for Error {
    fn from(err: azure_storage::Error) -> Self {
        let rendered = err.to_string();
        Self::Internal(anyhow::Error::msg(rendered))
    }
}

#[cfg(feature = "google-drive")]
impl From<google_drive3::Error> for Error {
    fn from(err: google_drive3::Error) -> Self {
        let rendered = err.to_string();
        Self::Internal(anyhow::Error::msg(rendered))
    }
}

#[cfg(feature = "google-drive")]
impl From<google_drive3::hyper::Error> for Error {
    fn from(err: google_drive3::hyper::Error) -> Self {
        Self::Internal(anyhow::Error::from(err))
    }
}
408
409pub trait ContextExt<T> {
410    fn context<C: Into<String>>(self, context: C) -> Result<T>;
411    fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Result<T>;
412}
413
414impl<T> ContextExt<T> for Result<T> {
415    fn context<C: Into<String>>(self, context: C) -> Result<T> {
416        self.map_err(|e| e.context(context))
417    }
418
419    fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Result<T> {
420        self.map_err(|e| e.with_context(f))
421    }
422}
423
424pub trait StdContextExt<T, E> {
425    fn context<C: Into<String>>(self, context: C) -> Result<T>;
426    fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Result<T>;
427}
428
429impl<T, E: StdError + Send + Sync + 'static> StdContextExt<T, E> for Result<T, E> {
430    fn context<C: Into<String>>(self, context: C) -> Result<T> {
431        self.map_err(|e| Error::internal(e).context(context))
432    }
433
434    fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Result<T> {
435        self.map_err(|e| Error::internal(e).with_context(f))
436    }
437}
438
439impl<T> ContextExt<T> for Option<T> {
440    fn context<C: Into<String>>(self, context: C) -> Result<T> {
441        self.ok_or_else(|| Error::client(context))
442    }
443
444    fn with_context<C: Into<String>, F: FnOnce() -> C>(self, f: F) -> Result<T> {
445        self.ok_or_else(|| Error::client(f()))
446    }
447}
448
#[cfg(feature = "server")]
impl IntoResponse for Error {
    /// Renders the error as a JSON `{"error": ...}` response.
    ///
    /// The status is `400 Bad Request` when the root error (below any context layers) is a
    /// client error, as in `From<Error> for ApiError`. It is `500 Internal Server Error`
    /// otherwise.
    fn into_response(self) -> Response {
        tracing::debug!("Error response:\n{:?}", self);

        // Classify by the root error: adding context to a client error must not turn a bad
        // request into a server error.
        let status_code = match self.without_contexts() {
            Error::Client { .. } => StatusCode::BAD_REQUEST,
            Error::HostLang(_) | Error::Internal(_) | Error::Context { .. } => {
                StatusCode::INTERNAL_SERVER_ERROR
            }
        };
        let error_msg = match &self {
            Error::Client { msg, .. } => msg.clone(),
            Error::HostLang(e) => e.to_string(),
            // NOTE(review): this sends the Debug form (context chain and, when captured,
            // backtraces) to the client; confirm that is intended for public endpoints.
            Error::Context { .. } | Error::Internal(_) => format!("{:?}", self),
        };
        let error_response = ErrorResponse { error: error_msg };
        (status_code, Json(error_response)).into_response()
    }
}
465
/// Returns early with a client ("Invalid Request") [`Error`](crate::error::Error) built from a
/// format string.
#[macro_export]
macro_rules! client_bail {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        return Err($crate::client_error!($fmt $(, $($arg)*)?))
    };
}

/// Builds a client ("Invalid Request") [`Error`](crate::error::Error) from a format string.
#[macro_export]
macro_rules! client_error {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        $crate::error::Error::client(format!($fmt $(, $($arg)*)?))
    };
}

/// Returns early with an internal [`Error`](crate::error::Error) built from a format string.
#[macro_export]
macro_rules! internal_bail {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        return Err($crate::internal_error!($fmt $(, $($arg)*)?))
    };
}

/// Builds an internal [`Error`](crate::error::Error) from a format string.
#[macro_export]
macro_rules! internal_error {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        $crate::error::Error::internal_msg(format!($fmt $(, $($arg)*)?))
    };
}
493
494// A wrapper around Error that fits into std::error::Error trait.
495pub struct SError(Error);
496
497impl SError {
498    pub fn inner(&self) -> &Error {
499        &self.0
500    }
501}
502
503impl Display for SError {
504    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505        Display::fmt(&self.0, f)
506    }
507}
508
509impl Debug for SError {
510    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511        Debug::fmt(&self.0, f)
512    }
513}
514
515impl std::error::Error for SError {
516    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
517        self.0.source()
518    }
519}
520
// Legacy types below - kept for backwards compatibility during migration

/// Rendered text of an error, shared by all clones of a [`ResidualError`].
struct ResidualErrorData {
    /// `Display` output of the original error.
    message: String,
    /// `Debug` output of the original error.
    debug: String,
}

/// A cheap-to-clone, text-only snapshot of an error.
///
/// `Display` and `Debug` reproduce the original error's `Display` and `Debug` output
/// respectively. The original error value itself is not retained.
#[derive(Clone)]
pub struct ResidualError(Arc<ResidualErrorData>);

impl ResidualError {
    /// Snapshots both the `Display` and the `Debug` rendering of `err`.
    pub fn new<Err: Display + Debug>(err: &Err) -> Self {
        Self(Arc::new(ResidualErrorData {
            message: err.to_string(),
            // Keep the real Debug rendering (for `Error` this includes context layers and any
            // captured backtrace). Using `to_string()` here would make `Debug` identical to
            // `Display` and leave the `Debug` bound unused.
            debug: format!("{err:?}"),
        }))
    }
}

impl Display for ResidualError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.0.message)
    }
}

impl Debug for ResidualError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.0.debug)
    }
}

impl StdError for ResidualError {}
553
554enum SharedErrorState {
555    Error(Error),
556    ResidualErrorMessage(ResidualError),
557}
558
559#[derive(Clone)]
560pub struct SharedError(Arc<Mutex<SharedErrorState>>);
561
562impl SharedError {
563    pub fn new(err: Error) -> Self {
564        Self(Arc::new(Mutex::new(SharedErrorState::Error(err))))
565    }
566
567    fn extract_error(&self) -> Error {
568        let mut state = self.0.lock().unwrap();
569        let mut_state = &mut *state;
570
571        let residual_err = match mut_state {
572            SharedErrorState::ResidualErrorMessage(err) => {
573                // Already extracted; return a generic internal error with the residual message.
574                return Error::internal(err.clone());
575            }
576            SharedErrorState::Error(err) => ResidualError::new(err),
577        };
578
579        let orig_state = std::mem::replace(
580            mut_state,
581            SharedErrorState::ResidualErrorMessage(residual_err),
582        );
583        let SharedErrorState::Error(err) = orig_state else {
584            panic!("Expected shared error state to hold Error");
585        };
586        err
587    }
588}
589
590impl Debug for SharedError {
591    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
592        let state = self.0.lock().unwrap();
593        match &*state {
594            SharedErrorState::Error(err) => Debug::fmt(err, f),
595            SharedErrorState::ResidualErrorMessage(err) => Debug::fmt(err, f),
596        }
597    }
598}
599
600impl Display for SharedError {
601    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
602        let state = self.0.lock().unwrap();
603        match &*state {
604            SharedErrorState::Error(err) => Display::fmt(err, f),
605            SharedErrorState::ResidualErrorMessage(err) => Display::fmt(err, f),
606        }
607    }
608}
609
610impl From<Error> for SharedError {
611    fn from(err: Error) -> Self {
612        Self(Arc::new(Mutex::new(SharedErrorState::Error(err))))
613    }
614}
615
616pub fn shared_ok<T>(value: T) -> std::result::Result<T, SharedError> {
617    Ok(value)
618}
619
620pub type SharedResult<T> = std::result::Result<T, SharedError>;
621
622pub trait SharedResultExt<T> {
623    fn into_result(self) -> Result<T>;
624}
625
626impl<T> SharedResultExt<T> for std::result::Result<T, SharedError> {
627    fn into_result(self) -> Result<T> {
628        match self {
629            Ok(value) => Ok(value),
630            Err(err) => Err(err.extract_error()),
631        }
632    }
633}
634
635pub trait SharedResultExtRef<'a, T> {
636    fn into_result(self) -> Result<&'a T>;
637}
638
639impl<'a, T> SharedResultExtRef<'a, T> for &'a std::result::Result<T, SharedError> {
640    fn into_result(self) -> Result<&'a T> {
641        match self {
642            Ok(value) => Ok(value),
643            Err(err) => Err(err.extract_error()),
644        }
645    }
646}
647
648pub fn invariance_violation() -> anyhow::Error {
649    anyhow::anyhow!("Invariance violation")
650}
651
652#[derive(Debug)]
653pub struct ApiError {
654    pub err: anyhow::Error,
655    #[cfg(feature = "http")]
656    pub status_code: StatusCode,
657}
658
659impl ApiError {
660    cfg_if::cfg_if! {
661        if #[cfg(feature = "http")] {
662        pub fn new(message: &str, status_code: StatusCode) -> Self {
663            Self {
664                err: anyhow::anyhow!("{}", message),
665                status_code,
666            }
667        }} else {
668            pub fn new(message: &str) -> Self {
669                Self {
670                    err: anyhow::anyhow!("{}", message),
671                }
672            }
673        }
674    }
675}
676
677impl Display for ApiError {
678    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
679        Display::fmt(&self.err, f)
680    }
681}
682
683impl StdError for ApiError {
684    fn source(&self) -> Option<&(dyn StdError + 'static)> {
685        self.err.source()
686    }
687}
688
/// JSON body of an error response: `{"error": "..."}`.
#[cfg(feature = "server")]
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}
#[cfg(feature = "server")]
impl IntoResponse for ApiError {
    /// Responds with `status_code` and a JSON body holding the error's Debug rendering.
    // NOTE(review): the Debug form may include backtraces; confirm exposing it is intended.
    fn into_response(self) -> Response {
        tracing::debug!("Internal server error:\n{:?}", self.err);
        let body = Json(ErrorResponse {
            error: format!("{:?}", self.err),
        });
        (self.status_code, body).into_response()
    }
}
707cfg_if::cfg_if! {
708    if #[cfg(feature = "http")] {
709        impl From<anyhow::Error> for ApiError {
710            fn from(err: anyhow::Error) -> ApiError {
711                if err.is::<ApiError>() {
712                    return err.downcast::<ApiError>().unwrap();
713                }
714                Self {
715                    err,
716                    status_code: StatusCode::INTERNAL_SERVER_ERROR,
717                }
718            }
719        }
720    } else {
721        impl From<anyhow::Error> for ApiError {
722            fn from(err: anyhow::Error) -> ApiError {
723                if err.is::<ApiError>() {
724                    return err.downcast::<ApiError>().unwrap();
725                }
726                Self {
727                    err,
728                }
729            }
730        }
731    }
732}
733impl From<Error> for ApiError {
734    fn from(err: Error) -> ApiError {
735        cfg_if::cfg_if! {
736            if #[cfg(feature = "http")] {
737                let status_code = match err.without_contexts() {
738                    Error::Client { .. } => StatusCode::BAD_REQUEST,
739                    _ => StatusCode::INTERNAL_SERVER_ERROR,
740                };
741                ApiError {
742                    err: anyhow::Error::from(err.std_error()),
743                    status_code,
744                }
745            } else {
746                ApiError {
747                    err: anyhow::Error::from(err.std_error()),
748                }
749            }
750        }
751    }
752}
/// Returns early with a `400 Bad Request` [`ApiError`](crate::error::ApiError) built from a
/// format string, converted into the function's error type with `.into()`.
#[cfg(feature = "http")]
#[macro_export]
macro_rules! api_bail {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        return Err($crate::error::ApiError::new(
            &format!($fmt $(, $($arg)*)?),
            $crate::error::StatusCode::BAD_REQUEST,
        )
        .into())
    };
}

/// Returns early with an [`ApiError`](crate::error::ApiError) built from a format string,
/// converted into the function's error type with `.into()` (no status code without `http`).
#[cfg(not(feature = "http"))]
#[macro_export]
macro_rules! api_bail {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        return Err($crate::error::ApiError::new(&format!($fmt $(, $($arg)*)?)).into())
    };
}
770
/// Builds a `400 Bad Request` [`ApiError`](crate::error::ApiError) from a format string.
#[cfg(feature = "http")]
#[macro_export]
macro_rules! api_error {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        $crate::error::ApiError::new(
            &format!($fmt $(, $($arg)*)?),
            $crate::error::StatusCode::BAD_REQUEST,
        )
    };
}

/// Builds an [`ApiError`](crate::error::ApiError) from a format string (no status code
/// without the `http` feature).
#[cfg(not(feature = "http"))]
#[macro_export]
macro_rules! api_error {
    ( $fmt:literal $(, $($arg:tt)*)?) => {
        $crate::error::ApiError::new(&format!($fmt $(, $($arg)*)?))
    };
}
788
#[cfg(test)]
mod tests {
    use super::*;
    use std::backtrace::BacktraceStatus;
    use std::io;

    /// Stand-in for an error raised by a host-language runtime.
    #[derive(Debug)]
    struct MockHostError(String);

    impl Display for MockHostError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockHostError: {}", self.0)
        }
    }

    impl StdError for MockHostError {}

    /// Returns the `MockHostError` at the root of `err`, panicking if the root is anything else.
    fn root_mock_host_error(err: &Error) -> &MockHostError {
        let Error::HostLang(host_err) = err.without_contexts() else {
            panic!("Expected HostLang variant");
        };
        let any: &dyn Any = host_err.as_ref();
        any.downcast_ref::<MockHostError>()
            .expect("HostLang payload should be a MockHostError")
    }

    #[test]
    fn test_client_error_creation() {
        let err = Error::client("invalid input");
        assert!(matches!(&err, Error::Client { msg, .. } if msg == "invalid input"));
        assert!(matches!(err.without_contexts(), Error::Client { .. }));
    }

    #[test]
    fn test_internal_error_creation() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Internal { .. }));
    }

    #[test]
    fn test_internal_msg_error_creation() {
        let err = Error::internal_msg("something went wrong");
        assert!(matches!(err, Error::Internal { .. }));
        assert_eq!(err.to_string(), "something went wrong");
    }

    #[test]
    fn test_host_error_creation_and_detection() {
        let err = Error::host(MockHostError("test error".to_string()));
        assert!(matches!(err.without_contexts(), Error::HostLang(_)));
        assert_eq!(root_mock_host_error(&err).0, "test error");
    }

    #[test]
    fn test_context_chaining() {
        let inner = Error::client("base error");
        let with_context: Result<()> = Err(inner);
        let wrapped = ContextExt::context(
            ContextExt::context(ContextExt::context(with_context, "layer 1"), "layer 2"),
            "layer 3",
        );

        let err = wrapped.unwrap_err();
        assert!(matches!(&err, Error::Context { msg, .. } if msg == "layer 3"));

        // `let-else` instead of `if let`, so the nested check can never be skipped silently.
        let Error::Context { source, .. } = &err else {
            panic!("Expected Context variant");
        };
        assert!(
            matches!(source.as_ref(), SError(Error::Context { msg, .. }) if msg == "layer 2")
        );
        assert!(
            matches!(err.without_contexts(), Error::Client { msg, .. } if msg == "base error")
        );
        assert_eq!(
            err.to_string(),
            "\nContext:\
             \n  1: layer 3\
             \n  2: layer 2\
             \n  3: layer 1\
             \nInvalid Request: base error"
        );
    }

    #[test]
    fn test_context_preserves_host_error() {
        let err = Error::host(MockHostError("original python error".to_string()));
        let wrapped: Result<()> = Err(err);
        let with_context = ContextExt::context(wrapped, "while processing request");

        let final_err = with_context.unwrap_err();
        assert!(matches!(final_err.without_contexts(), Error::HostLang(_)));
        assert_eq!(root_mock_host_error(&final_err).0, "original python error");
    }

    #[test]
    fn test_backtrace_captured_for_client_error() {
        let err = Error::client("test");
        let bt = err.backtrace();
        assert!(bt.is_some());
        let status = bt.unwrap().status();
        assert!(
            status == BacktraceStatus::Captured
                || status == BacktraceStatus::Disabled
                || status == BacktraceStatus::Unsupported
        );
    }

    #[test]
    fn test_backtrace_captured_for_internal_error() {
        let err = Error::internal_msg("test internal");
        let bt = err.backtrace();
        assert!(bt.is_some());
    }

    #[test]
    fn test_backtrace_traverses_context() {
        let inner = Error::internal_msg("base");
        let wrapped: Result<()> = Err(inner);
        let with_context = ContextExt::context(wrapped, "context");

        let err = with_context.unwrap_err();
        let bt = err.backtrace();
        assert!(bt.is_some());
    }

    #[test]
    fn test_option_context_ext() {
        let opt: Option<i32> = None;
        let result = opt.context("value was missing");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err.without_contexts(), Error::Client { .. }));
        assert!(matches!(&err, Error::Client { msg, .. } if msg == "value was missing"));
    }

    #[test]
    fn test_error_display_formats() {
        let client_err = Error::client("bad input");
        assert_eq!(client_err.to_string(), "Invalid Request: bad input");

        let internal_err = Error::internal_msg("db connection failed");
        assert_eq!(internal_err.to_string(), "db connection failed");

        let host_err = Error::host(MockHostError("py error".to_string()));
        assert_eq!(host_err.to_string(), "MockHostError: py error");
    }

    #[test]
    fn test_error_source_chain() {
        let inner = Error::internal_msg("root cause");
        let wrapped: Result<()> = Err(inner);
        let outer = ContextExt::context(wrapped, "outer context").unwrap_err();

        let source = outer
            .source()
            .expect("a context layer should expose its inner error as the source");
        assert_eq!(source.to_string(), "root cause");
    }

    #[test]
    fn test_without_contexts_strips_every_layer() {
        let err = Error::internal_msg("root")
            .context("a")
            .with_context(|| "b");
        assert!(matches!(&err, Error::Context { msg, .. } if msg == "b"));

        let root = err.without_contexts();
        assert!(matches!(root, Error::Internal(_)));
        assert_eq!(root.to_string(), "root");
    }

    #[test]
    fn test_std_context_ext_wraps_foreign_error() {
        let res: std::result::Result<(), io::Error> =
            Err(io::Error::new(io::ErrorKind::Other, "disk full"));
        let err = StdContextExt::context(res, "writing cache").unwrap_err();
        assert!(matches!(&err, Error::Context { msg, .. } if msg == "writing cache"));

        let root = err.without_contexts();
        assert!(matches!(root, Error::Internal(_)));
        assert_eq!(root.to_string(), "disk full");
    }

    #[test]
    fn test_shared_error_hands_out_original_once() {
        let shared = SharedError::new(Error::client("boom"));

        let first: SharedResult<()> = Err(shared.clone());
        let first_err = first.into_result().unwrap_err();
        assert!(matches!(first_err, Error::Client { .. }));

        // Later consumers only get the rendered message of the original error.
        let second: SharedResult<()> = Err(shared.clone());
        let second_err = second.into_result().unwrap_err();
        assert!(matches!(second_err, Error::Internal(_)));
        assert_eq!(second_err.to_string(), "Invalid Request: boom");
        assert_eq!(shared.to_string(), "Invalid Request: boom");
    }

    #[test]
    fn test_shared_result_ref_into_result_ok() {
        let res: SharedResult<i32> = Ok(5);
        assert_eq!((&res).into_result().unwrap(), &5);
    }
}