cogno_attr/
lib.rs

1use debug::debug_enabled;
2use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
3use std::fs::File;
4use std::io;
5use std::io::Read;
6use std::path::Path;
7
8mod debug;
9mod module_ref;
10
11/// Mark a function as a Cogno test.
12///
13/// This attribute accepts a `spec` parameter which groups the test under a specification identifier.
14/// The attribute is optional but recommended!
15///
16/// ```
17/// #[cogno_test(spec = "rfc-1034")]
18/// fn example_test() {}
19/// ```
20///
21/// If you choose not to group your tests by specification because what you are testing is standalone then
22///
23/// ```
24/// #[cogno_test]
25/// fn example_test() {}
26/// ```
27///
28/// Your test should use the provided assertion macros like `should_eq!` and avoid panicking unless necessary.
29/// That means you should avoid Rust's `assert_eq!` and other test assertion macros.
30/// However, a program failing to start or being unable to open a file would be valid reasons to panic and fail the test.
31#[proc_macro_attribute]
32pub fn cogno_test(attr: TokenStream, item: TokenStream) -> TokenStream {
33    if debug_enabled() {
34        println!("cogno_test attr => {}", attr.to_string());
35        println!("cogno_test => {}", item.to_string());
36    }
37
38    let mut spec_id = String::new();
39    let mut header_src = String::new();
40    let mut attr_iter = attr.into_iter();
41    if let Some(TokenTree::Ident(id)) = attr_iter.next() {
42        match id.to_string().as_str() {
43            "spec" => {
44                attr_iter.next();
45                if let Some(TokenTree::Literal(id)) = attr_iter.next() {
46                    spec_id = id.to_string();
47                    header_src.push_str(
48                        format!(
49                            r#"
50                        if !controller.lock().unwrap().is_spec_enabled({}) {{
51                            cogno::tracing::event!(cogno::tracing::Level::INFO, "skipped");
52                            return;
53                        }}
54                        "#,
55                            spec_id
56                        )
57                            .as_str(),
58                    );
59                }
60            }
61            _ => {
62                panic!("Unrecognised syntax in test attribute");
63            }
64        }
65    }
66
67    // Default to literal empty string if no spec_id provided.
68    // This is expected to be unusual, but is supported
69    if spec_id.is_empty() {
70        spec_id.push_str("\"\"");
71    }
72
73    let mut ret = TokenStream::new();
74
75    let mut fn_found = false;
76    let mut param_injected = false;
77    let mut fn_name = String::new();
78    for token in item {
79        if !param_injected {
80            if !fn_found {
81                if token.to_string() == "fn" {
82                    fn_found = true;
83                }
84                ret.extend(Some(token));
85                continue;
86            }
87
88            match token {
89                TokenTree::Group(g) => {
90                    if g.delimiter() == Delimiter::Parenthesis {
91                        ret.extend(to_token_stream("(controller: &mut std::sync::Arc<std::sync::Mutex<cogno::TestController>>)"));
92                        param_injected = true;
93                    } else {
94                        panic!("unexpected group after test function name");
95                    }
96                }
97                TokenTree::Ident(i) => {
98                    fn_name = i.to_string();
99                    ret.extend(Some(TokenTree::Ident(i)));
100                }
101                other => {
102                    ret.extend(Some(other));
103                }
104            }
105            continue;
106        }
107
108        let group = match token {
109            TokenTree::Group(g) => {
110                if g.delimiter() == Delimiter::Brace {
111                    g
112                } else {
113                    ret.extend(Some(TokenTree::Group(g)));
114                    continue;
115                }
116            }
117            other => {
118                ret.extend(Some(other));
119                continue;
120            }
121        };
122
123        let mut new_body = TokenStream::new();
124
125        let mut group_stream = group.stream().into_iter().peekable();
126        while let Some(tt) = group_stream.peek() {
127            match tt.to_string().as_str() {
128                "should_eq" | "should_not_eq" | "must_eq" | "must_not_eq" | "may_eq" => {
129                    new_body.extend(group_stream.next());
130
131                    if group_stream.peek().is_some()
132                        && group_stream.peek().unwrap().to_string() == "!"
133                    {
134                        new_body.extend(group_stream.next());
135                        match group_stream.next() {
136                            Some(TokenTree::Group(g)) => {
137                                let mut new_group = TokenStream::new();
138                                new_group.extend(to_token_stream("controller_thread_ref,"));
139                                new_group.extend(g.stream());
140
141                                new_body.extend(Some(TokenTree::from(Group::new(
142                                    g.delimiter(),
143                                    new_group,
144                                ))));
145                            }
146                            _ => {
147                                panic!("expected arguments after assertion macro");
148                            }
149                        }
150                    } else {
151                        panic!("identifier conflicts with an assertion macro")
152                    }
153                }
154                _ => {
155                    new_body.extend(group_stream.next());
156                }
157            }
158        }
159
160        let mut traced_header_src = String::new();
161        traced_header_src.push_str(format!(r#"
162        let span = cogno::tracing::span!(cogno::tracing::Level::INFO, "{}");
163        let _enter = span.enter();
164        cogno::tracing::event!(cogno::tracing::Level::INFO, "enter");
165        {}
166        "#, fn_name, header_src).as_str());
167
168        let wrapped_body = to_token_stream(
169            format!(
170                r#"
171            {}
172            controller.lock().unwrap().register("{}", {});
173
174    let controller_thread_ref = controller.clone();
175
176    let result = std::thread::Builder::new()
177    .name("{}".to_string())
178    .spawn(move || {{
179        std::panic::catch_unwind(move || {{
180                {}
181            }})
182        }}).unwrap().join().unwrap();
183
184        cogno::tracing::event!(cogno::tracing::Level::INFO, "exit");
185        match result {{
186            Ok(_) => {{
187                controller.lock().unwrap().complete();
188            }}
189            _ => {{}}
190        }};
191        "#,
192                traced_header_src,
193                fn_name,
194                spec_id,
195                fn_name,
196                new_body.to_string()
197            )
198                .as_str(),
199        );
200
201        ret.extend(Some(TokenTree::from(Group::new(
202            group.delimiter(),
203            wrapped_body,
204        ))));
205    }
206
207    if debug_enabled() {
208        println!("cogno_test transformed => {}", ret.to_string());
209    }
210
211    ret
212}
213
214/// Generate the main function to run Cogno tests.
215///
216/// The entry point of your program should be marked with this attributed and be empty
217///
218/// ```
219/// #[cogno_main]
220/// fn main() {}
221/// ```
222///
223/// The generated code will include a `TestController` and invocations of each of your test functions.
224#[proc_macro_attribute]
225pub fn cogno_main(_: TokenStream, item: TokenStream) -> TokenStream {
226    if debug_enabled() {
227        println!("cogno_main => {}", item.to_string());
228    }
229
230    let manifest_path = option_env!("COGNO_MANIFEST");
231    // TODO check up to date? should really always run with cargo cogno but could just run with cargo run
232    if manifest_path.is_none() {
233        panic!("Run with `cargo cogno`")
234    }
235
236    let manifest = load_manifest(manifest_path.unwrap()).unwrap();
237
238    if debug_enabled() {
239        println!("manifest => {:?}", manifest);
240    }
241
242    let mut ret = String::new();
243    ret.push_str("fn main() {");
244
245    ret.push_str(r#"
246    if "true" == std::env::var("COGNO_TRACE").unwrap_or(String::from("false")).as_str() {
247        let sub = cogno::tracing_subscriber::FmtSubscriber::new();
248        cogno::tracing::subscriber::set_global_default(sub)
249            .expect("setting tracing default failed");
250    }
251    let span = cogno::tracing::span!(cogno::tracing::Level::INFO, "cogno_main");
252    let _enter = span.enter();
253    cogno::tracing::event!(cogno::tracing::Level::INFO, "starting");
254
255    let mut controller = std::sync::Arc::new(std::sync::Mutex::new(cogno::TestController::new().unwrap()));
256    "#);
257
258    ret.push_str(
259        r#"
260    let controller_panic_ref = controller.clone();
261    std::panic::set_hook(Box::new(move |info| {
262        cogno::tracing::event!(cogno::tracing::Level::INFO, "captured a panic - {}", info);
263        let mut controller_handle = controller_panic_ref.lock().unwrap();
264        controller_handle.set_panic_info(info.to_string());
265    }));
266    "#,
267    );
268
269    for module_ref in manifest {
270        ret.push_str(format!("{}", module_ref.to_source()).as_str());
271    }
272
273    ret.push_str(r#"
274    cogno::tracing::event!(cogno::tracing::Level::INFO, "finishing report");
275    let finalize_result = controller.lock().unwrap().finalize();
276    finalize_result.unwrap();
277    cogno::tracing::event!(cogno::tracing::Level::INFO, "done");
278    "#);
279    ret.push_str("}");
280
281    let ret = to_token_stream(ret.as_str());
282
283    if debug_enabled() {
284        println!("cogno_main transformed => {}", ret.to_string());
285    }
286
287    ret
288}
289
/// Parses a string of Rust source into a `TokenStream`.
///
/// Panics if `code` cannot be tokenized.
fn to_token_stream(code: &str) -> TokenStream {
    let parsed: Result<TokenStream, _> = code.parse();
    parsed.unwrap()
}
293
294fn load_manifest<P: AsRef<Path>>(source: P) -> Result<Vec<module_ref::ModuleRef>, io::Error> {
295    let mut content = String::new();
296    File::open(source)?.read_to_string(&mut content)?;
297    let module_refs = serde_json::from_str(content.as_str())?;
298    Ok(module_refs)
299}