1use crate::{EmbeddingModel, ModelConfig, TrainingStats, Triple, Vector};
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use chrono::Utc;
11use scirs2_core::ndarray_ext::{Array1, Array2};
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use uuid::Uuid;
15
/// Top-level configuration for [`CausalRepresentationModel`], grouping the settings of each
/// sub-component. `Default` is derived, so every section falls back to its own `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CausalRepresentationConfig {
    /// Generic embedding-model settings; this model reads `dimensions` and `max_epochs`.
    pub base_config: ModelConfig,
    /// Settings for causal structure discovery.
    pub causal_discovery: CausalDiscoveryConfig,
    /// Settings for the structural causal model.
    pub scm_config: StructuralCausalModelConfig,
    /// Settings for interventions.
    pub intervention_config: InterventionConfig,
    /// Settings for counterfactual reasoning.
    pub counterfactual_config: CounterfactualConfig,
    /// Settings for disentangled representation learning (`method` and `num_factors` are read).
    pub disentanglement_config: DisentanglementConfig,
}
31
/// Settings for [`CausalRepresentationModel::discover_causal_structure`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalDiscoveryConfig {
    /// Discovery algorithm to run.
    pub algorithm: CausalDiscoveryAlgorithm,
    /// Threshold on |Pearson correlation| below which the PC routine treats two variables as
    /// independent. NOTE(review): despite the name, it is compared against a correlation, not
    /// a p-value.
    pub significance_threshold: f32,
    /// Intended upper bound on parents per node. NOTE(review): not read by the discovery
    /// routines in this module section — confirm whether it is enforced elsewhere.
    pub max_parents: usize,
    /// Whether interventional data should be used. NOTE(review): not read by visible code.
    pub use_interventions: bool,
    /// Settings for constraint-based algorithms.
    pub constraint_settings: ConstraintSettings,
    /// Settings for score-based algorithms.
    pub score_settings: ScoreSettings,
}

impl Default for CausalDiscoveryConfig {
    /// PC algorithm, threshold 0.05, at most 5 parents, interventions enabled.
    fn default() -> Self {
        Self {
            algorithm: CausalDiscoveryAlgorithm::PC,
            significance_threshold: 0.05,
            max_parents: 5,
            use_interventions: true,
            constraint_settings: ConstraintSettings::default(),
            score_settings: ScoreSettings::default(),
        }
    }
}

/// Causal discovery algorithms. Only `PC`, `GES` and `NOTEARS` are dispatched to a dedicated
/// routine by `discover_causal_structure`; every other variant currently falls back to PC.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CausalDiscoveryAlgorithm {
    PC,
    FCI,
    GES,
    LiNGAM,
    NOTEARS,
    DirectLiNGAM,
    CAM,
}
80
/// Settings for constraint-based discovery.
///
/// NOTE(review): none of these fields is read by the discovery code in this module section;
/// `independence_test` always thresholds a marginal Pearson correlation. Confirm before
/// relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintSettings {
    /// Conditional-independence test to use.
    pub independence_test: IndependenceTest,
    /// Significance level for the test.
    pub alpha: f32,
    /// Presumably selects the order-independent ("stable") PC variant — confirm.
    pub stable: bool,
    /// Largest conditioning-set size to consider.
    pub max_cond_set_size: usize,
}

impl Default for ConstraintSettings {
    /// Partial correlation, alpha 0.05, stable, conditioning sets of size up to 3.
    fn default() -> Self {
        Self {
            independence_test: IndependenceTest::PartialCorrelation,
            alpha: 0.05,
            stable: true,
            max_cond_set_size: 3,
        }
    }
}

/// Available (conditional) independence tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndependenceTest {
    PartialCorrelation,
    MutualInformation,
    KernelTest,
    DistanceCorrelation,
    HilbertSchmidt,
}
114
/// Settings for score-based discovery.
///
/// NOTE(review): not read by `run_ges_algorithm` in this module section, which always
/// scores with `compute_bic_score` and a greedy search — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreSettings {
    /// Score to optimise.
    pub score_function: ScoreFunction,
    /// Complexity penalty multiplier.
    pub penalty: f32,
    /// Search strategy over graphs.
    pub search_strategy: SearchStrategy,
    /// Iteration budget for the search.
    pub max_iterations: usize,
}

impl Default for ScoreSettings {
    /// BIC, penalty 1.0, greedy hill climbing, 1000 iterations.
    fn default() -> Self {
        Self {
            score_function: ScoreFunction::BIC,
            penalty: 1.0,
            search_strategy: SearchStrategy::GreedyHillClimbing,
            max_iterations: 1000,
        }
    }
}

/// Graph scoring functions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreFunction {
    BIC,
    AIC,
    LogLikelihood,
    MDL,
    BDeu,
    BGe,
}

/// Search strategies for score-based discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
    GreedyHillClimbing,
    TabuSearch,
    SimulatedAnnealing,
    GeneticAlgorithm,
    BeamSearch,
}
159
/// Settings describing the structural causal model.
///
/// NOTE(review): these fields are not read by the structural-equation code in this module
/// section (equations are always linear with Gaussian-style noise) — confirm intended use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuralCausalModelConfig {
    /// Declared type per variable name.
    pub variable_types: HashMap<String, VariableType>,
    /// Declared functional form per variable name.
    pub functional_forms: HashMap<String, FunctionalForm>,
    /// Noise distribution of the exogenous terms.
    pub noise_model: NoiseModel,
    /// Strategy for identifying causal effects.
    pub identification: IdentificationStrategy,
}

impl Default for StructuralCausalModelConfig {
    /// No per-variable declarations, Gaussian noise, back-door identification.
    fn default() -> Self {
        Self {
            variable_types: HashMap::new(),
            functional_forms: HashMap::new(),
            noise_model: NoiseModel::Gaussian,
            identification: IdentificationStrategy::BackDoorCriterion,
        }
    }
}

/// Measurement type of a variable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VariableType {
    Continuous,
    Discrete,
    Binary,
    Categorical,
    Ordinal,
}

/// Functional form of a structural equation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FunctionalForm {
    Linear,
    Nonlinear,
    Additive,
    Multiplicative,
    Polynomial,
    NeuralNetwork,
}

/// Distribution family of exogenous noise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseModel {
    Gaussian,
    Uniform,
    Exponential,
    Laplace,
    StudentT,
    Mixture,
}

