Skip to main content

ocelot_base/
error.rs

1use std::error::Error as StdError;
2use std::fmt::{Debug, Display, Formatter};
3use std::panic::Location;
4use tracing_error::{SpanTrace, SpanTraceStatus};
5
6use crate::compilation_stage::CompilationStage;
7use crate::shared_string::SharedString;
8use crate::unansi;
9
10#[derive(Debug)]
11pub enum ErrorKind {
12    Message(SharedString),
13    CompilationError(CompilationStage),
14    Std(Box<dyn StdError + Send + Sync + 'static>),
15}
16
17impl ErrorKind {
18    /// Returns the inner message for message-based errors.
19    pub fn as_message(&self) -> Option<&SharedString> {
20        match self {
21            Self::Message(message) => Some(message),
22            Self::CompilationError(_) | Self::Std(_) => None,
23        }
24    }
25}
26
27impl Display for ErrorKind {
28    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29        match self {
30            Self::Message(message) => f.write_str(message),
31            Self::CompilationError(stage) => write!(f, "{stage} compilation error"),
32            Self::Std(error) => Display::fmt(error, f),
33        }
34    }
35}
36
37#[derive(Debug)]
38pub struct OcelotError {
39    kind: ErrorKind,
40    source: Option<Box<OcelotError>>,
41    location: &'static Location<'static>,
42    span_trace: SpanTrace,
43}
44
45impl OcelotError {
46    #[track_caller]
47    pub fn new(kind: ErrorKind) -> Self {
48        Self::at_location(kind, Location::caller())
49    }
50
51    pub fn at_location(kind: ErrorKind, location: &'static Location<'static>) -> Self {
52        Self {
53            kind,
54            source: None,
55            location,
56            span_trace: SpanTrace::capture(),
57        }
58    }
59
60    #[track_caller]
61    pub fn message(s: impl Into<SharedString>) -> Self {
62        Self::message_at_location(s, Location::caller())
63    }
64
65    pub fn message_at_location(
66        s: impl Into<SharedString>,
67        location: &'static Location<'static>,
68    ) -> Self {
69        Self::at_location(ErrorKind::Message(s.into()), location)
70    }
71
72    #[track_caller]
73    pub fn compilation_error(stage: CompilationStage) -> Self {
74        Self::compilation_error_at_location(stage, Location::caller())
75    }
76
77    pub fn compilation_error_at_location(
78        stage: CompilationStage,
79        location: &'static Location<'static>,
80    ) -> Self {
81        Self::at_location(ErrorKind::CompilationError(stage), location)
82    }
83
84    #[track_caller]
85    pub fn std(error: impl StdError + Send + Sync + 'static) -> Self {
86        Self::std_at_location(error, Location::caller())
87    }
88
89    pub fn std_at_location(
90        error: impl StdError + Send + Sync + 'static,
91        location: &'static Location<'static>,
92    ) -> Self {
93        Self::at_location(ErrorKind::Std(Box::new(error)), location)
94    }
95
96    pub fn kind(&self) -> &ErrorKind {
97        &self.kind
98    }
99
100    pub fn source(&self) -> Option<&OcelotError> {
101        self.source.as_deref()
102    }
103
104    pub fn location(&self) -> &'static Location<'static> {
105        self.location
106    }
107
108    pub fn span_trace(&self) -> &SpanTrace {
109        &self.span_trace
110    }
111
112    pub fn with_source(mut self, source: impl Into<OcelotError>) -> Self {
113        self.source = Some(Box::new(source.into()));
114        self
115    }
116
117    #[track_caller]
118    pub fn with_std_source(mut self, source: impl StdError + Send + Sync + 'static) -> Self {
119        self.source = Some(Box::new(OcelotError::std_at_location(
120            source,
121            Location::caller(),
122        )));
123        self
124    }
125
126    pub fn with_std_source_at_location(
127        mut self,
128        source: impl StdError + Send + Sync + 'static,
129        location: &'static Location<'static>,
130    ) -> Self {
131        self.source = Some(Box::new(OcelotError::std_at_location(source, location)));
132        self
133    }
134
135    pub fn write_to(&self, write: &mut dyn std::fmt::Write) -> std::fmt::Result {
136        writeln!(write, "{} {}", style("1;31", "× error"), self.kind)?;
137        self.write_details(write, "")?;
138        Ok(())
139    }
140
141    pub fn to_test_string(&self) -> String {
142        let mut test_string = String::new();
143        if self.write_to(&mut test_string).is_err() {
144            test_string.push_str("failed to render error");
145        }
146        unansi(&test_string)
147    }
148}
149
150impl OcelotError {
151    fn write_details(&self, write: &mut dyn std::fmt::Write, prefix: &str) -> std::fmt::Result {
152        let show_span_trace = self.source.is_none();
153
154        writeln!(
155            write,
156            "{}{} {}:{}:{}",
157            prefix,
158            style("2;37", "  at"),
159            self.location.file(),
160            self.location.line(),
161            self.location.column()
162        )?;
163
164        if show_span_trace && self.span_trace.status() == SpanTraceStatus::CAPTURED {
165            writeln!(write, "{}{}", prefix, style("36", "  span trace:"))?;
166            write_span_trace(write, prefix, &self.span_trace)?;
167        }
168
169        if let Some(source) = self.source.as_deref() {
170            write_rendered_cause(
171                write,
172                prefix,
173                &style("33", "caused by:"),
174                &source.kind.to_string(),
175            )?;
176            source.write_child_details(write, &format!("{prefix}   "))?;
177        }
178
179        Ok(())
180    }
181
182    fn write_child_details(
183        &self,
184        write: &mut dyn std::fmt::Write,
185        prefix: &str,
186    ) -> std::fmt::Result {
187        let show_span_trace = self.source.is_none();
188
189        writeln!(
190            write,
191            "{}{} {}:{}:{}",
192            prefix,
193            style("2;37", "  at"),
194            self.location.file(),
195            self.location.line(),
196            self.location.column()
197        )?;
198
199        if show_span_trace && self.span_trace.status() == SpanTraceStatus::CAPTURED {
200            writeln!(write, "{}{}", prefix, style("36", "  span trace:"))?;
201            write_span_trace(write, prefix, &self.span_trace)?;
202        }
203
204        if let Some(source) = self.source.as_deref() {
205            write_rendered_cause(
206                write,
207                prefix,
208                &style("33", "caused by:"),
209                &source.kind.to_string(),
210            )?;
211            source.write_child_details(write, &format!("{prefix}   "))?;
212        }
213
214        Ok(())
215    }
216}
217
/// Writes `label` together with an already-rendered cause message.
///
/// A single-line message goes on the label's line as `{prefix}{label} {rendered}`.
/// A message containing a newline puts the label on its own line, then writes
/// each line of `rendered` beneath it, indented by `prefix` plus three spaces.
fn write_rendered_cause(
    write: &mut dyn std::fmt::Write,
    prefix: &str,
    label: &str,
    rendered: &str,
) -> std::fmt::Result {
    if !rendered.contains('\n') {
        return writeln!(write, "{prefix}{label} {rendered}");
    }

    writeln!(write, "{prefix}{label}")?;
    rendered
        .lines()
        .try_for_each(|line| writeln!(write, "{prefix}   {line}"))
}
235
236fn write_span_trace(
237    write: &mut dyn std::fmt::Write,
238    prefix: &str,
239    span_trace: &SpanTrace,
240) -> std::fmt::Result {
241    let mut result = Ok(());
242    let mut span_index = 0;
243
244    span_trace.with_spans(|metadata, fields| {
245        if span_index > 0 && writeln!(write).is_err() {
246            result = Err(std::fmt::Error);
247            return false;
248        }
249
250        if writeln!(
251            write,
252            "{}    {}: {}::{}",
253            prefix,
254            span_index,
255            metadata.target(),
256            metadata.name()
257        )
258        .is_err()
259        {
260            result = Err(std::fmt::Error);
261            return false;
262        }
263
264        if !fields.is_empty()
265            && writeln!(
266                write,
267                "{}       {}",
268                prefix,
269                format_span_trace_fields(fields)
270            )
271            .is_err()
272        {
273            result = Err(std::fmt::Error);
274            return false;
275        }
276
277        if let Some((file, line)) = metadata
278            .file()
279            .and_then(|file| metadata.line().map(|line| (file, line)))
280            && writeln!(write, "{}       at {}:{}", prefix, file, line).is_err()
281        {
282            result = Err(std::fmt::Error);
283            return false;
284        }
285
286        span_index += 1;
287        true
288    });
289
290    result
291}
292
293fn format_span_trace_fields(fields: &str) -> String {
294    let mut formatted = String::new();
295
296    for (index, field) in fields.split_whitespace().enumerate() {
297        if index > 0 {
298            formatted.push(' ');
299        }
300
301        if let Some((key, value)) = field.split_once('=') {
302            formatted.push_str(key);
303            formatted.push(':');
304            formatted.push(' ');
305            formatted.push_str(&style("1;97", value));
306        } else {
307            formatted.push_str(field);
308        }
309    }
310
311    formatted
312}
313
/// Wraps `text` in the ANSI SGR sequence `ESC[{code}m` and appends the reset
/// sequence `ESC[0m`.
fn style(code: &str, text: &str) -> String {
    ["\u{1b}[", code, "m", text, "\u{1b}[0m"].concat()
}
317
318impl<T> From<T> for OcelotError
319where
320    T: StdError + Send + Sync + 'static,
321{
322    #[track_caller]
323    fn from(value: T) -> Self {
324        Self::std(value)
325    }
326}
327
328#[macro_export]
329macro_rules! err {
330    ($($arg:tt)*) => {
331        $crate::error::OcelotError::message(format!($($arg)*))
332    };
333}
334pub use err;
335
336#[macro_export]
337macro_rules! bail {
338    ($($arg:tt)*) => {
339        return Err($crate::err!($($arg)*))
340    };
341}
342pub use bail;
343
#[cfg(test)]
mod tests {
    use super::format_span_trace_fields;

    /// Once the ANSI styling is stripped, each `key=value` field should read as
    /// `key: value`, with pairs separated by single spaces.
    #[test]
    fn test_format_span_trace_fields() {
        const INPUT: &str =
            "sources_dir=verification/sources output_dir=verification/output/ocelot";
        const EXPECTED: &str =
            "sources_dir: verification/sources output_dir: verification/output/ocelot";

        let plain = crate::unansi(&format_span_trace_fields(INPUT));

        assert_eq!(plain, EXPECTED);
    }
}