Skip to main content

deep_causality/utils_test/
test_utils.rs

1/*
2 * SPDX-License-Identifier: MIT
3 * Copyright (c) 2023 - 2026. The DeepCausality Authors and Contributors. All Rights Reserved.
4 */
5use crate::*;
6use deep_causality_haft::LogAddEntry;
7
8use deep_causality_uncertain::{Uncertain, UncertainBool, UncertainF64};
9use std::sync::{Arc, RwLock};
10
11pub fn get_context() -> BaseContext {
12    let id = 1;
13    let name = "base context";
14    let capacity = 10; // adjust as needed
15    Context::with_capacity(id, name, capacity)
16}
17
18pub fn get_test_assumption_vec() -> Vec<Assumption> {
19    let a1 = get_test_assumption();
20    let a2 = get_test_assumption();
21    let a3 = get_test_assumption();
22    Vec::from_iter([a1, a2, a3])
23}
24
25pub fn get_test_obs_vec() -> Vec<Observation> {
26    let o1 = Observation::new(0, 10.0, 1.0);
27    let o2 = Observation::new(1, 10.0, 1.0);
28    let o3 = Observation::new(2, 10.0, 1.0);
29    let o4 = Observation::new(3, 12.0, 0.0);
30    let o5 = Observation::new(4, 14.0, 0.0);
31    Vec::from_iter([o1, o2, o3, o4, o5])
32}
33
34pub fn get_test_inf_vec() -> Vec<Inference> {
35    let i1 = get_test_inferable(0, true);
36    let i2 = get_test_inferable(1, false);
37    Vec::from_iter([i1, i2])
38}
39
40pub fn get_deterministic_test_causality_vec() -> BaseCausaloidVec<f64, bool> {
41    let q1 = get_test_causaloid_deterministic(1);
42    let q2 = get_test_causaloid_deterministic(2);
43    let q3 = get_test_causaloid_deterministic(3);
44    Vec::from_iter([q1, q2, q3])
45}
46pub fn get_probabilistic_test_causality_vec() -> BaseCausaloidVec<f64, f64> {
47    let q1 = get_test_causaloid_probabilistic();
48    let q2 = get_test_causaloid_probabilistic();
49    let q3 = get_test_causaloid_probabilistic();
50    Vec::from_iter([q1, q2, q3])
51}
52
53pub fn get_uncertain_bool_test_causality_vec() -> BaseCausaloidVec<f64, UncertainBool> {
54    let q1 = get_test_causaloid_uncertain_bool();
55    let q2 = get_test_causaloid_uncertain_bool();
56    let q3 = get_test_causaloid_uncertain_bool();
57    Vec::from_iter([q1, q2, q3])
58}
59
60pub fn get_uncertain_float_test_causality_vec() -> BaseCausaloidVec<f64, UncertainF64> {
61    let q1 = get_test_causaloid_uncertain_float();
62    let q2 = get_test_causaloid_uncertain_float();
63    let q3 = get_test_causaloid_uncertain_float();
64    Vec::from_iter([q1, q2, q3])
65}
66
67pub fn get_test_single_data(val: NumericalValue) -> PropagatingEffect<NumericalValue> {
68    PropagatingEffect::pure(val)
69}
70
71pub fn get_test_causaloid_deterministic_true() -> BaseCausaloid<bool, bool> {
72    let description = "tests nothing; always returns true";
73    fn causal_fn(_: bool) -> PropagatingEffect<bool> {
74        let mut effect = PropagatingEffect::pure(true);
75        effect.logs.add_entry("Just return true");
76        effect
77    }
78    Causaloid::new(3, causal_fn, description)
79}
80
81pub fn get_test_causaloid_deterministic_false() -> BaseCausaloid<bool, bool> {
82    let description = "tests nothing; always returns true";
83    fn causal_fn(_: bool) -> PropagatingEffect<bool> {
84        PropagatingEffect::pure(false)
85    }
86    Causaloid::new(3, causal_fn, description)
87}
88
89pub fn get_test_causaloid_probabilistic() -> BaseCausaloid<NumericalValue, f64> {
90    let id: IdentificationValue = 3;
91    let description = "tests whether data exceeds threshold of 0.55";
92
93    fn causal_fn(obs: NumericalValue) -> PropagatingEffect<NumericalValue> {
94        let threshold: NumericalValue = 0.55;
95        let output = if obs.ge(&threshold) { 1.0 } else { 0.0 };
96        PropagatingEffect::pure(output)
97    }
98
99    Causaloid::new(id, causal_fn, description)
100}
101
102pub fn get_test_causaloid_uncertain_bool() -> BaseCausaloid<f64, UncertainBool> {
103    let description = "tests whether data exceeds threshold of 0.55 and returns uncertain bool";
104
105    fn causal_fn(obs: NumericalValue) -> PropagatingEffect<UncertainBool> {
106        let threshold: NumericalValue = 0.55;
107
108        let output = if obs > threshold {
109            Uncertain::<bool>::point(true)
110        } else {
111            Uncertain::<bool>::point(false)
112        };
113        PropagatingEffect::pure(output)
114    }
115
116    Causaloid::new(3, causal_fn, description)
117}
118
119pub fn get_test_causaloid_uncertain_float() -> BaseCausaloid<f64, UncertainF64> {
120    let description = "tests whether data exceeds threshold of 0.55 and returns uncertain bool";
121    fn causal_fn(obs: NumericalValue) -> PropagatingEffect<UncertainF64> {
122        let threshold: NumericalValue = 0.55;
123        let output = if obs > threshold {
124            Uncertain::<f64>::point(1.0f64)
125        } else {
126            Uncertain::<f64>::point(0.0f64)
127        };
128        PropagatingEffect::pure(output)
129    }
130
131    Causaloid::new(3, causal_fn, description)
132}
133
134pub fn get_test_causaloid_deterministic(
135    id: IdentificationValue,
136) -> BaseCausaloid<NumericalValue, bool> {
137    let description = "tests whether data exceeds threshold of 0.55";
138    fn causal_fn(obs: NumericalValue) -> PropagatingEffect<bool> {
139        let threshold: NumericalValue = 0.55;
140        let output = obs.ge(&threshold);
141        PropagatingEffect::pure(output)
142    }
143
144    Causaloid::new(id, causal_fn, description)
145}
146
147pub fn get_test_causaloid_probabilistic_bool_output() -> BaseCausaloid<NumericalValue, f64> {
148    let id: IdentificationValue = 4;
149    let description =
150        "tests whether data exceeds threshold of 0.55 and returns bool probabilistically";
151
152    fn causal_fn(obs: NumericalValue) -> PropagatingEffect<f64> {
153        let threshold: NumericalValue = 0.55;
154        let output = if obs.ge(&threshold) { 1.0 } else { 0.0 };
155        PropagatingEffect::pure(output)
156    }
157
158    Causaloid::new(id, causal_fn, description)
159}
160pub fn get_test_causaloid_deterministic_with_context(
161    context: BaseContext,
162) -> BaseCausaloid<bool, bool> {
163    let id: IdentificationValue = 1;
164    let context = Arc::new(RwLock::new(context));
165    let description = "Inverts any input";
166
167    fn causal_fn_deterministic_with_context(
168        obs: EffectValue<bool>,
169        _state: (),
170        context: Option<Arc<RwLock<BaseContext>>>,
171    ) -> PropagatingProcess<bool, (), Arc<RwLock<BaseContext>>> {
172        if context.is_none() {
173            return PropagatingProcess::from_error(CausalityError(CausalityErrorEnum::Custom(
174                "Context is missing".into(),
175            )));
176        }
177
178        let input_val = obs.into_value().unwrap_or(false);
179
180        // Cloning the Arc to keep a reference, then locking
181        let ctx_ref = context.as_ref().unwrap().clone();
182        let ctx = ctx_ref.read().expect("Failed to read context");
183
184        let current_id = ctx.id();
185        // Assuming context ID 1 is "true" logic
186        if current_id == 1 {
187            PropagatingProcess::pure(input_val)
188        } else {
189            PropagatingProcess::pure(!input_val)
190        }
191    }
192    Causaloid::new_with_context(
193        id,
194        causal_fn_deterministic_with_context
195            as fn(
196                EffectValue<bool>,
197                (),
198                Option<Arc<RwLock<BaseContext>>>,
199            ) -> PropagatingProcess<bool, (), Arc<RwLock<BaseContext>>>,
200        context,
201        description,
202    )
203}
204
205pub fn get_test_causaloid_deterministic_input_output() -> BaseCausaloid<bool, bool> {
206    let id: IdentificationValue = 2;
207    let description = "Inverts any input";
208    fn causal_fn(obs: bool) -> PropagatingEffect<bool> {
209        PropagatingEffect::pure(!obs)
210    }
211    Causaloid::new(id, causal_fn, description)
212}
213
214pub fn get_test_error_causaloid() -> BaseCausaloid<bool, bool> {
215    let id: IdentificationValue = 1;
216    let description = "tests whether data exceeds threshold of 0.55";
217
218    fn causal_fn(_: bool) -> PropagatingEffect<bool> {
219        PropagatingEffect::from_error(CausalityError::new(CausalityErrorEnum::Custom(
220            "Test error".into(),
221        )))
222    }
223
224    Causaloid::new(id, causal_fn, description)
225}
226
227// BaseContext is a type alias for a basic context that can be used for testing
228// It matches the type signature of the base causaloid also uses in these tests.
229// See src/types/alias_types/csm_types for definition.
230pub fn get_base_context() -> BaseContext {
231    let id = 1;
232    let name = "base context";
233    let mut context = Context::with_capacity(id, name, 10);
234    assert_eq!(context.size(), 0);
235
236    let root = Root::new(id);
237    let contextoid = Contextoid::new(id, ContextoidType::Root(root));
238    let idx = context.add_node(contextoid).expect("Failed to add node");
239    assert_eq!(idx, 0);
240    assert_eq!(context.size(), 1);
241
242    context
243}
244
245pub fn get_test_context() -> BaseContext {
246    let mut context = Context::with_capacity(1, "Test-Context", 10);
247
248    let id = 1;
249    let root = Root::new(id);
250    let contextoid = Contextoid::new(id, ContextoidType::Root(root));
251    context.add_node(contextoid).expect("Failed to add node");
252
253    context
254}
255
256pub fn get_test_inferable(id: IdentificationValue, inverse: bool) -> Inference {
257    let question = "".to_string() as DescriptionValue;
258    let all_obs = get_test_obs_vec();
259
260    if inverse {
261        let target_threshold = 11.0;
262        let target_effect = 0.0;
263        let observation = all_obs.percent_observation(target_threshold, target_effect);
264        let threshold = 0.55;
265        let effect = 0.0; // false
266        let target = 0.0; // false
267
268        Inference::new(id, question, observation, threshold, effect, target)
269    } else {
270        let target_threshold = 10.0;
271        let target_effect = 1.0;
272        let observation = all_obs.percent_observation(target_threshold, target_effect);
273        let threshold = 0.55;
274        let effect = 1.0; //true
275        let target = 1.0; //true
276
277        Inference::new(id, question, observation, threshold, effect, target)
278    }
279}
280
281pub fn get_test_observation() -> Observation {
282    Observation::new(0, 14.0, 1.0)
283}
284
285pub fn get_test_assumption() -> Assumption {
286    let id: IdentificationValue = 1;
287    let description: String = "Test assumption that data are there".to_string() as DescriptionValue;
288    let assumption_fn: EvalFn = test_fn_has_data;
289
290    Assumption::new(id, description, assumption_fn)
291}
292
293fn test_fn_has_data(data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
294    Ok(!data.is_empty()) // Data is NOT empty i.e. true when it is 
295}
296
297pub fn get_test_assumption_false() -> Assumption {
298    let id: IdentificationValue = 2;
299    let description: String =
300        "Test assumption that is always false".to_string() as DescriptionValue;
301    let assumption_fn: EvalFn = test_fn_is_false;
302    Assumption::new(id, description, assumption_fn)
303}
304
305fn test_fn_is_false(_data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
306    Ok(false)
307}
308
309pub fn get_test_assumption_error() -> Assumption {
310    let id: IdentificationValue = 2;
311    let description: String =
312        "Test assumption that raises an error".to_string() as DescriptionValue;
313    let assumption_fn: EvalFn = test_fn_is_error;
314    Assumption::new(id, description, assumption_fn)
315}
316
317fn test_fn_is_error(_data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
318    Err(AssumptionError::AssumptionFailed(String::from(
319        "Test error",
320    )))
321}
322
323pub fn get_test_num_array() -> [NumericalValue; 10] {
324    [8.4, 8.5, 9.1, 9.3, 9.4, 9.5, 9.7, 9.7, 9.9, 9.9]
325}
326
327pub fn get_test_causaloid(id: IdentificationValue) -> BaseCausaloid<f64, bool> {
328    let description = "tests whether data exceeds threshold of 0.55";
329
330    fn causal_fn(evidence: f64) -> PropagatingEffect<bool> {
331        let mut log = EffectLog::new();
332        log.add_entry(&format!("Processing evidence: {}", evidence));
333
334        if evidence.is_sign_negative() {
335            log.add_entry("Observation is negative, returning error.");
336            let mut effect = PropagatingEffect::from_error(CausalityError::new(
337                CausalityErrorEnum::Custom("Observation is negative".into()),
338            ));
339            effect.logs = log;
340            return effect;
341        }
342
343        let threshold: NumericalValue = 0.55;
344        let is_active = evidence.ge(&threshold);
345        log.add_entry(&format!(
346            "Evidence {} >= threshold {}: {}",
347            evidence, threshold, is_active
348        ));
349
350        let mut effect = PropagatingEffect::pure(is_active);
351        effect.logs = log;
352        effect
353    }
354
355    Causaloid::new(id, causal_fn, description)
356}
357
358pub fn get_test_causaloid_num_input_output(id: IdentificationValue) -> BaseCausaloid<f64, f64> {
359    let description = "tests whether data exceeds threshold of 0.55";
360
361    fn causal_fn(evidence: f64) -> PropagatingEffect<f64> {
362        let mut log = EffectLog::new();
363        log.add_entry(&format!("Processing evidence: {}", evidence));
364
365        if evidence.is_sign_negative() {
366            log.add_entry("Observation is negative, returning error.");
367            let mut effect = PropagatingEffect::from_error(CausalityError::new(
368                CausalityErrorEnum::Custom("Observation is negative".into()),
369            ));
370            effect.logs = log;
371            return effect;
372        }
373
374        let threshold: NumericalValue = 0.55;
375        let is_active = if evidence.ge(&threshold) { 1.0 } else { 0.0 };
376        log.add_entry(&format!(
377            "Evidence {} >= threshold {}: {}",
378            evidence, threshold, is_active
379        ));
380
381        let mut effect = PropagatingEffect::pure(is_active);
382        effect.logs = log;
383        effect
384    }
385
386    Causaloid::new(id, causal_fn, description)
387}
388
/// Returns an array of `N` samples in which every element is `0.99`.
pub fn generate_sample_data<const N: usize>() -> [f64; N] {
    std::array::from_fn(|_| 0.99)
}