/// Causal-effect identification strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IdentificationStrategy {
    BackDoorCriterion,
    FrontDoorCriterion,
    InstrumentalVariable,
    DoCalculus,
    NaturalExperiment,
}
225
/// Settings for generating or applying interventions.
///
/// NOTE(review): not read by `intervene` in this module section — confirm intended use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterventionConfig {
    /// Intervention kinds that may be used.
    pub intervention_types: Vec<InterventionType>,
    /// Default intervention strength.
    pub intervention_strength: f32,
    /// Maximum number of variables targeted by one intervention.
    pub max_intervention_targets: usize,
    /// Whether soft (non-clamping) interventions are allowed.
    pub soft_interventions: bool,
    /// Distribution from which intervention values are drawn.
    pub intervention_distribution: InterventionDistribution,
}

impl Default for InterventionConfig {
    /// `Do`, `Soft` and `Shift` interventions, strength 1.0, up to 3 targets, Gaussian values.
    fn default() -> Self {
        Self {
            intervention_types: vec![
                InterventionType::Do,
                InterventionType::Soft,
                InterventionType::Shift,
            ],
            intervention_strength: 1.0,
            max_intervention_targets: 3,
            soft_interventions: true,
            intervention_distribution: InterventionDistribution::Gaussian,
        }
    }
}

/// Kinds of intervention. NOTE(review): `intervene` currently treats every kind as a hard
/// assignment of the target to its value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InterventionType {
    Do,
    Soft,
    Shift,
    Noise,
    Mechanism,
}

/// Distributions for sampling intervention values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InterventionDistribution {
    Gaussian,
    Uniform,
    Delta,
    Mixture,
}
280
/// Settings for counterfactual reasoning.
///
/// NOTE(review): not read by `answer_counterfactual` in this module section — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualConfig {
    /// Counterfactual inference method.
    pub reasoning_method: CounterfactualMethod,
    /// Twin-network architecture settings.
    pub twin_network: TwinNetworkConfig,
    /// Fairness constraints applied to counterfactuals.
    pub fairness_constraints: FairnessConstraints,
    /// Settings for generated explanations.
    pub explanation_config: ExplanationConfig,
}

impl Default for CounterfactualConfig {
    /// Twin-network reasoning with default sub-settings.
    fn default() -> Self {
        Self {
            reasoning_method: CounterfactualMethod::TwinNetwork,
            twin_network: TwinNetworkConfig::default(),
            fairness_constraints: FairnessConstraints::default(),
            explanation_config: ExplanationConfig::default(),
        }
    }
}

/// Counterfactual inference methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CounterfactualMethod {
    TwinNetwork,
    StructuralEquations,
    GAN,
    VAE,
    NormalizingFlows,
}
314
/// Layer layout of a twin (factual / counterfactual) network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TwinNetworkConfig {
    /// Number of layers shared by both branches.
    pub shared_layers: usize,
    /// Number of factual-branch layers.
    pub factual_layers: usize,
    /// Number of counterfactual-branch layers.
    pub counterfactual_layers: usize,
    /// Weight of the consistency term between branches.
    pub consistency_weight: f32,
}

impl Default for TwinNetworkConfig {
    /// 3 shared layers, 2 per branch, consistency weight 1.0.
    fn default() -> Self {
        Self {
            shared_layers: 3,
            factual_layers: 2,
            counterfactual_layers: 2,
            consistency_weight: 1.0,
        }
    }
}
338
/// Fairness constraints for counterfactual reasoning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairnessConstraints {
    /// Names of protected attributes.
    pub protected_attributes: Vec<String>,
    /// Fairness criteria to enforce.
    pub fairness_criteria: Vec<FairnessCriterion>,
    /// Strength with which the constraints are applied.
    pub constraint_strength: f32,
}

impl Default for FairnessConstraints {
    /// No protected attributes; counterfactual fairness at strength 1.0.
    fn default() -> Self {
        Self {
            protected_attributes: Vec::new(),
            fairness_criteria: vec![FairnessCriterion::CounterfactualFairness],
            constraint_strength: 1.0,
        }
    }
}

/// Fairness criteria.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FairnessCriterion {
    CounterfactualFairness,
    IndividualFairness,
    GroupFairness,
    EqualOpportunity,
    DemographicParity,
}
369
/// Settings for generated explanations.
///
/// NOTE(review): `generate_explanation` in this module section does not read these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplanationConfig {
    /// Kinds of explanation to produce.
    pub explanation_types: Vec<ExplanationType>,
    /// Maximum explanation length (unit not defined here — confirm).
    pub max_explanation_length: usize,
    /// Whether to attach confidence values.
    pub include_confidence: bool,
}

impl Default for ExplanationConfig {
    /// Causal, counterfactual and contrastive explanations, length 10, with confidence.
    fn default() -> Self {
        Self {
            explanation_types: vec![
                ExplanationType::Causal,
                ExplanationType::Counterfactual,
                ExplanationType::Contrastive,
            ],
            max_explanation_length: 10,
            include_confidence: true,
        }
    }
}

/// Kinds of explanation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplanationType {
    Causal,
    Counterfactual,
    Contrastive,
    Abductive,
    Necessary,
    Sufficient,
}
405
/// Settings for `learn_disentangled_representations`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DisentanglementConfig {
    /// Method dispatched by `learn_disentangled_representations`.
    pub method: DisentanglementMethod,
    /// Beta weight. NOTE(review): read but unused by `learn_beta_vae`.
    pub beta: f32,
    /// Number of latent factors (columns of `latent_factors`).
    pub num_factors: usize,
    /// Supervision level. NOTE(review): not read by visible code.
    pub supervision: FactorSupervision,
}

impl Default for DisentanglementConfig {
    /// Beta-VAE with beta 4.0, 10 factors, unsupervised.
    fn default() -> Self {
        Self {
            method: DisentanglementMethod::BetaVAE,
            beta: 4.0,
            num_factors: 10,
            supervision: FactorSupervision::Unsupervised,
        }
    }
}

/// Disentanglement methods. `BetaVAE`, `FactorVAE` and `ICA` have dedicated routines; the
/// rest currently fall back to the beta-VAE routine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DisentanglementMethod {
    BetaVAE,
    FactorVAE,
    BetaTCVAE,
    ICA,
    SlowFeatureAnalysis,
    CausalVAE,
}

