1use std::{
2 cell::RefCell,
3 collections::{btree_map, BTreeMap, VecDeque},
4 fmt::Debug,
5 marker::PhantomData,
6 rc::Rc,
7 sync::Arc,
8};
9
10use serde::{de::DeserializeOwned, Serialize};
11
12use crate::{
13 interpreter::VertexInfo,
14 ir::{EdgeParameters, FieldValue, IndexedQuery},
15};
16
17use super::{
18 execution::interpret_ir,
19 trace::{FunctionCall, Opid, Trace, TraceOp, TraceOpContent, YieldValue},
20 Adapter, AsVertex, ContextIterator, ContextOutcomeIterator, DataContext, ResolveEdgeInfo,
21 ResolveInfo, VertexIterator,
22};
23
24#[derive(Clone, Debug)]
25struct TraceReaderAdapter<'trace, Vertex>
26where
27 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
28{
29 next_op: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
30}
31
/// Pulls the next item out of an iterator stored behind a `RefCell`.
///
/// The mutable borrow lives only for the duration of this call, so it is already
/// released by the time the caller inspects the returned item.
fn advance_ref_iter<T, Iter: Iterator<Item = T>>(iter: &RefCell<Iter>) -> Option<T> {
    let mut borrowed = iter.borrow_mut();
    borrowed.next()
}
37
38#[derive(Debug)]
39struct TraceReaderStartingVerticesIter<'trace, Vertex>
40where
41 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
42{
43 exhausted: bool,
44 parent_opid: Opid,
45 inner: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
46}
47
48#[allow(unused_variables)]
49impl<'trace, Vertex> Iterator for TraceReaderStartingVerticesIter<'trace, Vertex>
50where
51 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
52{
53 type Item = Vertex;
54
55 fn next(&mut self) -> Option<Self::Item> {
56 assert!(!self.exhausted);
57
58 let (_, trace_op) = advance_ref_iter(self.inner.as_ref())
59 .expect("Expected to have an item but found none.");
60 assert_eq!(
61 self.parent_opid,
62 trace_op.parent_opid.expect("Expected an operation with a parent_opid."),
63 "Expected parent_opid {:?} did not match operation {:#?}",
64 self.parent_opid,
65 trace_op,
66 );
67
68 match &trace_op.content {
69 TraceOpContent::OutputIteratorExhausted => {
70 self.exhausted = true;
71 None
72 }
73 TraceOpContent::YieldFrom(YieldValue::ResolveStartingVertices(vertex)) => {
74 Some(vertex.clone())
75 }
76 _ => unreachable!(),
77 }
78 }
79}
80
81struct TraceReaderResolvePropertiesIter<'trace, V, Vertex>
82where
83 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
84{
85 exhausted: bool,
86 parent_opid: Opid,
87 contexts: ContextIterator<'trace, V>,
88 input_batch: VecDeque<DataContext<V>>,
89 inner: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
90}
91
92#[allow(unused_variables)]
93impl<'trace, V, Vertex> Iterator for TraceReaderResolvePropertiesIter<'trace, V, Vertex>
94where
95 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
96 V: AsVertex<Vertex>,
97{
98 type Item = (DataContext<V>, FieldValue);
99
100 fn next(&mut self) -> Option<Self::Item> {
101 assert!(!self.exhausted);
102 let next_op = loop {
103 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
104 .expect("Expected to have an item but found none.");
105 assert_eq!(
106 self.parent_opid,
107 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
108 "Expected parent_opid {:?} did not match operation {:#?}",
109 self.parent_opid,
110 input_op,
111 );
112
113 if let TraceOpContent::AdvanceInputIterator = &input_op.content {
114 let input_data = self.contexts.next();
115
116 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
117 .expect("Expected to have an item but found none.");
118 assert_eq!(
119 self.parent_opid,
120 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
121 "Expected parent_opid {:?} did not match operation {:#?}",
122 self.parent_opid,
123 input_op,
124 );
125
126 if let TraceOpContent::YieldInto(context) = &input_op.content {
127 let input_context = input_data.unwrap();
128 assert_eq!(
129 context,
130 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
131 "at {input_op:?}"
132 );
133 self.input_batch.push_back(input_context);
134 } else if let TraceOpContent::InputIteratorExhausted = &input_op.content {
135 assert!(input_data.is_none(), "at {input_op:?}");
136 } else {
137 unreachable!();
138 }
139 } else {
140 break input_op;
141 }
142 };
143
144 match &next_op.content {
145 TraceOpContent::YieldFrom(YieldValue::ResolveProperty(trace_context, value)) => {
146 let input_context = self.input_batch.pop_front().unwrap();
147 assert_eq!(
148 trace_context,
149 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
150 "at {next_op:?}"
151 );
152 Some((input_context, value.clone()))
153 }
154 TraceOpContent::OutputIteratorExhausted => {
155 assert!(self.input_batch.pop_front().is_none(), "at {next_op:?}");
156 self.exhausted = true;
157 None
158 }
159 _ => unreachable!(),
160 }
161 }
162}
163
164struct TraceReaderResolveCoercionIter<'query, 'trace, V, Vertex>
165where
166 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
167 V: AsVertex<Vertex>,
168 'trace: 'query,
169{
170 exhausted: bool,
171 parent_opid: Opid,
172 contexts: ContextIterator<'query, V>,
173 input_batch: VecDeque<DataContext<V>>,
174 inner: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
175}
176
177#[allow(unused_variables)]
178impl<'query, 'trace, V, Vertex> Iterator
179 for TraceReaderResolveCoercionIter<'query, 'trace, V, Vertex>
180where
181 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
182 V: AsVertex<Vertex>,
183 'trace: 'query,
184{
185 type Item = (DataContext<V>, bool);
186
187 fn next(&mut self) -> Option<Self::Item> {
188 assert!(!self.exhausted);
189 let next_op = loop {
190 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
191 .expect("Expected to have an item but found none.");
192 assert_eq!(
193 self.parent_opid,
194 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
195 "Expected parent_opid {:?} did not match operation {:#?}",
196 self.parent_opid,
197 input_op,
198 );
199
200 if let TraceOpContent::AdvanceInputIterator = &input_op.content {
201 let input_data = self.contexts.next();
202
203 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
204 .expect("Expected to have an item but found none.");
205 assert_eq!(
206 self.parent_opid,
207 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
208 "Expected parent_opid {:?} did not match operation {:#?}",
209 self.parent_opid,
210 input_op,
211 );
212
213 if let TraceOpContent::YieldInto(context) = &input_op.content {
214 let input_context = input_data.unwrap();
215 assert_eq!(
216 context,
217 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
218 "at {input_op:?}"
219 );
220
221 self.input_batch.push_back(input_context);
222 } else if let TraceOpContent::InputIteratorExhausted = &input_op.content {
223 assert!(input_data.is_none(), "at {input_op:?}");
224 } else {
225 unreachable!();
226 }
227 } else {
228 break input_op;
229 }
230 };
231
232 match &next_op.content {
233 TraceOpContent::YieldFrom(YieldValue::ResolveCoercion(trace_context, can_coerce)) => {
234 let input_context = self.input_batch.pop_front().unwrap();
235 assert_eq!(
236 trace_context,
237 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
238 "at {next_op:?}"
239 );
240 Some((input_context, *can_coerce))
241 }
242 TraceOpContent::OutputIteratorExhausted => {
243 assert!(self.input_batch.pop_front().is_none(), "at {next_op:?}");
244 self.exhausted = true;
245 None
246 }
247 _ => unreachable!(),
248 }
249 }
250}
251
252struct TraceReaderResolveNeighborsIter<'query, 'trace, V, Vertex>
253where
254 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
255 V: AsVertex<Vertex>,
256 'trace: 'query,
257{
258 exhausted: bool,
259 parent_opid: Opid,
260 contexts: ContextIterator<'query, V>,
261 input_batch: VecDeque<DataContext<V>>,
262 inner: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
263}
264
265impl<'query, 'trace, V, Vertex> Iterator
266 for TraceReaderResolveNeighborsIter<'query, 'trace, V, Vertex>
267where
268 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
269 V: AsVertex<Vertex>,
270 'trace: 'query,
271{
272 type Item = (DataContext<V>, VertexIterator<'query, Vertex>);
273
274 fn next(&mut self) -> Option<Self::Item> {
275 assert!(!self.exhausted);
276 let next_op = loop {
277 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
278 .expect("Expected to have an item but found none.");
279 assert_eq!(
280 self.parent_opid,
281 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
282 "Expected parent_opid {:?} did not match operation {:#?}",
283 self.parent_opid,
284 input_op,
285 );
286
287 if let TraceOpContent::AdvanceInputIterator = &input_op.content {
288 let input_data = self.contexts.next();
289
290 let (_, input_op) = advance_ref_iter(self.inner.as_ref())
291 .expect("Expected to have an item but found none.");
292 assert_eq!(
293 self.parent_opid,
294 input_op.parent_opid.expect("Expected an operation with a parent_opid."),
295 "Expected parent_opid {:?} did not match operation {:#?}",
296 self.parent_opid,
297 input_op,
298 );
299
300 if let TraceOpContent::YieldInto(context) = &input_op.content {
301 let input_context = input_data.unwrap();
302 assert_eq!(
303 context,
304 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
305 "at {input_op:?}"
306 );
307
308 self.input_batch.push_back(input_context);
309 } else if let TraceOpContent::InputIteratorExhausted = &input_op.content {
310 assert!(input_data.is_none(), "at {input_op:?}");
311 } else {
312 unreachable!();
313 }
314 } else {
315 break input_op;
316 }
317 };
318
319 match &next_op.content {
320 TraceOpContent::YieldFrom(YieldValue::ResolveNeighborsOuter(trace_context)) => {
321 let input_context = self.input_batch.pop_front().unwrap();
322 assert_eq!(
323 trace_context,
324 &input_context.clone().flat_map(&mut |v| v.into_vertex()),
325 "at {next_op:?}"
326 );
327
328 let neighbors = Box::new(TraceReaderNeighborIter {
329 exhausted: false,
330 parent_iterator_opid: next_op.opid,
331 next_index: 0,
332 inner: self.inner.clone(),
333 _phantom: PhantomData,
334 });
335 Some((input_context, neighbors))
336 }
337 TraceOpContent::OutputIteratorExhausted => {
338 assert!(self.input_batch.pop_front().is_none(), "at {next_op:?}");
339 self.exhausted = true;
340 None
341 }
342 _ => unreachable!(),
343 }
344 }
345}
346
347struct TraceReaderNeighborIter<'query, 'trace, Vertex>
348where
349 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
350 'trace: 'query,
351{
352 exhausted: bool,
353 parent_iterator_opid: Opid,
354 next_index: usize,
355 inner: Rc<RefCell<btree_map::Iter<'trace, Opid, TraceOp<Vertex>>>>,
356 _phantom: PhantomData<&'query ()>,
357}
358
359impl<'query, 'trace, Vertex> Iterator for TraceReaderNeighborIter<'query, 'trace, Vertex>
360where
361 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
362 'trace: 'query,
363{
364 type Item = Vertex;
365
366 fn next(&mut self) -> Option<Self::Item> {
367 let (_, trace_op) = advance_ref_iter(self.inner.as_ref())
368 .expect("Expected to have an item but found none.");
369 assert!(!self.exhausted);
370 assert_eq!(
371 self.parent_iterator_opid,
372 trace_op.parent_opid.expect("Expected an operation with a parent_opid."),
373 "Expected parent_opid {:?} did not match operation {:#?}",
374 self.parent_iterator_opid,
375 trace_op,
376 );
377
378 match &trace_op.content {
379 TraceOpContent::OutputIteratorExhausted => {
380 self.exhausted = true;
381 None
382 }
383 TraceOpContent::YieldFrom(YieldValue::ResolveNeighborsInner(index, vertex)) => {
384 assert_eq!(self.next_index, *index, "at {trace_op:?}");
385 self.next_index += 1;
386 Some(vertex.clone())
387 }
388 _ => unreachable!(),
389 }
390 }
391}
392
393#[allow(unused_variables)]
394impl<'trace, Vertex> Adapter<'trace> for TraceReaderAdapter<'trace, Vertex>
395where
396 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'trace,
397{
398 type Vertex = Vertex;
399
400 fn resolve_starting_vertices(
401 &self,
402 edge_name: &Arc<str>,
403 parameters: &EdgeParameters,
404 resolve_info: &ResolveInfo,
405 ) -> VertexIterator<'trace, Self::Vertex> {
406 let (root_opid, trace_op) = advance_ref_iter(self.next_op.as_ref())
407 .expect("Expected a resolve_starting_vertices() call operation, but found none.");
408 assert_eq!(None, trace_op.parent_opid);
409
410 if let TraceOpContent::Call(FunctionCall::ResolveStartingVertices(vid)) = trace_op.content {
411 assert_eq!(vid, resolve_info.vid());
412
413 Box::new(TraceReaderStartingVerticesIter {
414 exhausted: false,
415 parent_opid: *root_opid,
416 inner: self.next_op.clone(),
417 })
418 } else {
419 unreachable!()
420 }
421 }
422
423 fn resolve_property<V: AsVertex<Self::Vertex> + 'trace>(
424 &self,
425 contexts: ContextIterator<'trace, V>,
426 type_name: &Arc<str>,
427 property_name: &Arc<str>,
428 resolve_info: &ResolveInfo,
429 ) -> ContextOutcomeIterator<'trace, V, FieldValue> {
430 let (root_opid, trace_op) = advance_ref_iter(self.next_op.as_ref())
431 .expect("Expected a resolve_property() call operation, but found none.");
432 assert_eq!(None, trace_op.parent_opid);
433
434 if let TraceOpContent::Call(FunctionCall::ResolveProperty(vid, op_type_name, property)) =
435 &trace_op.content
436 {
437 assert_eq!(*vid, resolve_info.vid());
438 assert_eq!(op_type_name, type_name);
439 assert_eq!(property, property_name);
440
441 Box::new(TraceReaderResolvePropertiesIter {
442 exhausted: false,
443 parent_opid: *root_opid,
444 contexts,
445 input_batch: Default::default(),
446 inner: self.next_op.clone(),
447 })
448 } else {
449 unreachable!()
450 }
451 }
452
453 fn resolve_neighbors<V: AsVertex<Self::Vertex> + 'trace>(
454 &self,
455 contexts: ContextIterator<'trace, V>,
456 type_name: &Arc<str>,
457 edge_name: &Arc<str>,
458 parameters: &EdgeParameters,
459 resolve_info: &ResolveEdgeInfo,
460 ) -> ContextOutcomeIterator<'trace, V, VertexIterator<'trace, Self::Vertex>> {
461 let (root_opid, trace_op) = advance_ref_iter(self.next_op.as_ref())
462 .expect("Expected a resolve_property() call operation, but found none.");
463 assert_eq!(None, trace_op.parent_opid);
464
465 if let TraceOpContent::Call(FunctionCall::ResolveNeighbors(vid, op_type_name, eid)) =
466 &trace_op.content
467 {
468 assert_eq!(*vid, resolve_info.origin_vid());
469 assert_eq!(op_type_name, type_name);
470 assert_eq!(*eid, resolve_info.eid());
471
472 Box::new(TraceReaderResolveNeighborsIter {
473 exhausted: false,
474 parent_opid: *root_opid,
475 contexts,
476 input_batch: Default::default(),
477 inner: self.next_op.clone(),
478 })
479 } else {
480 unreachable!()
481 }
482 }
483
484 fn resolve_coercion<V: AsVertex<Self::Vertex> + 'trace>(
485 &self,
486 contexts: ContextIterator<'trace, V>,
487 type_name: &Arc<str>,
488 coerce_to_type: &Arc<str>,
489 resolve_info: &ResolveInfo,
490 ) -> ContextOutcomeIterator<'trace, V, bool> {
491 let (root_opid, trace_op) = advance_ref_iter(self.next_op.as_ref())
492 .expect("Expected a resolve_coercion() call operation, but found none.");
493 assert_eq!(None, trace_op.parent_opid);
494
495 if let TraceOpContent::Call(FunctionCall::ResolveCoercion(vid, from_type, to_type)) =
496 &trace_op.content
497 {
498 assert_eq!(*vid, resolve_info.vid());
499 assert_eq!(from_type, type_name);
500 assert_eq!(to_type, coerce_to_type);
501
502 Box::new(TraceReaderResolveCoercionIter {
503 exhausted: false,
504 parent_opid: *root_opid,
505 contexts,
506 input_batch: Default::default(),
507 inner: self.next_op.clone(),
508 })
509 } else {
510 unreachable!()
511 }
512 }
513}
514
515#[allow(dead_code)]
516pub fn assert_interpreted_results<'query, 'trace, Vertex>(
517 trace: &Trace<Vertex>,
518 expected_results: &[BTreeMap<Arc<str>, FieldValue>],
519 complete: bool,
520) where
521 Vertex: Clone + Debug + PartialEq + Eq + Serialize + DeserializeOwned + 'query,
522 'trace: 'query,
523{
524 let next_op = Rc::new(RefCell::new(trace.ops.iter()));
525 let trace_reader_adapter = Arc::new(TraceReaderAdapter { next_op: next_op.clone() });
526
527 let query: Arc<IndexedQuery> = Arc::new(trace.ir_query.clone().try_into().unwrap());
528 let arguments = Arc::new(
529 trace.arguments.iter().map(|(k, v)| (Arc::from(k.to_owned()), v.clone())).collect(),
530 );
531 let mut trace_iter = interpret_ir(trace_reader_adapter, query, arguments).unwrap();
532 let mut expected_iter = expected_results.iter();
533
534 loop {
535 let expected_row = expected_iter.next();
536 let trace_row = trace_iter.next();
537
538 if let Some(expected_row_content) = expected_row {
539 let trace_expected_row = {
540 let mut next_op_ref = next_op.borrow_mut();
541 let Some((_, trace_op)) = next_op_ref.next() else {
542 panic!("Reached the end of the trace without producing result {trace_row:#?}");
543 };
544 let TraceOpContent::ProduceQueryResult(expected_result) = &trace_op.content else {
545 panic!("Expected the trace to produce a result {trace_row:#?} but got another type of operation instead: {trace_op:#?}");
546 };
547 drop(next_op_ref);
548
549 expected_result
550 };
551 assert_eq!(
552 trace_expected_row, expected_row_content,
553 "This trace is self-inconsistent: trace produces row {trace_expected_row:#?} \
554 but results have row {expected_row_content:#?}",
555 );
556
557 assert_eq!(expected_row, trace_row.as_ref());
558 } else {
559 if complete {
560 assert_eq!(None, trace_row);
561 }
562 return;
563 }
564 }
565}
566
#[cfg(test)]
mod tests {
    use std::{fmt::Debug, fs, path::Path};

