directed/
lib.rs

1#![doc = include_str!("../README.md")]
2mod error;
3mod graphs;
4mod node;
5mod registry;
6mod stage;
7mod types;
8
9pub use directed_stage_macro::stage;
10pub use error::*;
11pub use graphs::{EdgeInfo, Graph};
12pub use node::{AnyNode, Cached, DowncastEq, Node};
13pub use registry::{NodeId, Registry};
14pub use stage::{EvalStrategy, ReevaluationRule, RefType, Stage};
15pub use types::{DataLabel, NodeOutput};
16
17#[cfg(test)]
18mod tests {
19    extern crate self as directed;
20    use super::*;
21    use directed_stage_macro::stage;
22    use std::sync::atomic::{AtomicUsize, Ordering};
23
24    // A simple sanity-check test that doesn't try anything interesting
25    #[test]
26    fn basic_macro_test() {
27        #[stage(lazy, cache_last)]
28        fn TinyStage1() -> String {
29            println!("Running stage 1");
30            String::from("This is the output!")
31        }
32
33        #[stage(lazy, cache_last)]
34        fn TinyStage2(input: String, input2: String) -> String {
35            println!("Running stage 2");
36            input.to_uppercase() + " [" + &input2.chars().count().to_string() + " chars]"
37        }
38
39        #[stage(cache_last)]
40        fn TinyStage3(input: String) {
41            println!("Running stage 3");
42            assert_eq!("THIS IS THE OUTPUT! [19 chars]", input);
43        }
44
45        let mut registry = Registry::new();
46        let node_1 = registry.register(TinyStage1::new());
47        let node_2 = registry.register(TinyStage2::new());
48        let node_3 = registry.register(TinyStage3::new());
49        let graph = graph! {
50            nodes: [node_1, node_2, node_3],
51            connections: {
52                node_1 => node_2: input,
53                node_1 => node_2: input2,
54                node_2 => node_3: input,
55            }
56        }
57        .unwrap();
58
59        graph.execute(&mut registry).unwrap();
60    }
61
62    // Test multiple output stages
63    #[test]
64    fn multiple_output_stage_test() {
65        #[stage(out(number: i32, text: String))]
66        fn MultiOutputStage() -> NodeOutput {
67            let value1 = 42;
68            let value2 = String::from("Hello");
69            output! {
70                number: value1,
71                text: value2
72            }
73        }
74
75        #[stage]
76        fn ConsumerStage1(number: i32) {
77            assert_eq!(number, 42);
78        }
79
80        #[stage]
81        fn ConsumerStage2(text: String) {
82            assert_eq!(text, "Hello");
83        }
84
85        let mut registry = Registry::new();
86        let producer = registry.register(MultiOutputStage::new());
87        let consumer1 = registry.register(ConsumerStage1::new());
88        let consumer2 = registry.register(ConsumerStage2::new());
89
90        let graph = graph! {
91            nodes: [producer, consumer1, consumer2],
92            connections: {
93                producer: number => consumer1: number,
94                producer: text => consumer2: text,
95            }
96        }
97        .unwrap();
98
99        graph.execute(&mut registry).unwrap();
100    }
101
102    // Test evaluating lazy vs urgent nodes
103    #[test]
104    fn lazy_and_urgent_eval_test() {
105        static COUNTER: AtomicUsize = AtomicUsize::new(0);
106
107        #[stage(lazy, cache_last)]
108        fn LazyStage() -> i32 {
109            COUNTER.fetch_add(1, Ordering::SeqCst);
110            42
111        }
112
113        #[stage(cache_last)]
114        fn UrgentStage(input: i32) {
115            assert_eq!(input, 42);
116            assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
117        }
118
119        let mut registry = Registry::new();
120        let lazy_node = registry.register(LazyStage::new());
121        let urgent_node = registry.register(UrgentStage::new());
122
123        let graph = graph! {
124            nodes: [lazy_node, urgent_node],
125            connections: {
126                lazy_node => urgent_node: input,
127            }
128        }
129        .unwrap();
130
131        // Reset counter
132        COUNTER.store(0, Ordering::SeqCst);
133
134        // Execute should evaluate LazyStage because UrgentStage depends on it
135        graph.execute(&mut registry).unwrap();
136    }
137
138    // Test transparent vs opaque reevaluation rules
139    #[test]
140    fn transparent_opaque_reevaluation_test() {
141        static TRANSPARENT_COUNTER: AtomicUsize = AtomicUsize::new(0);
142        static OPAQUE_COUNTER: AtomicUsize = AtomicUsize::new(0);
143
144        #[stage(lazy, cache_last)]
145        fn SourceStage() -> i32 {
146            println!("SourceStage");
147            42
148        }
149
150        #[stage(lazy, cache_last)]
151        fn TransparentStage(input: i32) -> i32 {
152            println!("TransparentStage");
153            TRANSPARENT_COUNTER.fetch_add(1, Ordering::SeqCst);
154            input * 2
155        }
156
157        #[stage(lazy)]
158        fn OpaqueStage(input: &i32) -> i32 {
159            println!("OpaqueStage");
160            OPAQUE_COUNTER.fetch_add(1, Ordering::SeqCst);
161            input * 3
162        }
163
164        #[stage]
165        fn SinkStage(t_input: &i32, o_input: &i32) {
166            println!("SinkStage");
167            assert_eq!(*t_input, 84);
168            assert_eq!(*o_input, 126);
169        }
170
171        let mut registry = Registry::new();
172        let source = registry.register(SourceStage::new());
173        let transparent = registry.register(TransparentStage::new());
174        let opaque = registry.register(OpaqueStage::new());
175        let sink = registry.register(SinkStage::new());
176
177        let graph = graph! {
178            nodes: [source, transparent, opaque, sink],
179            connections: {
180                source => transparent: input,
181                source => opaque: input,
182                transparent => sink: t_input,
183                opaque => sink: o_input,
184            }
185        }
186        .unwrap();
187
188        // Reset counters
189        TRANSPARENT_COUNTER.store(0, Ordering::SeqCst);
190        OPAQUE_COUNTER.store(0, Ordering::SeqCst);
191
192        // First execution
193        graph.execute(&mut registry).unwrap();
194        assert_eq!(TRANSPARENT_COUNTER.load(Ordering::SeqCst), 1);
195        assert_eq!(OPAQUE_COUNTER.load(Ordering::SeqCst), 1);
196
197        // Second execution - transparent stage shouldn't execute again since inputs haven't changed
198        graph.execute(&mut registry).unwrap();
199        assert_eq!(TRANSPARENT_COUNTER.load(Ordering::SeqCst), 1); // Still 1
200        assert_eq!(OPAQUE_COUNTER.load(Ordering::SeqCst), 2); // Increased to 2
201    }
202
203    // Test graph cycle detection
204    #[test]
205    fn cycle_detection_test() {
206        #[stage]
207        fn StageA(input: i32) -> i32 {
208            input + 1
209        }
210
211        #[stage]
212        fn StageB(input: i32) -> i32 {
213            input * 2
214        }
215
216        let mut registry = Registry::new();
217        let node_a = registry.register(StageA::new());
218        let node_b = registry.register(StageB::new());
219
220        // Attempt to create a cyclic graph
221        let result = graph! {
222            nodes: [node_a, node_b],
223            connections: {
224                node_a => node_b: input,
225                node_b => node_a: input,
226            }
227        };
228
229        // The graph creation should fail due to cycle detection
230        assert!(result.is_err());
231    }
232
233    // Test registry functionality
234    #[test]
235    fn registry_operations_test() {
236        #[stage]
237        fn SimpleStage() -> i32 {
238            42
239        }
240
241        let mut registry = Registry::new();
242
243        // Register a node
244        let node_id = registry.register(SimpleStage::new());
245
246        // Validate node type
247        registry.validate_node_type::<SimpleStage>(node_id).unwrap();
248
249        // Validate incorrect type
250        #[stage]
251        fn OtherStage() -> String {
252            "hello".to_string()
253        }
254        assert!(registry.validate_node_type::<OtherStage>(node_id).is_err());
255
256        // Get node
257        assert!(registry.get(node_id).is_some());
258
259        // Get mutable node
260        assert!(registry.get_mut(node_id).is_some());
261
262        // Unregister
263        let node = registry
264            .unregister::<SimpleStage>(node_id)
265            .unwrap()
266            .unwrap();
267        assert!(node.stage.eval_strategy() == EvalStrategy::Urgent);
268
269        // Node no longer exists
270        assert!(registry.get(node_id).is_none());
271    }
272
273    // Test error handling when node doesn't exist
274    #[test]
275    fn nonexistent_node_test() {
276        let mut registry = Registry::new();
277
278        // Node ID that doesn't exist
279        let invalid_id = 9999;
280
281        // Various operations should fail
282        assert!(registry.get(invalid_id).is_none());
283        assert!(registry.get_mut(invalid_id).is_none());
284        assert!(registry.unregister_and_drop(invalid_id).is_err());
285    }
286
287    // Test type mismatches in connections
288    #[test]
289    fn type_mismatch_test() {
290        #[stage]
291        fn StringStage() -> String {
292            "Hello".to_string()
293        }
294
295        #[stage]
296        fn IntegerConsumer(_input: i32) {
297            // This should never execute due to type mismatch
298            panic!("Should not execute");
299        }
300
301        let mut registry = Registry::new();
302        let producer = registry.register(StringStage::new());
303        let consumer = registry.register(IntegerConsumer::new());
304
305        // Create graph with type-incompatible connection
306        let graph = graph! {
307            nodes: [producer, consumer],
308            connections: {
309                producer => consumer: input,
310            }
311        }
312        .unwrap();
313
314        // Execution should fail due to type mismatch when flowing data
315        let result = graph.execute(&mut registry);
316        assert!(result.is_err());
317    }
318
319    // Test missing inputs
320    #[test]
321    fn missing_input_test() {
322        #[stage]
323        fn ConsumerStage(_input1: i32, _input2: String) {
324            // This should never execute due to missing input
325            panic!("Should not execute");
326        }
327
328        #[stage]
329        fn ProducerStage() -> i32 {
330            42
331        }
332
333        let mut registry = Registry::new();
334        let producer = registry.register(ProducerStage::new());
335        let consumer = registry.register(ConsumerStage::new());
336
337        // Only connect one of the required inputs
338        let graph = graph! {
339            nodes: [producer, consumer],
340            connections: {
341                producer => consumer: input1,
342            }
343        }
344        .unwrap();
345
346        // Execution should fail due to missing input
347        let result = graph.execute(&mut registry);
348        assert!(result.is_err());
349    }
350
351    // Test DataLabel functionality
352    #[test]
353    fn data_label_test() {
354        let label1 = DataLabel::new("test");
355        let label2 = DataLabel::new("test");
356        let label3 = DataLabel::new("different");
357
358        assert_eq!(label1, label2);
359        assert_ne!(label1, label3);
360
361        let const_label = DataLabel::new_const("const");
362        assert_eq!(const_label.inner(), Some("const"));
363
364        let from_str: DataLabel = "string".into();
365        assert_eq!(from_str.inner(), Some("string"));
366    }
367
368    // Test graph with diamond pattern
369    #[test]
370    fn diamond_graph_test() {
371        #[stage]
372        fn Source() -> i32 {
373            10
374        }
375
376        #[stage]
377        fn PathA(input: i32) -> i32 {
378            input * 2
379        }
380
381        #[stage]
382        fn PathB(input: i32) -> i32 {
383            input + 5
384        }
385
386        #[stage]
387        fn Sink(a: i32, b: i32) {
388            assert_eq!(a, 20); // 10 * 2
389            assert_eq!(b, 15); // 10 + 5
390        }
391
392        let mut registry = Registry::new();
393        let source = registry.register(Source::new());
394        let path_a = registry.register(PathA::new());
395        let path_b = registry.register(PathB::new());
396        let sink = registry.register(Sink::new());
397
398        let graph = graph! {
399            nodes: [source, path_a, path_b, sink],
400            connections: {
401                source => path_a: input,
402                source => path_b: input,
403                path_a => sink: a,
404                path_b => sink: b,
405            }
406        }
407        .unwrap();
408
409        graph.execute(&mut registry).unwrap();
410    }
411
412    // Test accessing outputs by wrong name
413    #[test]
414    fn invalid_output_name_test() {
415        #[stage]
416        fn MultiOutputStage() -> NodeOutput {
417            output! {
418                output1: 42,
419                output2: "Hello".to_string()
420            }
421        }
422
423        #[stage]
424        fn ConsumerStage(_input: i32) {
425            // Should never execute
426            panic!("Should not execute");
427        }
428
429        let mut registry = Registry::new();
430        let producer = registry.register(MultiOutputStage::new());
431        let consumer = registry.register(ConsumerStage::new());
432
433        // Connect with non-existent output name
434        let graph = graph! {
435            nodes: [producer, consumer],
436            connections: {
437                producer: nonexistent => consumer: input,
438            }
439        }
440        .unwrap();
441
442        // Should fail because the output name doesn't exist
443        let result = graph.execute(&mut registry);
444        assert!(result.is_err());
445    }
446
447    /// Test nodes with internal state
448    #[test]
449    fn node_with_state_test() {
450        #[stage(state((u8, u8)))]
451        fn StateStage() {
452            assert_eq!(state.1, state.0 * 5);
453            state.0 += 1;
454            state.1 += 5;
455            println!("State is {}", state.1);
456        }
457
458        let mut registry = Registry::new();
459        // Note: If the state has an implementation of "default", the simple
460        // register can still be called instead
461        let node = registry.register_with_state(StateStage::new(), (1, 5));
462        let graph = graph! {
463            nodes: [node],
464            connections: {}
465        }
466        .unwrap();
467
468        // TODO: Actually return results so this test can be real (right now it would pass if state never updated)
469        graph.execute(&mut registry).unwrap();
470        graph.execute(&mut registry).unwrap();
471        graph.execute(&mut registry).unwrap();
472        graph.execute(&mut registry).unwrap();
473    }
474
475    // Test the output! macro
476    #[test]
477    fn output_macro_test() {
478        #[stage(out(number: i32, text: String, vector: Vec<i32>))]
479        fn ProduceOutput1() -> NodeOutput {
480            println!("Running ProduceOutput1");
481            let number = 42;
482            let text = "hello".to_string();
483            let vector = vec![1, 2, 3];
484
485            output! {
486                number,
487                text,
488                vector
489            }
490        }
491
492        #[stage]
493        fn ConsumeOutputs(num: i32, txt: String, vec: Vec<i32>) {
494            assert_eq!(num, 42);
495            assert_eq!(txt, "hello");
496            assert_eq!(vec, vec![1, 2, 3]);
497        }
498
499        let mut registry = Registry::new();
500        let producer = registry.register(ProduceOutput1::new());
501        let consumer = registry.register(ConsumeOutputs::new());
502
503        let graph = graph! {
504            nodes: [producer, consumer],
505            connections: {
506                producer: number => consumer: num,
507                producer: text => consumer: txt,
508                producer: vector => consumer: vec,
509            }
510        }
511        .unwrap();
512
513        graph.execute(&mut registry).unwrap();
514    }
515
516    // Test registry node type validation
517    #[test]
518    fn registry_type_validation_test() {
519        #[stage]
520        fn StageA() -> i32 {
521            42
522        }
523
524        #[stage]
525        fn StageB() -> String {
526            "hello".to_string()
527        }
528
529        let mut registry = Registry::new();
530        let node_a = registry.register(StageA::new());
531
532        // Correct type validation should succeed
533        assert!(registry.validate_node_type::<StageA>(node_a).is_ok());
534
535        // Incorrect type validation should fail
536        assert!(registry.validate_node_type::<StageB>(node_a).is_err());
537
538        // Unregistering with incorrect type should fail
539        assert!(registry.unregister::<StageB>(node_a).is_err());
540
541        // Unregistering with correct type should succeed
542        assert!(registry.unregister::<StageA>(node_a).is_ok());
543    }
544
545    #[test]
546    fn basic_cache_all_test() {
547        static COUNTER: AtomicUsize = AtomicUsize::new(0);
548
549        #[stage(lazy, cache_all)]
550        fn CacheStage1() -> String {
551            println!("Running stage 1");
552            COUNTER.fetch_add(1, Ordering::SeqCst);
553            String::from("This is the output!")
554        }
555
556        #[stage(lazy, cache_all)]
557        fn CacheStage2(input: String, input2: String) -> String {
558            println!("Running stage 2");
559            COUNTER.fetch_add(1, Ordering::SeqCst);
560            input.to_uppercase() + " [" + &input2.chars().count().to_string() + " chars]"
561        }
562
563        #[stage(cache_last)]
564        fn TinyStage3(input: String) {
565            println!("Running stage 3");
566            assert_eq!("THIS IS THE OUTPUT! [19 chars]", input);
567        }
568
569        #[stage(lazy, cache_all)]
570        fn CacheStage1Alternate() -> String {
571            println!("Running alt stage 1");
572            COUNTER.fetch_add(1, Ordering::SeqCst);
573            String::from("This is a different output!")
574        }
575
576        #[stage(cache_last)]
577        fn TinyStage3Alternate(input: String) {
578            println!("Running alt stage 3");
579            assert_eq!("THIS IS A DIFFERENT OUTPUT! [27 chars]", input);
580        }
581
582        let mut registry = Registry::new();
583        let node_1 = registry.register(CacheStage1::new());
584        let node_2 = registry.register(CacheStage2::new());
585        let node_3 = registry.register(TinyStage3::new());
586        let node_1_alt = registry.register(CacheStage1Alternate::new());
587        let node_3_alt = registry.register(TinyStage3Alternate::new());
588
589        let graph1 = graph! {
590            nodes: [node_1, node_2, node_3],
591            connections: {
592                node_1 => node_2: input,
593                node_1 => node_2: input2,
594                node_2 => node_3: input,
595            }
596        }
597        .unwrap();
598
599        graph1.execute(&mut registry).unwrap();
600        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
601        graph1.execute(&mut registry).unwrap();
602        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
603
604        // Now with a modified graph, but same stage 2
605        let graph2 = graph! {
606            nodes: [node_1_alt, node_2, node_3_alt],
607            connections: {
608                node_1_alt => node_2: input,
609                node_1_alt => node_2: input2,
610                node_2 => node_3_alt: input,
611            }
612        }
613        .unwrap();
614
615        graph2.execute(&mut registry).unwrap();
616        assert_eq!(COUNTER.load(Ordering::SeqCst), 4);
617        graph2.execute(&mut registry).unwrap();
618        assert_eq!(COUNTER.load(Ordering::SeqCst), 4);
619        graph1.execute(&mut registry).unwrap();
620        assert_eq!(COUNTER.load(Ordering::SeqCst), 4);
621    }
622
623    /// Test connections without data
624    #[test]
625    fn blank_connections_test() {
626        static COUNTER: AtomicUsize = AtomicUsize::new(0);
627        #[stage(lazy)]
628        fn TinyStage1() {
629            println!("Running stage 1");
630            COUNTER.fetch_add(1, Ordering::SeqCst);
631        }
632
633        #[stage(lazy)]
634        fn TinyStage2() {
635            println!("Running stage 2");
636            assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
637            COUNTER.fetch_add(1, Ordering::SeqCst);
638        }
639
640        #[stage]
641        fn TinyStage3() {
642            println!("Running stage 3");
643            assert_eq!(COUNTER.load(Ordering::SeqCst), 3);
644            COUNTER.fetch_add(1, Ordering::SeqCst);
645        }
646
647        let mut registry = Registry::new();
648        let node_1 = registry.register(TinyStage1::new());
649        let node_2 = registry.register(TinyStage2::new());
650        let node_3 = registry.register(TinyStage3::new());
651        let graph = graph! {
652            nodes: [node_1, node_2, node_3],
653            connections: {
654                node_1 => node_2,
655                node_2 => node_3,
656                node_1 => node_3,
657            }
658        }
659        .unwrap();
660
661        graph.execute(&mut registry).unwrap();
662        assert_eq!(COUNTER.load(Ordering::SeqCst), 4);
663    }
664
665    // TODO: Specific test for trace generation
666}
667
// Async execution tests; compiled only for test builds with the `tokio` feature enabled.
#[cfg(all(test, feature = "tokio"))]
mod async_tests {
    extern crate self as directed;
    use super::*;
    use directed_stage_macro::stage;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Two lazy async stages can only complete if `execute_async` runs them
    /// concurrently (interleaved), because each one waits for a message the
    /// other sends.
    ///
    /// - `SlowStage1` sends `1`, then waits for `2`.
    /// - `SlowStage2` waits for `1`, then sends `2`.
    ///
    /// `CombineStage` consumes both results.
    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    async fn parallel_execution_test() {
        use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
        // FIXME: this handshake-based test is known to fail intermittently.
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        // Messages from SlowStage1 to SlowStage2...
        let (one_to_two_tx, one_to_two_rx) = unbounded_channel::<u8>();
        // ...and from SlowStage2 back to SlowStage1.
        let (two_to_one_tx, two_to_one_rx) = unbounded_channel::<u8>();

        // Each slow stage's state is (sender to its peer, receiver from its peer).
        #[stage(lazy, state((UnboundedSender<u8>, UnboundedReceiver<u8>)))]
        async fn SlowStage1() -> i32 {
            println!("Running SlowStage1");
            let (outbox, inbox) = state;
            outbox.send(1).unwrap();
            let reply = inbox.recv().await.unwrap();
            assert_eq!(reply, 2);
            COUNTER.fetch_add(1, Ordering::SeqCst);
            42
        }

        #[stage(lazy, state((UnboundedSender<u8>, UnboundedReceiver<u8>)))]
        async fn SlowStage2() -> String {
            println!("Running SlowStage2");
            let (outbox, inbox) = state;
            let request = inbox.recv().await.unwrap();
            assert_eq!(request, 1);
            outbox.send(2).unwrap();
            COUNTER.fetch_add(1, Ordering::SeqCst);
            "hello".to_string()
        }

        #[stage]
        fn CombineStage(as_num: i32, as_text: String) -> String {
            println!("Running CombineStage");
            format!("{} {}", as_text, as_num)
        }

        let mut registry = Registry::new();
        let stage1 =
            registry.register_with_state(SlowStage1::new(), (one_to_two_tx, two_to_one_rx));
        let stage2 =
            registry.register_with_state(SlowStage2::new(), (two_to_one_tx, one_to_two_rx));
        let combine = registry.register(CombineStage::new());

        println!("Node {stage1} is SlowStage1");
        println!("Node {stage2} is SlowStage2");
        println!("Node {combine} is CombineStage");

        let graph = std::sync::Arc::new(
            graph! {
                nodes: [stage1, stage2, combine],
                connections: {
                    stage1 => combine: as_num,
                    stage2 => combine: as_text,
                }
            }
            .unwrap(),
        );

        COUNTER.store(0, Ordering::SeqCst);

        // The registry is moved into an async mutex and handed to the executor.
        let shared_registry = tokio::sync::Mutex::new(registry);
        graph.execute_async(shared_registry).await.unwrap();

        // Each slow stage bumps the counter once when it finishes.
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}