/// Amount of supervision available for latent factors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSupervision {
    Unsupervised,
    WeaklySupervised,
    FullySupervised,
}
448
/// Directed graph over named variables, stored as dense square matrices.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Variable names; index `i` addresses row/column `i` of both matrices.
    pub variables: Vec<String>,
    /// `adjacency[[i, j]] > 0.0` means an edge `i -> j` (entries are set to 1.0 or 0.0).
    pub adjacency: Array2<f32>,
    /// Weight of edge `i -> j`; 0.0 when the edge is absent.
    pub edge_weights: Array2<f32>,
    /// Index pairs flagged as confounded. NOTE(review): not populated by visible code.
    pub confounders: HashSet<(usize, usize)>,
}
461
462impl CausalGraph {
463 pub fn new(variables: Vec<String>) -> Self {
464 let n = variables.len();
465 Self {
466 variables,
467 adjacency: Array2::zeros((n, n)),
468 edge_weights: Array2::zeros((n, n)),
469 confounders: HashSet::new(),
470 }
471 }
472
473 pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
474 if from < self.adjacency.nrows() && to < self.adjacency.ncols() {
475 self.adjacency[[from, to]] = 1.0;
476 self.edge_weights[[from, to]] = weight;
477 }
478 }
479
480 pub fn remove_edge(&mut self, from: usize, to: usize) {
481 if from < self.adjacency.nrows() && to < self.adjacency.ncols() {
482 self.adjacency[[from, to]] = 0.0;
483 self.edge_weights[[from, to]] = 0.0;
484 }
485 }
486
487 pub fn get_parents(&self, node: usize) -> Vec<usize> {
488 let mut parents = Vec::new();
489 for i in 0..self.adjacency.nrows() {
490 if self.adjacency[[i, node]] > 0.0 {
491 parents.push(i);
492 }
493 }
494 parents
495 }
496
497 pub fn get_children(&self, node: usize) -> Vec<usize> {
498 let mut children = Vec::new();
499 for j in 0..self.adjacency.ncols() {
500 if self.adjacency[[node, j]] > 0.0 {
501 children.push(j);
502 }
503 }
504 children
505 }
506
507 pub fn is_acyclic(&self) -> bool {
508 let n = self.variables.len();
510 let mut visited = vec![false; n];
511 let mut rec_stack = vec![false; n];
512
513 for i in 0..n {
514 if !visited[i] && self.has_cycle_dfs(i, &mut visited, &mut rec_stack) {
515 return false;
516 }
517 }
518 true
519 }
520
521 fn has_cycle_dfs(
522 &self,
523 node: usize,
524 visited: &mut Vec<bool>,
525 rec_stack: &mut Vec<bool>,
526 ) -> bool {
527 visited[node] = true;
528 rec_stack[node] = true;
529
530 for child in self.get_children(node) {
531 if (!visited[child] && self.has_cycle_dfs(child, visited, rec_stack))
532 || rec_stack[child]
533 {
534 return true;
535 }
536 }
537
538 rec_stack[node] = false;
539 false
540 }
541}
542
/// Structural equation `target := f(parents) + noise` for one variable.
#[derive(Debug, Clone)]
pub struct StructuralEquation {
    /// Name of the variable this equation determines.
    pub target: String,
    /// Parent variable names, in the order matching `linear_coefficients`.
    pub parents: Vec<String>,
    /// One linear coefficient per parent.
    pub linear_coefficients: Array1<f32>,
    /// Optional weight matrix of a single tanh hidden layer (columns = number of parents).
    pub nonlinear_function: Option<Array2<f32>>,
    /// Variance of the additive noise term.
    pub noise_variance: f32,
}
557
558impl StructuralEquation {
559 pub fn new(target: String, parents: Vec<String>) -> Self {
560 let num_parents = parents.len();
561 Self {
562 target,
563 parents,
564 linear_coefficients: Array1::zeros(num_parents),
565 nonlinear_function: None,
566 noise_variance: 1.0,
567 }
568 }
569
570 pub fn evaluate(&self, parent_values: &Array1<f32>) -> f32 {
571 let mut result = 0.0;
572
573 if parent_values.len() == self.linear_coefficients.len() {
575 result += self.linear_coefficients.dot(parent_values);
576 }
577
578 if let Some(ref weights) = self.nonlinear_function {
580 if weights.ncols() == parent_values.len() {
581 let hidden = weights.dot(parent_values);
582 result += hidden.mapv(|x| x.tanh()).sum();
583 }
584 }
585
586 {
588 use scirs2_core::random::{Random, Rng};
589 let mut random = Random::default();
590 result += random.random::<f32>() * self.noise_variance.sqrt();
591 }
592
593 result
594 }
595}
596
/// An intervention on one or more variables.
#[derive(Debug, Clone)]
pub struct Intervention {
    /// Names of the intervened variables.
    pub targets: Vec<String>,
    /// Value per target, matched by position; extra targets without a value are ignored
    /// by `intervene`.
    pub values: Array1<f32>,
    /// Kind of intervention. NOTE(review): `intervene` does not read it.
    pub intervention_type: InterventionType,
    /// Strength of the intervention. NOTE(review): `intervene` does not read it.
    pub strength: f32,
}

impl Intervention {
    /// Creates an intervention with strength 1.0.
    pub fn new(
        targets: Vec<String>,
        values: Array1<f32>,
        intervention_type: InterventionType,
    ) -> Self {
        Self {
            targets,
            values,
            intervention_type,
            strength: 1.0,
        }
    }
}
624
/// A counterfactual question: "given this evidence, what would the query variables have
/// been under this intervention?"
#[derive(Debug, Clone)]
pub struct CounterfactualQuery {
    /// Observed (factual) values by variable name.
    pub factual_evidence: HashMap<String, f32>,
    /// Hypothetical intervention.
    pub intervention: Intervention,
    /// Variables whose counterfactual values are requested.
    pub query_variables: Vec<String>,
}
635
/// Causal representation model. It discovers a causal graph from observational data, fits
/// linear structural equations over that graph, and answers interventional and
/// counterfactual queries.
#[derive(Debug)]
pub struct CausalRepresentationModel {
    /// Full configuration.
    pub config: CausalRepresentationConfig,
    /// Random identifier generated in `new`.
    pub model_id: Uuid,

