1use std::{ops::Range, panic::UnwindSafe};
5
6use proc_macro2::{LexError, Span, TokenStream, TokenTree};
7
8use crate::wire::panic::{PanicLocation, PanicReport};
9
10pub struct Output {
13 pub text: String,
14 pub spans: Vec<OutputEntry>,
15}
/// Span information for one leaf token of the rendered output.
#[derive(Debug, PartialEq, Eq)]
pub struct OutputEntry {
    /// `true` when the token belongs to the error appended after a caught panic.
    pub is_panic: bool,
    /// For regular entries, a byte range into the macro input, or `0..0` when
    /// the token cannot be traced back to the input. For panic entries, the
    /// range supplied together with the error to `collect_outputs`.
    pub range: Range<usize>,
}

impl OutputEntry {
    /// Builds an entry from its two fields.
    pub fn new(is_panic: bool, range: Range<usize>) -> Self {
        OutputEntry { range, is_panic }
    }
}
28
29#[allow(clippy::missing_panics_doc, reason = "panic is in catch block")]
37pub fn entry(input: &str, body: impl FnOnce(TokenStream) -> TokenStream + UnwindSafe) -> Output {
38 let mut panics = None;
39
40 let (tokens, anchor) = panic::run_and_catch(|| {
41 let (input, anchor) = parse_input(input).unwrap();
42 (body(input), anchor)
43 })
44 .unwrap_or_else(|e| {
45 let (msg, location) = panic_to_compile_error(e);
46 panics = Some((
47 msg.clone(),
48 location .map(|loc| Range {
50 start: loc.line as usize,
51 end: loc.column as usize,
52 })
53 .unwrap_or_default(),
54 ));
55 (TokenStream::new(), None)
56 });
57
58 collect_outputs(tokens, panics, anchor)
59}
60fn panic_to_compile_error(e: PanicReport) -> (TokenStream, Option<PanicLocation>) {
61 let message = format!(
62 "panic in charm (at {location}): {error}",
63 error = e.message,
64 location = e.location.as_ref().map_or_else(
65 || "<unknown location>".to_string(),
66 PanicLocation::to_string
67 )
68 );
69 let msg = syn::Error::new(Span::call_site(), message).to_compile_error();
73
74 (msg, e.location)
75}
76
77pub fn parse_input(source: &str) -> Result<(TokenStream, Option<Span>), LexError> {
87 let tokens: TokenStream = source.parse()?;
88 let anchor = first_leaf_span(&tokens);
89 Ok((tokens, anchor))
90}
91
92#[allow(clippy::needless_pass_by_value, reason = "consume token stream")]
94#[must_use]
95pub fn collect_outputs(
96 tokens: TokenStream,
97 error: Option<(TokenStream, Range<usize>)>,
98 anchor: Option<Span>,
99) -> Output {
100 let mut output = {
101 let resulted_stream = crate::ux::flush_output(tokens);
103
104 let source_range_fn = |span: Span| source_range(span, anchor);
105 Output {
107 text: resulted_stream.to_string(),
108 spans: flatten_leaf_spans(&resulted_stream, &source_range_fn)
109 .into_iter()
110 .map(|range| OutputEntry::new(false, range))
111 .collect(),
112 }
113 };
114
115 if let Some((panics, range)) = error {
117 let error_source_ranges = |_| range.clone();
118 output.text.push(' ');
119 output.text.push_str(&panics.to_string());
120
121 output.spans.extend(
122 flatten_leaf_spans(&panics, &error_source_ranges)
123 .into_iter()
124 .map(|range| OutputEntry::new(true, range)),
125 );
126 }
127 output
128}
129
130fn first_leaf_span(tokens: &TokenStream) -> Option<Span> {
131 for token in tokens.clone() {
132 match token {
133 TokenTree::Group(group) => {
134 if let Some(span) = first_leaf_span(&group.stream()) {
135 return Some(span);
136 }
137 }
138 TokenTree::Ident(ident) => return Some(ident.span()),
139 TokenTree::Punct(punct) => return Some(punct.span()),
140 TokenTree::Literal(literal) => return Some(literal.span()),
141 }
142 }
143 None
144}
145
146fn flatten_leaf_spans(
147 tokens: &TokenStream,
148 source_range_fn: &dyn Fn(Span) -> Range<usize>,
149) -> Vec<Range<usize>> {
150 let mut spans = Vec::new();
151 collect_leaf_spans(tokens, &mut spans, source_range_fn);
152 spans
153}
154
155fn collect_leaf_spans(
156 tokens: &TokenStream,
157 spans: &mut Vec<Range<usize>>,
158 source_range_fn: &dyn Fn(Span) -> Range<usize>,
159) {
160 for token in tokens.clone() {
161 match token {
162 TokenTree::Group(group) => collect_leaf_spans(&group.stream(), spans, source_range_fn),
163 TokenTree::Ident(ident) => spans.push(source_range_fn(ident.span())),
164 TokenTree::Punct(punct) => spans.push(source_range_fn(punct.span())),
165 TokenTree::Literal(literal) => spans.push(source_range_fn(literal.span())),
166 }
167 }
168}
169const CALL_SITE_RANGE: Range<usize> = 0..0;
170fn source_range(span: Span, anchor: Option<Span>) -> Range<usize> {
171 let range = span.byte_range();
172 if range.is_empty() {
173 return CALL_SITE_RANGE;
174 }
175
176 match anchor {
177 Some(anchor) if anchor.join(span).is_some() => range,
178 _ => CALL_SITE_RANGE,
181 }
182}
183
mod panic {
    //! Panic capture for [`run_and_catch`].
    //!
    //! A process-wide panic hook is installed once. It records the message and
    //! location of panics raised on threads currently inside `run_and_catch`.
    //! Every other panic is forwarded to the hook that was installed before it.

