deep_causality/utils_test/
test_utils.rs1use crate::*;
6use deep_causality_haft::LogAddEntry;
7
8use deep_causality_uncertain::{Uncertain, UncertainBool, UncertainF64};
9use std::sync::{Arc, RwLock};
10
11pub fn get_context() -> BaseContext {
12 let id = 1;
13 let name = "base context";
14 let capacity = 10; Context::with_capacity(id, name, capacity)
16}
17
18pub fn get_test_assumption_vec() -> Vec<Assumption> {
19 let a1 = get_test_assumption();
20 let a2 = get_test_assumption();
21 let a3 = get_test_assumption();
22 Vec::from_iter([a1, a2, a3])
23}
24
25pub fn get_test_obs_vec() -> Vec<Observation> {
26 let o1 = Observation::new(0, 10.0, 1.0);
27 let o2 = Observation::new(1, 10.0, 1.0);
28 let o3 = Observation::new(2, 10.0, 1.0);
29 let o4 = Observation::new(3, 12.0, 0.0);
30 let o5 = Observation::new(4, 14.0, 0.0);
31 Vec::from_iter([o1, o2, o3, o4, o5])
32}
33
34pub fn get_test_inf_vec() -> Vec<Inference> {
35 let i1 = get_test_inferable(0, true);
36 let i2 = get_test_inferable(1, false);
37 Vec::from_iter([i1, i2])
38}
39
40pub fn get_deterministic_test_causality_vec() -> BaseCausaloidVec<f64, bool> {
41 let q1 = get_test_causaloid_deterministic(1);
42 let q2 = get_test_causaloid_deterministic(2);
43 let q3 = get_test_causaloid_deterministic(3);
44 Vec::from_iter([q1, q2, q3])
45}
46pub fn get_probabilistic_test_causality_vec() -> BaseCausaloidVec<f64, f64> {
47 let q1 = get_test_causaloid_probabilistic();
48 let q2 = get_test_causaloid_probabilistic();
49 let q3 = get_test_causaloid_probabilistic();
50 Vec::from_iter([q1, q2, q3])
51}
52
53pub fn get_uncertain_bool_test_causality_vec() -> BaseCausaloidVec<f64, UncertainBool> {
54 let q1 = get_test_causaloid_uncertain_bool();
55 let q2 = get_test_causaloid_uncertain_bool();
56 let q3 = get_test_causaloid_uncertain_bool();
57 Vec::from_iter([q1, q2, q3])
58}
59
60pub fn get_uncertain_float_test_causality_vec() -> BaseCausaloidVec<f64, UncertainF64> {
61 let q1 = get_test_causaloid_uncertain_float();
62 let q2 = get_test_causaloid_uncertain_float();
63 let q3 = get_test_causaloid_uncertain_float();
64 Vec::from_iter([q1, q2, q3])
65}
66
67pub fn get_test_single_data(val: NumericalValue) -> PropagatingEffect<NumericalValue> {
68 PropagatingEffect::pure(val)
69}
70
71pub fn get_test_causaloid_deterministic_true() -> BaseCausaloid<bool, bool> {
72 let description = "tests nothing; always returns true";
73 fn causal_fn(_: bool) -> PropagatingEffect<bool> {
74 let mut effect = PropagatingEffect::pure(true);
75 effect.logs.add_entry("Just return true");
76 effect
77 }
78 Causaloid::new(3, causal_fn, description)
79}
80
81pub fn get_test_causaloid_deterministic_false() -> BaseCausaloid<bool, bool> {
82 let description = "tests nothing; always returns true";
83 fn causal_fn(_: bool) -> PropagatingEffect<bool> {
84 PropagatingEffect::pure(false)
85 }
86 Causaloid::new(3, causal_fn, description)
87}
88
89pub fn get_test_causaloid_probabilistic() -> BaseCausaloid<NumericalValue, f64> {
90 let id: IdentificationValue = 3;
91 let description = "tests whether data exceeds threshold of 0.55";
92
93 fn causal_fn(obs: NumericalValue) -> PropagatingEffect<NumericalValue> {
94 let threshold: NumericalValue = 0.55;
95 let output = if obs.ge(&threshold) { 1.0 } else { 0.0 };
96 PropagatingEffect::pure(output)
97 }
98
99 Causaloid::new(id, causal_fn, description)
100}
101
102pub fn get_test_causaloid_uncertain_bool() -> BaseCausaloid<f64, UncertainBool> {
103 let description = "tests whether data exceeds threshold of 0.55 and returns uncertain bool";
104
105 fn causal_fn(obs: NumericalValue) -> PropagatingEffect<UncertainBool> {
106 let threshold: NumericalValue = 0.55;
107
108 let output = if obs > threshold {
109 Uncertain::<bool>::point(true)
110 } else {
111 Uncertain::<bool>::point(false)
112 };
113 PropagatingEffect::pure(output)
114 }
115
116 Causaloid::new(3, causal_fn, description)
117}
118
119pub fn get_test_causaloid_uncertain_float() -> BaseCausaloid<f64, UncertainF64> {
120 let description = "tests whether data exceeds threshold of 0.55 and returns uncertain bool";
121 fn causal_fn(obs: NumericalValue) -> PropagatingEffect<UncertainF64> {
122 let threshold: NumericalValue = 0.55;
123 let output = if obs > threshold {
124 Uncertain::<f64>::point(1.0f64)
125 } else {
126 Uncertain::<f64>::point(0.0f64)
127 };
128 PropagatingEffect::pure(output)
129 }
130
131 Causaloid::new(3, causal_fn, description)
132}
133
134pub fn get_test_causaloid_deterministic(
135 id: IdentificationValue,
136) -> BaseCausaloid<NumericalValue, bool> {
137 let description = "tests whether data exceeds threshold of 0.55";
138 fn causal_fn(obs: NumericalValue) -> PropagatingEffect<bool> {
139 let threshold: NumericalValue = 0.55;
140 let output = obs.ge(&threshold);
141 PropagatingEffect::pure(output)
142 }
143
144 Causaloid::new(id, causal_fn, description)
145}
146
147pub fn get_test_causaloid_probabilistic_bool_output() -> BaseCausaloid<NumericalValue, f64> {
148 let id: IdentificationValue = 4;
149 let description =
150 "tests whether data exceeds threshold of 0.55 and returns bool probabilistically";
151
152 fn causal_fn(obs: NumericalValue) -> PropagatingEffect<f64> {
153 let threshold: NumericalValue = 0.55;
154 let output = if obs.ge(&threshold) { 1.0 } else { 0.0 };
155 PropagatingEffect::pure(output)
156 }
157
158 Causaloid::new(id, causal_fn, description)
159}
160pub fn get_test_causaloid_deterministic_with_context(
161 context: BaseContext,
162) -> BaseCausaloid<bool, bool> {
163 let id: IdentificationValue = 1;
164 let context = Arc::new(RwLock::new(context));
165 let description = "Inverts any input";
166
167 fn causal_fn_deterministic_with_context(
168 obs: EffectValue<bool>,
169 _state: (),
170 context: Option<Arc<RwLock<BaseContext>>>,
171 ) -> PropagatingProcess<bool, (), Arc<RwLock<BaseContext>>> {
172 if context.is_none() {
173 return PropagatingProcess::from_error(CausalityError(CausalityErrorEnum::Custom(
174 "Context is missing".into(),
175 )));
176 }
177
178 let input_val = obs.into_value().unwrap_or(false);
179
180 let ctx_ref = context.as_ref().unwrap().clone();
182 let ctx = ctx_ref.read().expect("Failed to read context");
183
184 let current_id = ctx.id();
185 if current_id == 1 {
187 PropagatingProcess::pure(input_val)
188 } else {
189 PropagatingProcess::pure(!input_val)
190 }
191 }
192 Causaloid::new_with_context(
193 id,
194 causal_fn_deterministic_with_context
195 as fn(
196 EffectValue<bool>,
197 (),
198 Option<Arc<RwLock<BaseContext>>>,
199 ) -> PropagatingProcess<bool, (), Arc<RwLock<BaseContext>>>,
200 context,
201 description,
202 )
203}
204
205pub fn get_test_causaloid_deterministic_input_output() -> BaseCausaloid<bool, bool> {
206 let id: IdentificationValue = 2;
207 let description = "Inverts any input";
208 fn causal_fn(obs: bool) -> PropagatingEffect<bool> {
209 PropagatingEffect::pure(!obs)
210 }
211 Causaloid::new(id, causal_fn, description)
212}
213
214pub fn get_test_error_causaloid() -> BaseCausaloid<bool, bool> {
215 let id: IdentificationValue = 1;
216 let description = "tests whether data exceeds threshold of 0.55";
217
218 fn causal_fn(_: bool) -> PropagatingEffect<bool> {
219 PropagatingEffect::from_error(CausalityError::new(CausalityErrorEnum::Custom(
220 "Test error".into(),
221 )))
222 }
223
224 Causaloid::new(id, causal_fn, description)
225}
226
227pub fn get_base_context() -> BaseContext {
231 let id = 1;
232 let name = "base context";
233 let mut context = Context::with_capacity(id, name, 10);
234 assert_eq!(context.size(), 0);
235
236 let root = Root::new(id);
237 let contextoid = Contextoid::new(id, ContextoidType::Root(root));
238 let idx = context.add_node(contextoid).expect("Failed to add node");
239 assert_eq!(idx, 0);
240 assert_eq!(context.size(), 1);
241
242 context
243}
244
245pub fn get_test_context() -> BaseContext {
246 let mut context = Context::with_capacity(1, "Test-Context", 10);
247
248 let id = 1;
249 let root = Root::new(id);
250 let contextoid = Contextoid::new(id, ContextoidType::Root(root));
251 context.add_node(contextoid).expect("Failed to add node");
252
253 context
254}
255
256pub fn get_test_inferable(id: IdentificationValue, inverse: bool) -> Inference {
257 let question = "".to_string() as DescriptionValue;
258 let all_obs = get_test_obs_vec();
259
260 if inverse {
261 let target_threshold = 11.0;
262 let target_effect = 0.0;
263 let observation = all_obs.percent_observation(target_threshold, target_effect);
264 let threshold = 0.55;
265 let effect = 0.0; let target = 0.0; Inference::new(id, question, observation, threshold, effect, target)
269 } else {
270 let target_threshold = 10.0;
271 let target_effect = 1.0;
272 let observation = all_obs.percent_observation(target_threshold, target_effect);
273 let threshold = 0.55;
274 let effect = 1.0; let target = 1.0; Inference::new(id, question, observation, threshold, effect, target)
278 }
279}
280
281pub fn get_test_observation() -> Observation {
282 Observation::new(0, 14.0, 1.0)
283}
284
285pub fn get_test_assumption() -> Assumption {
286 let id: IdentificationValue = 1;
287 let description: String = "Test assumption that data are there".to_string() as DescriptionValue;
288 let assumption_fn: EvalFn = test_fn_has_data;
289
290 Assumption::new(id, description, assumption_fn)
291}
292
293fn test_fn_has_data(data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
294 Ok(!data.is_empty()) }
296
297pub fn get_test_assumption_false() -> Assumption {
298 let id: IdentificationValue = 2;
299 let description: String =
300 "Test assumption that is always false".to_string() as DescriptionValue;
301 let assumption_fn: EvalFn = test_fn_is_false;
302 Assumption::new(id, description, assumption_fn)
303}
304
305fn test_fn_is_false(_data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
306 Ok(false)
307}
308
309pub fn get_test_assumption_error() -> Assumption {
310 let id: IdentificationValue = 2;
311 let description: String =
312 "Test assumption that raises an error".to_string() as DescriptionValue;
313 let assumption_fn: EvalFn = test_fn_is_error;
314 Assumption::new(id, description, assumption_fn)
315}
316
317fn test_fn_is_error(_data: &[PropagatingEffect<f64>]) -> Result<bool, AssumptionError> {
318 Err(AssumptionError::AssumptionFailed(String::from(
319 "Test error",
320 )))
321}
322
323pub fn get_test_num_array() -> [NumericalValue; 10] {
324 [8.4, 8.5, 9.1, 9.3, 9.4, 9.5, 9.7, 9.7, 9.9, 9.9]
325}
326
327pub fn get_test_causaloid(id: IdentificationValue) -> BaseCausaloid<f64, bool> {
328 let description = "tests whether data exceeds threshold of 0.55";
329
330 fn causal_fn(evidence: f64) -> PropagatingEffect<bool> {
331 let mut log = EffectLog::new();
332 log.add_entry(&format!("Processing evidence: {}", evidence));
333
334 if evidence.is_sign_negative() {
335 log.add_entry("Observation is negative, returning error.");
336 let mut effect = PropagatingEffect::from_error(CausalityError::new(
337 CausalityErrorEnum::Custom("Observation is negative".into()),
338 ));
339 effect.logs = log;
340 return effect;
341 }
342
343 let threshold: NumericalValue = 0.55;
344 let is_active = evidence.ge(&threshold);
345 log.add_entry(&format!(
346 "Evidence {} >= threshold {}: {}",
347 evidence, threshold, is_active
348 ));
349
350 let mut effect = PropagatingEffect::pure(is_active);
351 effect.logs = log;
352 effect
353 }
354
355 Causaloid::new(id, causal_fn, description)
356}
357
358pub fn get_test_causaloid_num_input_output(id: IdentificationValue) -> BaseCausaloid<f64, f64> {
359 let description = "tests whether data exceeds threshold of 0.55";
360
361 fn causal_fn(evidence: f64) -> PropagatingEffect<f64> {
362 let mut log = EffectLog::new();
363 log.add_entry(&format!("Processing evidence: {}", evidence));
364
365 if evidence.is_sign_negative() {
366 log.add_entry("Observation is negative, returning error.");
367 let mut effect = PropagatingEffect::from_error(CausalityError::new(
368 CausalityErrorEnum::Custom("Observation is negative".into()),
369 ));
370 effect.logs = log;
371 return effect;
372 }
373
374 let threshold: NumericalValue = 0.55;
375 let is_active = if evidence.ge(&threshold) { 1.0 } else { 0.0 };
376 log.add_entry(&format!(
377 "Evidence {} >= threshold {}: {}",
378 evidence, threshold, is_active
379 ));
380
381 let mut effect = PropagatingEffect::pure(is_active);
382 effect.logs = log;
383 effect
384 }
385
386 Causaloid::new(id, causal_fn, description)
387}
388
/// Returns an array of length `N` in which every element is `0.99`.
pub fn generate_sample_data<const N: usize>() -> [f64; N] {
    const FILL: f64 = 0.99;
    [FILL; N]
}