    /// Current causal graph, rebuilt by `discover_causal_structure`.
    pub causal_graph: CausalGraph,
    /// Fitted structural equation per variable name.
    pub structural_equations: HashMap<String, StructuralEquation>,

    /// Per-variable embeddings. NOTE(review): never populated by code in this section.
    pub variable_embeddings: HashMap<String, Array1<f32>>,
    /// Latent factors; after `learn_disentangled_representations`, one row per
    /// observational sample and `num_factors` columns.
    pub latent_factors: Array2<f32>,

    /// `dimensions × dimensions` factual-branch weights.
    pub factual_network: Array2<f32>,
    /// `dimensions × dimensions` weights used by `answer_counterfactual`.
    pub counterfactual_network: Array2<f32>,
    /// `dimensions × dimensions` shared weights.
    pub shared_network: Array2<f32>,

    /// Observational samples, each mapping variable name to value.
    pub observational_data: Vec<HashMap<String, f32>>,
    /// Samples recorded under an intervention. NOTE(review): not read by discovery here.
    pub interventional_data: Vec<(HashMap<String, f32>, Intervention)>,

    /// Entity IRI to dense id, filled by `add_triple`.
    pub entities: HashMap<String, usize>,
    /// Relation IRI to dense id, filled by `add_triple`.
    pub relations: HashMap<String, usize>,

