1use std::collections::HashMap;
2
3use crate::{
4 errors::{ExprError, IterationError, RuntimeError, RuntimeErrorKind},
5 stmt::DataEntries,
6 DataEntry, DataRow, EntryIndex, EvalContext, ExpectedEntry, ExpectedValue, InputEntry,
7 InputValue, OutputEntry, OutputResultEntry, OutputValue, Signal, SignalType, StmtIterator,
8 TestCase, TestDriver,
9};
10
/// Iterator over the evaluated rows of a [`TestCase`].
///
/// For every row it writes the inputs through a [`TestDriver`] and, when the row
/// requests it, reads the outputs back. Each item is either a [`DataRow`] pairing
/// expected and actual outputs, or an [`IterationError`].
#[derive(Debug)]
pub struct DataRowIterator<'a, 'b, T> {
    /// Evaluation context passed to statement evaluation; it also receives the
    /// outputs read from the driver and exposes the variables via `vars()`.
    ctx: EvalContext,
    /// Borrowed test-case data plus the row-expansion state.
    test_data: DataRowIteratorTestData<'a>,
    /// Driver used to write inputs and read outputs.
    driver: &'b mut T,
}
18
/// Source of the actual value for one expected column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OutputEntryIndex<'a> {
    /// The signal is neither virtual nor among the driver's outputs; its actual
    /// value is reported as `X`.
    None,
    /// Position of the signal in the output list returned by the driver.
    Output(usize),
    /// Virtual signal whose value is obtained by evaluating this expression.
    Virtual(&'a crate::expr::Expr),
}
25
/// Test-case data borrowed by `DataRowIterator`, together with the state needed
/// to expand `X`/`C` inputs and to detect which entries changed between rows.
#[derive(Debug)]
struct DataRowIteratorTestData<'a> {
    /// All signals of the test case, addressed by signal index.
    signals: &'a [Signal],
    /// Iterator over the test case's statements; yields raw rows.
    iter: StmtIterator<'a>,
    /// For each input signal: the row entry that feeds it, or a default.
    input_indices: &'a [EntryIndex],
    /// For each expected column: the row entry holding the expected value, or a default.
    expected_indices: &'a [EntryIndex],
    /// Parallel to `expected_indices`: where each column's actual value comes
    /// from. Empty until `build_output_indices` succeeds.
    output_indices: Vec<OutputEntryIndex<'a>>,
    /// Entries of the previously emitted row, used to compute `changed` flags.
    prev: Option<Vec<DataEntry>>,
    /// Stack of pending rows; the last element is emitted next. `X`/`C`
    /// expansion pushes additional rows onto it.
    cache: Vec<DataEntries>,
}
40
/// One row after `X`/`C` expansion, ready to be exchanged with the driver.
#[derive(Debug)]
struct EvaluatedRow<'a> {
    /// Line number carried over from the statement row (presumably the source
    /// line — confirm against `StmtIterator`).
    line: usize,
    /// Values to write to the input signals.
    inputs: Vec<InputEntry<'a>>,
    /// Expected value for each expected column.
    expected: Vec<ExpectedEntry<'a>>,
    /// Whether outputs are read back after writing the inputs; set to `false`
    /// on the extra rows generated by `C` expansion.
    update_output: bool,
}
48
49impl EntryIndex {
50 pub(crate) fn signal_index(&self) -> usize {
51 match self {
52 EntryIndex::Entry {
53 entry_index: _,
54 signal_index,
55 } => *signal_index,
56 EntryIndex::Default { signal_index } => *signal_index,
57 }
58 }
59
60 pub(crate) fn indexes(&self, entry_index: usize) -> bool {
61 match self {
62 EntryIndex::Entry {
63 entry_index: i,
64 signal_index: _,
65 } => *i == entry_index,
66 EntryIndex::Default { signal_index: _ } => false,
67 }
68 }
69}
70
71impl<'a, 'b, T: TestDriver> Iterator for DataRowIterator<'a, 'b, T> {
72 type Item = Result<DataRow<'a>, IterationError<T::Error>>;
73
74 fn next(&mut self) -> Option<Self::Item> {
75 let row = match self.test_data.get_row(&mut self.ctx) {
76 Ok(val) => val?,
77 Err(err) => {
78 return Some(Err(IterationError::Runtime(
79 RuntimeErrorKind::ExprError(err).into(),
80 )))
81 }
82 };
83
84 match self.handle_io(&row.inputs, row.update_output) {
85 Ok(outputs) => Some(Ok(row.into_data_row(outputs))),
86 Err(err) => Some(Err(err)),
87 }
88 }
89}
90
91impl<'a, 'b, T: TestDriver> DataRowIterator<'a, 'b, T> {
92 pub(crate) fn try_new(
93 test_case: &'a TestCase,
94 driver: &'b mut T,
95 ) -> Result<Self, IterationError<T::Error>> {
96 let mut test_data = DataRowIteratorTestData::new(test_case);
97
98 let inputs = test_data.generate_default_input_entries();
99 let outputs = driver.write_input_and_read_output(&inputs)?;
100 test_data.build_output_indices(&outputs, &test_case.read_outputs)?;
101
102 let ctx = EvalContext::new_with_outputs(&outputs);
103
104 Ok(Self {
105 ctx,
106 test_data,
107 driver,
108 })
109 }
110
111 fn handle_io(
112 &mut self,
113 inputs: &[InputEntry<'a>],
114 update_output: bool,
115 ) -> Result<Vec<OutputValue>, IterationError<T::Error>> {
116 if update_output {
117 let outputs = self.driver.write_input_and_read_output(inputs)?;
118 self.ctx.set_outputs(&outputs);
119 self.test_data.extract_output_values(outputs, &mut self.ctx)
120 } else {
121 self.driver.write_input(inputs)?;
122 Ok(vec![])
123 }
124 }
125
126 pub fn vars(&self) -> HashMap<String, i64> {
128 self.ctx.vars()
129 }
130}
131
132impl<'a> DataRowIteratorTestData<'a> {
133 fn generate_default_input_entries(&self) -> Vec<InputEntry<'a>> {
134 self.input_indices
135 .iter()
136 .map(|index| {
137 let signal = &self.signals[index.signal_index()];
138 let value = signal.default_value().unwrap();
139 InputEntry {
140 signal,
141 value,
142 changed: false,
143 }
144 })
145 .collect()
146 }
147
148 fn generate_input_entries(
149 &self,
150 stmt_entries: &[DataEntry],
151 changed: &[bool],
152 ) -> Vec<InputEntry<'a>> {
153 self.input_indices
154 .iter()
155 .map(|index| match index {
156 EntryIndex::Entry {
157 entry_index,
158 signal_index,
159 } => {
160 let signal = &self.signals[*signal_index];
161 let value = match &stmt_entries[*entry_index] {
162 DataEntry::Number(n) => InputValue::Value(n & ((1 << signal.bits) - 1)),
163 DataEntry::Z => InputValue::Z,
164 _ => unreachable!(),
165 };
166 let changed = changed[*entry_index];
167 InputEntry {
168 signal,
169 value,
170 changed,
171 }
172 }
173 EntryIndex::Default { signal_index } => {
174 let signal = &self.signals[*signal_index];
175 InputEntry {
176 signal,
177 value: signal.default_value().unwrap(),
178 changed: false,
179 }
180 }
181 })
182 .collect()
183 }
184
185 fn generate_expected_entries(&self, stmt_entries: &[DataEntry]) -> Vec<ExpectedEntry<'a>> {
186 self.expected_indices
187 .iter()
188 .map(|index| match index {
189 EntryIndex::Entry {
190 entry_index,
191 signal_index,
192 } => {
193 let signal = &self.signals[*signal_index];
194 let value = match &stmt_entries[*entry_index] {
195 DataEntry::Number(n) => {
196 let mask = if signal.bits < 64 {
197 (1 << signal.bits) - 1
198 } else {
199 -1
200 };
201 ExpectedValue::Value(n & mask)
202 }
203 DataEntry::Z => ExpectedValue::Z,
204 DataEntry::X => ExpectedValue::X,
205 _ => unreachable!(),
206 };
207 ExpectedEntry { signal, value }
208 }
209 EntryIndex::Default { signal_index } => {
210 let signal = &self.signals[*signal_index];
211 ExpectedEntry {
212 signal,
213 value: ExpectedValue::X,
214 }
215 }
216 })
217 .collect()
218 }
219
220 fn build_output_indices<E: std::error::Error>(
221 &mut self,
222 outputs: &[OutputEntry<'_>],
223 read_outputs: &[usize],
224 ) -> Result<(), IterationError<E>> {
225 let mut output_indices = Vec::with_capacity(outputs.len());
226 let mut found_outputs = vec![];
227
228 for expected_index in self.expected_indices.iter() {
229 let signal = &self.signals[expected_index.signal_index()];
230 let entry = if let SignalType::Virtual { expr } = &signal.typ {
231 OutputEntryIndex::Virtual(&expr.expr)
232 } else if let Some(n) = outputs.iter().position(|output| output.signal == signal) {
233 OutputEntryIndex::Output(n)
234 } else {
235 OutputEntryIndex::None
236 };
237 output_indices.push(entry);
238 if matches!(entry, OutputEntryIndex::Output(_)) {
239 found_outputs.push(expected_index.signal_index());
240 }
241 }
242
243 let missing = read_outputs
244 .iter()
245 .filter_map(|read| {
246 if !found_outputs.contains(read) {
247 Some(self.signals[*read].name.clone())
248 } else {
249 None
250 }
251 })
252 .collect::<Vec<_>>();
253 if missing.is_empty() {
254 self.output_indices = output_indices;
255 Ok(())
256 } else {
257 Err(IterationError::Runtime(
258 RuntimeErrorKind::MissingOutputs(missing.join(", ")).into(),
259 ))
260 }
261 }
262
263 fn num_outputs(&self) -> usize {
264 self.output_indices
265 .iter()
266 .filter(|i| matches!(i, OutputEntryIndex::Output(_)))
267 .count()
268 }
269
270 fn extract_output_values<E: std::error::Error>(
271 &self,
272 outputs: Vec<OutputEntry<'_>>,
273 ctx: &mut EvalContext,
274 ) -> Result<Vec<OutputValue>, IterationError<E>> {
275 let num_outputs = self.num_outputs();
276
277 if outputs.len() != num_outputs {
278 return Err(IterationError::Runtime(
279 RuntimeErrorKind::WrongNumberOfOutputs(num_outputs, outputs.len()).into(),
280 ));
281 }
282
283 ctx.swap_vars();
284 let result = self
285 .expected_indices
286 .iter()
287 .zip(&self.output_indices)
288 .map(|(expected_index, output_index)| match *output_index {
289 OutputEntryIndex::Output(output_entry_index) => {
290 let expected_signal = &self.signals[expected_index.signal_index()];
291 let output_signal = outputs[output_entry_index].signal;
292
293 if expected_signal == output_signal {
294 Ok(outputs[output_entry_index].value)
295 } else {
296 Err(IterationError::Runtime(
297 RuntimeErrorKind::WrongOutputOrder.into(),
298 ))
299 }
300 }
301 OutputEntryIndex::Virtual(expr) => expr
302 .eval(ctx)
303 .map(OutputValue::Value)
304 .map_err(|err| IterationError::Runtime(RuntimeError(err.into()))),
305 OutputEntryIndex::None => Ok(OutputValue::X),
306 })
307 .collect();
308 ctx.swap_vars();
309 result
310 }
311
312 fn new(test_case: &'a TestCase) -> Self {
313 DataRowIteratorTestData {
314 iter: StmtIterator::new(&test_case.stmts),
315 signals: &test_case.signals,
316 input_indices: &test_case.input_indices,
317 expected_indices: &test_case.expected_indices,
318 output_indices: vec![],
319 prev: None,
320 cache: vec![],
321 }
322 }
323
324 fn check_changed_entries(&self, stmt_entries: &[DataEntry]) -> Vec<bool> {
325 if let Some(prev) = &self.prev {
326 stmt_entries
327 .iter()
328 .zip(prev)
329 .map(|(new, old)| new != old)
330 .collect()
331 } else {
332 vec![true; stmt_entries.len()]
333 }
334 }
335
336 fn entry_is_input(&self, entry_index: usize) -> bool {
337 self.input_indices
338 .iter()
339 .any(|entry| entry.indexes(entry_index))
340 }
341
342 fn expand_x(&mut self) {
343 loop {
344 let row_result = self
345 .cache
346 .last()
347 .expect("cache should be refilled before calling expand_x");
348
349 let Some(x_index) =
350 row_result
351 .entries
352 .iter()
353 .enumerate()
354 .rev()
355 .find_map(|(i, entry)| {
356 if entry == &DataEntry::X && self.entry_is_input(i) {
357 Some(i)
358 } else {
359 None
360 }
361 })
362 else {
363 break;
364 };
365 let mut row_result = self.cache.pop().unwrap();
366 row_result.entries[x_index] = DataEntry::Number(1);
367 self.cache.push(row_result.clone());
368 row_result.entries[x_index] = DataEntry::Number(0);
369 self.cache.push(row_result);
370 }
371 }
372
373 fn expand_c(&mut self) {
374 let mut row_result = self
375 .cache
376 .pop()
377 .expect("cache should be refilled before calling expand_c");
378
379 let c_indices = row_result
380 .entries
381 .iter()
382 .enumerate()
383 .filter_map(|(i, entry)| {
384 if entry == &DataEntry::C && self.entry_is_input(i) {
385 Some(i)
386 } else {
387 None
388 }
389 })
390 .collect::<Vec<_>>();
391
392 if c_indices.is_empty() {
393 self.cache.push(row_result);
394 } else {
395 for &i in &c_indices {
396 row_result.entries[i] = DataEntry::Number(0);
397 }
398 self.cache.push(row_result.clone());
399 for entry_index in self.expected_indices {
400 match entry_index {
401 EntryIndex::Entry {
402 entry_index,
403 signal_index: _,
404 } => row_result.entries[*entry_index] = DataEntry::X,
405 EntryIndex::Default { signal_index: _ } => continue,
406 }
407 }
408 row_result.update_output = false;
409 for &i in &c_indices {
410 row_result.entries[i] = DataEntry::Number(1);
411 }
412 self.cache.push(row_result.clone());
413 for &i in &c_indices {
414 row_result.entries[i] = DataEntry::Number(0);
415 }
416 self.cache.push(row_result);
417 }
418 }
419
420 fn get_row(&mut self, ctx: &mut EvalContext) -> Result<Option<EvaluatedRow<'a>>, ExprError> {
421 if self.cache.is_empty() {
422 let Some(row_result) = self.iter.next_with_context(ctx)? else {
423 return Ok(None);
424 };
425 self.cache.push(row_result);
426 }
427
428 self.expand_x();
429 self.expand_c();
430
431 let row_result = self.cache.pop().unwrap();
432
433 let changed = self.check_changed_entries(&row_result.entries);
434
435 let inputs = self.generate_input_entries(&row_result.entries, &changed);
436
437 let expected = self.generate_expected_entries(&row_result.entries);
438
439 let line = row_result.line;
440 let update_output = row_result.update_output;
441
442 self.prev = Some(row_result.entries);
443
444 Ok(Some(EvaluatedRow {
445 line,
446 inputs,
447 expected,
448 update_output,
449 }))
450 }
451}
452
453impl<'a> EvaluatedRow<'a> {
454 fn into_data_row(self, outputs: Vec<OutputValue>) -> DataRow<'a> {
455 let outputs = self
456 .expected
457 .into_iter()
458 .zip(outputs)
459 .map(|(expected_entry, output_value)| OutputResultEntry {
460 signal: expected_entry.signal,
461 output: output_value,
462 expected: expected_entry.value,
463 })
464 .collect();
465
466 DataRow {
467 inputs: self.inputs,
468 outputs,
469 line: self.line,
470 }
471 }
472}