test_dsl/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use condition::TestCondition;
5use verb::TestVerb;
6
7pub mod arguments;
8pub mod condition;
9pub mod error;
10pub mod test_case;
11pub mod verb;
12pub use kdl;
13
14pub struct TestDsl<H> {
15    verbs: HashMap<String, Box<dyn TestVerb<H>>>,
16    conditions: HashMap<String, Box<dyn condition::TestCondition<H>>>,
17}
18
19impl<H> std::fmt::Debug for TestDsl<H> {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("TestDsl").finish_non_exhaustive()
22    }
23}
24
25impl<H: 'static> Default for TestDsl<H> {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl<H: 'static> TestDsl<H> {
32    pub fn new() -> Self {
33        TestDsl {
34            verbs: HashMap::default(),
35            conditions: HashMap::default(),
36        }
37    }
38
39    pub fn add_verb(&mut self, name: impl AsRef<str>, verb: impl TestVerb<H>) {
40        let existing = self.verbs.insert(name.as_ref().to_string(), Box::new(verb));
41        assert!(existing.is_none());
42    }
43
44    pub fn add_condition(
45        &mut self,
46        name: impl AsRef<str>,
47        condition: impl condition::TestCondition<H>,
48    ) {
49        let existing = self
50            .conditions
51            .insert(name.as_ref().to_string(), Box::new(condition));
52
53        assert!(existing.is_none());
54    }
55
56    pub fn parse_document(
57        &self,
58        input: miette::NamedSource<Arc<str>>,
59    ) -> Result<Vec<test_case::TestCase<H>>, error::TestParseError> {
60        let document = kdl::KdlDocument::parse(input.inner())?;
61
62        let mut cases = vec![];
63
64        let mut errors = vec![];
65
66        for testcase_node in document.nodes() {
67            if testcase_node.name().value() != "testcase" {
68                errors.push(error::TestErrorCase::NotTestcase {
69                    span: testcase_node.name().span(),
70                });
71
72                continue;
73            }
74
75            let mut testcase = test_case::TestCase::new(input.clone());
76
77            for verb_node in testcase_node.iter_children() {
78                match self.parse_verb(verb_node) {
79                    Ok(verb) => testcase.creators.push(verb),
80                    Err(e) => errors.push(e),
81                }
82            }
83
84            cases.push(testcase);
85        }
86
87        if !errors.is_empty() {
88            return Err(error::TestParseError {
89                errors,
90                source_code: Some(input.clone()),
91            });
92        }
93
94        Ok(cases)
95    }
96
97    fn parse_condition(
98        &self,
99        condition_node: &kdl::KdlNode,
100    ) -> Result<Box<dyn TestConditionCreator<H>>, error::TestErrorCase> {
101        self.conditions
102            .get(condition_node.name().value())
103            .ok_or_else(|| error::TestErrorCase::UnknownCondition {
104                condition: condition_node.name().span(),
105            })
106            .map(|cond| {
107                Box::new(DirectCondition {
108                    condition: cond.clone(),
109                    node: condition_node.clone(),
110                }) as Box<_>
111            })
112    }
113
114    fn parse_verb(
115        &self,
116        verb_node: &kdl::KdlNode,
117    ) -> Result<Box<dyn TestVerbCreator<H>>, error::TestErrorCase> {
118        match verb_node.name().value() {
119            "repeat" => {
120                let times = verb_node
121                    .get(0)
122                    .ok_or_else(|| error::TestErrorCase::MissingArgument {
123                        parent: verb_node.name().span(),
124                        missing: String::from("`repeat` takes one argument, the repetition count"),
125                    })?
126                    .as_integer()
127                    .ok_or_else(|| error::TestErrorCase::WrongArgumentType {
128                        parent: verb_node.name().span(),
129                        argument: verb_node.iter().next().unwrap().span(),
130                        expected: String::from("Expected an integer"),
131                    })? as usize;
132
133                let block = verb_node
134                    .iter_children()
135                    .map(|node| self.parse_verb(node))
136                    .collect::<Result<_, _>>()?;
137
138                Ok(Box::new(Repeat { times, block }))
139            }
140            "group" => Ok(Box::new(Group {
141                block: verb_node
142                    .iter_children()
143                    .map(|n| self.parse_verb(n))
144                    .collect::<Result<_, _>>()?,
145            })),
146            "assert" => Ok(Box::new(AssertConditions {
147                conditions: verb_node
148                    .iter_children()
149                    .map(|node| self.parse_condition(node))
150                    .collect::<Result<_, _>>()?,
151            })),
152            name => {
153                let verb = self
154                    .verbs
155                    .get(name)
156                    .ok_or_else(|| error::TestErrorCase::UnknownVerb {
157                        verb: verb_node.name().span(),
158                    })?
159                    .clone();
160
161                Ok(Box::new(DirectVerb {
162                    verb,
163                    node: verb_node.clone(),
164                }))
165            }
166        }
167    }
168}
169
170trait TestVerbCreator<H> {
171    fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_>;
172}
173
174trait TestConditionCreator<H> {
175    fn get_test_conditions(&self) -> Box<dyn Iterator<Item = TestConditionInstance<'_, H>> + '_>;
176}
177
178struct Group<H> {
179    block: Vec<Box<dyn TestVerbCreator<H>>>,
180}
181
182impl<H: 'static> TestVerbCreator<H> for Group<H> {
183    fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
184        Box::new(self.block.iter().flat_map(|c| c.get_test_verbs()))
185    }
186}
187
188struct Repeat<H> {
189    times: usize,
190    block: Vec<Box<dyn TestVerbCreator<H>>>,
191}
192
193impl<H: 'static> TestVerbCreator<H> for Repeat<H> {
194    fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
195        Box::new(
196            std::iter::repeat_with(|| self.block.iter().flat_map(|c| c.get_test_verbs()))
197                .take(self.times)
198                .flatten(),
199        )
200    }
201}
202
203struct DirectVerb<H> {
204    verb: Box<dyn TestVerb<H>>,
205    node: kdl::KdlNode,
206}
207
208impl<H: 'static> TestVerbCreator<H> for DirectVerb<H> {
209    fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
210        Box::new(std::iter::once(TestVerbInstance {
211            verb: self.verb.clone(),
212            node: &self.node,
213        }))
214    }
215}
216
217struct DirectCondition<H> {
218    condition: Box<dyn TestCondition<H>>,
219    node: kdl::KdlNode,
220}
221
222impl<H: 'static> TestConditionCreator<H> for DirectCondition<H> {
223    fn get_test_conditions(&self) -> Box<dyn Iterator<Item = TestConditionInstance<'_, H>> + '_> {
224        Box::new(std::iter::once(TestConditionInstance {
225            condition: self.condition.clone(),
226            node: &self.node,
227        }))
228    }
229}
230
231struct AssertConditions<H> {
232    conditions: Vec<Box<dyn TestConditionCreator<H>>>,
233}
234
235impl<H: 'static> TestVerbCreator<H> for AssertConditions<H> {
236    fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
237        Box::new(
238            self.conditions
239                .iter()
240                .flat_map(|cond| cond.get_test_conditions())
241                .map(|cond| TestVerbInstance {
242                    node: cond.node,
243                    verb: Box::new(AssertVerb {
244                        condition: cond.condition,
245                    }),
246                }),
247        )
248    }
249}
250
251struct AssertVerb<H> {
252    condition: Box<dyn TestCondition<H>>,
253}
254
255impl<H: 'static> TestVerb<H> for AssertVerb<H> {
256    fn run(&self, harness: &mut H, node: &kdl::KdlNode) -> Result<(), error::TestErrorCase> {
257        self.condition.check_now(harness, node).and_then(|res| {
258            res.then_some(())
259                .ok_or_else(|| error::TestErrorCase::Error {
260                    error: miette::miette!("Assert failed"),
261                    label: node.span(),
262                })
263        })
264    }
265
266    fn clone_box(&self) -> Box<dyn TestVerb<H>> {
267        Box::new(self.clone())
268    }
269}
270
271impl<H: 'static> Clone for AssertVerb<H> {
272    fn clone(&self) -> Self {
273        AssertVerb {
274            condition: self.condition.clone(),
275        }
276    }
277}
278
279struct TestConditionInstance<'p, H> {
280    condition: Box<dyn TestCondition<H>>,
281    node: &'p kdl::KdlNode,
282}
283
284struct TestVerbInstance<'p, H> {
285    verb: Box<dyn TestVerb<H>>,
286    node: &'p kdl::KdlNode,
287}
288
289impl<'p, H: 'static> TestVerbInstance<'p, H> {
290    fn run<'h>(&'p self, harness: &'h mut H) -> Result<(), error::TestErrorCase> {
291        self.verb.run(harness, self.node)
292    }
293}
294
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use miette::NamedSource;

    use crate::TestDsl;
    use crate::verb::FunctionVerb;

    struct ArithmeticHarness {
        value: AtomicUsize,
    }

    /// Builds a DSL with the `add_one` and `mul_two` verbs shared by every test.
    fn arithmetic_dsl() -> TestDsl<ArithmeticHarness> {
        let mut dsl = TestDsl::<ArithmeticHarness>::new();

        dsl.add_verb(
            "add_one",
            FunctionVerb::new(|ah: &mut ArithmeticHarness| {
                ah.value.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }),
        );

        dsl.add_verb(
            "mul_two",
            FunctionVerb::new(|ah: &mut ArithmeticHarness| {
                let current = ah.value.load(Ordering::SeqCst);
                ah.value.store(current * 2, Ordering::SeqCst);
                Ok(())
            }),
        );

        dsl
    }

    /// Parses `source`, runs its first test case against a harness starting
    /// at zero, and returns the harness's final value.
    fn run_first_case(dsl: &TestDsl<ArithmeticHarness>, source: &str) -> usize {
        let cases = dsl
            .parse_document(NamedSource::new("test.kdl", Arc::from(source)))
            .unwrap();

        let mut harness = ArithmeticHarness {
            value: AtomicUsize::new(0),
        };
        cases[0].run(&mut harness).unwrap();

        harness.value.load(Ordering::SeqCst)
    }

    #[test]
    fn simple_test() {
        let dsl = arithmetic_dsl();

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                add_one
                add_one
                mul_two
            }
            "#,
        );

        // (0 + 1 + 1) * 2
        assert_eq!(result, 4);
    }

    #[test]
    fn repeat_test() {
        let dsl = arithmetic_dsl();

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                repeat 2 {
                    repeat 2 {
                        add_one
                        mul_two
                    }
                }
            }
            "#,
        );

        // Four rounds of (x + 1) * 2: 2, 6, 14, 30.
        assert_eq!(result, 30);
    }

    #[test]
    fn check_arguments_work() {
        let mut dsl = arithmetic_dsl();
        dsl.add_verb(
            "add",
            FunctionVerb::new(|ah: &mut ArithmeticHarness, num: usize| {
                ah.value.fetch_add(num, Ordering::SeqCst);
                Ok(())
            }),
        );

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                repeat 2 {
                    repeat 2 {
                        group {
                            add 2
                            mul_two
                        }
                    }
                }
            }
            "#,
        );

        // Four rounds of (x + 2) * 2: 4, 12, 28, 60.
        assert_eq!(result, 60);
    }
}