    /// Statistics of the most recent `train` call.
    pub training_stats: Option<TrainingStats>,
    /// Set once `train` has completed.
    pub is_trained: bool,
}
669
670impl CausalRepresentationModel {
671 pub fn new(config: CausalRepresentationConfig) -> Self {
673 let model_id = Uuid::new_v4();
674 let dimensions = config.base_config.dimensions;
675
676 Self {
677 config,
678 model_id,
679 causal_graph: CausalGraph::new(Vec::new()),
680 structural_equations: HashMap::new(),
681 variable_embeddings: HashMap::new(),
682 latent_factors: Array2::zeros((0, dimensions)),
683 factual_network: {
684 use scirs2_core::random::{Random, Rng};
685 let mut random = Random::default();
686 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
687 },
688 counterfactual_network: {
689 use scirs2_core::random::{Random, Rng};
690 let mut random = Random::default();
691 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
692 },
693 shared_network: {
694 use scirs2_core::random::{Random, Rng};
695 let mut random = Random::default();
696 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
697 },
698 observational_data: Vec::new(),
699 interventional_data: Vec::new(),
700 entities: HashMap::new(),
701 relations: HashMap::new(),
702 training_stats: None,
703 is_trained: false,
704 }
705 }
706
    /// Appends one observational sample (variable name to value).
    pub fn add_observational_data(&mut self, data: HashMap<String, f32>) {
        self.observational_data.push(data);
    }

    /// Appends one sample recorded under `intervention`.
    pub fn add_interventional_data(
        &mut self,
        data: HashMap<String, f32>,
        intervention: Intervention,
    ) {
        self.interventional_data.push((data, intervention));
    }
720
721 pub fn discover_causal_structure(&mut self) -> Result<()> {
723 match self.config.causal_discovery.algorithm {
724 CausalDiscoveryAlgorithm::PC => self.run_pc_algorithm(),
725 CausalDiscoveryAlgorithm::GES => self.run_ges_algorithm(),
726 CausalDiscoveryAlgorithm::NOTEARS => self.run_notears_algorithm(),
727 _ => self.run_pc_algorithm(), }
729 }
730
731 fn run_pc_algorithm(&mut self) -> Result<()> {
733 if self.observational_data.is_empty() {
734 return Ok(());
735 }
736
737 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
739 self.causal_graph = CausalGraph::new(variables.clone());
740
741 for i in 0..variables.len() {
743 for j in (i + 1)..variables.len() {
744 if self.independence_test(&variables[i], &variables[j], &[])? {
745 continue;
747 } else {
748 self.causal_graph.add_edge(i, j, 1.0);
750 self.causal_graph.add_edge(j, i, 1.0);
751 }
752 }
753 }
754
755 self.orient_edges()?;
757
758 Ok(())
759 }
760
761 fn run_ges_algorithm(&mut self) -> Result<()> {
763 if self.observational_data.is_empty() {
764 return Ok(());
765 }
766
767 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
768 self.causal_graph = CausalGraph::new(variables.clone());
769
770 let mut current_score = self.compute_bic_score()?;
772 let mut improved = true;
773
774 while improved {
775 improved = false;
776 let mut best_score = current_score;
777 let mut best_operation = None;
778
779 for i in 0..variables.len() {
781 for j in 0..variables.len() {
782 if i != j && self.causal_graph.adjacency[[i, j]] == 0.0 {
783 self.causal_graph.add_edge(i, j, 1.0);
784 if self.causal_graph.is_acyclic() {
785 let score = self.compute_bic_score()?;
786 if score > best_score {
787 best_score = score;
788 best_operation = Some((i, j, true)); }
790 }
791 self.causal_graph.remove_edge(i, j);
792 }
793 }
794 }
795
796 for i in 0..variables.len() {
798 for j in 0..variables.len() {
799 if self.causal_graph.adjacency[[i, j]] > 0.0 {
800 self.causal_graph.remove_edge(i, j);
801 let score = self.compute_bic_score()?;
802 if score > best_score {
803 best_score = score;
804 best_operation = Some((i, j, false)); }
806 self.causal_graph.add_edge(i, j, 1.0);
807 }
808 }
809 }
810
811 if let Some((i, j, add)) = best_operation {
813 if add {
814 self.causal_graph.add_edge(i, j, 1.0);
815 } else {
816 self.causal_graph.remove_edge(i, j);
817 }
818 current_score = best_score;
819 improved = true;
820 }
821 }
822
823 Ok(())
824 }
825
826 fn run_notears_algorithm(&mut self) -> Result<()> {
828 if self.observational_data.is_empty() {
832 return Ok(());
833 }
834
835 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
836 self.causal_graph = CausalGraph::new(variables.clone());
837
838 let n = variables.len();
840 let mut weights = {
841 use scirs2_core::random::{Random, Rng};
842 let mut random = Random::default();
843 Array2::from_shape_fn((n, n), |_| random.random::<f32>() * 0.1)
844 };
845
846 for _iteration in 0..100 {
848 let data_loss = self.compute_likelihood_loss(&weights)?;
850 let acyclicity_loss = self.compute_acyclicity_constraint(&weights);
851 let _total_loss = data_loss + acyclicity_loss;
852
853 weights *= 0.99; weights.mapv_inplace(|x| if x.abs() < 0.1 { 0.0 } else { x });
858 }
859
860 for i in 0..n {
862 for j in 0..n {
863 if weights[[i, j]].abs() > 0.1 {
864 self.causal_graph.add_edge(i, j, weights[[i, j]]);
865 }
866 }
867 }
868
869 Ok(())
870 }
871
872 fn independence_test(
874 &self,
875 var1: &str,
876 var2: &str,
877 _conditioning_set: &[&str],
878 ) -> Result<bool> {
879 let data1: Vec<f32> = self
881 .observational_data
882 .iter()
883 .filter_map(|row| row.get(var1))
884 .cloned()
885 .collect();
886
887 let data2: Vec<f32> = self
888 .observational_data
889 .iter()
890 .filter_map(|row| row.get(var2))
891 .cloned()
892 .collect();
893
894 if data1.len() != data2.len() || data1.is_empty() {
895 return Ok(true); }
897
898 let correlation = self.compute_correlation(&data1, &data2);
900 let threshold = self.config.causal_discovery.significance_threshold;
901
902 Ok(correlation.abs() < threshold)
903 }
904
905 fn compute_correlation(&self, data1: &[f32], data2: &[f32]) -> f32 {
907 if data1.len() != data2.len() || data1.is_empty() {
908 return 0.0;
909 }
910
911 let mean1 = data1.iter().sum::<f32>() / data1.len() as f32;
912 let mean2 = data2.iter().sum::<f32>() / data2.len() as f32;
913
914 let mut numerator = 0.0;
915 let mut denominator1 = 0.0;
916 let mut denominator2 = 0.0;
917
918 for i in 0..data1.len() {
919 let diff1 = data1[i] - mean1;
920 let diff2 = data2[i] - mean2;
921 numerator += diff1 * diff2;
922 denominator1 += diff1 * diff1;
923 denominator2 += diff2 * diff2;
924 }
925
926 if denominator1 == 0.0 || denominator2 == 0.0 {
927 0.0
928 } else {
929 numerator / (denominator1 * denominator2).sqrt()
930 }
931 }
932
    /// Turns every bidirected pair `i <-> j` into a single directed edge. It keeps `i -> j`
    /// when `compute_edge_score(i, j)` is strictly greater than `compute_edge_score(j, i)`,
    /// and keeps `j -> i` otherwise.
    ///
    /// NOTE(review): as written, `compute_edge_score` returns a Pearson correlation, which is
    /// symmetric in its arguments, so the two scores are equal and the `else` branch always
    /// runs. Each pair is first met with the lower index as `i`, so the surviving edge
    /// always points from the higher index to the lower one. The orientation therefore
    /// reflects variable order, not the data.
    fn orient_edges(&mut self) -> Result<()> {
        let n = self.causal_graph.variables.len();

        for i in 0..n {
            for j in 0..n {
                // Only pairs that are still linked in both directions need orienting.
                if i != j
                    && self.causal_graph.adjacency[[i, j]] > 0.0
                    && self.causal_graph.adjacency[[j, i]] > 0.0
                {
                    let score_ij = self.compute_edge_score(i, j)?;
                    let score_ji = self.compute_edge_score(j, i)?;

                    if score_ij > score_ji {
                        self.causal_graph.remove_edge(j, i);
                    } else {
                        self.causal_graph.remove_edge(i, j);
                    }
                }
            }
        }

        Ok(())
    }
959
960 fn compute_edge_score(&self, from: usize, to: usize) -> Result<f32> {
962 if from >= self.causal_graph.variables.len() || to >= self.causal_graph.variables.len() {
964 return Ok(0.0);
965 }
966
967 let var1 = &self.causal_graph.variables[from];
968 let var2 = &self.causal_graph.variables[to];
969
970 let data1: Vec<f32> = self
971 .observational_data
972 .iter()
973 .filter_map(|row| row.get(var1))
974 .cloned()
975 .collect();
976
977 let data2: Vec<f32> = self
978 .observational_data
979 .iter()
980 .filter_map(|row| row.get(var2))
981 .cloned()
982 .collect();
983
984 Ok(self.compute_correlation(&data1, &data2))
985 }
986
987 fn compute_bic_score(&self) -> Result<f32> {
989 let _n_samples = self.observational_data.len() as f32;
990 let n_variables = self.causal_graph.variables.len() as f32;
991 let n_edges = self.causal_graph.adjacency.sum();
992
993 let log_likelihood = self.compute_log_likelihood()?;
995 let penalty = (n_edges * n_variables.ln()) / 2.0;
996
997 Ok(log_likelihood - penalty)
998 }
999
1000 fn compute_log_likelihood(&self) -> Result<f32> {
1002 let mut total_likelihood = 0.0;
1004
1005 for data_point in &self.observational_data {
1006 let mut point_likelihood = 0.0;
1007
1008 for &value in data_point.values() {
1009 let variance: f32 = 1.0; point_likelihood += -0.5 * (value * value / variance + variance.ln());
1012 }
1013
1014 total_likelihood += point_likelihood;
1015 }
1016
1017 Ok(total_likelihood)
1018 }
1019
1020 fn compute_likelihood_loss(&self, weights: &Array2<f32>) -> Result<f32> {
1022 let mut loss = 0.0;
1023
1024 for data_point in &self.observational_data {
1025 for (i, var) in self.causal_graph.variables.iter().enumerate() {
1026 if let Some(&value) = data_point.get(var) {
1027 let mut predicted = 0.0;
1029 for (j, parent_var) in self.causal_graph.variables.iter().enumerate() {
1030 if let Some(&parent_value) = data_point.get(parent_var) {
1031 predicted += weights[[j, i]] * parent_value;
1032 }
1033 }
1034
1035 let error = value - predicted;
1036 loss += error * error;
1037 }
1038 }
1039 }
1040
1041 Ok(loss)
1042 }
1043
    /// Acyclicity term used by `run_notears_algorithm`. It is the trace of the element-wise
    /// square `W ∘ W` (the sum of squared self-loop weights) minus the number of variables.
    ///
    /// NOTE(review): NOTEARS uses h(W) = tr(exp(W ∘ W)) − d, which is zero exactly for DAGs.
    /// This expression omits the matrix exponential, so it only sees self-loops and equals
    /// −d for any graph without them. The only caller currently discards the value.
    fn compute_acyclicity_constraint(&self, weights: &Array2<f32>) -> f32 {
        let w_squared = weights * weights;
        let trace = w_squared.diag().sum();
        trace - self.causal_graph.variables.len() as f32
    }
1051
1052 pub fn learn_structural_equations(&mut self) -> Result<()> {
1054 for (i, variable) in self.causal_graph.variables.iter().enumerate() {
1055 let parents = self.causal_graph.get_parents(i);
1056 let parent_names: Vec<String> = parents
1057 .iter()
1058 .map(|&p| self.causal_graph.variables[p].clone())
1059 .collect();
1060
1061 let mut equation = StructuralEquation::new(variable.clone(), parent_names.clone());
1062
1063 if !parent_names.is_empty() {
1065 self.fit_structural_equation(&mut equation)?;
1066 }
1067
1068 self.structural_equations.insert(variable.clone(), equation);
1069 }
1070
1071 Ok(())
1072 }
1073
1074 fn fit_structural_equation(&self, equation: &mut StructuralEquation) -> Result<()> {
1076 let mut x = Vec::new();
1078 let mut y = Vec::new();
1079
1080 for data_point in &self.observational_data {
1081 if let Some(&target_value) = data_point.get(&equation.target) {
1082 let mut parent_values = Vec::new();
1083 let mut all_parents_present = true;
1084
1085 for parent in &equation.parents {
1086 if let Some(&parent_value) = data_point.get(parent) {
1087 parent_values.push(parent_value);
1088 } else {
1089 all_parents_present = false;
1090 break;
1091 }
1092 }
1093
1094 if all_parents_present {
1095 x.push(parent_values);
1096 y.push(target_value);
1097 }
1098 }
1099 }
1100
1101 if !x.is_empty() && !x[0].is_empty() {
1102 let n_samples = x.len();
1104 let n_features = x[0].len();
1105
1106 let x_matrix = Array2::from_shape_fn((n_samples, n_features), |(i, j)| x[i][j]);
1108 let y_vector = Array1::from_vec(y);
1109
1110 let mut coefficients = Array1::zeros(n_features);
1113 for j in 0..n_features {
1114 let mut numerator = 0.0;
1115 let mut denominator = 0.0;
1116
1117 for i in 0..n_samples {
1118 numerator += x_matrix[[i, j]] * y_vector[i];
1119 denominator += x_matrix[[i, j]] * x_matrix[[i, j]];
1120 }
1121
1122 if denominator > 0.0 {
1123 coefficients[j] = numerator / denominator;
1124 }
1125 }
1126
1127 equation.linear_coefficients = coefficients;
1128 }
1129
1130 Ok(())
1131 }
1132
1133 pub fn intervene(&self, intervention: &Intervention) -> Result<HashMap<String, f32>> {
1135 let mut result = HashMap::new();
1136
1137 for (i, target) in intervention.targets.iter().enumerate() {
1139 if i < intervention.values.len() {
1140 result.insert(target.clone(), intervention.values[i]);
1141 }
1142 }
1143
1144 for variable in &self.causal_graph.variables {
1146 if !intervention.targets.contains(variable) {
1147 if let Some(equation) = self.structural_equations.get(variable) {
1148 let mut parent_values = Array1::zeros(equation.parents.len());
1149 let mut all_parents_available = true;
1150
1151 for (i, parent) in equation.parents.iter().enumerate() {
1152 if let Some(&value) = result.get(parent) {
1153 parent_values[i] = value;
1154 } else {
1155 all_parents_available = false;
1156 break;
1157 }
1158 }
1159
1160 if all_parents_available {
1161 let value = equation.evaluate(&parent_values);
1162 result.insert(variable.clone(), value);
1163 }
1164 }
1165 }
1166 }
1167
1168 Ok(result)
1169 }
1170
1171 pub fn answer_counterfactual(
1173 &self,
1174 query: &CounterfactualQuery,
1175 ) -> Result<HashMap<String, f32>> {
1176 let _latent_values = self.abduction(&query.factual_evidence)?;
1178
1179 let intervened_values = self.intervene(&query.intervention)?;
1181
1182 let mut counterfactual_values = intervened_values;
1184
1185 for query_var in &query.query_variables {
1187 if let Some(var_embedding) = self.variable_embeddings.get(query_var) {
1188 let counterfactual_output = self.counterfactual_network.dot(var_embedding);
1190 let counterfactual_value = counterfactual_output.mean().unwrap_or(0.0);
1191 counterfactual_values.insert(query_var.clone(), counterfactual_value);
1192 }
1193 }
1194
1195 Ok(counterfactual_values)
1196 }
1197
1198 fn abduction(&self, evidence: &HashMap<String, f32>) -> Result<Array1<f32>> {
1200 let latent_dim = self.config.disentanglement_config.num_factors;
1202 let mut latent_values = Array1::zeros(latent_dim);
1203
1204 for (i, (_var, &value)) in evidence.iter().enumerate() {
1206 if i < latent_dim {
1207 latent_values[i] = value;
1208 }
1209 }
1210
1211 Ok(latent_values)
1212 }
1213
1214 pub fn generate_explanation(
1216 &self,
1217 query_var: &str,
1218 evidence: &HashMap<String, f32>,
1219 ) -> Result<String> {
1220 let mut explanation = String::new();
1221
1222 if let Some(var_idx) = self
1224 .causal_graph
1225 .variables
1226 .iter()
1227 .position(|v| v == query_var)
1228 {
1229 let parents = self.causal_graph.get_parents(var_idx);
1230
1231 explanation.push_str(&format!("The value of {query_var} is caused by:\n"));
1232
1233 for &parent_idx in &parents {
1234 let parent_var = &self.causal_graph.variables[parent_idx];
1235 let causal_strength = self.causal_graph.edge_weights[[parent_idx, var_idx]];
1236
1237 if let Some(&parent_value) = evidence.get(parent_var) {
1238 explanation.push_str(&format!(
1239 "- {parent_var} (value: {parent_value:.2}, causal strength: {causal_strength:.2})\n"
1240 ));
1241 }
1242 }
1243 }
1244
1245 Ok(explanation)
1246 }
1247
1248 pub fn learn_disentangled_representations(&mut self) -> Result<()> {
1250 match self.config.disentanglement_config.method {
1251 DisentanglementMethod::BetaVAE => self.learn_beta_vae(),
1252 DisentanglementMethod::FactorVAE => self.learn_factor_vae(),
1253 DisentanglementMethod::ICA => self.learn_ica(),
1254 _ => self.learn_beta_vae(),
1255 }
1256 }
1257
1258 fn learn_beta_vae(&mut self) -> Result<()> {
1260 let num_factors = self.config.disentanglement_config.num_factors;
1261 let _beta = self.config.disentanglement_config.beta;
1262
1263 self.latent_factors = {
1265 use scirs2_core::random::{Random, Rng};
1266 let mut random = Random::default();
1267 Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
1268 random.random::<f32>()
1269 })
1270 };
1271
1272 for _epoch in 0..100 {
1274 for (i, data_point) in self.observational_data.iter().enumerate() {
1275 let mut latent_sample = Array1::zeros(num_factors);
1277 for (j, (_, &value)) in data_point.iter().enumerate() {
1278 if j < num_factors {
1279 latent_sample[j] = value; }
1281 }
1282
1283 self.latent_factors.row_mut(i).assign(&latent_sample);
1285 }
1286 }
1287
1288 Ok(())
1289 }
1290
    /// FactorVAE placeholder: delegates to `learn_beta_vae`.
    fn learn_factor_vae(&mut self) -> Result<()> {
        self.learn_beta_vae()
    }

    /// ICA placeholder. It replaces `latent_factors` with an `n_samples × num_factors`
    /// matrix of `random::<f32>()` draws; the data itself is not decomposed.
    fn learn_ica(&mut self) -> Result<()> {
        let num_factors = self.config.disentanglement_config.num_factors;

        self.latent_factors = {
            use scirs2_core::random::{Random, Rng};
            let mut random = Random::default();
            Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
                random.random::<f32>()
            })
        };

        Ok(())
    }
