1use debug::debug_enabled;
2use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
3use std::fs::File;
4use std::io;
5use std::io::Read;
6use std::path::Path;
7
8mod debug;
9mod module_ref;
10
11#[proc_macro_attribute]
32pub fn cogno_test(attr: TokenStream, item: TokenStream) -> TokenStream {
33 if debug_enabled() {
34 println!("cogno_test attr => {}", attr.to_string());
35 println!("cogno_test => {}", item.to_string());
36 }
37
38 let mut spec_id = String::new();
39 let mut header_src = String::new();
40 let mut attr_iter = attr.into_iter();
41 if let Some(TokenTree::Ident(id)) = attr_iter.next() {
42 match id.to_string().as_str() {
43 "spec" => {
44 attr_iter.next();
45 if let Some(TokenTree::Literal(id)) = attr_iter.next() {
46 spec_id = id.to_string();
47 header_src.push_str(
48 format!(
49 r#"
50 if !controller.lock().unwrap().is_spec_enabled({}) {{
51 cogno::tracing::event!(cogno::tracing::Level::INFO, "skipped");
52 return;
53 }}
54 "#,
55 spec_id
56 )
57 .as_str(),
58 );
59 }
60 }
61 _ => {
62 panic!("Unrecognised syntax in test attribute");
63 }
64 }
65 }
66
67 if spec_id.is_empty() {
70 spec_id.push_str("\"\"");
71 }
72
73 let mut ret = TokenStream::new();
74
75 let mut fn_found = false;
76 let mut param_injected = false;
77 let mut fn_name = String::new();
78 for token in item {
79 if !param_injected {
80 if !fn_found {
81 if token.to_string() == "fn" {
82 fn_found = true;
83 }
84 ret.extend(Some(token));
85 continue;
86 }
87
88 match token {
89 TokenTree::Group(g) => {
90 if g.delimiter() == Delimiter::Parenthesis {
91 ret.extend(to_token_stream("(controller: &mut std::sync::Arc<std::sync::Mutex<cogno::TestController>>)"));
92 param_injected = true;
93 } else {
94 panic!("unexpected group after test function name");
95 }
96 }
97 TokenTree::Ident(i) => {
98 fn_name = i.to_string();
99 ret.extend(Some(TokenTree::Ident(i)));
100 }
101 other => {
102 ret.extend(Some(other));
103 }
104 }
105 continue;
106 }
107
108 let group = match token {
109 TokenTree::Group(g) => {
110 if g.delimiter() == Delimiter::Brace {
111 g
112 } else {
113 ret.extend(Some(TokenTree::Group(g)));
114 continue;
115 }
116 }
117 other => {
118 ret.extend(Some(other));
119 continue;
120 }
121 };
122
123 let mut new_body = TokenStream::new();
124
125 let mut group_stream = group.stream().into_iter().peekable();
126 while let Some(tt) = group_stream.peek() {
127 match tt.to_string().as_str() {
128 "should_eq" | "should_not_eq" | "must_eq" | "must_not_eq" | "may_eq" => {
129 new_body.extend(group_stream.next());
130
131 if group_stream.peek().is_some()
132 && group_stream.peek().unwrap().to_string() == "!"
133 {
134 new_body.extend(group_stream.next());
135 match group_stream.next() {
136 Some(TokenTree::Group(g)) => {
137 let mut new_group = TokenStream::new();
138 new_group.extend(to_token_stream("controller_thread_ref,"));
139 new_group.extend(g.stream());
140
141 new_body.extend(Some(TokenTree::from(Group::new(
142 g.delimiter(),
143 new_group,
144 ))));
145 }
146 _ => {
147 panic!("expected arguments after assertion macro");
148 }
149 }
150 } else {
151 panic!("identifier conflicts with an assertion macro")
152 }
153 }
154 _ => {
155 new_body.extend(group_stream.next());
156 }
157 }
158 }
159
160 let mut traced_header_src = String::new();
161 traced_header_src.push_str(format!(r#"
162 let span = cogno::tracing::span!(cogno::tracing::Level::INFO, "{}");
163 let _enter = span.enter();
164 cogno::tracing::event!(cogno::tracing::Level::INFO, "enter");
165 {}
166 "#, fn_name, header_src).as_str());
167
168 let wrapped_body = to_token_stream(
169 format!(
170 r#"
171 {}
172 controller.lock().unwrap().register("{}", {});
173
174 let controller_thread_ref = controller.clone();
175
176 let result = std::thread::Builder::new()
177 .name("{}".to_string())
178 .spawn(move || {{
179 std::panic::catch_unwind(move || {{
180 {}
181 }})
182 }}).unwrap().join().unwrap();
183
184 cogno::tracing::event!(cogno::tracing::Level::INFO, "exit");
185 match result {{
186 Ok(_) => {{
187 controller.lock().unwrap().complete();
188 }}
189 _ => {{}}
190 }};
191 "#,
192 traced_header_src,
193 fn_name,
194 spec_id,
195 fn_name,
196 new_body.to_string()
197 )
198 .as_str(),
199 );
200
201 ret.extend(Some(TokenTree::from(Group::new(
202 group.delimiter(),
203 wrapped_body,
204 ))));
205 }
206
207 if debug_enabled() {
208 println!("cogno_test transformed => {}", ret.to_string());
209 }
210
211 ret
212}
213
214#[proc_macro_attribute]
225pub fn cogno_main(_: TokenStream, item: TokenStream) -> TokenStream {
226 if debug_enabled() {
227 println!("cogno_main => {}", item.to_string());
228 }
229
230 let manifest_path = option_env!("COGNO_MANIFEST");
231 if manifest_path.is_none() {
233 panic!("Run with `cargo cogno`")
234 }
235
236 let manifest = load_manifest(manifest_path.unwrap()).unwrap();
237
238 if debug_enabled() {
239 println!("manifest => {:?}", manifest);
240 }
241
242 let mut ret = String::new();
243 ret.push_str("fn main() {");
244
245 ret.push_str(r#"
246 if "true" == std::env::var("COGNO_TRACE").unwrap_or(String::from("false")).as_str() {
247 let sub = cogno::tracing_subscriber::FmtSubscriber::new();
248 cogno::tracing::subscriber::set_global_default(sub)
249 .expect("setting tracing default failed");
250 }
251 let span = cogno::tracing::span!(cogno::tracing::Level::INFO, "cogno_main");
252 let _enter = span.enter();
253 cogno::tracing::event!(cogno::tracing::Level::INFO, "starting");
254
255 let mut controller = std::sync::Arc::new(std::sync::Mutex::new(cogno::TestController::new().unwrap()));
256 "#);
257
258 ret.push_str(
259 r#"
260 let controller_panic_ref = controller.clone();
261 std::panic::set_hook(Box::new(move |info| {
262 cogno::tracing::event!(cogno::tracing::Level::INFO, "captured a panic - {}", info);
263 let mut controller_handle = controller_panic_ref.lock().unwrap();
264 controller_handle.set_panic_info(info.to_string());
265 }));
266 "#,
267 );
268
269 for module_ref in manifest {
270 ret.push_str(format!("{}", module_ref.to_source()).as_str());
271 }
272
273 ret.push_str(r#"
274 cogno::tracing::event!(cogno::tracing::Level::INFO, "finishing report");
275 let finalize_result = controller.lock().unwrap().finalize();
276 finalize_result.unwrap();
277 cogno::tracing::event!(cogno::tracing::Level::INFO, "done");
278 "#);
279 ret.push_str("}");
280
281 let ret = to_token_stream(ret.as_str());
282
283 if debug_enabled() {
284 println!("cogno_main transformed => {}", ret.to_string());
285 }
286
287 ret
288}
289
/// Tokenizes Rust source text into a `TokenStream`.
///
/// Every token produced this way carries a `Span::call_site()` span.
///
/// # Panics
/// If `code` does not tokenize, e.g. because of unbalanced delimiters. Callers pass source text
/// assembled by this crate, so a failure indicates a code-generation bug. The offending source
/// is included in the panic message to make it diagnosable.
fn to_token_stream(code: &str) -> TokenStream {
    code.parse().unwrap_or_else(|err| {
        panic!("cogno generated code that failed to tokenize ({:?}):\n{}", err, code)
    })
}
293
294fn load_manifest<P: AsRef<Path>>(source: P) -> Result<Vec<module_ref::ModuleRef>, io::Error> {
295 let mut content = String::new();
296 File::open(source)?.read_to_string(&mut content)?;
297 let module_refs = serde_json::from_str(content.as_str())?;
298 Ok(module_refs)
299}