1use std::collections::HashMap;
2use std::sync::Arc;
3
4use condition::TestCondition;
5use verb::TestVerb;
6
7pub mod arguments;
8pub mod condition;
9pub mod error;
10pub mod test_case;
11pub mod verb;
12pub use kdl;
13
14pub struct TestDsl<H> {
15 verbs: HashMap<String, Box<dyn TestVerb<H>>>,
16 conditions: HashMap<String, Box<dyn condition::TestCondition<H>>>,
17}
18
19impl<H> std::fmt::Debug for TestDsl<H> {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("TestDsl").finish_non_exhaustive()
22 }
23}
24
25impl<H: 'static> Default for TestDsl<H> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<H: 'static> TestDsl<H> {
32 pub fn new() -> Self {
33 TestDsl {
34 verbs: HashMap::default(),
35 conditions: HashMap::default(),
36 }
37 }
38
39 pub fn add_verb(&mut self, name: impl AsRef<str>, verb: impl TestVerb<H>) {
40 let existing = self.verbs.insert(name.as_ref().to_string(), Box::new(verb));
41 assert!(existing.is_none());
42 }
43
44 pub fn add_condition(
45 &mut self,
46 name: impl AsRef<str>,
47 condition: impl condition::TestCondition<H>,
48 ) {
49 let existing = self
50 .conditions
51 .insert(name.as_ref().to_string(), Box::new(condition));
52
53 assert!(existing.is_none());
54 }
55
56 pub fn parse_document(
57 &self,
58 input: miette::NamedSource<Arc<str>>,
59 ) -> Result<Vec<test_case::TestCase<H>>, error::TestParseError> {
60 let document = kdl::KdlDocument::parse(input.inner())?;
61
62 let mut cases = vec![];
63
64 let mut errors = vec![];
65
66 for testcase_node in document.nodes() {
67 if testcase_node.name().value() != "testcase" {
68 errors.push(error::TestErrorCase::NotTestcase {
69 span: testcase_node.name().span(),
70 });
71
72 continue;
73 }
74
75 let mut testcase = test_case::TestCase::new(input.clone());
76
77 for verb_node in testcase_node.iter_children() {
78 match self.parse_verb(verb_node) {
79 Ok(verb) => testcase.creators.push(verb),
80 Err(e) => errors.push(e),
81 }
82 }
83
84 cases.push(testcase);
85 }
86
87 if !errors.is_empty() {
88 return Err(error::TestParseError {
89 errors,
90 source_code: Some(input.clone()),
91 });
92 }
93
94 Ok(cases)
95 }
96
97 fn parse_condition(
98 &self,
99 condition_node: &kdl::KdlNode,
100 ) -> Result<Box<dyn TestConditionCreator<H>>, error::TestErrorCase> {
101 self.conditions
102 .get(condition_node.name().value())
103 .ok_or_else(|| error::TestErrorCase::UnknownCondition {
104 condition: condition_node.name().span(),
105 })
106 .map(|cond| {
107 Box::new(DirectCondition {
108 condition: cond.clone(),
109 node: condition_node.clone(),
110 }) as Box<_>
111 })
112 }
113
114 fn parse_verb(
115 &self,
116 verb_node: &kdl::KdlNode,
117 ) -> Result<Box<dyn TestVerbCreator<H>>, error::TestErrorCase> {
118 match verb_node.name().value() {
119 "repeat" => {
120 let times = verb_node
121 .get(0)
122 .ok_or_else(|| error::TestErrorCase::MissingArgument {
123 parent: verb_node.name().span(),
124 missing: String::from("`repeat` takes one argument, the repetition count"),
125 })?
126 .as_integer()
127 .ok_or_else(|| error::TestErrorCase::WrongArgumentType {
128 parent: verb_node.name().span(),
129 argument: verb_node.iter().next().unwrap().span(),
130 expected: String::from("Expected an integer"),
131 })? as usize;
132
133 let block = verb_node
134 .iter_children()
135 .map(|node| self.parse_verb(node))
136 .collect::<Result<_, _>>()?;
137
138 Ok(Box::new(Repeat { times, block }))
139 }
140 "group" => Ok(Box::new(Group {
141 block: verb_node
142 .iter_children()
143 .map(|n| self.parse_verb(n))
144 .collect::<Result<_, _>>()?,
145 })),
146 "assert" => Ok(Box::new(AssertConditions {
147 conditions: verb_node
148 .iter_children()
149 .map(|node| self.parse_condition(node))
150 .collect::<Result<_, _>>()?,
151 })),
152 name => {
153 let verb = self
154 .verbs
155 .get(name)
156 .ok_or_else(|| error::TestErrorCase::UnknownVerb {
157 verb: verb_node.name().span(),
158 })?
159 .clone();
160
161 Ok(Box::new(DirectVerb {
162 verb,
163 node: verb_node.clone(),
164 }))
165 }
166 }
167 }
168}
169
170trait TestVerbCreator<H> {
171 fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_>;
172}
173
174trait TestConditionCreator<H> {
175 fn get_test_conditions(&self) -> Box<dyn Iterator<Item = TestConditionInstance<'_, H>> + '_>;
176}
177
178struct Group<H> {
179 block: Vec<Box<dyn TestVerbCreator<H>>>,
180}
181
182impl<H: 'static> TestVerbCreator<H> for Group<H> {
183 fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
184 Box::new(self.block.iter().flat_map(|c| c.get_test_verbs()))
185 }
186}
187
188struct Repeat<H> {
189 times: usize,
190 block: Vec<Box<dyn TestVerbCreator<H>>>,
191}
192
193impl<H: 'static> TestVerbCreator<H> for Repeat<H> {
194 fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
195 Box::new(
196 std::iter::repeat_with(|| self.block.iter().flat_map(|c| c.get_test_verbs()))
197 .take(self.times)
198 .flatten(),
199 )
200 }
201}
202
203struct DirectVerb<H> {
204 verb: Box<dyn TestVerb<H>>,
205 node: kdl::KdlNode,
206}
207
208impl<H: 'static> TestVerbCreator<H> for DirectVerb<H> {
209 fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
210 Box::new(std::iter::once(TestVerbInstance {
211 verb: self.verb.clone(),
212 node: &self.node,
213 }))
214 }
215}
216
217struct DirectCondition<H> {
218 condition: Box<dyn TestCondition<H>>,
219 node: kdl::KdlNode,
220}
221
222impl<H: 'static> TestConditionCreator<H> for DirectCondition<H> {
223 fn get_test_conditions(&self) -> Box<dyn Iterator<Item = TestConditionInstance<'_, H>> + '_> {
224 Box::new(std::iter::once(TestConditionInstance {
225 condition: self.condition.clone(),
226 node: &self.node,
227 }))
228 }
229}
230
231struct AssertConditions<H> {
232 conditions: Vec<Box<dyn TestConditionCreator<H>>>,
233}
234
235impl<H: 'static> TestVerbCreator<H> for AssertConditions<H> {
236 fn get_test_verbs(&self) -> Box<dyn Iterator<Item = TestVerbInstance<'_, H>> + '_> {
237 Box::new(
238 self.conditions
239 .iter()
240 .flat_map(|cond| cond.get_test_conditions())
241 .map(|cond| TestVerbInstance {
242 node: cond.node,
243 verb: Box::new(AssertVerb {
244 condition: cond.condition,
245 }),
246 }),
247 )
248 }
249}
250
251struct AssertVerb<H> {
252 condition: Box<dyn TestCondition<H>>,
253}
254
255impl<H: 'static> TestVerb<H> for AssertVerb<H> {
256 fn run(&self, harness: &mut H, node: &kdl::KdlNode) -> Result<(), error::TestErrorCase> {
257 self.condition.check_now(harness, node).and_then(|res| {
258 res.then_some(())
259 .ok_or_else(|| error::TestErrorCase::Error {
260 error: miette::miette!("Assert failed"),
261 label: node.span(),
262 })
263 })
264 }
265
266 fn clone_box(&self) -> Box<dyn TestVerb<H>> {
267 Box::new(self.clone())
268 }
269}
270
271impl<H: 'static> Clone for AssertVerb<H> {
272 fn clone(&self) -> Self {
273 AssertVerb {
274 condition: self.condition.clone(),
275 }
276 }
277}
278
279struct TestConditionInstance<'p, H> {
280 condition: Box<dyn TestCondition<H>>,
281 node: &'p kdl::KdlNode,
282}
283
284struct TestVerbInstance<'p, H> {
285 verb: Box<dyn TestVerb<H>>,
286 node: &'p kdl::KdlNode,
287}
288
289impl<'p, H: 'static> TestVerbInstance<'p, H> {
290 fn run<'h>(&'p self, harness: &'h mut H) -> Result<(), error::TestErrorCase> {
291 self.verb.run(harness, self.node)
292 }
293}
294
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use miette::NamedSource;

    use crate::verb::FunctionVerb;
    use crate::TestDsl;

    struct ArithmeticHarness {
        value: AtomicUsize,
    }

    /// Builds a DSL with the `add_one` and `mul_two` verbs registered.
    fn arithmetic_dsl() -> TestDsl<ArithmeticHarness> {
        let mut dsl = TestDsl::<ArithmeticHarness>::new();

        dsl.add_verb(
            "add_one",
            FunctionVerb::new(|ah: &mut ArithmeticHarness| {
                ah.value.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }),
        );

        dsl.add_verb(
            "mul_two",
            FunctionVerb::new(|ah: &mut ArithmeticHarness| {
                let current = ah.value.load(Ordering::SeqCst);
                ah.value.store(current * 2, Ordering::SeqCst);
                Ok(())
            }),
        );

        dsl
    }

    /// Parses `source`, runs its first test case on a harness starting at 0,
    /// and returns the harness's final value.
    fn run_first_case(dsl: &TestDsl<ArithmeticHarness>, source: &str) -> usize {
        let cases = dsl
            .parse_document(NamedSource::new("test.kdl", Arc::from(source)))
            .unwrap();

        let mut harness = ArithmeticHarness {
            value: AtomicUsize::new(0),
        };
        cases[0].run(&mut harness).unwrap();

        harness.value.load(Ordering::SeqCst)
    }

    #[test]
    fn simple_test() {
        let dsl = arithmetic_dsl();

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                add_one
                add_one
                mul_two
            }
            "#,
        );

        assert_eq!(result, 4);
    }

    #[test]
    fn repeat_test() {
        let dsl = arithmetic_dsl();

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                repeat 2 {
                    repeat 2 {
                        add_one
                        mul_two
                    }
                }
            }
            "#,
        );

        assert_eq!(result, 30);
    }

    #[test]
    fn check_arguments_work() {
        let mut dsl = arithmetic_dsl();
        dsl.add_verb(
            "add",
            FunctionVerb::new(|ah: &mut ArithmeticHarness, num: usize| {
                ah.value.fetch_add(num, Ordering::SeqCst);
                Ok(())
            }),
        );

        let result = run_first_case(
            &dsl,
            r#"
            testcase {
                repeat 2 {
                    repeat 2 {
                        group {
                            add 2
                            mul_two
                        }
                    }
                }
            }
            "#,
        );

        assert_eq!(result, 60);
    }
}