1315}
1316
1317#[async_trait]
1318impl EmbeddingModel for CausalRepresentationModel {
    /// Base embedding-model configuration (`config.base_config`).
    fn config(&self) -> &ModelConfig {
        &self.config.base_config
    }

    /// Identifier generated in `new`.
    fn model_id(&self) -> &Uuid {
        &self.model_id
    }

    /// Static type name of this model.
    fn model_type(&self) -> &'static str {
        "CausalRepresentationModel"
    }
1330
1331 fn add_triple(&mut self, triple: Triple) -> Result<()> {
1332 let subject_str = triple.subject.iri.clone();
1333 let predicate_str = triple.predicate.iri.clone();
1334 let object_str = triple.object.iri.clone();
1335
1336 let next_entity_id = self.entities.len();
1338 self.entities.entry(subject_str).or_insert(next_entity_id);
1339 let next_entity_id = self.entities.len();
1340 self.entities.entry(object_str).or_insert(next_entity_id);
1341
1342 let next_relation_id = self.relations.len();
1344 self.relations
1345 .entry(predicate_str)
1346 .or_insert(next_relation_id);
1347
1348 Ok(())
1349 }
1350
    /// Runs up to `epochs` training epochs (default: `base_config.max_epochs`).
    ///
    /// Every 10th epoch, starting at epoch 0, the causal structure and structural equations
    /// are re-learned. Every 5th epoch the disentangled representation is re-learned. Marks
    /// the model as trained and stores the returned statistics.
    ///
    /// NOTE(review): the per-epoch loss is a random placeholder (`0.1 * random::<f64>()`),
    /// so `final_loss`, the early stop and `convergence_achieved` carry no information. With
    /// `epochs == 0` the stats report `final_loss = 0.0` and `convergence_achieved = true`.
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
        let start_time = std::time::Instant::now();

        let mut loss_history = Vec::new();

        for epoch in 0..epochs {
            if epoch % 10 == 0 {
                self.discover_causal_structure()?;
                self.learn_structural_equations()?;
            }

            if epoch % 5 == 0 {
                self.learn_disentangled_representations()?;
            }

            // Placeholder loss; see NOTE above.
            let epoch_loss = {
                use scirs2_core::random::{Random, Rng};
                let mut random = Random::default();
                0.1 * random.random::<f64>()
            };
            loss_history.push(epoch_loss);

            if epoch > 10 && epoch_loss < 1e-6 {
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        let final_loss = loss_history.last().copied().unwrap_or(0.0);

        let stats = TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss,
            training_time_seconds: training_time,
            convergence_achieved: final_loss < 1e-4,
            loss_history,
        };

        self.training_stats = Some(stats.clone());
        self.is_trained = true;

        Ok(stats)
    }
