Skip to main content

ocelot_base/
error.rs

1use std::error::Error as StdError;
2use std::fmt::{Debug, Display, Formatter};
3use std::panic::Location;
4use tracing_error::{SpanTrace, SpanTraceStatus};
5
6use crate::shared_string::SharedString;
7use crate::unansi;
8
9#[derive(Debug)]
10pub enum ErrorKind {
11    Message(SharedString),
12    Std(Box<dyn StdError + Send + Sync + 'static>),
13}
14
15impl Display for ErrorKind {
16    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17        match self {
18            Self::Message(message) => f.write_str(message),
19            Self::Std(error) => Display::fmt(error, f),
20        }
21    }
22}
23
24#[derive(Debug)]
25pub struct OcelotError {
26    kind: ErrorKind,
27    source: Option<Box<OcelotError>>,
28    location: &'static Location<'static>,
29    span_trace: SpanTrace,
30}
31
32impl OcelotError {
33    #[track_caller]
34    pub fn new(kind: ErrorKind) -> Self {
35        Self::at_location(kind, Location::caller())
36    }
37
38    pub fn at_location(kind: ErrorKind, location: &'static Location<'static>) -> Self {
39        Self {
40            kind,
41            source: None,
42            location,
43            span_trace: SpanTrace::capture(),
44        }
45    }
46
47    #[track_caller]
48    pub fn message(s: impl Into<SharedString>) -> Self {
49        Self::message_at_location(s, Location::caller())
50    }
51
52    pub fn message_at_location(
53        s: impl Into<SharedString>,
54        location: &'static Location<'static>,
55    ) -> Self {
56        Self::at_location(ErrorKind::Message(s.into()), location)
57    }
58
59    #[track_caller]
60    pub fn std(error: impl StdError + Send + Sync + 'static) -> Self {
61        Self::std_at_location(error, Location::caller())
62    }
63
64    pub fn std_at_location(
65        error: impl StdError + Send + Sync + 'static,
66        location: &'static Location<'static>,
67    ) -> Self {
68        Self::at_location(ErrorKind::Std(Box::new(error)), location)
69    }
70
71    pub fn kind(&self) -> &ErrorKind {
72        &self.kind
73    }
74
75    pub fn source(&self) -> Option<&OcelotError> {
76        self.source.as_deref()
77    }
78
79    pub fn location(&self) -> &'static Location<'static> {
80        self.location
81    }
82
83    pub fn span_trace(&self) -> &SpanTrace {
84        &self.span_trace
85    }
86
87    pub fn with_source(mut self, source: impl Into<OcelotError>) -> Self {
88        self.source = Some(Box::new(source.into()));
89        self
90    }
91
92    #[track_caller]
93    pub fn with_std_source(mut self, source: impl StdError + Send + Sync + 'static) -> Self {
94        self.source = Some(Box::new(OcelotError::std_at_location(
95            source,
96            Location::caller(),
97        )));
98        self
99    }
100
101    pub fn with_std_source_at_location(
102        mut self,
103        source: impl StdError + Send + Sync + 'static,
104        location: &'static Location<'static>,
105    ) -> Self {
106        self.source = Some(Box::new(OcelotError::std_at_location(source, location)));
107        self
108    }
109
110    pub fn write_to(&self, write: &mut dyn std::fmt::Write) -> std::fmt::Result {
111        writeln!(write, "{} {}", style("1;31", "× error"), self.kind)?;
112        self.write_details(write, "")?;
113        Ok(())
114    }
115
116    pub fn to_test_string(&self) -> String {
117        let mut test_string = String::new();
118        if self.write_to(&mut test_string).is_err() {
119            test_string.push_str("failed to render error");
120        }
121        unansi(&test_string)
122    }
123}
124
125impl OcelotError {
126    fn write_details(&self, write: &mut dyn std::fmt::Write, prefix: &str) -> std::fmt::Result {
127        let show_span_trace = self.source.is_none();
128
129        writeln!(
130            write,
131            "{}{} {}:{}:{}",
132            prefix,
133            style("2;37", "  at"),
134            self.location.file(),
135            self.location.line(),
136            self.location.column()
137        )?;
138
139        if show_span_trace && self.span_trace.status() == SpanTraceStatus::CAPTURED {
140            writeln!(write, "{}{}", prefix, style("36", "  span trace:"))?;
141            write_span_trace(write, prefix, &self.span_trace)?;
142        }
143
144        if let Some(source) = self.source.as_deref() {
145            write_rendered_cause(
146                write,
147                prefix,
148                &style("33", "caused by:"),
149                &source.kind.to_string(),
150            )?;
151            source.write_child_details(write, &format!("{prefix}   "))?;
152        }
153
154        Ok(())
155    }
156
157    fn write_child_details(
158        &self,
159        write: &mut dyn std::fmt::Write,
160        prefix: &str,
161    ) -> std::fmt::Result {
162        let show_span_trace = self.source.is_none();
163
164        writeln!(
165            write,
166            "{}{} {}:{}:{}",
167            prefix,
168            style("2;37", "  at"),
169            self.location.file(),
170            self.location.line(),
171            self.location.column()
172        )?;
173
174        if show_span_trace && self.span_trace.status() == SpanTraceStatus::CAPTURED {
175            writeln!(write, "{}{}", prefix, style("36", "  span trace:"))?;
176            write_span_trace(write, prefix, &self.span_trace)?;
177        }
178
179        if let Some(source) = self.source.as_deref() {
180            write_rendered_cause(
181                write,
182                prefix,
183                &style("33", "caused by:"),
184                &source.kind.to_string(),
185            )?;
186            source.write_child_details(write, &format!("{prefix}   "))?;
187        }
188
189        Ok(())
190    }
191}
192
/// Writes `label` and `rendered` at `prefix`.
///
/// Single-line text is placed on the same line as the label, separated by one
/// space. Multi-line text goes below the label, each line indented three spaces
/// past `prefix`.
fn write_rendered_cause(
    write: &mut dyn std::fmt::Write,
    prefix: &str,
    label: &str,
    rendered: &str,
) -> std::fmt::Result {
    if !rendered.contains('\n') {
        return writeln!(write, "{prefix}{label} {rendered}");
    }

    writeln!(write, "{prefix}{label}")?;
    rendered
        .lines()
        .try_for_each(|line| writeln!(write, "{prefix}   {line}"))
}
210
211fn write_span_trace(
212    write: &mut dyn std::fmt::Write,
213    prefix: &str,
214    span_trace: &SpanTrace,
215) -> std::fmt::Result {
216    let mut result = Ok(());
217    let mut span_index = 0;
218
219    span_trace.with_spans(|metadata, fields| {
220        if span_index > 0 && writeln!(write).is_err() {
221            result = Err(std::fmt::Error);
222            return false;
223        }
224
225        if writeln!(
226            write,
227            "{}    {}: {}::{}",
228            prefix,
229            span_index,
230            metadata.target(),
231            metadata.name()
232        )
233        .is_err()
234        {
235            result = Err(std::fmt::Error);
236            return false;
237        }
238
239        if !fields.is_empty()
240            && writeln!(
241                write,
242                "{}       {}",
243                prefix,
244                format_span_trace_fields(fields)
245            )
246            .is_err()
247        {
248            result = Err(std::fmt::Error);
249            return false;
250        }
251
252        if let Some((file, line)) = metadata
253            .file()
254            .and_then(|file| metadata.line().map(|line| (file, line)))
255            && writeln!(write, "{}       at {}:{}", prefix, file, line).is_err()
256        {
257            result = Err(std::fmt::Error);
258            return false;
259        }
260
261        span_index += 1;
262        true
263    });
264
265    result
266}
267
/// Re-renders a span's formatted field string for display.
///
/// Each `key=value` field becomes `key: value`, with the value highlighted via
/// `style("1;97", …)`. Fields without `=` are copied through unchanged. Output
/// fields are separated by a single space.
///
/// NOTE(review): assumes string values are Debug-formatted (double-quoted,
/// backslash-escaped), as tracing-subscriber's default field formatter does.
/// TODO confirm this against the formatter installed for `tracing_error`.
fn format_span_trace_fields(fields: &str) -> String {
    let mut formatted = String::with_capacity(fields.len());

    for (index, field) in split_span_trace_fields(fields).into_iter().enumerate() {
        if index > 0 {
            formatted.push(' ');
        }

        match field.split_once('=') {
            Some((key, value)) => {
                formatted.push_str(key);
                formatted.push_str(": ");
                formatted.push_str(&style("1;97", value));
            }
            None => formatted.push_str(field),
        }
    }

    formatted
}

