digital_test_runner/
data_row_iterator.rs

1use std::collections::HashMap;
2
3use crate::{
4    errors::{ExprError, IterationError, RuntimeError, RuntimeErrorKind},
5    stmt::DataEntries,
6    DataEntry, DataRow, EntryIndex, EvalContext, ExpectedEntry, ExpectedValue, InputEntry,
7    InputValue, OutputEntry, OutputResultEntry, OutputValue, Signal, SignalType, StmtIterator,
8    TestCase, TestDriver,
9};
10
11#[derive(Debug)]
12/// An iterator over the test results for a dynamic test
13pub struct DataRowIterator<'a, 'b, T> {
14    ctx: EvalContext,
15    test_data: DataRowIteratorTestData<'a>,
16    driver: &'b mut T,
17}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20enum OutputEntryIndex<'a> {
21    None,
22    Output(usize),
23    Virtual(&'a crate::expr::Expr),
24}
25
26#[derive(Debug)]
27struct DataRowIteratorTestData<'a> {
28    signals: &'a [Signal],
29    iter: StmtIterator<'a>,
30    /// List of inputs which links signals to test entries
31    input_indices: &'a [EntryIndex],
32    /// List of expected values which links signals to test entries
33    expected_indices: &'a [EntryIndex],
34    /// List with the same number of entries as `expected_indices`.
35    /// Each non-trivial entry is an index into the output vec from the driver.
36    output_indices: Vec<OutputEntryIndex<'a>>,
37    prev: Option<Vec<DataEntry>>,
38    cache: Vec<DataEntries>,
39}
40
41#[derive(Debug)]
42struct EvaluatedRow<'a> {
43    line: usize,
44    inputs: Vec<InputEntry<'a>>,
45    expected: Vec<ExpectedEntry<'a>>,
46    update_output: bool,
47}
48
49impl EntryIndex {
50    pub(crate) fn signal_index(&self) -> usize {
51        match self {
52            EntryIndex::Entry {
53                entry_index: _,
54                signal_index,
55            } => *signal_index,
56            EntryIndex::Default { signal_index } => *signal_index,
57        }
58    }
59
60    pub(crate) fn indexes(&self, entry_index: usize) -> bool {
61        match self {
62            EntryIndex::Entry {
63                entry_index: i,
64                signal_index: _,
65            } => *i == entry_index,
66            EntryIndex::Default { signal_index: _ } => false,
67        }
68    }
69}
70
71impl<'a, 'b, T: TestDriver> Iterator for DataRowIterator<'a, 'b, T> {
72    type Item = Result<DataRow<'a>, IterationError<T::Error>>;
73
74    fn next(&mut self) -> Option<Self::Item> {
75        let row = match self.test_data.get_row(&mut self.ctx) {
76            Ok(val) => val?,
77            Err(err) => {
78                return Some(Err(IterationError::Runtime(
79                    RuntimeErrorKind::ExprError(err).into(),
80                )))
81            }
82        };
83
84        match self.handle_io(&row.inputs, row.update_output) {
85            Ok(outputs) => Some(Ok(row.into_data_row(outputs))),
86            Err(err) => Some(Err(err)),
87        }
88    }
89}
90
91impl<'a, 'b, T: TestDriver> DataRowIterator<'a, 'b, T> {
92    pub(crate) fn try_new(
93        test_case: &'a TestCase,
94        driver: &'b mut T,
95    ) -> Result<Self, IterationError<T::Error>> {
96        let mut test_data = DataRowIteratorTestData::new(test_case);
97
98        let inputs = test_data.generate_default_input_entries();
99        let outputs = driver.write_input_and_read_output(&inputs)?;
100        test_data.build_output_indices(&outputs, &test_case.read_outputs)?;
101
102        let ctx = EvalContext::new_with_outputs(&outputs);
103
104        Ok(Self {
105            ctx,
106            test_data,
107            driver,
108        })
109    }
110
111    fn handle_io(
112        &mut self,
113        inputs: &[InputEntry<'a>],
114        update_output: bool,
115    ) -> Result<Vec<OutputValue>, IterationError<T::Error>> {
116        if update_output {
117            let outputs = self.driver.write_input_and_read_output(inputs)?;
118            self.ctx.set_outputs(&outputs);
119            self.test_data.extract_output_values(outputs, &mut self.ctx)
120        } else {
121            self.driver.write_input(inputs)?;
122            Ok(vec![])
123        }
124    }
125
126    /// The current value of all variables
127    pub fn vars(&self) -> HashMap<String, i64> {
128        self.ctx.vars()
129    }
130}
131
132impl<'a> DataRowIteratorTestData<'a> {
133    fn generate_default_input_entries(&self) -> Vec<InputEntry<'a>> {
134        self.input_indices
135            .iter()
136            .map(|index| {
137                let signal = &self.signals[index.signal_index()];
138                let value = signal.default_value().unwrap();
139                InputEntry {
140                    signal,
141                    value,
142                    changed: false,
143                }
144            })
145            .collect()
146    }
147
148    fn generate_input_entries(
149        &self,
150        stmt_entries: &[DataEntry],
151        changed: &[bool],
152    ) -> Vec<InputEntry<'a>> {
153        self.input_indices
154            .iter()
155            .map(|index| match index {
156                EntryIndex::Entry {
157                    entry_index,
158                    signal_index,
159                } => {
160                    let signal = &self.signals[*signal_index];
161                    let value = match &stmt_entries[*entry_index] {
162                        DataEntry::Number(n) => InputValue::Value(n & ((1 << signal.bits) - 1)),
163                        DataEntry::Z => InputValue::Z,
164                        _ => unreachable!(),
165                    };
166                    let changed = changed[*entry_index];
167                    InputEntry {
168                        signal,
169                        value,
170                        changed,
171                    }
172                }
173                EntryIndex::Default { signal_index } => {
174                    let signal = &self.signals[*signal_index];
175                    InputEntry {
176                        signal,
177                        value: signal.default_value().unwrap(),
178                        changed: false,
179                    }
180                }
181            })
182            .collect()
183    }
184
185    fn generate_expected_entries(&self, stmt_entries: &[DataEntry]) -> Vec<ExpectedEntry<'a>> {
186        self.expected_indices
187            .iter()
188            .map(|index| match index {
189                EntryIndex::Entry {
190                    entry_index,
191                    signal_index,
192                } => {
193                    let signal = &self.signals[*signal_index];
194                    let value = match &stmt_entries[*entry_index] {
195                        DataEntry::Number(n) => {
196                            let mask = if signal.bits < 64 {
197                                (1 << signal.bits) - 1
198                            } else {
199                                -1
200                            };
201                            ExpectedValue::Value(n & mask)
202                        }
203                        DataEntry::Z => ExpectedValue::Z,
204                        DataEntry::X => ExpectedValue::X,
205                        _ => unreachable!(),
206                    };
207                    ExpectedEntry { signal, value }
208                }
209                EntryIndex::Default { signal_index } => {
210                    let signal = &self.signals[*signal_index];
211                    ExpectedEntry {
212                        signal,
213                        value: ExpectedValue::X,
214                    }
215                }
216            })
217            .collect()
218    }
219
220    fn build_output_indices<E: std::error::Error>(
221        &mut self,
222        outputs: &[OutputEntry<'_>],
223        read_outputs: &[usize],
224    ) -> Result<(), IterationError<E>> {
225        let mut output_indices = Vec::with_capacity(outputs.len());
226        let mut found_outputs = vec![];
227
228        for expected_index in self.expected_indices.iter() {
229            let signal = &self.signals[expected_index.signal_index()];
230            let entry = if let SignalType::Virtual { expr } = &signal.typ {
231                OutputEntryIndex::Virtual(&expr.expr)
232            } else if let Some(n) = outputs.iter().position(|output| output.signal == signal) {
233                OutputEntryIndex::Output(n)
234            } else {
235                OutputEntryIndex::None
236            };
237            output_indices.push(entry);
238            if matches!(entry, OutputEntryIndex::Output(_)) {
239                found_outputs.push(expected_index.signal_index());
240            }
241        }
242
243        let missing = read_outputs
244            .iter()
245            .filter_map(|read| {
246                if !found_outputs.contains(read) {
247                    Some(self.signals[*read].name.clone())
248                } else {
249                    None
250                }
251            })
252            .collect::<Vec<_>>();
253        if missing.is_empty() {
254            self.output_indices = output_indices;
255            Ok(())
256        } else {
257            Err(IterationError::Runtime(
258                RuntimeErrorKind::MissingOutputs(missing.join(", ")).into(),
259            ))
260        }
261    }
262
263    fn num_outputs(&self) -> usize {
264        self.output_indices
265            .iter()
266            .filter(|i| matches!(i, OutputEntryIndex::Output(_)))
267            .count()
268    }
269
270    fn extract_output_values<E: std::error::Error>(
271        &self,
272        outputs: Vec<OutputEntry<'_>>,
273        ctx: &mut EvalContext,
274    ) -> Result<Vec<OutputValue>, IterationError<E>> {
275        let num_outputs = self.num_outputs();
276
277        if outputs.len() != num_outputs {
278            return Err(IterationError::Runtime(
279                RuntimeErrorKind::WrongNumberOfOutputs(num_outputs, outputs.len()).into(),
280            ));
281        }
282
283        ctx.swap_vars();
284        let result = self
285            .expected_indices
286            .iter()
287            .zip(&self.output_indices)
288            .map(|(expected_index, output_index)| match *output_index {
289                OutputEntryIndex::Output(output_entry_index) => {
290                    let expected_signal = &self.signals[expected_index.signal_index()];
291                    let output_signal = outputs[output_entry_index].signal;
292
293                    if expected_signal == output_signal {
294                        Ok(outputs[output_entry_index].value)
295                    } else {
296                        Err(IterationError::Runtime(
297                            RuntimeErrorKind::WrongOutputOrder.into(),
298                        ))
299                    }
300                }
301                OutputEntryIndex::Virtual(expr) => expr
302                    .eval(ctx)
303                    .map(OutputValue::Value)
304                    .map_err(|err| IterationError::Runtime(RuntimeError(err.into()))),
305                OutputEntryIndex::None => Ok(OutputValue::X),
306            })
307            .collect();
308        ctx.swap_vars();
309        result
310    }
311
312    fn new(test_case: &'a TestCase) -> Self {
313        DataRowIteratorTestData {
314            iter: StmtIterator::new(&test_case.stmts),
315            signals: &test_case.signals,
316            input_indices: &test_case.input_indices,
317            expected_indices: &test_case.expected_indices,
318            output_indices: vec![],
319            prev: None,
320            cache: vec![],
321        }
322    }
323
324    fn check_changed_entries(&self, stmt_entries: &[DataEntry]) -> Vec<bool> {
325        if let Some(prev) = &self.prev {
326            stmt_entries
327                .iter()
328                .zip(prev)
329                .map(|(new, old)| new != old)
330                .collect()
331        } else {
332            vec![true; stmt_entries.len()]
333        }
334    }
335
336    fn entry_is_input(&self, entry_index: usize) -> bool {
337        self.input_indices
338            .iter()
339            .any(|entry| entry.indexes(entry_index))
340    }
341
342    fn expand_x(&mut self) {
343        loop {
344            let row_result = self
345                .cache
346                .last()
347                .expect("cache should be refilled before calling expand_x");
348
349            let Some(x_index) =
350                row_result
351                    .entries
352                    .iter()
353                    .enumerate()
354                    .rev()
355                    .find_map(|(i, entry)| {
356                        if entry == &DataEntry::X && self.entry_is_input(i) {
357                            Some(i)
358                        } else {
359                            None
360                        }
361                    })
362            else {
363                break;
364            };
365            let mut row_result = self.cache.pop().unwrap();
366            row_result.entries[x_index] = DataEntry::Number(1);
367            self.cache.push(row_result.clone());
368            row_result.entries[x_index] = DataEntry::Number(0);
369            self.cache.push(row_result);
370        }
371    }
372
373    fn expand_c(&mut self) {
374        let mut row_result = self
375            .cache
376            .pop()
377            .expect("cache should be refilled before calling expand_c");
378
379        let c_indices = row_result
380            .entries
381            .iter()
382            .enumerate()
383            .filter_map(|(i, entry)| {
384                if entry == &DataEntry::C && self.entry_is_input(i) {
385                    Some(i)
386                } else {
387                    None
388                }
389            })
390            .collect::<Vec<_>>();
391
392        if c_indices.is_empty() {
393            self.cache.push(row_result);
394        } else {
395            for &i in &c_indices {
396                row_result.entries[i] = DataEntry::Number(0);
397            }
398            self.cache.push(row_result.clone());
399            for entry_index in self.expected_indices {
400                match entry_index {
401                    EntryIndex::Entry {
402                        entry_index,
403                        signal_index: _,
404                    } => row_result.entries[*entry_index] = DataEntry::X,
405                    EntryIndex::Default { signal_index: _ } => continue,
406                }
407            }
408            row_result.update_output = false;
409            for &i in &c_indices {
410                row_result.entries[i] = DataEntry::Number(1);
411            }
412            self.cache.push(row_result.clone());
413            for &i in &c_indices {
414                row_result.entries[i] = DataEntry::Number(0);
415            }
416            self.cache.push(row_result);
417        }
418    }
419
420    fn get_row(&mut self, ctx: &mut EvalContext) -> Result<Option<EvaluatedRow<'a>>, ExprError> {
421        if self.cache.is_empty() {
422            let Some(row_result) = self.iter.next_with_context(ctx)? else {
423                return Ok(None);
424            };
425            self.cache.push(row_result);
426        }
427
428        self.expand_x();
429        self.expand_c();
430
431        let row_result = self.cache.pop().unwrap();
432
433        let changed = self.check_changed_entries(&row_result.entries);
434
435        let inputs = self.generate_input_entries(&row_result.entries, &changed);
436
437        let expected = self.generate_expected_entries(&row_result.entries);
438
439        let line = row_result.line;
440        let update_output = row_result.update_output;
441
442        self.prev = Some(row_result.entries);
443
444        Ok(Some(EvaluatedRow {
445            line,
446            inputs,
447            expected,
448            update_output,
449        }))
450    }
451}
452
453impl<'a> EvaluatedRow<'a> {
454    fn into_data_row(self, outputs: Vec<OutputValue>) -> DataRow<'a> {
455        let outputs = self
456            .expected
457            .into_iter()
458            .zip(outputs)
459            .map(|(expected_entry, output_value)| OutputResultEntry {
460                signal: expected_entry.signal,
461                output: output_value,
462                expected: expected_entry.value,
463            })
464            .collect();
465
466        DataRow {
467            inputs: self.inputs,
468            outputs,
469            line: self.line,
470        }
471    }
472}