Skip to main content

token_goblin_runtime/
wire.rs

1//! Guest-side wire format for the dylib boundary.
//! It's an internal module, but feel free to enjoy the docs.
3
4use std::{ops::Range, panic::UnwindSafe};
5
6use proc_macro2::{LexError, Span, TokenStream, TokenTree};
7
8use crate::wire::panic::{PanicLocation, PanicReport};
9
10/// Guest output packet returned from dylib `entry`.
11/// THIS TYPE SHOULD MATCH `token_goblin::span_recovery::Output`
12pub struct Output {
13    pub text: String,
14    pub spans: Vec<OutputEntry>,
15}
/// Single entry in the output.
///
/// For regular tokens (`is_panic == false`) `range` is a byte range into the
/// host input, with `0..0` standing for the macro call site. For panic
/// tokens (`is_panic == true`) `range` instead carries the panic location
/// packed as `line..column`.
///
/// THIS TYPE SHOULD MATCH `token_goblin::span_recovery::OutputEntry`
/// (keep field order and types in sync; derives do not affect layout).
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OutputEntry {
    pub is_panic: bool,
    pub range: Range<usize>,
}
impl OutputEntry {
    /// Creates an entry; see the type-level docs for how `range` is read.
    #[must_use]
    pub fn new(is_panic: bool, range: Range<usize>) -> Self {
        Self { is_panic, range }
    }
}
28
29/// The entry for `charm` that handle input/output converions, and panics handling.
30///
31/// check out:
32/// - [`panic::run_and_catch`] for panic handling.
33/// - [`parse_input`] for input parsing.
34/// - [`collect_outputs`] for output serialization.
35///
36#[allow(clippy::missing_panics_doc, reason = "panic is in catch block")]
37pub fn entry(input: &str, body: impl FnOnce(TokenStream) -> TokenStream + UnwindSafe) -> Output {
38    let mut panics = None;
39
40    let (tokens, anchor) = panic::run_and_catch(|| {
41        let (input, anchor) = parse_input(input).unwrap();
42        (body(input), anchor)
43    })
44    .unwrap_or_else(|e| {
45        let (msg, location) = panic_to_compile_error(e);
46        panics = Some((
47            msg.clone(),
48            location // TODO: Rewrite it with normal types
49                .map(|loc| Range {
50                    start: loc.line as usize,
51                    end: loc.column as usize,
52                })
53                .unwrap_or_default(),
54        ));
55        (TokenStream::new(), None)
56    });
57
58    collect_outputs(tokens, panics, anchor)
59}
60fn panic_to_compile_error(e: PanicReport) -> (TokenStream, Option<PanicLocation>) {
61    let message = format!(
62        "panic in charm (at {location}): {error}",
63        error = e.message,
64        location = e.location.as_ref().map_or_else(
65            || "<unknown location>".to_string(),
66            PanicLocation::to_string
67        )
68    );
69    // we duplicate panics, to visualize it as error on both sides:
70    // - call site of macro
71    // - macro definition position
72    let msg = syn::Error::new(Span::call_site(), message).to_compile_error();
73
74    (msg, e.location)
75}
76
77/// Parse canonical host input text into a local fallback token stream.
78///
79/// Returns:
80/// - `TokenStream` suitable for further parsing
81/// - anchor - span that is used in `output` to filter spans from external source (e.g. embedded `TokenStream::from_str`)
82///
83/// # Errors
84/// - `LexError` - if input is not a valid token stream.
85///
86pub fn parse_input(source: &str) -> Result<(TokenStream, Option<Span>), LexError> {
87    let tokens: TokenStream = source.parse()?;
88    let anchor = first_leaf_span(&tokens);
89    Ok((tokens, anchor))
90}
91
92/// Serialize macro output into text plus flattened leaf-token source ranges.
93#[allow(clippy::needless_pass_by_value, reason = "consume token stream")]
94#[must_use]
95pub fn collect_outputs(
96    tokens: TokenStream,
97    error: Option<(TokenStream, Range<usize>)>,
98    anchor: Option<Span>,
99) -> Output {
100    let mut output = {
101        // collect external tts first
102        let resulted_stream = crate::ux::flush_output(tokens);
103
104        let source_range_fn = |span: Span| source_range(span, anchor);
105        // then regular output with remapped spans
106        Output {
107            text: resulted_stream.to_string(),
108            spans: flatten_leaf_spans(&resulted_stream, &source_range_fn)
109                .into_iter()
110                .map(|range| OutputEntry::new(false, range))
111                .collect(),
112        }
113    };
114
115    // And then panic with location info
116    if let Some((panics, range)) = error {
117        let error_source_ranges = |_| range.clone();
118        output.text.push(' ');
119        output.text.push_str(&panics.to_string());
120
121        output.spans.extend(
122            flatten_leaf_spans(&panics, &error_source_ranges)
123                .into_iter()
124                .map(|range| OutputEntry::new(true, range)),
125        );
126    }
127    output
128}
129
130fn first_leaf_span(tokens: &TokenStream) -> Option<Span> {
131    for token in tokens.clone() {
132        match token {
133            TokenTree::Group(group) => {
134                if let Some(span) = first_leaf_span(&group.stream()) {
135                    return Some(span);
136                }
137            }
138            TokenTree::Ident(ident) => return Some(ident.span()),
139            TokenTree::Punct(punct) => return Some(punct.span()),
140            TokenTree::Literal(literal) => return Some(literal.span()),
141        }
142    }
143    None
144}
145
146fn flatten_leaf_spans(
147    tokens: &TokenStream,
148    source_range_fn: &dyn Fn(Span) -> Range<usize>,
149) -> Vec<Range<usize>> {
150    let mut spans = Vec::new();
151    collect_leaf_spans(tokens, &mut spans, source_range_fn);
152    spans
153}
154
155fn collect_leaf_spans(
156    tokens: &TokenStream,
157    spans: &mut Vec<Range<usize>>,
158    source_range_fn: &dyn Fn(Span) -> Range<usize>,
159) {
160    for token in tokens.clone() {
161        match token {
162            TokenTree::Group(group) => collect_leaf_spans(&group.stream(), spans, source_range_fn),
163            TokenTree::Ident(ident) => spans.push(source_range_fn(ident.span())),
164            TokenTree::Punct(punct) => spans.push(source_range_fn(punct.span())),
165            TokenTree::Literal(literal) => spans.push(source_range_fn(literal.span())),
166        }
167    }
168}
169const CALL_SITE_RANGE: Range<usize> = 0..0;
170fn source_range(span: Span, anchor: Option<Span>) -> Range<usize> {
171    let range = span.byte_range();
172    if range.is_empty() {
173        return CALL_SITE_RANGE;
174    }
175
176    match anchor {
177        Some(anchor) if anchor.join(span).is_some() => range,
178        // Either it already call_site, or it is just span from "virtual file"
179        // in both cases we map it to call_site
180        _ => CALL_SITE_RANGE,
181    }
182}
183
// Panic capture with location info.
//
// `catch_unwind` only returns the panic payload. While at least one
// `run_and_catch` call is in flight, a capture hook is installed. It records
// the message and source location of panics raised inside those calls.
// Panics on threads that are not inside `run_and_catch` are forwarded to the
// previously installed hook. That hook is restored once the last in-flight
// call finishes.
mod panic {
    use std::{
        any::Any,
        cell::{Cell, RefCell},
        fmt::{self, Display},
        panic::{self, AssertUnwindSafe, PanicHookInfo},
        sync::{Arc, Mutex, MutexGuard, PoisonError},
    };

