uncertain_rs/
computation.rs

1use crate::operations::{Arithmetic, arithmetic::BinaryOperation};
2use crate::traits::Shareable;
3use std::collections::HashMap;
4use std::sync::Arc;
5
/// Configuration for adaptively choosing how many samples to draw when
/// evaluating a computation graph.
///
/// Derives `PartialEq` so configurations can be compared directly (e.g. in
/// tests or when deciding whether a context needs reconfiguring).
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveSampling {
    /// Minimum sample count to try.
    pub min_samples: usize,
    /// Maximum sample count allowed.
    pub max_samples: usize,
    /// Relative error threshold for convergence.
    pub error_threshold: f64,
    /// Factor by which the sample count is increased on each iteration.
    pub growth_factor: f64,
}
18
19impl Default for AdaptiveSampling {
20    fn default() -> Self {
21        Self {
22            min_samples: 100,
23            max_samples: 10000,
24            error_threshold: 0.01,
25            growth_factor: 1.5,
26        }
27    }
28}
29
/// Strategy deciding which computation-graph nodes a `SampleContext` caches.
///
/// `Hash` is derived so strategies can be used as map/set keys. `Default`
/// yields `Adaptive`, the same strategy `SampleContext::new` uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CachingStrategy {
    /// Cache all intermediate results.
    Aggressive,
    /// Cache only expensive operations (deep sub-graphs and conditionals).
    Conservative,
    /// Cache based on an estimate of the node's computation cost.
    #[default]
    Adaptive,
}
40
41/// Context for memoizing samples within a single evaluation to ensure
42/// shared variables produce the same sample value throughout an evaluation
43pub struct SampleContext {
44    /// Memoized values indexed by node ID
45    memoized_values: HashMap<uuid::Uuid, Box<dyn std::any::Any + Send>>,
46    /// Current caching strategy
47    caching_strategy: CachingStrategy,
48    /// Adaptive sampling configuration
49    adaptive_sampling: AdaptiveSampling,
50}
51
52impl SampleContext {
53    /// Create a new empty sample context
54    #[must_use]
55    pub fn new() -> Self {
56        Self {
57            memoized_values: HashMap::new(),
58            caching_strategy: CachingStrategy::Adaptive,
59            adaptive_sampling: AdaptiveSampling::default(),
60        }
61    }
62
63    /// Create a sample context with specific caching strategy
64    #[must_use]
65    pub fn with_caching_strategy(strategy: CachingStrategy) -> Self {
66        Self {
67            memoized_values: HashMap::new(),
68            caching_strategy: strategy,
69            adaptive_sampling: AdaptiveSampling::default(),
70        }
71    }
72
73    /// Get a memoized value for a given node ID
74    #[must_use]
75    pub fn get_value<T: Clone + 'static>(&self, id: &uuid::Uuid) -> Option<T> {
76        self.memoized_values.get(id)?.downcast_ref::<T>().cloned()
77    }
78
79    /// Set a memoized value for a given node ID
80    pub fn set_value<T: Clone + Send + 'static>(&mut self, id: uuid::Uuid, value: T) {
81        self.memoized_values.insert(id, Box::new(value));
82    }
83
84    /// Clear all memoized values
85    pub fn clear(&mut self) {
86        self.memoized_values.clear();
87    }
88
89    /// Get the number of memoized values
90    #[must_use]
91    pub fn len(&self) -> usize {
92        self.memoized_values.len()
93    }
94
95    /// Check if the context is empty
96    #[must_use]
97    pub fn is_empty(&self) -> bool {
98        self.memoized_values.is_empty()
99    }
100
101    /// Determine if a node should be cached based on strategy and cost
102    #[must_use]
103    pub fn should_cache_node(&self, node: &ComputationNode<impl Shareable>) -> bool {
104        match self.caching_strategy {
105            CachingStrategy::Aggressive => true,
106            CachingStrategy::Conservative => {
107                // Only cache expensive operations (depth > 2 or complex nodes)
108                node.depth() > 2 || matches!(node, ComputationNode::Conditional { .. })
109            }
110            CachingStrategy::Adaptive => {
111                let complexity = node.compute_complexity();
112                complexity > 5 // Threshold for caching
113            }
114        }
115    }
116
117    /// Get the adaptive sampling configuration
118    #[must_use]
119    pub fn adaptive_sampling(&self) -> &AdaptiveSampling {
120        &self.adaptive_sampling
121    }
122
123    /// Set adaptive sampling configuration
124    pub fn set_adaptive_sampling(&mut self, config: AdaptiveSampling) {
125        self.adaptive_sampling = config;
126    }
127}
128
129impl Default for SampleContext {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135/// Computation graph node for lazy evaluation using indirect enum
136///
137/// This enables building complex expressions like `(x + y) * 2.0 - z` as a computation
138/// graph that's only evaluated when samples are needed, with proper memoization to
139/// ensure shared variables use the same sample within a single evaluation.
140#[derive(Clone)]
141pub enum ComputationNode<T> {
142    /// Leaf node representing a direct sampling function with unique ID
143    Leaf {
144        id: uuid::Uuid,
145        sample: Arc<dyn Fn() -> T + Send + Sync>,
146    },
147
148    /// Binary operation node for combining two uncertain values
149    BinaryOp {
150        left: Box<ComputationNode<T>>,
151        right: Box<ComputationNode<T>>,
152        operation: BinaryOperation,
153    },
154
155    /// Unary operation node for transforming a single uncertain value
156    UnaryOp {
157        operand: Box<ComputationNode<T>>,
158        operation: UnaryOperation<T>,
159    },
160
161    /// Conditional node for if-then-else logic
162    Conditional {
163        condition: Box<ComputationNode<bool>>,
164        if_true: Box<ComputationNode<T>>,
165        if_false: Box<ComputationNode<T>>,
166    },
167}
168
/// A unary transformation carried by a `ComputationNode::UnaryOp` node.
#[derive(Clone)]
pub enum UnaryOperation<T> {
    /// Maps the operand's value through the function.
    Map(Arc<dyn Fn(T) -> T + Send + Sync>),
    /// Predicate intended for rejection sampling.
    ///
    /// NOTE(review): the evaluators in this module currently return the operand
    /// unchanged for `Filter` and never call the predicate.
    Filter(Arc<dyn Fn(&T) -> bool + Send + Sync>),
}
175
176impl<T> ComputationNode<T>
177where
178    T: Shareable,
179{
180    /// Evaluates the computation graph node with memoization context
181    ///
182    /// This is the core evaluation method that respects memoization to ensure
183    /// shared variables produce consistent samples within a single evaluation.
184    ///
185    /// # Panics
186    ///
187    /// - Panics if called on a `BinaryOp` variant. Use `evaluate_arithmetic` instead for binary operations.
188    /// - Panics if called on a `Conditional` variant. Use `evaluate_conditional` instead for conditional operations.
189    pub fn evaluate(&self, context: &mut SampleContext) -> T {
190        match self {
191            ComputationNode::Leaf { id, sample } => {
192                // Check if we already have a memoized value for this node
193                if let Some(cached) = context.get_value::<T>(id) {
194                    cached
195                } else {
196                    // Generate new sample and memoize it
197                    let value = sample();
198                    context.set_value(*id, value.clone());
199                    value
200                }
201            }
202
203            ComputationNode::UnaryOp { operand, operation } => {
204                let operand_val = operand.evaluate(context);
205                match operation {
206                    UnaryOperation::Map(func) => func(operand_val),
207                    UnaryOperation::Filter(_) => {
208                        // Filter requires special handling with rejection sampling
209                        // This is a simplified implementation
210                        operand_val
211                    }
212                }
213            }
214
215            // These variants require special handling based on type constraints
216            ComputationNode::BinaryOp { .. } => {
217                panic!(
218                    "BinaryOp evaluation requires arithmetic trait bounds. Use evaluate_arithmetic instead."
219                )
220            }
221
222            ComputationNode::Conditional { .. } => {
223                panic!(
224                    "Conditional evaluation requires specific handling. Use evaluate_conditional instead."
225                )
226            }
227        }
228    }
229
230    /// Evaluates arithmetic operations with proper trait bounds
231    ///
232    /// # Panics
233    ///
234    /// Panics if called on a `Conditional` variant with a boolean condition, as this is not supported in arithmetic context.
235    pub fn evaluate_arithmetic(&self, context: &mut SampleContext) -> T
236    where
237        T: Arithmetic,
238    {
239        match self {
240            ComputationNode::Leaf { id, sample } => {
241                if let Some(cached) = context.get_value::<T>(id) {
242                    cached
243                } else {
244                    let value = sample();
245                    context.set_value(*id, value.clone());
246                    value
247                }
248            }
249
250            ComputationNode::BinaryOp {
251                left,
252                right,
253                operation,
254            } => {
255                let left_val = left.evaluate_arithmetic(context);
256                let right_val = right.evaluate_arithmetic(context);
257                operation.apply(left_val, right_val)
258            }
259
260            ComputationNode::UnaryOp { operand, operation } => {
261                let operand_val = operand.evaluate_arithmetic(context);
262                match operation {
263                    UnaryOperation::Map(func) => func(operand_val),
264                    UnaryOperation::Filter(_) => operand_val,
265                }
266            }
267
268            ComputationNode::Conditional {
269                condition: _,
270                if_true: _,
271                if_false: _,
272            } => {
273                panic!(
274                    "Conditional evaluation with bool condition not supported in arithmetic context"
275                )
276            }
277        }
278    }
279
280    /// Evaluates the computation graph node in a new context
281    ///
282    /// This creates a fresh context for evaluation, useful when you want
283    /// independent samples without memoization effects.
284    #[must_use]
285    pub fn evaluate_fresh(&self) -> T
286    where
287        T: Arithmetic,
288    {
289        let mut context = SampleContext::new();
290        self.evaluate_conditional_with_arithmetic(&mut context)
291    }
292
293    /// Creates a new leaf node
294    pub fn leaf<F>(sample: F) -> Self
295    where
296        F: Fn() -> T + Send + Sync + 'static,
297    {
298        ComputationNode::Leaf {
299            id: uuid::Uuid::new_v4(),
300            sample: Arc::new(sample),
301        }
302    }
303
304    /// Creates a new binary operation node
305    #[must_use]
306    pub fn binary_op(
307        left: ComputationNode<T>,
308        right: ComputationNode<T>,
309        operation: BinaryOperation,
310    ) -> Self {
311        ComputationNode::BinaryOp {
312            left: Box::new(left),
313            right: Box::new(right),
314            operation,
315        }
316    }
317
318    /// Creates a new unary map operation node
319    pub fn map<F>(operand: ComputationNode<T>, func: F) -> Self
320    where
321        F: Fn(T) -> T + Send + Sync + 'static,
322    {
323        ComputationNode::UnaryOp {
324            operand: Box::new(operand),
325            operation: UnaryOperation::Map(Arc::new(func)),
326        }
327    }
328
329    /// Creates a conditional node
330    #[must_use]
331    pub fn conditional(
332        condition: ComputationNode<bool>,
333        if_true: ComputationNode<T>,
334        if_false: ComputationNode<T>,
335    ) -> Self {
336        ComputationNode::Conditional {
337            condition: Box::new(condition),
338            if_true: Box::new(if_true),
339            if_false: Box::new(if_false),
340        }
341    }
342
343    /// Counts the number of nodes in the computation graph
344    #[must_use]
345    pub fn node_count(&self) -> usize {
346        match self {
347            ComputationNode::Leaf { .. } => 1,
348            ComputationNode::BinaryOp { left, right, .. } => {
349                1 + left.node_count() + right.node_count()
350            }
351            ComputationNode::UnaryOp { operand, .. } => 1 + operand.node_count(),
352            ComputationNode::Conditional {
353                condition,
354                if_true,
355                if_false,
356            } => 1 + condition.node_count() + if_true.node_count() + if_false.node_count(),
357        }
358    }
359
360    /// Gets the depth of the computation graph
361    #[must_use]
362    pub fn depth(&self) -> usize {
363        match self {
364            ComputationNode::Leaf { .. } => 1,
365            ComputationNode::BinaryOp { left, right, .. } => 1 + left.depth().max(right.depth()),
366            ComputationNode::UnaryOp { operand, .. } => 1 + operand.depth(),
367            ComputationNode::Conditional {
368                condition,
369                if_true,
370                if_false,
371            } => 1 + condition.depth().max(if_true.depth().max(if_false.depth())),
372        }
373    }
374
375    /// Checks if the computation graph contains any conditional nodes
376    #[must_use]
377    pub fn has_conditionals(&self) -> bool {
378        match self {
379            ComputationNode::Leaf { .. } => false,
380            ComputationNode::BinaryOp { left, right, .. } => {
381                left.has_conditionals() || right.has_conditionals()
382            }
383            ComputationNode::UnaryOp { operand, .. } => operand.has_conditionals(),
384            ComputationNode::Conditional { .. } => true,
385        }
386    }
387
388    /// Estimate computational complexity of the node for caching decisions
389    #[must_use]
390    pub fn compute_complexity(&self) -> usize {
391        match self {
392            ComputationNode::Leaf { .. } => 1,
393            ComputationNode::BinaryOp { left, right, .. } => {
394                2 + left.compute_complexity() + right.compute_complexity()
395            }
396            ComputationNode::UnaryOp { operand, .. } => 1 + operand.compute_complexity(),
397            ComputationNode::Conditional {
398                condition,
399                if_true,
400                if_false,
401            } => {
402                5 + condition.compute_complexity()
403                    + if_true.compute_complexity()
404                    + if_false.compute_complexity()
405            }
406        }
407    }
408
409    /// Generate a structural hash for computation graph caching
410    #[must_use]
411    pub fn structural_hash(&self) -> u64 {
412        use std::collections::hash_map::DefaultHasher;
413        use std::hash::Hasher;
414
415        let mut hasher = DefaultHasher::new();
416        self.hash_structure(&mut hasher);
417        hasher.finish()
418    }
419
420    fn hash_structure(&self, hasher: &mut impl std::hash::Hasher) {
421        use std::hash::Hash;
422
423        match self {
424            ComputationNode::Leaf { id, .. } => {
425                "leaf".hash(hasher);
426                id.hash(hasher);
427            }
428            ComputationNode::BinaryOp {
429                left,
430                right,
431                operation,
432            } => {
433                "binary".hash(hasher);
434                operation.hash(hasher);
435                left.hash_structure(hasher);
436                right.hash_structure(hasher);
437            }
438            ComputationNode::UnaryOp { operand, .. } => {
439                "unary".hash(hasher);
440                operand.hash_structure(hasher);
441            }
442            ComputationNode::Conditional {
443                condition,
444                if_true,
445                if_false,
446            } => {
447                "conditional".hash(hasher);
448                condition.hash_structure(hasher);
449                if_true.hash_structure(hasher);
450                if_false.hash_structure(hasher);
451            }
452        }
453    }
454}
455
456// Specialized implementation for handling conditionals with boolean conditions
457impl ComputationNode<bool> {
458    /// Evaluates boolean computation nodes
459    ///
460    /// # Panics
461    ///
462    /// Panics if called on a `BinaryOp` variant as boolean binary operations are not implemented.
463    pub fn evaluate_bool(&self, context: &mut SampleContext) -> bool {
464        match self {
465            ComputationNode::Leaf { id, sample } => {
466                if let Some(cached) = context.get_value::<bool>(id) {
467                    cached
468                } else {
469                    let value = sample();
470                    context.set_value(*id, value);
471                    value
472                }
473            }
474            ComputationNode::UnaryOp { operand, operation } => {
475                let operand_val = operand.evaluate_bool(context);
476                match operation {
477                    UnaryOperation::Map(func) => func(operand_val),
478                    UnaryOperation::Filter(_) => operand_val,
479                }
480            }
481            ComputationNode::BinaryOp { .. } => {
482                panic!("Boolean binary operations not implemented")
483            }
484            ComputationNode::Conditional {
485                condition,
486                if_true,
487                if_false,
488            } => {
489                let condition_val = condition.evaluate_bool(context);
490                if condition_val {
491                    if_true.evaluate_bool(context)
492                } else {
493                    if_false.evaluate_bool(context)
494                }
495            }
496        }
497    }
498}
499
500// Add a specialized method for evaluating conditionals with arithmetic return types
501impl<T> ComputationNode<T>
502where
503    T: Shareable,
504{
505    /// Evaluates conditional nodes where condition is bool and branches return T
506    pub fn evaluate_conditional_with_arithmetic(&self, context: &mut SampleContext) -> T
507    where
508        T: Arithmetic,
509    {
510        match self {
511            ComputationNode::Conditional {
512                condition,
513                if_true,
514                if_false,
515            } => {
516                let condition_val = condition.evaluate_bool(context);
517                if condition_val {
518                    if_true.evaluate_arithmetic(context)
519                } else {
520                    if_false.evaluate_arithmetic(context)
521                }
522            }
523            _ => self.evaluate_arithmetic(context),
524        }
525    }
526}
527
/// Rewrites computation graphs into cheaper, equivalent forms.
///
/// Holds a type-erased cache of previously processed subexpressions, so one
/// optimizer can be reused across graphs of different value types.
pub struct GraphOptimizer {
    /// Processed subexpressions keyed by `ComputationNode::structural_hash`,
    /// stored type-erased as `ComputationNode<T>` values.
    pub subexpression_cache: HashMap<u64, Box<dyn std::any::Any + Send + Sync>>,
}
533
534impl GraphOptimizer {
535    /// Create a new graph optimizer
536    #[must_use]
537    pub fn new() -> Self {
538        Self {
539            subexpression_cache: HashMap::new(),
540        }
541    }
542
543    /// Optimizes a computation graph by applying various transformations
544    #[must_use]
545    pub fn optimize<T>(&mut self, node: ComputationNode<T>) -> ComputationNode<T>
546    where
547        T: Shareable + Arithmetic + PartialEq + Clone,
548    {
549        let node = self.eliminate_common_subexpressions(node);
550        let node = Self::eliminate_identity_operations(node);
551        Self::constant_folding(node)
552    }
553
554    /// Eliminates common subexpressions by reusing nodes with same structure
555    pub fn eliminate_common_subexpressions<T>(
556        &mut self,
557        node: ComputationNode<T>,
558    ) -> ComputationNode<T>
559    where
560        T: Shareable,
561    {
562        let hash = node.structural_hash();
563
564        // Check if we have a cached version of this subexpression
565        if let Some(cached_node) = self.subexpression_cache.get(&hash)
566            && let Some(cached) = cached_node.downcast_ref::<ComputationNode<T>>()
567        {
568            return cached.clone();
569        }
570
571        // Recursively optimize children and cache this node
572        let optimized = match node {
573            ComputationNode::BinaryOp {
574                left,
575                right,
576                operation,
577            } => {
578                let left_opt = Box::new(self.eliminate_common_subexpressions(*left));
579                let right_opt = Box::new(self.eliminate_common_subexpressions(*right));
580                ComputationNode::BinaryOp {
581                    left: left_opt,
582                    right: right_opt,
583                    operation,
584                }
585            }
586            ComputationNode::UnaryOp { operand, operation } => {
587                let operand_opt = Box::new(self.eliminate_common_subexpressions(*operand));
588                ComputationNode::UnaryOp {
589                    operand: operand_opt,
590                    operation,
591                }
592            }
593            ComputationNode::Conditional {
594                condition,
595                if_true,
596                if_false,
597            } => {
598                let condition_opt = Box::new(self.eliminate_common_subexpressions(*condition));
599                let if_true_opt = Box::new(self.eliminate_common_subexpressions(*if_true));
600                let if_false_opt = Box::new(self.eliminate_common_subexpressions(*if_false));
601                ComputationNode::Conditional {
602                    condition: condition_opt,
603                    if_true: if_true_opt,
604                    if_false: if_false_opt,
605                }
606            }
607            leaf @ ComputationNode::Leaf { .. } => leaf,
608        };
609
610        // Cache this subexpression for future use
611        self.subexpression_cache
612            .insert(hash, Box::new(optimized.clone()));
613
614        optimized
615    }
616
617    /// Eliminates identity operations like `x + 0` or `x * 1`
618    #[allow(clippy::too_many_lines)]
619    fn eliminate_identity_operations<T>(node: ComputationNode<T>) -> ComputationNode<T>
620    where
621        T: Shareable + Arithmetic + PartialEq + Clone,
622    {
623        match node {
624            ComputationNode::BinaryOp {
625                left,
626                right,
627                operation,
628            } => Self::eliminate_identity_operations_binary(*left, *right, operation),
629            ComputationNode::UnaryOp { operand, operation } => {
630                Self::eliminate_identity_operations_unary(*operand, operation)
631            }
632            ComputationNode::Conditional {
633                condition,
634                if_true,
635                if_false,
636            } => Self::eliminate_identity_operations_conditional(*condition, *if_true, *if_false),
637            ComputationNode::Leaf { .. } => node,
638        }
639    }
640
641    /// Handles identity operation elimination for binary operations
642    fn eliminate_identity_operations_binary<T>(
643        left: ComputationNode<T>,
644        right: ComputationNode<T>,
645        operation: BinaryOperation,
646    ) -> ComputationNode<T>
647    where
648        T: Shareable + Arithmetic + PartialEq + Clone,
649    {
650        let left_opt = Self::eliminate_identity_operations(left);
651        let right_opt = Self::eliminate_identity_operations(right);
652
653        // Check for identity operations by operation type
654        match operation {
655            BinaryOperation::Add => {
656                if let Some(result) = Self::check_addition_identities(&left_opt, &right_opt) {
657                    return result;
658                }
659            }
660            BinaryOperation::Sub => {
661                if let Some(result) = Self::check_subtraction_identities(&left_opt, &right_opt) {
662                    return result;
663                }
664            }
665            BinaryOperation::Mul => {
666                if let Some(result) = Self::check_multiplication_identities(&left_opt, &right_opt) {
667                    return result;
668                }
669            }
670            BinaryOperation::Div => {
671                if let Some(result) = Self::check_division_identities(&left_opt, &right_opt) {
672                    return result;
673                }
674            }
675        }
676
677        ComputationNode::BinaryOp {
678            left: Box::new(left_opt),
679            right: Box::new(right_opt),
680            operation,
681        }
682    }
683
684    /// Checks for addition identity operations: x + 0 = x, 0 + x = x
685    fn check_addition_identities<T>(
686        left: &ComputationNode<T>,
687        right: &ComputationNode<T>,
688    ) -> Option<ComputationNode<T>>
689    where
690        T: Shareable + Arithmetic + PartialEq + Clone,
691    {
692        match (left, right) {
693            // x + 0 = x
694            (
695                left,
696                ComputationNode::Leaf {
697                    sample: right_sample,
698                    ..
699                },
700            ) => {
701                if Self::is_constant_zero(right_sample) {
702                    return Some(left.clone());
703                }
704            }
705            // 0 + x = x
706            (
707                ComputationNode::Leaf {
708                    sample: left_sample,
709                    ..
710                },
711                right,
712            ) => {
713                if Self::is_constant_zero(left_sample) {
714                    return Some(right.clone());
715                }
716            }
717            _ => {}
718        }
719        None
720    }
721
722    /// Checks for subtraction identity operations: x - 0 = x
723    fn check_subtraction_identities<T>(
724        left: &ComputationNode<T>,
725        right: &ComputationNode<T>,
726    ) -> Option<ComputationNode<T>>
727    where
728        T: Shareable + Arithmetic + PartialEq + Clone,
729    {
730        // x - 0 = x
731        if let (
732            left,
733            ComputationNode::Leaf {
734                sample: right_sample,
735                ..
736            },
737        ) = (left, right)
738            && Self::is_constant_zero(right_sample)
739        {
740            return Some(left.clone());
741        }
742        None
743    }
744
745    /// Checks for multiplication identity operations: x * 0 = 0, 0 * x = 0, x * 1 = x, 1 * x = x
746    fn check_multiplication_identities<T>(
747        left: &ComputationNode<T>,
748        right: &ComputationNode<T>,
749    ) -> Option<ComputationNode<T>>
750    where
751        T: Shareable + Arithmetic + PartialEq + Clone,
752    {
753        // Check for zero multiplication first (x * 0 = 0, 0 * x = 0)
754        match (left, right) {
755            // x * 0 = 0
756            (
757                _left,
758                ComputationNode::Leaf {
759                    sample: right_sample,
760                    ..
761                },
762            ) => {
763                if Self::is_constant_zero(right_sample) {
764                    return Some(ComputationNode::leaf(|| T::zero()));
765                }
766            }
767            // 0 * x = 0
768            (
769                ComputationNode::Leaf {
770                    sample: left_sample,
771                    ..
772                },
773                _right,
774            ) => {
775                if Self::is_constant_zero(left_sample) {
776                    return Some(ComputationNode::leaf(|| T::zero()));
777                }
778            }
779            _ => {}
780        }
781
782        // Check for identity multiplication (x * 1 = x, 1 * x = x)
783        match (left, right) {
784            // x * 1 = x
785            (
786                left,
787                ComputationNode::Leaf {
788                    sample: right_sample,
789                    ..
790                },
791            ) => {
792                if Self::is_constant_one(right_sample) {
793                    return Some(left.clone());
794                }
795            }
796            // 1 * x = x
797            (
798                ComputationNode::Leaf {
799                    sample: left_sample,
800                    ..
801                },
802                right,
803            ) => {
804                if Self::is_constant_one(left_sample) {
805                    return Some(right.clone());
806                }
807            }
808            _ => {}
809        }
810
811        None
812    }
813
814    /// Checks for division identity operations: x / 1 = x
815    fn check_division_identities<T>(
816        left: &ComputationNode<T>,
817        right: &ComputationNode<T>,
818    ) -> Option<ComputationNode<T>>
819    where
820        T: Shareable + Arithmetic + PartialEq + Clone,
821    {
822        // x / 1 = x
823        if let (
824            left,
825            ComputationNode::Leaf {
826                sample: right_sample,
827                ..
828            },
829        ) = (left, right)
830            && Self::is_constant_one(right_sample)
831        {
832            return Some(left.clone());
833        }
834        None
835    }
836
837    /// Handles identity operation elimination for unary operations
838    fn eliminate_identity_operations_unary<T>(
839        operand: ComputationNode<T>,
840        operation: UnaryOperation<T>,
841    ) -> ComputationNode<T>
842    where
843        T: Shareable + Arithmetic + PartialEq + Clone,
844    {
845        let operand_opt = Self::eliminate_identity_operations(operand);
846        ComputationNode::UnaryOp {
847            operand: Box::new(operand_opt),
848            operation,
849        }
850    }
851
852    /// Handles identity operation elimination for conditional operations
853    fn eliminate_identity_operations_conditional<T>(
854        condition: ComputationNode<bool>,
855        if_true: ComputationNode<T>,
856        if_false: ComputationNode<T>,
857    ) -> ComputationNode<T>
858    where
859        T: Shareable + Arithmetic + PartialEq + Clone,
860    {
861        // For conditionals, we need to handle the boolean condition separately
862        let condition_opt = Self::eliminate_identity_operations_bool(condition);
863        let if_true_opt = Self::eliminate_identity_operations(if_true);
864        let if_false_opt = Self::eliminate_identity_operations(if_false);
865        ComputationNode::Conditional {
866            condition: Box::new(condition_opt),
867            if_true: Box::new(if_true_opt),
868            if_false: Box::new(if_false_opt),
869        }
870    }
871
872    /// Eliminates identity operations for boolean types (no arithmetic operations)
873    fn eliminate_identity_operations_bool(node: ComputationNode<bool>) -> ComputationNode<bool> {
874        match node {
875            ComputationNode::UnaryOp { operand, operation } => {
876                let operand_opt = Self::eliminate_identity_operations_bool(*operand);
877                ComputationNode::UnaryOp {
878                    operand: Box::new(operand_opt),
879                    operation,
880                }
881            }
882            ComputationNode::Conditional {
883                condition,
884                if_true,
885                if_false,
886            } => {
887                let condition_opt = Self::eliminate_identity_operations_bool(*condition);
888                let if_true_opt = Self::eliminate_identity_operations_bool(*if_true);
889                let if_false_opt = Self::eliminate_identity_operations_bool(*if_false);
890                ComputationNode::Conditional {
891                    condition: Box::new(condition_opt),
892                    if_true: Box::new(if_true_opt),
893                    if_false: Box::new(if_false_opt),
894                }
895            }
896            ComputationNode::Leaf { .. } | ComputationNode::BinaryOp { .. } => node,
897        }
898    }
899
900    /// Helper function to check if a sampling function returns zero
901    fn is_constant_zero<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
902    where
903        T: PartialEq + Clone + Arithmetic,
904    {
905        // Sample a few times to check if it's consistently zero
906        for _ in 0..3 {
907            if sample_fn() != T::zero() {
908                return false;
909            }
910        }
911        true
912    }
913
914    /// Helper function to check if a sampling function returns one
915    fn is_constant_one<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
916    where
917        T: PartialEq + Clone + Arithmetic,
918    {
919        // Sample a few times to check if it's consistently one
920        for _ in 0..3 {
921            if sample_fn() != T::one() {
922                return false;
923            }
924        }
925        true
926    }
927
928    /// Performs constant folding for compile-time evaluation of constant expressions
929    fn constant_folding<T>(node: ComputationNode<T>) -> ComputationNode<T>
930    where
931        T: Shareable + Arithmetic + Clone + PartialEq,
932    {
933        match node {
934            ComputationNode::BinaryOp {
935                left,
936                right,
937                operation,
938            } => Self::constant_folding_binary_op(*left, *right, operation),
939            ComputationNode::UnaryOp { operand, operation } => {
940                Self::constant_folding_unary_op(*operand, operation)
941            }
942            ComputationNode::Conditional {
943                condition,
944                if_true,
945                if_false,
946            } => Self::constant_folding_conditional(*condition, *if_true, *if_false),
947            ComputationNode::Leaf { .. } => node,
948        }
949    }
950
951    /// Handles constant folding for binary operations
952    fn constant_folding_binary_op<T>(
953        left: ComputationNode<T>,
954        right: ComputationNode<T>,
955        operation: BinaryOperation,
956    ) -> ComputationNode<T>
957    where
958        T: Shareable + Arithmetic + Clone + PartialEq,
959    {
960        let left_opt = Self::constant_folding(left);
961        let right_opt = Self::constant_folding(right);
962
963        if let (
964            ComputationNode::Leaf {
965                sample: left_sample,
966                ..
967            },
968            ComputationNode::Leaf {
969                sample: right_sample,
970                ..
971            },
972        ) = (&left_opt, &right_opt)
973            && Self::is_constant(left_sample)
974            && Self::is_constant(right_sample)
975        {
976            let left_val = left_sample();
977            let right_val = right_sample();
978            let result = match operation {
979                BinaryOperation::Add => left_val + right_val,
980                BinaryOperation::Sub => left_val - right_val,
981                BinaryOperation::Mul => left_val * right_val,
982                BinaryOperation::Div => left_val / right_val,
983            };
984            return ComputationNode::leaf(move || result.clone());
985        }
986
987        ComputationNode::BinaryOp {
988            left: Box::new(left_opt),
989            right: Box::new(right_opt),
990            operation,
991        }
992    }
993
994    /// Handles constant folding for unary operations
995    fn constant_folding_unary_op<T>(
996        operand: ComputationNode<T>,
997        operation: UnaryOperation<T>,
998    ) -> ComputationNode<T>
999    where
1000        T: Shareable + Arithmetic + Clone + PartialEq,
1001    {
1002        let operand_opt = Self::constant_folding(operand);
1003
1004        if let ComputationNode::Leaf {
1005            sample: operand_sample,
1006            ..
1007        } = &operand_opt
1008            && Self::is_constant(operand_sample)
1009        {
1010            let operand_val = operand_sample();
1011            let result = match operation {
1012                UnaryOperation::Map(func) => func(operand_val),
1013                UnaryOperation::Filter(_) => operand_val, // Filter doesn't change the value
1014            };
1015            return ComputationNode::leaf(move || result.clone());
1016        }
1017
1018        ComputationNode::UnaryOp {
1019            operand: Box::new(operand_opt),
1020            operation,
1021        }
1022    }
1023
1024    /// Handles constant folding for conditional operations
1025    fn constant_folding_conditional<T>(
1026        condition: ComputationNode<bool>,
1027        if_true: ComputationNode<T>,
1028        if_false: ComputationNode<T>,
1029    ) -> ComputationNode<T>
1030    where
1031        T: Shareable + Arithmetic + Clone + PartialEq,
1032    {
1033        let condition_opt = Self::constant_folding_bool(condition);
1034        let if_true_opt = Self::constant_folding(if_true);
1035        let if_false_opt = Self::constant_folding(if_false);
1036
1037        // Check if condition is constant
1038        if let ComputationNode::Leaf {
1039            sample: condition_sample,
1040            ..
1041        } = &condition_opt
1042            && Self::is_constant_bool(condition_sample)
1043        {
1044            let condition_val = condition_sample();
1045            if condition_val {
1046                return if_true_opt;
1047            }
1048            return if_false_opt;
1049        }
1050
1051        ComputationNode::Conditional {
1052            condition: Box::new(condition_opt),
1053            if_true: Box::new(if_true_opt),
1054            if_false: Box::new(if_false_opt),
1055        }
1056    }
1057
1058    /// Performs constant folding for boolean types
1059    fn constant_folding_bool(node: ComputationNode<bool>) -> ComputationNode<bool> {
1060        match node {
1061            ComputationNode::UnaryOp { operand, operation } => {
1062                Self::constant_folding_bool_unary_op(*operand, operation)
1063            }
1064            ComputationNode::Conditional {
1065                condition,
1066                if_true,
1067                if_false,
1068            } => Self::constant_folding_bool_conditional(*condition, *if_true, *if_false),
1069            ComputationNode::Leaf { .. } | ComputationNode::BinaryOp { .. } => node,
1070        }
1071    }
1072
1073    /// Handles constant folding for boolean unary operations
1074    fn constant_folding_bool_unary_op(
1075        operand: ComputationNode<bool>,
1076        operation: UnaryOperation<bool>,
1077    ) -> ComputationNode<bool> {
1078        let operand_opt = Self::constant_folding_bool(operand);
1079
1080        if let ComputationNode::Leaf {
1081            sample: operand_sample,
1082            ..
1083        } = &operand_opt
1084            && Self::is_constant_bool(operand_sample)
1085        {
1086            let operand_val = operand_sample();
1087            let result = match operation {
1088                UnaryOperation::Map(func) => func(operand_val),
1089                UnaryOperation::Filter(_) => operand_val, // Filter doesn't change the value
1090            };
1091            return ComputationNode::leaf(move || result);
1092        }
1093
1094        ComputationNode::UnaryOp {
1095            operand: Box::new(operand_opt),
1096            operation,
1097        }
1098    }
1099
1100    /// Handles constant folding for boolean conditional operations
1101    fn constant_folding_bool_conditional(
1102        condition: ComputationNode<bool>,
1103        if_true: ComputationNode<bool>,
1104        if_false: ComputationNode<bool>,
1105    ) -> ComputationNode<bool> {
1106        let condition_opt = Self::constant_folding_bool(condition);
1107        let if_true_opt = Self::constant_folding_bool(if_true);
1108        let if_false_opt = Self::constant_folding_bool(if_false);
1109
1110        if let ComputationNode::Leaf {
1111            sample: condition_sample,
1112            ..
1113        } = &condition_opt
1114            && Self::is_constant_bool(condition_sample)
1115        {
1116            let condition_val = condition_sample();
1117            if condition_val {
1118                return if_true_opt;
1119            }
1120            return if_false_opt;
1121        }
1122
1123        ComputationNode::Conditional {
1124            condition: Box::new(condition_opt),
1125            if_true: Box::new(if_true_opt),
1126            if_false: Box::new(if_false_opt),
1127        }
1128    }
1129
1130    /// Helper function to check if a sampling function returns a constant value
1131    fn is_constant<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
1132    where
1133        T: PartialEq + Clone,
1134    {
1135        // Sample a few times to check if it's consistently the same value
1136        let first_sample = sample_fn();
1137        for _ in 0..3 {
1138            if sample_fn() != first_sample {
1139                return false;
1140            }
1141        }
1142        true
1143    }
1144
1145    /// Helper function to check if a boolean sampling function returns a constant value
1146    fn is_constant_bool(sample_fn: &Arc<dyn Fn() -> bool + Send + Sync>) -> bool {
1147        // Sample a few times to check if it's consistently the same value
1148        let first_sample = sample_fn();
1149        for _ in 0..3 {
1150            if sample_fn() != first_sample {
1151                return false;
1152            }
1153        }
1154        true
1155    }
1156}
1157
1158impl Default for GraphOptimizer {
1159    fn default() -> Self {
1160        Self::new()
1161    }
1162}
1163
1164/// Computation graph visualizer for debugging and analysis
1165pub struct GraphVisualizer;
1166
1167impl GraphVisualizer {
1168    /// Generates a DOT graph representation for visualization
1169    #[must_use]
1170    pub fn to_dot<T>(node: &ComputationNode<T>) -> String
1171    where
1172        T: Shareable,
1173    {
1174        let mut dot = String::from("digraph G {\n");
1175        let mut node_id = 0;
1176        Self::add_node_to_dot(node, &mut dot, &mut node_id);
1177        dot.push_str("}\n");
1178        dot
1179    }
1180
1181    fn add_node_to_dot<T>(node: &ComputationNode<T>, dot: &mut String, node_id: &mut usize) -> usize
1182    where
1183        T: Shareable,
1184    {
1185        use std::fmt::Write;
1186        let current_id = *node_id;
1187        *node_id += 1;
1188
1189        match node {
1190            ComputationNode::Leaf { .. } => {
1191                writeln!(dot, "  {current_id} [label=\"Leaf\", shape=circle];").unwrap();
1192            }
1193            ComputationNode::BinaryOp {
1194                left,
1195                right,
1196                operation,
1197            } => {
1198                let op_name = match operation {
1199                    BinaryOperation::Add => "Add",
1200                    BinaryOperation::Sub => "Sub",
1201                    BinaryOperation::Mul => "Mul",
1202                    BinaryOperation::Div => "Div",
1203                };
1204                writeln!(dot, "  {current_id} [label=\"{op_name}\", shape=box];").unwrap();
1205
1206                let left_id = Self::add_node_to_dot(left, dot, node_id);
1207                let right_id = Self::add_node_to_dot(right, dot, node_id);
1208
1209                writeln!(dot, "  {current_id} -> {left_id};").unwrap();
1210                writeln!(dot, "  {current_id} -> {right_id};").unwrap();
1211            }
1212            ComputationNode::UnaryOp { operand, .. } => {
1213                writeln!(dot, "  {current_id} [label=\"UnaryOp\", shape=box];").unwrap();
1214                let operand_id = Self::add_node_to_dot(operand, dot, node_id);
1215                writeln!(dot, "  {current_id} -> {operand_id};").unwrap();
1216            }
1217            ComputationNode::Conditional {
1218                condition,
1219                if_true,
1220                if_false,
1221            } => {
1222                writeln!(dot, "  {current_id} [label=\"If\", shape=diamond];").unwrap();
1223
1224                let cond_id = Self::add_node_to_dot(condition, dot, node_id);
1225                let true_id = Self::add_node_to_dot(if_true, dot, node_id);
1226                let false_id = Self::add_node_to_dot(if_false, dot, node_id);
1227
1228                writeln!(dot, "  {current_id} -> {cond_id} [label=\"cond\"];").unwrap();
1229                writeln!(dot, "  {current_id} -> {true_id} [label=\"true\"];").unwrap();
1230                writeln!(dot, "  {current_id} -> {false_id} [label=\"false\"];").unwrap();
1231            }
1232        }
1233
1234        current_id
1235    }
1236
1237    /// Prints a text-based representation of the computation graph
1238    pub fn print_tree<T>(node: &ComputationNode<T>, indent: usize)
1239    where
1240        T: Shareable,
1241    {
1242        let prefix = "  ".repeat(indent);
1243
1244        match node {
1245            ComputationNode::Leaf { id, .. } => {
1246                println!("{prefix}Leaf({id})");
1247            }
1248            ComputationNode::BinaryOp {
1249                left,
1250                right,
1251                operation,
1252            } => {
1253                let op_name = match operation {
1254                    BinaryOperation::Add => "Add",
1255                    BinaryOperation::Sub => "Sub",
1256                    BinaryOperation::Mul => "Mul",
1257                    BinaryOperation::Div => "Div",
1258                };
1259                println!("{prefix}{op_name}");
1260                Self::print_tree(left, indent + 1);
1261                Self::print_tree(right, indent + 1);
1262            }
1263            ComputationNode::UnaryOp { operand, .. } => {
1264                println!("{prefix}UnaryOp");
1265                Self::print_tree(operand, indent + 1);
1266            }
1267            ComputationNode::Conditional {
1268                condition,
1269                if_true,
1270                if_false,
1271            } => {
1272                println!("{prefix}Conditional");
1273                println!("{prefix}  Condition:");
1274                Self::print_tree(condition, indent + 2);
1275                println!("{prefix}  If True:");
1276                Self::print_tree(if_true, indent + 2);
1277                println!("{prefix}  If False:");
1278                Self::print_tree(if_false, indent + 2);
1279            }
1280        }
1281    }
1282}
1283
/// Performance profiler for computation graphs
///
/// Records wall-clock durations of closures, grouped by a caller-chosen name.
pub struct GraphProfiler {
    execution_times: HashMap<String, Vec<std::time::Duration>>,
}

impl GraphProfiler {
    /// Create a new profiler
    #[must_use]
    pub fn new() -> Self {
        Self {
            execution_times: HashMap::new(),
        }
    }

    /// Profile the execution of a computation graph
    ///
    /// Runs `func` once, records its elapsed wall-clock time under `name`, and returns
    /// whatever `func` returned.
    pub fn profile_execution<T, F>(&mut self, name: &str, func: F) -> T
    where
        F: FnOnce() -> T,
    {
        let start = std::time::Instant::now();
        let result = func();
        let duration = start.elapsed();

        self.execution_times
            .entry(name.to_string())
            .or_default()
            .push(duration);

        result
    }

    /// Get profiling statistics
    ///
    /// Returns `None` when nothing has been recorded under `name`. For an even number
    /// of runs, the median is the mean of the two middle durations.
    #[must_use]
    pub fn get_stats(&self, name: &str) -> Option<ProfileStats> {
        let mut sorted_times = self.execution_times.get(name)?.clone();
        sorted_times.sort_unstable();

        // `first`/`last` also serve as the emptiness check.
        let min = *sorted_times.first()?;
        let max = *sorted_times.last()?;

        let count = sorted_times.len();
        let total: std::time::Duration = sorted_times.iter().sum();
        let average = Self::mean_duration(total, count);

        let mid = count / 2;
        let median = if count % 2 == 1 {
            sorted_times[mid]
        } else {
            // The slice is sorted, so `upper >= lower`. Halving the gap avoids both
            // underflow and overflow.
            let (lower, upper) = (sorted_times[mid - 1], sorted_times[mid]);
            lower + (upper - lower) / 2
        };

        Some(ProfileStats {
            count,
            total,
            average,
            median,
            min,
            max,
        })
    }

    /// Mean of `total` over `count` runs. `count` must be non-zero.
    fn mean_duration(total: std::time::Duration, count: usize) -> std::time::Duration {
        match u32::try_from(count) {
            Ok(runs) => total / runs,
            // `Duration` only divides by `u32`; beyond that, divide the nanosecond total.
            Err(_) => {
                let nanos = total.as_nanos() / count as u128;
                std::time::Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
            }
        }
    }

    /// Print all profiling results
    ///
    /// Entries are listed in lexicographic order of their names, so the report is
    /// stable across runs.
    pub fn print_report(&self) {
        println!("=== Computation Graph Profiling Report ===");
        let mut names: Vec<&String> = self.execution_times.keys().collect();
        names.sort_unstable();
        for name in names {
            if let Some(stats) = self.get_stats(name) {
                println!("\n{name}:");
                println!("  Count: {}", stats.count);
                println!("  Total: {:?}", stats.total);
                println!("  Average: {:?}", stats.average);
                println!("  Median: {:?}", stats.median);
                println!("  Min: {:?}", stats.min);
                println!("  Max: {:?}", stats.max);
            }
        }
    }
}

impl Default for GraphProfiler {
    fn default() -> Self {
        Self::new()
    }
}

/// Statistics from profiling computation graph execution
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProfileStats {
    /// Number of executions
    pub count: usize,
    /// Total execution time
    pub total: std::time::Duration,
    /// Average execution time
    pub average: std::time::Duration,
    /// Median execution time
    pub median: std::time::Duration,
    /// Minimum execution time
    pub min: std::time::Duration,
    /// Maximum execution time
    pub max: std::time::Duration,
}
1391
1392#[cfg(test)]
1393mod tests {
1394    use super::*;
1395    use crate::Uncertain;
1396
1397    #[test]
1398    fn test_sample_context_memoization() {
1399        let mut context = SampleContext::new();
1400        let id = uuid::Uuid::new_v4();
1401
1402        // Set a value
1403        context.set_value(id, 42.0);
1404
1405        // Should get the same value back
1406        assert_eq!(context.get_value::<f64>(&id), Some(42.0));
1407
1408        // Different ID should return None
1409        let other_id = uuid::Uuid::new_v4();
1410        assert_eq!(context.get_value::<f64>(&other_id), None);
1411    }
1412
1413    #[test]
1414    #[allow(clippy::float_cmp)]
1415    fn test_computation_node_evaluation() {
1416        let left = ComputationNode::leaf(|| 5.0);
1417        let right = ComputationNode::leaf(|| 3.0);
1418        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1419
1420        let result = add_node.evaluate_fresh();
1421        assert_eq!(result, 8.0);
1422    }
1423
1424    #[test]
1425    #[allow(clippy::float_cmp)]
1426    fn test_shared_variable_memoization() {
1427        let mut context = SampleContext::new();
1428
1429        // Create a leaf node
1430        let leaf_id = uuid::Uuid::new_v4();
1431        let leaf = ComputationNode::Leaf {
1432            id: leaf_id,
1433            sample: Arc::new(rand::random::<f64>),
1434        };
1435
1436        // Evaluate twice with the same context
1437        let val1 = leaf.evaluate(&mut context);
1438        let val2 = leaf.evaluate(&mut context);
1439
1440        // Should get the same value due to memoization
1441        assert_eq!(val1, val2);
1442    }
1443
1444    #[test]
1445    fn test_computation_graph_metrics() {
1446        let left = ComputationNode::leaf(|| 1.0);
1447        let right = ComputationNode::leaf(|| 2.0);
1448        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1449
1450        assert_eq!(add_node.node_count(), 3); // 2 leaves + 1 binary op
1451        assert_eq!(add_node.depth(), 2); // Binary op -> leaves
1452        assert!(!add_node.has_conditionals());
1453    }
1454
1455    #[test]
1456    #[allow(clippy::float_cmp)]
1457    fn test_conditional_node() {
1458        let condition = ComputationNode::leaf(|| true);
1459        let if_true = ComputationNode::leaf(|| 10.0);
1460        let if_false = ComputationNode::leaf(|| 20.0);
1461
1462        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1463
1464        let result = conditional.evaluate_fresh();
1465        assert_eq!(result, 10.0); // Should pick the true branch
1466        assert!(conditional.has_conditionals());
1467    }
1468
1469    #[test]
1470    fn test_graph_visualizer_dot_output() {
1471        let left = ComputationNode::leaf(|| 1.0);
1472        let right = ComputationNode::leaf(|| 2.0);
1473        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1474
1475        let dot = GraphVisualizer::to_dot(&add_node);
1476
1477        assert!(dot.contains("digraph G"));
1478        assert!(dot.contains("Add"));
1479        assert!(dot.contains("Leaf"));
1480    }
1481
1482    #[test]
1483    fn test_profiler() {
1484        let mut profiler = GraphProfiler::new();
1485
1486        let result = profiler.profile_execution("test", || {
1487            std::thread::sleep(std::time::Duration::from_millis(10));
1488            42
1489        });
1490
1491        assert_eq!(result, 42);
1492
1493        let stats = profiler.get_stats("test").unwrap();
1494        assert_eq!(stats.count, 1);
1495        assert!(stats.total >= std::time::Duration::from_millis(10));
1496    }
1497
1498    #[test]
1499    fn test_complex_computation_graph() {
1500        // Test (x + y) * (x - y) where x and y are uncertain values
1501        let x = Uncertain::normal(5.0, 1.0);
1502        let y = Uncertain::normal(3.0, 1.0);
1503
1504        // This should build a computation graph
1505        let sum = x.clone() + y.clone();
1506        let diff = x - y;
1507        let product = sum * diff;
1508
1509        // The computation graph should have multiple nodes
1510        assert!(product.node.node_count() > 5);
1511        assert!(product.node.depth() > 2);
1512
1513        // Should be able to evaluate multiple times
1514        let sample1 = product.sample();
1515        let sample2 = product.sample();
1516
1517        // Values should be reasonable (approximately (5+3)*(5-3) = 16 with some variance)
1518        // With normal distributions (5±1) and (3±1), the result can vary significantly
1519        // Allow for wide variance due to unbounded normal distributions and multiplication
1520        // Statistical analysis shows 99.9% of values fall within [-50, 150]
1521        assert!(sample1 > -50.0 && sample1 < 150.0);
1522        assert!(sample2 > -50.0 && sample2 < 150.0);
1523    }
1524
1525    #[test]
1526    fn test_sample_context_clear() {
1527        let mut context = SampleContext::new();
1528        let id = uuid::Uuid::new_v4();
1529
1530        context.set_value(id, 42.0);
1531        assert_eq!(context.len(), 1);
1532        assert!(!context.is_empty());
1533
1534        context.clear();
1535        assert_eq!(context.len(), 0);
1536        assert!(context.is_empty());
1537        assert_eq!(context.get_value::<f64>(&id), None);
1538    }
1539
1540    #[test]
1541    fn test_sample_context_default() {
1542        let context = SampleContext::default();
1543        assert!(context.is_empty());
1544        assert_eq!(context.len(), 0);
1545    }
1546
1547    #[test]
1548    #[should_panic(expected = "BinaryOp evaluation requires arithmetic trait bounds")]
1549    fn test_evaluate_panic_on_binary_op() {
1550        let left = ComputationNode::leaf(|| 1.0);
1551        let right = ComputationNode::leaf(|| 2.0);
1552        let binary_op = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1553
1554        let mut context = SampleContext::new();
1555        binary_op.evaluate(&mut context);
1556    }
1557
1558    #[test]
1559    #[should_panic(expected = "Conditional evaluation requires specific handling")]
1560    fn test_evaluate_panic_on_conditional() {
1561        let condition = ComputationNode::leaf(|| true);
1562        let if_true = ComputationNode::leaf(|| 10.0);
1563        let if_false = ComputationNode::leaf(|| 20.0);
1564        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1565
1566        let mut context = SampleContext::new();
1567        conditional.evaluate(&mut context);
1568    }
1569
1570    #[test]
1571    #[should_panic(
1572        expected = "Conditional evaluation with bool condition not supported in arithmetic context"
1573    )]
1574    fn test_evaluate_arithmetic_panic_on_conditional() {
1575        let condition = ComputationNode::leaf(|| true);
1576        let if_true = ComputationNode::leaf(|| 10.0);
1577        let if_false = ComputationNode::leaf(|| 20.0);
1578        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1579
1580        let mut context = SampleContext::new();
1581        conditional.evaluate_arithmetic(&mut context);
1582    }
1583
1584    #[test]
1585    #[should_panic(expected = "Boolean binary operations not implemented")]
1586    fn test_evaluate_bool_panic_on_binary_op() {
1587        let left = ComputationNode::leaf(|| true);
1588        let right = ComputationNode::leaf(|| false);
1589        let binary_op = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1590
1591        let mut context = SampleContext::new();
1592        binary_op.evaluate_bool(&mut context);
1593    }
1594
1595    #[test]
1596    #[allow(clippy::float_cmp)]
1597    fn test_unary_map_operation() {
1598        let operand = ComputationNode::leaf(|| 5.0);
1599        let mapped = ComputationNode::map(operand, |x| x * 2.0);
1600
1601        let result = mapped.evaluate_fresh();
1602        assert_eq!(result, 10.0);
1603    }
1604
1605    #[test]
1606    #[allow(clippy::float_cmp)]
1607    fn test_unary_filter_operation() {
1608        let operand = ComputationNode::leaf(|| 42.0);
1609        let filtered = ComputationNode::UnaryOp {
1610            operand: Box::new(operand),
1611            operation: UnaryOperation::Filter(Arc::new(|x: &f64| *x > 0.0)),
1612        };
1613
1614        let mut context = SampleContext::new();
1615        let result = filtered.evaluate(&mut context);
1616        assert_eq!(result, 42.0); // Filter currently just passes through
1617    }
1618
1619    #[test]
1620    fn test_graph_optimizer() {
1621        let node = ComputationNode::leaf(|| 1.0);
1622        let mut optimizer = GraphOptimizer::new();
1623        let optimized_node = optimizer.optimize(node);
1624        assert_eq!(optimized_node.node_count(), 1);
1625    }
1626
1627    #[test]
1628    fn test_common_subexpression_elimination() {
1629        let mut optimizer = GraphOptimizer::new();
1630
1631        // Create a common subexpression: (x + y) * (x + y)
1632        let x = ComputationNode::leaf(|| 2.0);
1633        let y = ComputationNode::leaf(|| 3.0);
1634        let sum = ComputationNode::binary_op(x.clone(), y.clone(), BinaryOperation::Add);
1635
1636        // Create the expression: (x + y) * (x + y)
1637        let expr = ComputationNode::binary_op(sum.clone(), sum, BinaryOperation::Mul);
1638
1639        // First optimization should cache the sum subexpression
1640        let optimized1 = optimizer.eliminate_common_subexpressions(expr.clone());
1641
1642        // Second optimization should reuse the cached sum subexpression
1643        let optimized2 = optimizer.eliminate_common_subexpressions(expr);
1644
1645        // Both should produce the same result
1646        let result1: f64 = optimized1.evaluate_fresh();
1647        let result2: f64 = optimized2.evaluate_fresh();
1648        assert!((result1 - result2).abs() < f64::EPSILON);
1649
1650        // The cache should contain the sum subexpression
1651        assert!(!optimizer.subexpression_cache.is_empty());
1652    }
1653
1654    #[test]
1655    #[allow(clippy::similar_names)]
1656    fn test_common_subexpression_elimination_complex() {
1657        let mut optimizer = GraphOptimizer::new();
1658
1659        // Create a more complex expression with multiple common subexpressions
1660        let a = ComputationNode::leaf(|| 1.0);
1661        let b = ComputationNode::leaf(|| 2.0);
1662        let c = ComputationNode::leaf(|| 3.0);
1663
1664        // Common subexpression: a + b
1665        let sum_ab = ComputationNode::binary_op(a.clone(), b.clone(), BinaryOperation::Add);
1666
1667        // Expression: (a + b) * (a + b) + (a + b) * c
1668        let expr1 =
1669            ComputationNode::binary_op(sum_ab.clone(), sum_ab.clone(), BinaryOperation::Mul);
1670        let expr2 = ComputationNode::binary_op(sum_ab.clone(), c.clone(), BinaryOperation::Mul);
1671        let final_expr = ComputationNode::binary_op(expr1, expr2, BinaryOperation::Add);
1672
1673        let optimized = optimizer.eliminate_common_subexpressions(final_expr);
1674
1675        // Should produce correct result
1676        let result: f64 = optimized.evaluate_fresh();
1677        let expected = (1.0 + 2.0) * (1.0 + 2.0) + (1.0 + 2.0) * 3.0;
1678        assert!((result - expected).abs() < f64::EPSILON);
1679
1680        // Cache should contain the common subexpression
1681        assert!(!optimizer.subexpression_cache.is_empty());
1682    }
1683
1684    #[test]
1685    fn test_identity_operation_elimination() {
1686        // Test x + 0 = x
1687        let x = ComputationNode::leaf(|| 5.0);
1688        let zero = ComputationNode::leaf(|| 0.0);
1689        let add_zero = ComputationNode::binary_op(x.clone(), zero, BinaryOperation::Add);
1690
1691        let optimized = GraphOptimizer::eliminate_identity_operations(add_zero);
1692        let result: f64 = optimized.evaluate_fresh();
1693        assert!((result - 5.0).abs() < f64::EPSILON);
1694
1695        // Test x * 1 = x
1696        let one = ComputationNode::leaf(|| 1.0);
1697        let mul_one = ComputationNode::binary_op(x.clone(), one, BinaryOperation::Mul);
1698
1699        let optimized = GraphOptimizer::eliminate_identity_operations(mul_one);
1700        let result: f64 = optimized.evaluate_fresh();
1701        assert!((result - 5.0).abs() < f64::EPSILON);
1702
1703        // Test x - 0 = x
1704        let zero2 = ComputationNode::leaf(|| 0.0);
1705        let sub_zero = ComputationNode::binary_op(x.clone(), zero2, BinaryOperation::Sub);
1706
1707        let optimized = GraphOptimizer::eliminate_identity_operations(sub_zero);
1708        let result: f64 = optimized.evaluate_fresh();
1709        assert!((result - 5.0).abs() < f64::EPSILON);
1710
1711        // Test x / 1 = x
1712        let one2 = ComputationNode::leaf(|| 1.0);
1713        let div_one = ComputationNode::binary_op(x.clone(), one2, BinaryOperation::Div);
1714
1715        let optimized = GraphOptimizer::eliminate_identity_operations(div_one);
1716        let result: f64 = optimized.evaluate_fresh();
1717        assert!((result - 5.0).abs() < f64::EPSILON);
1718
1719        // Test x * 0 = 0
1720        let zero3 = ComputationNode::leaf(|| 0.0);
1721        let mul_zero = ComputationNode::binary_op(x.clone(), zero3, BinaryOperation::Mul);
1722
1723        let optimized = GraphOptimizer::eliminate_identity_operations(mul_zero);
1724        let result: f64 = optimized.evaluate_fresh();
1725        assert!((result - 0.0).abs() < f64::EPSILON);
1726    }
1727
1728    #[test]
1729    fn test_constant_folding() {
1730        // Test constant addition: 2 + 3 = 5
1731        let two = ComputationNode::leaf(|| 2.0);
1732        let three = ComputationNode::leaf(|| 3.0);
1733        let add_const = ComputationNode::binary_op(two, three, BinaryOperation::Add);
1734
1735        let optimized = GraphOptimizer::constant_folding(add_const);
1736        let result: f64 = optimized.evaluate_fresh();
1737        assert!((result - 5.0).abs() < f64::EPSILON);
1738
1739        // Test constant multiplication: 4 * 5 = 20
1740        let four = ComputationNode::leaf(|| 4.0);
1741        let five = ComputationNode::leaf(|| 5.0);
1742        let mul_const = ComputationNode::binary_op(four, five, BinaryOperation::Mul);
1743
1744        let optimized = GraphOptimizer::constant_folding(mul_const);
1745        let result: f64 = optimized.evaluate_fresh();
1746        assert!((result - 20.0).abs() < f64::EPSILON);
1747
1748        // Test constant division: 10 / 2 = 5
1749        let ten = ComputationNode::leaf(|| 10.0);
1750        let two_div = ComputationNode::leaf(|| 2.0);
1751        let div_const = ComputationNode::binary_op(ten, two_div, BinaryOperation::Div);
1752
1753        let optimized = GraphOptimizer::constant_folding(div_const);
1754        let result: f64 = optimized.evaluate_fresh();
1755        assert!((result - 5.0).abs() < f64::EPSILON);
1756
1757        // Test constant subtraction: 8 - 3 = 5
1758        let eight = ComputationNode::leaf(|| 8.0);
1759        let three_sub = ComputationNode::leaf(|| 3.0);
1760        let sub_const = ComputationNode::binary_op(eight, three_sub, BinaryOperation::Sub);
1761
1762        let optimized = GraphOptimizer::constant_folding(sub_const);
1763        let result: f64 = optimized.evaluate_fresh();
1764        assert!((result - 5.0).abs() < f64::EPSILON);
1765    }
1766
1767    #[test]
1768    fn test_constant_folding_conditional() {
1769        // Test constant condition: if true then 10 else 20 = 10
1770        let true_condition = ComputationNode::leaf(|| true);
1771        let if_true = ComputationNode::leaf(|| 10.0);
1772        let if_false = ComputationNode::leaf(|| 20.0);
1773        let conditional = ComputationNode::conditional(true_condition, if_true, if_false);
1774
1775        let optimized = GraphOptimizer::constant_folding(conditional);
1776        let result: f64 = optimized.evaluate_fresh();
1777        assert!((result - 10.0).abs() < f64::EPSILON);
1778
1779        // Test constant condition: if false then 10 else 20 = 20
1780        let false_condition = ComputationNode::leaf(|| false);
1781        let if_true2 = ComputationNode::leaf(|| 10.0);
1782        let if_false2 = ComputationNode::leaf(|| 20.0);
1783        let conditional2 = ComputationNode::conditional(false_condition, if_true2, if_false2);
1784
1785        let optimized = GraphOptimizer::constant_folding(conditional2);
1786        let result: f64 = optimized.evaluate_fresh();
1787        assert!((result - 20.0).abs() < f64::EPSILON);
1788    }
1789
1790    #[test]
1791    fn test_constant_folding_unary() {
1792        // Test constant unary operation: map(|x| x * 2) on constant 5 = 10
1793        let five = ComputationNode::leaf(|| 5.0);
1794        let double = ComputationNode::map(five, |x| x * 2.0);
1795
1796        let optimized = GraphOptimizer::constant_folding(double);
1797        let result: f64 = optimized.evaluate_fresh();
1798        assert!((result - 10.0).abs() < f64::EPSILON);
1799    }
1800
1801    #[test]
1802    fn test_graph_visualizer_print_tree() {
1803        let left = ComputationNode::leaf(|| 1.0);
1804        let right = ComputationNode::leaf(|| 2.0);
1805        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);
1806
1807        // This mainly tests that print_tree doesn't panic
1808        GraphVisualizer::print_tree(&add_node, 0);
1809    }
1810
1811    #[test]
1812    fn test_graph_visualizer_dot_conditional() {
1813        let condition = ComputationNode::leaf(|| true);
1814        let if_true = ComputationNode::leaf(|| 10.0);
1815        let if_false = ComputationNode::leaf(|| 20.0);
1816        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1817
1818        let dot = GraphVisualizer::to_dot(&conditional);
1819
1820        assert!(dot.contains("digraph G"));
1821        assert!(dot.contains("If"));
1822        assert!(dot.contains("diamond"));
1823        assert!(dot.contains("cond"));
1824        assert!(dot.contains("true"));
1825        assert!(dot.contains("false"));
1826    }
1827
1828    #[test]
1829    fn test_graph_visualizer_dot_unary_op() {
1830        let operand = ComputationNode::leaf(|| 5.0);
1831        let unary = ComputationNode::map(operand, |x| x * 2.0);
1832
1833        let dot = GraphVisualizer::to_dot(&unary);
1834
1835        assert!(dot.contains("digraph G"));
1836        assert!(dot.contains("UnaryOp"));
1837        assert!(dot.contains("Leaf"));
1838    }
1839
1840    #[test]
1841    fn test_profiler_default() {
1842        let profiler = GraphProfiler::default();
1843        assert!(profiler.get_stats("nonexistent").is_none());
1844    }
1845
1846    #[test]
1847    fn test_profiler_get_stats_nonexistent() {
1848        let profiler = GraphProfiler::new();
1849        assert!(profiler.get_stats("nonexistent").is_none());
1850    }
1851
1852    #[test]
1853    fn test_profiler_multiple_executions() {
1854        let mut profiler = GraphProfiler::new();
1855
1856        profiler.profile_execution("test", || {
1857            std::thread::sleep(std::time::Duration::from_millis(1));
1858        });
1859        profiler.profile_execution("test", || {
1860            std::thread::sleep(std::time::Duration::from_millis(2));
1861        });
1862        profiler.profile_execution("test", || {
1863            std::thread::sleep(std::time::Duration::from_millis(3));
1864        });
1865
1866        let stats = profiler.get_stats("test").unwrap();
1867        assert_eq!(stats.count, 3);
1868        assert!(stats.min <= stats.median);
1869        assert!(stats.median <= stats.max);
1870        assert!(stats.average.as_nanos() > 0);
1871
1872        profiler.print_report();
1873    }
1874
1875    #[test]
1876    #[allow(clippy::float_cmp)]
1877    fn test_conditional_evaluation_false_branch() {
1878        let condition = ComputationNode::leaf(|| false);
1879        let if_true = ComputationNode::leaf(|| 10.0);
1880        let if_false = ComputationNode::leaf(|| 20.0);
1881        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1882
1883        let result = conditional.evaluate_fresh();
1884        assert_eq!(result, 20.0);
1885    }
1886
1887    #[test]
1888    fn test_bool_conditional_evaluation() {
1889        let condition = ComputationNode::leaf(|| true);
1890        let if_true = ComputationNode::leaf(|| true);
1891        let if_false = ComputationNode::leaf(|| false);
1892        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1893
1894        let mut context = SampleContext::new();
1895        let result = conditional.evaluate_bool(&mut context);
1896        assert!(result);
1897    }
1898
1899    #[test]
1900    fn test_bool_unary_operation() {
1901        let operand = ComputationNode::leaf(|| true);
1902        let mapped = ComputationNode::map(operand, |x| !x);
1903
1904        let mut context = SampleContext::new();
1905        let result = mapped.evaluate_bool(&mut context);
1906        assert!(!result);
1907    }
1908
1909    #[test]
1910    #[allow(clippy::float_cmp)]
1911    fn test_binary_operations_subtraction() {
1912        let left = ComputationNode::leaf(|| 10.0);
1913        let right = ComputationNode::leaf(|| 3.0);
1914        let sub_node = ComputationNode::binary_op(left, right, BinaryOperation::Sub);
1915
1916        let result = sub_node.evaluate_fresh();
1917        assert_eq!(result, 7.0);
1918    }
1919
1920    #[test]
1921    #[allow(clippy::float_cmp)]
1922    fn test_binary_operations_multiplication() {
1923        let left = ComputationNode::leaf(|| 4.0);
1924        let right = ComputationNode::leaf(|| 5.0);
1925        let mul_node = ComputationNode::binary_op(left, right, BinaryOperation::Mul);
1926
1927        let result = mul_node.evaluate_fresh();
1928        assert_eq!(result, 20.0);
1929    }
1930
1931    #[test]
1932    #[allow(clippy::float_cmp)]
1933    fn test_binary_operations_division() {
1934        let left = ComputationNode::leaf(|| 15.0);
1935        let right = ComputationNode::leaf(|| 3.0);
1936        let div_node = ComputationNode::binary_op(left, right, BinaryOperation::Div);
1937
1938        let result = div_node.evaluate_fresh();
1939        assert_eq!(result, 5.0);
1940    }
1941
1942    #[test]
1943    fn test_nested_conditional_depth() {
1944        let condition1 = ComputationNode::leaf(|| true);
1945        let condition2 = ComputationNode::leaf(|| false);
1946        let leaf1 = ComputationNode::leaf(|| 1.0);
1947        let _leaf2 = ComputationNode::leaf(|| 2.0);
1948        let leaf3 = ComputationNode::leaf(|| 3.0);
1949        let leaf4 = ComputationNode::leaf(|| 4.0);
1950
1951        let inner_conditional = ComputationNode::conditional(condition2, leaf3, leaf4);
1952        let outer_conditional = ComputationNode::conditional(condition1, leaf1, inner_conditional);
1953
1954        assert_eq!(outer_conditional.depth(), 3);
1955        assert_eq!(outer_conditional.node_count(), 7);
1956        assert!(outer_conditional.has_conditionals());
1957    }
1958
1959    #[test]
1960    #[allow(clippy::float_cmp)]
1961    fn test_evaluate_conditional_with_arithmetic() {
1962        let condition = ComputationNode::leaf(|| true);
1963        let if_true = ComputationNode::leaf(|| 42.0);
1964        let if_false = ComputationNode::leaf(|| 24.0);
1965        let conditional = ComputationNode::conditional(condition, if_true, if_false);
1966
1967        let mut context = SampleContext::new();
1968        let result = conditional.evaluate_conditional_with_arithmetic(&mut context);
1969        assert_eq!(result, 42.0);
1970
1971        let leaf = ComputationNode::leaf(|| 99.0);
1972        let result = leaf.evaluate_conditional_with_arithmetic(&mut context);
1973        assert_eq!(result, 99.0);
1974    }
1975
1976    #[test]
1977    fn test_sample_context_different_types() {
1978        let mut context = SampleContext::new();
1979        let id1 = uuid::Uuid::new_v4();
1980        let id2 = uuid::Uuid::new_v4();
1981
1982        context.set_value(id1, 42.0_f64);
1983        context.set_value(id2, 100_i32);
1984
1985        assert_eq!(context.get_value::<f64>(&id1), Some(42.0));
1986        assert_eq!(context.get_value::<i32>(&id2), Some(100));
1987        assert_eq!(context.get_value::<f64>(&id2), None); // Wrong type
1988        assert_eq!(context.get_value::<i32>(&id1), None); // Wrong type
1989
1990        assert_eq!(context.len(), 2);
1991    }
1992
1993    #[test]
1994    fn test_profile_stats_debug() {
1995        let stats = ProfileStats {
1996            count: 5,
1997            total: std::time::Duration::from_millis(100),
1998            average: std::time::Duration::from_millis(20),
1999            median: std::time::Duration::from_millis(18),
2000            min: std::time::Duration::from_millis(15),
2001            max: std::time::Duration::from_millis(30),
2002        };
2003
2004        let debug_str = format!("{stats:?}");
2005        assert!(debug_str.contains("ProfileStats"));
2006
2007        let cloned = stats.clone();
2008        assert_eq!(cloned.count, stats.count);
2009    }
2010}