    use std::{
        any::Any,
        cell::{Cell, RefCell},
        fmt::{self, Display},
        panic::{self, AssertUnwindSafe, PanicHookInfo},
        sync::Once,
    };

    /// Where a captured panic was raised (the panicking call site in Rust source).
    #[derive(Debug)]
    pub struct PanicLocation {
        pub file: String,
        pub line: u32,
        pub column: u32,
    }

    impl Display for PanicLocation {
        /// Formats as `file:line:column`.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}:{}:{}", self.file, self.line, self.column)
        }
    }

    /// A panic caught by [`run_and_catch`].
    #[derive(Debug)]
    pub struct PanicReport {
        /// The payload as text, or `"<non-string panic payload>"` if it was
        /// neither `&str` nor `String`.
        pub message: String,
        /// Location seen by the hook; `None` when no report was recorded.
        pub location: Option<PanicLocation>,
    }

    thread_local! {
        /// Report recorded by the hook for the latest captured panic on this thread.
        static LAST_PANIC: RefCell<Option<PanicReport>> = const { RefCell::new(None) };
        /// `true` while this thread is executing inside `run_and_catch`.
        static CAPTURING: Cell<bool> = const { Cell::new(false) };
    }

    /// Guards the one-time installation of the capturing hook.
    static HOOK_INSTALLED: Once = Once::new();

    /// Extracts a readable message from a panic payload (`&str` or `String`).
    fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String {
        if let Some(s) = payload.downcast_ref::<&str>() {
            (*s).to_string()
        } else if let Some(s) = payload.downcast_ref::<String>() {
            s.clone()
        } else {
            "<non-string panic payload>".to_string()
        }
    }

    /// Installs the capturing hook the first time it is called.
    ///
    /// The hook is never removed. Swapping hooks with `take_hook`/`set_hook`
    /// around every call races when several threads call `run_and_catch`
    /// concurrently. One thread can restore a stale hook while another is
    /// still catching, losing its report. The capturing hook can also stay
    /// installed for good and swallow unrelated panics. Gating on the
    /// thread-local `CAPTURING` flag keeps capture per-thread instead.
    fn install_panic_hook() {
        HOOK_INSTALLED.call_once(|| {
            let previous = panic::take_hook();
            panic::set_hook(Box::new(move |info: &PanicHookInfo<'_>| {
                // `try_with` so a panic during thread-local teardown is simply
                // forwarded instead of aborting on an inaccessible TLS slot.
                if CAPTURING.try_with(Cell::get).unwrap_or(false) {
                    record_panic(info);
                } else {
                    previous(info);
                }
            }));
        });
    }

    /// Stores the message and location of `info` as this thread's last panic.
    fn record_panic(info: &PanicHookInfo<'_>) {
        let message = panic_payload_to_string(info.payload());
        let location = info.location().map(|loc| PanicLocation {
            file: loc.file().to_string(),
            line: loc.line(),
            column: loc.column(),
        });
        LAST_PANIC.set(Some(PanicReport { message, location }));
    }

    /// Runs `f`, turning a panic into a [`PanicReport`] instead of unwinding further.
    ///
    /// Panics raised on the calling thread while `f` runs are recorded and not
    /// printed. Calls may be nested, and may run concurrently on different
    /// threads.
    ///
    /// # Errors
    ///
    /// Returns the panic's report if `f` panicked. If no report was recorded
    /// (for example, `resume_unwind` does not invoke panic hooks), the message
    /// is taken from the payload and the location is `None`.
    pub fn run_and_catch<F, R>(f: F) -> Result<R, PanicReport>
    where
        F: FnOnce() -> R,
    {
        install_panic_hook();
        LAST_PANIC.set(None);

        // Remember the outer state so nested calls restore it on exit.
        let was_capturing = CAPTURING.replace(true);
        let outcome = panic::catch_unwind(AssertUnwindSafe(f));
        CAPTURING.set(was_capturing);

        outcome.map_err(|payload| {
            LAST_PANIC.take().unwrap_or_else(|| PanicReport {
                message: panic_payload_to_string(&*payload),
                location: None,
            })
        })
    }
}
274
#[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests {
    use std::str::FromStr as _;

    use proc_macro2::{Literal, TokenTree};
    use syn::Ident;

    use super::*;

    /// Returns the first token of `tokens`, asserting that it is a literal.
    fn single_literal(tokens: &TokenStream) -> Literal {
        let first = tokens.clone().into_iter().next().expect("one token");
        match first {
            TokenTree::Literal(literal) => literal,
            other => panic!("expected literal, got {other:?}"),
        }
    }

    /// Lexes `12` in a parse of its own, so its spans are unrelated to any
    /// anchor returned by `parse_input`.
    fn unrelated_tokens() -> TokenStream {
        TokenStream::from_str(
            "
            12
            ",
        )
        .unwrap()
    }

    /// Macro body used by the `entry` tests: appends a call-site `foo` ident.
    fn append_foo(mut input: TokenStream) -> TokenStream {
        input.extend([Ident::new("foo", Span::call_site())]);
        input
    }

    #[test]
    fn input_anchor_joins_input_span() {
        let (tokens, anchor) = parse_input("12").expect("valid token stream");
        let literal = single_literal(&tokens);
        assert!(anchor.unwrap().join(literal.span()).is_some());
    }

    #[test]
    fn unrelated_parse_does_not_join_input_anchor() {
        let (_tokens, anchor) = parse_input("12").expect("valid token stream");
        let generated = unrelated_tokens();
        let literal = single_literal(&generated);
        assert!(anchor.unwrap().join(literal.span()).is_none());
    }

    #[test]
    fn output_maps_unrelated_spans_to_call_site() {
        let (_tokens, anchor) = parse_input("12").expect("valid token stream");
        let out = collect_outputs(unrelated_tokens(), None, anchor);
        assert_eq!(out.spans, [OutputEntry::new(false, 0..0)]);
    }

    #[test]
    fn output_preserves_input_relative_spans() {
        let (tokens, anchor) = parse_input("12").expect("valid token stream");
        let out = collect_outputs(tokens, None, anchor);
        assert_eq!(out.text, "12");
        assert_eq!(out.spans, [OutputEntry::new(false, 0..2)]);
    }

    #[test]
    fn panic_capture_hook_captures_panic() {
        let report = panic::run_and_catch(|| {
            panic!("test panic");
        })
        .unwrap_err();
        assert_eq!(report.message, "test panic");
        assert!(report.location.is_some());
    }

    #[test]
    fn test_entry_full_flow() {
        let output = entry("12", append_foo);
        assert_eq!(output.text, "12 foo");
        assert_eq!(
            output.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    #[test]
    fn test_checks_that_entry_cleanup() {
        let first = entry("12", append_foo);
        assert_eq!(first.text, "12 foo");

        let second = entry("12", append_foo);
        assert_eq!(second.text, "12 foo");
        assert_eq!(
            second.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    #[test]
    #[should_panic(expected = "second panic")]
    fn test_check_panic_hook_cleanup() {
        let output = entry("12", |_: TokenStream| {
            panic!("test panic");
        });
        assert!(output.text.contains("test panic"));
        panic!("second panic");
    }
}