    /// Source position of a panic, as reported to the panic hook.
    #[derive(Debug)]
    pub struct PanicLocation {
        pub file: String,
        pub line: u32,
        pub column: u32,
    }
    impl Display for PanicLocation {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}:{}:{}", self.file, self.line, self.column)
        }
    }
    /// Message and (when the hook saw the panic) location of a caught panic.
    #[derive(Debug)]
    pub struct PanicReport {
        pub message: String,
        pub location: Option<PanicLocation>,
    }

    /// Boxed panic hook, as handled by `std::panic::{take_hook, set_hook}`.
    type PanicHook = Box<dyn Fn(&PanicHookInfo<'_>) + 'static + Sync + Send>;

    /// Process-wide bookkeeping for the capture hook.
    struct HookState {
        /// Number of `run_and_catch` calls currently in flight, on any thread.
        active: usize,
        /// Hook that was installed before the capture hook; restored when
        /// `active` drops back to zero.
        previous: Option<Arc<PanicHook>>,
    }

    // The panic hook is a process-global resource, so installing and
    // restoring it is coordinated across threads through this lock. The hook
    // itself never takes the lock, so it cannot deadlock against a thread
    // that holds it while calling `take_hook`/`set_hook`.
    static HOOK_STATE: Mutex<HookState> = Mutex::new(HookState {
        active: 0,
        previous: None,
    });

    thread_local! {
        // Report written by the capture hook on the panicking thread.
        static LAST_PANIC: RefCell<Option<PanicReport>> = const { RefCell::new(None) };
        // Number of `run_and_catch` frames active on this thread; the capture
        // hook only records panics while it is non-zero.
        static CAPTURE_DEPTH: Cell<usize> = const { Cell::new(0) };
    }

    /// Best-effort text of a panic payload (`&str` and `String` payloads).
    fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String {
        if let Some(s) = payload.downcast_ref::<&str>() {
            s.to_string()
        } else if let Some(s) = payload.downcast_ref::<String>() {
            s.clone()
        } else {
            "<non-string panic payload>".to_string()
        }
    }

    /// Stores message and location of the panic described by `info` into this
    /// thread's `LAST_PANIC`.
    fn record_panic(info: &PanicHookInfo<'_>) {
        let message = panic_payload_to_string(info.payload());
        let location = info.location().map(|loc| PanicLocation {
            file: loc.file().to_string(),
            line: loc.line(),
            column: loc.column(),
        });
        LAST_PANIC.set(Some(PanicReport { message, location }));
    }

    fn lock_state() -> MutexGuard<'static, HookState> {
        // Nothing panics while the lock is held, but stay usable if it ever does.
        HOOK_STATE.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Registers one more in-flight call; the first one installs the capture hook.
    fn acquire_hook() {
        let mut state = lock_state();
        if state.active == 0 {
            let previous: Arc<PanicHook> = Arc::new(panic::take_hook());
            let fallback = Arc::clone(&previous);
            panic::set_hook(Box::new(move |info: &PanicHookInfo<'_>| {
                let capturing = CAPTURE_DEPTH
                    .try_with(Cell::get)
                    .is_ok_and(|depth| depth > 0);
                if capturing {
                    record_panic(info);
                } else {
                    // Not inside `run_and_catch` on this thread: keep the
                    // usual behaviour (e.g. printing the panic message).
                    fallback(info);
                }
            }));
            state.previous = Some(previous);
        }
        state.active += 1;
    }

    /// Unregisters one in-flight call; the last one restores the previous hook.
    fn release_hook() {
        let mut state = lock_state();
        state.active -= 1;
        if state.active > 0 {
            return;
        }
        if let Some(previous) = state.previous.take() {
            // Removing the capture hook drops its clone of `previous`, so the
            // original box normally comes back out of the `Arc` untouched.
            drop(panic::take_hook());
            let restored: PanicHook = match Arc::try_unwrap(previous) {
                Ok(hook) => hook,
                Err(shared) => Box::new(move |info: &PanicHookInfo<'_>| shared(info)),
            };
            panic::set_hook(restored);
        }
    }

    /// Runs `f`, converting a panic into a [`PanicReport`].
    ///
    /// # Errors
    /// Returns the report when `f` panics. The location is `None` only if the
    /// capture hook did not observe the panic.
    pub fn run_and_catch<F, R>(f: F) -> Result<R, PanicReport>
    where
        F: FnOnce() -> R,
    {
        LAST_PANIC.set(None);
        acquire_hook();
        CAPTURE_DEPTH.set(CAPTURE_DEPTH.get() + 1);

        let result = panic::catch_unwind(AssertUnwindSafe(f));

        CAPTURE_DEPTH.set(CAPTURE_DEPTH.get() - 1);
        release_hook();

        result.map_err(|payload| {
            let fallback_message = panic_payload_to_string(payload.as_ref());
            LAST_PANIC.take().unwrap_or(PanicReport {
                message: fallback_message,
                location: None,
            })
        })
    }
}
274
#[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests {
    use std::str::FromStr as _;

    use proc_macro2::{Literal, TokenTree};
    use syn::Ident;

    use super::*;

    /// Returns the first token of `tokens`, which must be a literal.
    fn single_literal(tokens: &TokenStream) -> Literal {
        match tokens.clone().into_iter().next().expect("one token") {
            TokenTree::Literal(literal) => literal,
            other => panic!("expected literal, got {other:?}"),
        }
    }

    #[test]
    fn input_anchor_joins_input_span() {
        let input = parse_input("12").expect("valid token stream");
        let literal = single_literal(&input.0);
        assert!(input.1.unwrap().join(literal.span()).is_some());
    }

    #[test]
    fn unrelated_parse_does_not_join_input_anchor() {
        let input = parse_input("12").expect("valid token stream");
        let generated = TokenStream::from_str(
            "
        12
    ",
        )
        .unwrap();
        let literal = single_literal(&generated);
        assert!(input.1.unwrap().join(literal.span()).is_none());
    }

    #[test]
    fn output_maps_unrelated_spans_to_call_site() {
        let input = parse_input("12").expect("valid token stream");
        let generated = TokenStream::from_str(
            "
        12
    ",
        )
        .unwrap();
        let out = collect_outputs(generated, None, input.1);
        assert_eq!(out.spans, [OutputEntry::new(false, 0..0)]);
    }

    #[test]
    fn output_preserves_input_relative_spans() {
        let input = parse_input("12").expect("valid token stream");
        let out = collect_outputs(input.0.clone(), None, input.1);
        assert_eq!(out.text, "12");
        assert_eq!(out.spans, [OutputEntry::new(false, 0..2)]);
    }

    #[test]
    fn panic_capture_hook_captures_panic() {
        let report = panic::run_and_catch(|| {
            panic!("test panic");
        })
        .unwrap_err();
        assert_eq!(report.message, "test panic");
        assert!(report.location.is_some());
    }

    #[test]
    fn test_entry_full_flow() {
        let input = "12";
        let body = |mut input: TokenStream| {
            input.extend([Ident::new("foo", Span::call_site())]);
            input
        };
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        assert_eq!(
            output.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    #[test]
    fn test_checks_that_entry_cleanup() {
        let input = "12";
        let body = |mut input: TokenStream| {
            input.extend([Ident::new("foo", Span::call_site())]);
            input
        };
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        // second call should produce same output
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        assert_eq!(
            output.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    #[test]
    #[should_panic(expected = "second panic")]
    fn test_check_panic_hook_cleanup() {
        let input = "12";
        let body = |_: TokenStream| {
            panic!("test panic");
        };
        let output = entry(input, body);
        assert!(output.text.contains("test panic"));
        panic!("second panic");
    }

    /// The panic path of `entry` must flag its entries with `is_panic`, and
    /// every token of the compile error shares the same packed location range.
    #[test]
    fn panic_output_entries_are_flagged() {
        let output = entry("12", |_: TokenStream| -> TokenStream {
            panic!("flagged panic")
        });
        assert!(output.text.contains("flagged panic"));

        let panic_entries: Vec<&OutputEntry> =
            output.spans.iter().filter(|e| e.is_panic).collect();
        assert!(!panic_entries.is_empty());
        assert!(panic_entries
            .windows(2)
            .all(|pair| pair[0].range == pair[1].range));
    }
}