    use serde::{de::DeserializeOwned, Serialize};
    use trustfall_filetests_macros::parameterize;

    use crate::{
        filesystem_interpreter::FilesystemVertex,
        interpreter::replay::assert_interpreted_results,
        numbers_interpreter::NumbersVertex,
        test_types::{
            TestIRQuery, TestIRQueryResult, TestInterpreterOutputData, TestInterpreterOutputTrace,
        },
    };

    /// Checks that a parsed trace was recorded for the expected IR query and arguments.
    /// It then replays the trace and requires it to yield exactly the expected results.
    fn check_trace<Vertex>(
        expected_ir: TestIRQuery,
        test_data: TestInterpreterOutputTrace<Vertex>,
        test_outputs: TestInterpreterOutputData,
    ) where
        Vertex: Debug + Clone + PartialEq + Eq + Serialize + DeserializeOwned,
    {
        assert_eq!(expected_ir.ir_query, test_data.trace.ir_query);
        assert_eq!(expected_ir.arguments, test_data.trace.arguments);

        assert_interpreted_results(&test_data.trace, &test_outputs.results, true);
    }

    /// Parses `input_data` as a trace over `Vertex` vertices. It checks that both the trace
    /// and `expected_ir` belong to `schema_name`, then replays it with `check_trace`.
    ///
    /// Replaces the former per-schema copies of this logic, which differed only in the
    /// vertex type and schema name.
    fn check_schema_trace<Vertex>(
        schema_name: &str,
        expected_ir: TestIRQuery,
        input_data: &str,
        test_outputs: TestInterpreterOutputData,
    ) where
        Vertex: Debug + Clone + PartialEq + Eq + Serialize + DeserializeOwned,
    {
        // A malformed trace file is a reachable test failure, not an impossible branch.
        let test_data: TestInterpreterOutputTrace<Vertex> = ron::from_str(input_data)
            .unwrap_or_else(|e| panic!("failed to parse trace file: {e}"));

        assert_eq!(expected_ir.schema_name, schema_name);
        assert_eq!(test_data.schema_name, schema_name);
        check_trace(expected_ir, test_data, test_outputs);
    }

    #[parameterize("trustfall_core/test_data/tests/valid_queries")]
    fn parameterized_tester(base: &Path, stem: &str) {
        let input_data = fs::read_to_string(base.join(format!("{stem}.trace.ron")))
            .expect("failed to read trace file");

        let output_data = fs::read_to_string(base.join(format!("{stem}.output.ron")))
            .expect("failed to read outputs file");
        let test_outputs: TestInterpreterOutputData =
            ron::from_str(&output_data).expect("failed to parse outputs file");

        let check_data = fs::read_to_string(base.join(format!("{stem}.ir.ron")))
            .expect("failed to read IR file");
        let expected_ir: TestIRQueryResult =
            ron::from_str(&check_data).expect("failed to parse IR file");
        let expected_ir = expected_ir.unwrap();

        match expected_ir.schema_name.as_str() {
            "filesystem" => check_schema_trace::<FilesystemVertex>(
                "filesystem",
                expected_ir,
                &input_data,
                test_outputs,
            ),
            "numbers" => check_schema_trace::<NumbersVertex>(
                "numbers",
                expected_ir,
                &input_data,
                test_outputs,
            ),
            _ => panic!("unsupported schema in test data: {}", expected_ir.schema_name),
        }
    }
}