/// Splits `fields` on whitespace that is not inside a double-quoted value.
///
/// A plain `split_whitespace` would cut a quoted value such as `path="my dir"`
/// in two. That highlights only half of it, mis-parses any `=` in the second
/// half, and collapses repeated spaces inside it. A backslash inside quotes
/// escapes the next character, so `\"` does not end the quoted section. An
/// unterminated quote extends to the end of the input.
fn split_span_trace_fields(fields: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    let mut start: Option<usize> = None;
    let mut in_quotes = false;
    let mut escaped = false;

    for (offset, ch) in fields.char_indices() {
        if in_quotes {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_quotes = false;
            }
        } else if ch.is_whitespace() {
            if let Some(begin) = start.take() {
                tokens.push(&fields[begin..offset]);
            }
        } else {
            if start.is_none() {
                start = Some(offset);
            }
            if ch == '"' {
                in_quotes = true;
            }
        }
    }

    if let Some(begin) = start {
        tokens.push(&fields[begin..]);
    }

    tokens
}

/// Wraps `text` in an ANSI SGR escape sequence with the given `code`,
/// followed by a reset.
fn style(code: &str, text: &str) -> String {
    format!("\u{1b}[{code}m{text}\u{1b}[0m")
}
292
293impl<T> From<T> for OcelotError
294where
295    T: StdError + Send + Sync + 'static,
296{
297    #[track_caller]
298    fn from(value: T) -> Self {
299        Self::std(value)
300    }
301}
302
303#[macro_export]
304macro_rules! err {
305    ($($arg:tt)*) => {
306        $crate::error::OcelotError::message(format!($($arg)*))
307    };
308}
309pub use err;
310
311#[macro_export]
312macro_rules! bail {
313    ($($arg:tt)*) => {
314        return Err($crate::err!($($arg)*))
315    };
316}
317pub use bail;
318
#[cfg(test)]
mod tests {
    use super::format_span_trace_fields;

    /// Plain `key=value` fields render as `key: value` once ANSI styling is removed.
    #[test]
    fn test_format_span_trace_fields() {
        let input = "sources_dir=verification/sources output_dir=verification/output/ocelot";
        let expected =
            "sources_dir: verification/sources output_dir: verification/output/ocelot";

        assert_eq!(crate::unansi(&format_span_trace_fields(input)), expected);
    }
}