1#[cfg(feature = "syn")]
7use std::any::Any;
8use std::fmt::{Debug, Display};
9
/// A `Result` bundled with optional diagnostic metadata for use in tests.
///
/// The wrapper itself checks nothing: a value of this type only has an effect
/// once one of its `assert_*` methods is called, which is why it is marked
/// `#[must_use]`.
#[must_use = "a `TestResult` verifies nothing until one of its `assert_*` methods is called"]
#[derive(Debug)]
pub struct TestResult<T, E> {
    /// The outcome being checked.
    inner: Result<T, E>,
    /// Optional free-form description, echoed as a `Context:` line in failure messages.
    context: Option<String>,
    /// Optional source text, forwarded to the error formatter when an error is rendered.
    source: Option<String>,
}
16
17impl<T: Debug, E: Display + Debug + 'static> TestResult<T, E> {
18 pub fn new(result: Result<T, E>) -> Self {
19 Self {
20 inner: result,
21 context: None,
22 source: None,
23 }
24 }
25
26 pub fn with_context(mut self, context: &str) -> Self {
27 self.context = Some(context.to_string());
28 self
29 }
30
31 pub fn with_source(mut self, source: &str) -> Self {
32 self.source = Some(source.to_string());
33 self
34 }
35
36 fn format_context(&self) -> String {
37 self.context
38 .as_ref()
39 .map(|c| format!("\nContext: {}", c))
40 .unwrap_or_default()
41 }
42
43 fn format_err(&self, err: &E) -> String {
44 format_error_impl(err, self.source.as_deref())
45 }
46
47 pub fn assert_success(self) -> T {
49 let ctx = self.format_context();
50 match self.inner {
51 Ok(val) => val,
52 Err(ref e) => {
53 let msg = self.format_err(e);
54 panic!(
55 "\n🔴 TEST FAILED (Expected Success, but got Error):{}\nMessage: {}\nError Debug: {:?}\n",
56 ctx, msg, e
57 );
58 }
59 }
60 }
61
62 pub fn assert_success_is<Exp>(self, expected: Exp) -> T
65 where
66 T: PartialEq<Exp>,
67 Exp: Debug,
68 {
69 let ctx = self.format_context();
70 let val = self.assert_success();
71 if val != expected {
72 panic!(
73 "\n🔴 TEST FAILED (Value Mismatch):{}\nExpected: {:?}\nGot: {:?}\n",
74 ctx, expected, val
75 );
76 }
77 val
78 }
79
80 pub fn assert_success_with<F>(self, f: F) -> T
83 where
84 F: FnOnce(&T),
85 {
86 let val = self.assert_success();
87 f(&val);
88 val
89 }
90
91 pub fn assert_success_debug(self, expected_debug: &str) -> T {
94 let ctx = self.format_context();
95 let val = self.assert_success();
96 let actual_debug = format!("{:?}", val);
97 if actual_debug != expected_debug {
98 panic!(
99 "\n🔴 TEST FAILED (Debug Mismatch):{}\nExpected: {:?}\nGot: {:?}\n",
100 ctx, expected_debug, actual_debug
101 );
102 }
103 val
104 }
105
106 pub fn assert_failure(self) -> E {
108 let ctx = self.format_context();
109 match self.inner {
110 Ok(val) => {
111 panic!(
112 "\n🔴 TEST FAILED (Expected Failure, but got Success):{}\nParsed Value: {:?}\n",
113 ctx, val
114 );
115 }
116 Err(e) => e,
117 }
118 }
119
120 pub fn assert_failure_contains(self, expected_msg_part: &str) {
122 let ctx = self.format_context();
123 let source = self.source.clone();
124 let err = self.assert_failure();
125 let actual_msg = err.to_string();
126 if !actual_msg.contains(expected_msg_part) {
127 let formatted = format_error_impl(&err, source.as_deref());
128 panic!(
129 "\n🔴 TEST FAILED (Error Message Mismatch):{}\nExpected part: {:?}\nActual msg: {:?}\nError Debug: {:?}\nFormatted:\n{}\n",
130 ctx, expected_msg_part, actual_msg, err, formatted
131 );
132 }
133 }
134
135 pub fn assert_success_contains(self, expected_part: &str) -> T
137 where
138 T: Display,
139 {
140 let ctx = self.format_context();
141 let val = self.assert_success();
142 let val_str = val.to_string();
143 if !val_str.contains(expected_part) {
144 panic!(
145 "\n🔴 TEST FAILED (Content Mismatch):{}\nExpected to contain: {:?}\nGot: {:?}\n",
146 ctx, expected_part, val_str
147 );
148 }
149 val
150 }
151}
152
/// Conversion of a value into a [`TestResult`] so the `assert_*` helpers can
/// be called on it.
pub trait Testable<T, E> {
    /// Consumes `self` and returns it as a [`TestResult`].
    fn test(self) -> TestResult<T, E>;
}
156
// Only compiled when the `syn` feature is enabled.
#[cfg(feature = "syn")]
impl<T: Debug> Testable<T, syn::Error> for syn::Result<T> {
    /// Wraps this `syn::Result` via [`TestResult::new`], i.e. with no context
    /// and no source text attached.
    fn test(self) -> TestResult<T, syn::Error> {
        TestResult::new(self)
    }
}
163
/// Produces the human-readable text for `err`.
///
/// With the `syn` feature enabled, a `syn::Error` accompanied by source text
/// is rendered through `pretty_print_syn_error`; in every other case the
/// result is the error's plain `Display` output.
fn format_error_impl<E: Display + Debug + 'static>(err: &E, source: Option<&str>) -> String {
    #[cfg(feature = "syn")]
    {
        // The `'static` bound on `E` is what allows the `Any` downcast.
        let as_syn = (err as &dyn Any).downcast_ref::<syn::Error>();
        if let (Some(src), Some(syn_err)) = (source, as_syn) {
            return pretty_print_syn_error(syn_err, src);
        }
    }
    err.to_string()
}
173
/// Renders `err` followed by a rustc-style excerpt of the line of `source`
/// that the error's span starts on, with `^` markers under the span.
///
/// Falls back to plain `err.to_string()` when the span carries no line
/// information (line `0`) or when its line does not exist in `source`.
///
/// The gutter is sized to the width of the line number so the `|` separators
/// and the caret row line up with the source row, and tabs in the prefix of
/// the source line are mirrored in the caret padding so the markers stay
/// aligned on tab-indented input.
#[cfg(feature = "syn")]
fn pretty_print_syn_error(err: &syn::Error, source: &str) -> String {
    let span = err.span();
    let start = span.start();
    let end = span.end();

    // `start.line` is 1-indexed; `checked_sub` rejects the "no location"
    // value 0 and `nth` rejects lines past the end of `source`.
    let line = match start
        .line
        .checked_sub(1)
        .and_then(|idx| source.lines().nth(idx))
    {
        Some(line) => line,
        None => return err.to_string(),
    };

    let col = start.column;

    // Underline the whole span when it sits on a single line; otherwise (or
    // for an empty span) mark just the starting column.
    let width = if start.line == end.line {
        end.column.saturating_sub(col).max(1)
    } else {
        1
    };

    let line_no = start.line.to_string();
    let gutter = " ".repeat(line_no.len());

    // One padding character per column before the span: a tab where the
    // source has a tab, a space otherwise (including past the end of line).
    let mut prefix = line.chars();
    let pad: String = (0..col)
        .map(|_| match prefix.next() {
            Some('\t') => '\t',
            _ => ' ',
        })
        .collect();

    format!(
        "{msg}\n{gutter}--> line {line_no}:{col}\n{gutter} |\n{line_no} | {src}\n{gutter} | {pad}{carets}",
        msg = err,
        gutter = gutter,
        line_no = line_no,
        col = col,
        src = line,
        pad = pad,
        carets = "^".repeat(width)
    )
}