1397
1398 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
1399 if let Some(embedding) = self.variable_embeddings.get(entity) {
1400 Ok(Vector::new(embedding.to_vec()))
1401 } else {
1402 Err(anyhow!("Entity not found: {}", entity))
1403 }
1404 }
1405
1406 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
1407 if let Some(embedding) = self.variable_embeddings.get(relation) {
1408 Ok(Vector::new(embedding.to_vec()))
1409 } else {
1410 Err(anyhow!("Relation not found: {}", relation))
1411 }
1412 }
1413
1414 fn score_triple(&self, subject: &str, _predicate: &str, object: &str) -> Result<f64> {
1415 if let (Some(subject_idx), Some(object_idx)) = (
1417 self.causal_graph
1418 .variables
1419 .iter()
1420 .position(|v| v == subject),
1421 self.causal_graph.variables.iter().position(|v| v == object),
1422 ) {
1423 let causal_strength = self.causal_graph.edge_weights[[subject_idx, object_idx]];
1424 Ok(causal_strength as f64)
1425 } else {
1426 Ok(0.0)
1427 }
1428 }
1429
1430 fn predict_objects(
1431 &self,
1432 subject: &str,
1433 predicate: &str,
1434 k: usize,
1435 ) -> Result<Vec<(String, f64)>> {
1436 let mut scores = Vec::new();
1437
1438 for variable in &self.causal_graph.variables {
1439 if variable != subject {
1440 let score = self.score_triple(subject, predicate, variable)?;
1441 scores.push((variable.clone(), score));
1442 }
1443 }
1444
1445 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1446 scores.truncate(k);
1447
1448 Ok(scores)
1449 }
1450
1451 fn predict_subjects(
1452 &self,
1453 predicate: &str,
1454 object: &str,
1455 k: usize,
1456 ) -> Result<Vec<(String, f64)>> {
1457 let mut scores = Vec::new();
1458
1459 for variable in &self.causal_graph.variables {
1460 if variable != object {
1461 let score = self.score_triple(variable, predicate, object)?;
1462 scores.push((variable.clone(), score));
1463 }
1464 }
1465
1466 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1467 scores.truncate(k);
1468
1469 Ok(scores)
1470 }
1471
1472 fn predict_relations(
1473 &self,
1474 subject: &str,
1475 object: &str,
1476 k: usize,
1477 ) -> Result<Vec<(String, f64)>> {
1478 let mut scores = Vec::new();
1479
1480 for relation in self.relations.keys() {
1481 let score = self.score_triple(subject, relation, object)?;
1482 scores.push((relation.clone(), score));
1483 }
1484
1485 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1486 scores.truncate(k);
1487
1488 Ok(scores)
1489 }
1490
1491 fn get_entities(&self) -> Vec<String> {
1492 self.entities.keys().cloned().collect()
1493 }
1494
1495 fn get_relations(&self) -> Vec<String> {
1496 self.relations.keys().cloned().collect()
1497 }
1498
1499 fn get_stats(&self) -> crate::ModelStats {
1500 crate::ModelStats {
1501 num_entities: self.entities.len(),
1502 num_relations: self.relations.len(),
1503 num_triples: 0,
1504 dimensions: self.config.base_config.dimensions,
1505 is_trained: self.is_trained,
1506 model_type: self.model_type().to_string(),
1507 creation_time: Utc::now(),
1508 last_training_time: if self.is_trained {
1509 Some(Utc::now())
1510 } else {
1511 None
1512 },
1513 }
1514 }
1515
1516 fn save(&self, _path: &str) -> Result<()> {
1517 Ok(())
1518 }
1519
1520 fn load(&mut self, _path: &str) -> Result<()> {
1521 Ok(())
1522 }
1523
1524 fn clear(&mut self) {
1525 self.entities.clear();
1526 self.relations.clear();
1527 self.causal_graph = CausalGraph::new(Vec::new());
1528 self.structural_equations.clear();
1529 self.variable_embeddings.clear();
1530 self.observational_data.clear();
1531 self.interventional_data.clear();
1532 self.is_trained = false;
1533 self.training_stats = None;
1534 }
1535
1536 fn is_trained(&self) -> bool {
1537 self.is_trained
1538 }
1539
1540 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
1541 let mut results = Vec::new();
1542
1543 for text in texts {
1544 let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
1545 for (i, c) in text.chars().enumerate() {
1546 if i >= self.config.base_config.dimensions {
1547 break;
1548 }
1549 embedding[i] = (c as u8 as f32) / 255.0;
1550 }
1551 results.push(embedding);
1552 }
1553
1554 Ok(results)
1555 }
1556}
1557
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model from the default configuration.
    fn default_model() -> CausalRepresentationModel {
        CausalRepresentationModel::new(CausalRepresentationConfig::default())
    }

    #[test]
    fn test_causal_representation_config_default() {
        let discovery = CausalRepresentationConfig::default().causal_discovery;

        assert!(matches!(discovery.algorithm, CausalDiscoveryAlgorithm::PC));
        assert_eq!(discovery.significance_threshold, 0.05);
    }

    #[test]
    fn test_causal_graph_creation() {
        let names = ["X", "Y", "Z"].iter().map(|n| n.to_string()).collect();
        let mut graph = CausalGraph::new(names);

        // Chain X -> Y -> Z.
        graph.add_edge(0, 1, 0.5);
        graph.add_edge(1, 2, 0.8);

        assert_eq!(graph.get_children(0), vec![1]);
        assert_eq!(graph.get_parents(1), vec![0]);
        assert!(graph.is_acyclic());
    }

    #[test]
    fn test_structural_equation_creation() {
        let parents = vec!["X".to_string()];
        let equation = StructuralEquation::new("Y".to_string(), parents.clone());

        assert_eq!(equation.target, "Y");
        assert_eq!(equation.parents, parents);
    }

    #[test]
    fn test_intervention_creation() {
        let targets = vec!["X".to_string()];
        let intervention = Intervention::new(
            targets.clone(),
            Array1::from_vec(vec![1.0]),
            InterventionType::Do,
        );

        assert_eq!(intervention.targets, targets);
        assert!(matches!(
            intervention.intervention_type,
            InterventionType::Do
        ));
    }

    #[test]
    fn test_causal_representation_model_creation() {
        let model = default_model();

        assert_eq!(model.entities.len(), 0);
        assert_eq!(model.causal_graph.variables.len(), 0);
        assert!(!model.is_trained);
    }

    #[tokio::test]
    async fn test_causal_training() {
        let mut model = default_model();
        model.add_observational_data(HashMap::from([
            ("X".to_string(), 1.0),
            ("Y".to_string(), 2.0),
        ]));

        let stats = model.train(Some(5)).await.unwrap();

        assert_eq!(stats.epochs_completed, 5);
        assert!(model.is_trained());
    }

    #[test]
    fn test_causal_discovery() {
        let mut model = default_model();
        model.add_observational_data(HashMap::from([
            ("X".to_string(), 1.0),
            ("Y".to_string(), 2.0),
        ]));

        assert!(
            model.discover_causal_structure().is_ok(),
            "causal discovery should succeed on a single observation"
        );
    }

    #[test]
    fn test_counterfactual_query() {
        let model = default_model();

        let query = CounterfactualQuery {
            factual_evidence: HashMap::from([("X".to_string(), 1.0)]),
            intervention: Intervention::new(
                vec!["X".to_string()],
                Array1::from_vec(vec![2.0]),
                InterventionType::Do,
            ),
            query_variables: vec!["Y".to_string()],
        };

        assert!(
            model.answer_counterfactual(&query).is_ok(),
            "counterfactual query should succeed"
        );
    }
}