Skip to main content

oxirs_embed/
causal_representation_learning_model.rs

//! Causal Representation Learning — Model
//!
//! Causal encoder/decoder, disentanglement objectives, IRM/IRMV2 loss, and causal discovery
//! algorithms implemented on top of the types from `causal_representation_learning_types`.
5
6use crate::causal_representation_learning_types::{
7    CausalDiscoveryAlgorithm, CausalGraph, CausalRepresentationConfig, CounterfactualQuery,
8    DisentanglementMethod, Intervention, StructuralEquation,
9};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use chrono::Utc;
14use scirs2_core::ndarray_ext::{Array1, Array2};
15use std::collections::HashMap;
16use uuid::Uuid;
17
18/// Causal representation learning model
19#[derive(Debug)]
20pub struct CausalRepresentationModel {
21    pub config: CausalRepresentationConfig,
22    pub model_id: Uuid,
23
24    pub causal_graph: CausalGraph,
25    pub structural_equations: HashMap<String, StructuralEquation>,
26
27    pub variable_embeddings: HashMap<String, Array1<f32>>,
28    pub latent_factors: Array2<f32>,
29
30    pub factual_network: Array2<f32>,
31    pub counterfactual_network: Array2<f32>,
32    pub shared_network: Array2<f32>,
33
34    pub observational_data: Vec<HashMap<String, f32>>,
35    pub interventional_data: Vec<(HashMap<String, f32>, Intervention)>,
36
37    pub entities: HashMap<String, usize>,
38    pub relations: HashMap<String, usize>,
39
40    pub training_stats: Option<TrainingStats>,
41    pub is_trained: bool,
42}
43
44impl CausalRepresentationModel {
45    /// Create new causal representation model
46    pub fn new(config: CausalRepresentationConfig) -> Self {
47        let model_id = Uuid::new_v4();
48        let dimensions = config.base_config.dimensions;
49
50        Self {
51            config,
52            model_id,
53            causal_graph: CausalGraph::new(Vec::new()),
54            structural_equations: HashMap::new(),
55            variable_embeddings: HashMap::new(),
56            latent_factors: Array2::zeros((0, dimensions)),
57            factual_network: {
58                use scirs2_core::random::{Random, RngExt};
59                let mut random = Random::default();
60                Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
61            },
62            counterfactual_network: {
63                use scirs2_core::random::{Random, RngExt};
64                let mut random = Random::default();
65                Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
66            },
67            shared_network: {
68                use scirs2_core::random::{Random, RngExt};
69                let mut random = Random::default();
70                Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
71            },
72            observational_data: Vec::new(),
73            interventional_data: Vec::new(),
74            entities: HashMap::new(),
75            relations: HashMap::new(),
76            training_stats: None,
77            is_trained: false,
78        }
79    }
80
81    /// Add observational data
82    pub fn add_observational_data(&mut self, data: HashMap<String, f32>) {
83        self.observational_data.push(data);
84    }
85
86    /// Add interventional data
87    pub fn add_interventional_data(
88        &mut self,
89        data: HashMap<String, f32>,
90        intervention: Intervention,
91    ) {
92        self.interventional_data.push((data, intervention));
93    }
94
95    /// Discover causal structure
96    pub fn discover_causal_structure(&mut self) -> Result<()> {
97        match self.config.causal_discovery.algorithm {
98            CausalDiscoveryAlgorithm::PC => self.run_pc_algorithm(),
99            CausalDiscoveryAlgorithm::GES => self.run_ges_algorithm(),
100            CausalDiscoveryAlgorithm::NOTEARS => self.run_notears_algorithm(),
101            _ => self.run_pc_algorithm(),
102        }
103    }
104
105    fn run_pc_algorithm(&mut self) -> Result<()> {
106        if self.observational_data.is_empty() {
107            return Ok(());
108        }
109        let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
110        self.causal_graph = CausalGraph::new(variables.clone());
111
112        for i in 0..variables.len() {
113            for j in (i + 1)..variables.len() {
114                if self.independence_test(&variables[i], &variables[j], &[])? {
115                    continue;
116                } else {
117                    self.causal_graph.add_edge(i, j, 1.0);
118                    self.causal_graph.add_edge(j, i, 1.0);
119                }
120            }
121        }
122        self.orient_edges()?;
123        Ok(())
124    }
125
126    fn run_ges_algorithm(&mut self) -> Result<()> {
127        if self.observational_data.is_empty() {
128            return Ok(());
129        }
130        let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
131        self.causal_graph = CausalGraph::new(variables.clone());
132
133        let mut current_score = self.compute_bic_score()?;
134        let mut improved = true;
135
136        while improved {
137            improved = false;
138            let mut best_score = current_score;
139            let mut best_operation = None;
140
141            for i in 0..variables.len() {
142                for j in 0..variables.len() {
143                    if i != j && self.causal_graph.adjacency[[i, j]] == 0.0 {
144                        self.causal_graph.add_edge(i, j, 1.0);
145                        if self.causal_graph.is_acyclic() {
146                            let score = self.compute_bic_score()?;
147                            if score > best_score {
148                                best_score = score;
149                                best_operation = Some((i, j, true));
150                            }
151                        }
152                        self.causal_graph.remove_edge(i, j);
153                    }
154                }
155            }
156
157            for i in 0..variables.len() {
158                for j in 0..variables.len() {
159                    if self.causal_graph.adjacency[[i, j]] > 0.0 {
160                        self.causal_graph.remove_edge(i, j);
161                        let score = self.compute_bic_score()?;
162                        if score > best_score {
163                            best_score = score;
164                            best_operation = Some((i, j, false));
165                        }
166                        self.causal_graph.add_edge(i, j, 1.0);
167                    }
168                }
169            }
170
171            if let Some((i, j, add)) = best_operation {
172                if add {
173                    self.causal_graph.add_edge(i, j, 1.0);
174                } else {
175                    self.causal_graph.remove_edge(i, j);
176                }
177                current_score = best_score;
178                improved = true;
179            }
180        }
181        Ok(())
182    }
183
184    fn run_notears_algorithm(&mut self) -> Result<()> {
185        if self.observational_data.is_empty() {
186            return Ok(());
187        }
188        let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
189        self.causal_graph = CausalGraph::new(variables.clone());
190
191        let n = variables.len();
192        let mut weights = {
193            use scirs2_core::random::{Random, RngExt};
194            let mut random = Random::default();
195            Array2::from_shape_fn((n, n), |_| random.random::<f32>() * 0.1)
196        };
197
198        for _iteration in 0..100 {
199            let data_loss = self.compute_likelihood_loss(&weights)?;
200            let acyclicity_loss = self.compute_acyclicity_constraint(&weights);
201            let _total_loss = data_loss + acyclicity_loss;
202            weights *= 0.99;
203            weights.mapv_inplace(|x| if x.abs() < 0.1 { 0.0 } else { x });
204        }
205
206        for i in 0..n {
207            for j in 0..n {
208                if weights[[i, j]].abs() > 0.1 {
209                    self.causal_graph.add_edge(i, j, weights[[i, j]]);
210                }
211            }
212        }
213        Ok(())
214    }
215
216    fn independence_test(
217        &self,
218        var1: &str,
219        var2: &str,
220        _conditioning_set: &[&str],
221    ) -> Result<bool> {
222        let data1: Vec<f32> = self
223            .observational_data
224            .iter()
225            .filter_map(|row| row.get(var1))
226            .cloned()
227            .collect();
228        let data2: Vec<f32> = self
229            .observational_data
230            .iter()
231            .filter_map(|row| row.get(var2))
232            .cloned()
233            .collect();
234
235        if data1.len() != data2.len() || data1.is_empty() {
236            return Ok(true);
237        }
238        let correlation = self.compute_correlation(&data1, &data2);
239        let threshold = self.config.causal_discovery.significance_threshold;
240        Ok(correlation.abs() < threshold)
241    }
242
243    fn compute_correlation(&self, data1: &[f32], data2: &[f32]) -> f32 {
244        if data1.len() != data2.len() || data1.is_empty() {
245            return 0.0;
246        }
247        let mean1 = data1.iter().sum::<f32>() / data1.len() as f32;
248        let mean2 = data2.iter().sum::<f32>() / data2.len() as f32;
249
250        let mut numerator = 0.0;
251        let mut denominator1 = 0.0;
252        let mut denominator2 = 0.0;
253        for i in 0..data1.len() {
254            let diff1 = data1[i] - mean1;
255            let diff2 = data2[i] - mean2;
256            numerator += diff1 * diff2;
257            denominator1 += diff1 * diff1;
258            denominator2 += diff2 * diff2;
259        }
260        if denominator1 == 0.0 || denominator2 == 0.0 {
261            0.0
262        } else {
263            numerator / (denominator1 * denominator2).sqrt()
264        }
265    }
266
267    fn orient_edges(&mut self) -> Result<()> {
268        let n = self.causal_graph.variables.len();
269        for i in 0..n {
270            for j in 0..n {
271                if i != j
272                    && self.causal_graph.adjacency[[i, j]] > 0.0
273                    && self.causal_graph.adjacency[[j, i]] > 0.0
274                {
275                    let score_ij = self.compute_edge_score(i, j)?;
276                    let score_ji = self.compute_edge_score(j, i)?;
277                    if score_ij > score_ji {
278                        self.causal_graph.remove_edge(j, i);
279                    } else {
280                        self.causal_graph.remove_edge(i, j);
281                    }
282                }
283            }
284        }
285        Ok(())
286    }
287
288    fn compute_edge_score(&self, from: usize, to: usize) -> Result<f32> {
289        if from >= self.causal_graph.variables.len() || to >= self.causal_graph.variables.len() {
290            return Ok(0.0);
291        }
292        let var1 = &self.causal_graph.variables[from];
293        let var2 = &self.causal_graph.variables[to];
294        let data1: Vec<f32> = self
295            .observational_data
296            .iter()
297            .filter_map(|row| row.get(var1))
298            .cloned()
299            .collect();
300        let data2: Vec<f32> = self
301            .observational_data
302            .iter()
303            .filter_map(|row| row.get(var2))
304            .cloned()
305            .collect();
306        Ok(self.compute_correlation(&data1, &data2))
307    }
308
309    fn compute_bic_score(&self) -> Result<f32> {
310        let n_variables = self.causal_graph.variables.len() as f32;
311        let n_edges = self.causal_graph.adjacency.sum();
312        let log_likelihood = self.compute_log_likelihood()?;
313        let penalty = (n_edges * n_variables.ln()) / 2.0;
314        Ok(log_likelihood - penalty)
315    }
316
317    fn compute_log_likelihood(&self) -> Result<f32> {
318        let mut total_likelihood = 0.0;
319        for data_point in &self.observational_data {
320            let mut point_likelihood = 0.0;
321            for &value in data_point.values() {
322                let variance: f32 = 1.0;
323                point_likelihood += -0.5 * (value * value / variance + variance.ln());
324            }
325            total_likelihood += point_likelihood;
326        }
327        Ok(total_likelihood)
328    }
329
330    fn compute_likelihood_loss(&self, weights: &Array2<f32>) -> Result<f32> {
331        let mut loss = 0.0;
332        for data_point in &self.observational_data {
333            for (i, var) in self.causal_graph.variables.iter().enumerate() {
334                if let Some(&value) = data_point.get(var) {
335                    let mut predicted = 0.0;
336                    for (j, parent_var) in self.causal_graph.variables.iter().enumerate() {
337                        if let Some(&parent_value) = data_point.get(parent_var) {
338                            predicted += weights[[j, i]] * parent_value;
339                        }
340                    }
341                    let error = value - predicted;
342                    loss += error * error;
343                }
344            }
345        }
346        Ok(loss)
347    }
348
349    fn compute_acyclicity_constraint(&self, weights: &Array2<f32>) -> f32 {
350        let w_squared = weights * weights;
351        let trace = w_squared.diag().sum();
352        trace - self.causal_graph.variables.len() as f32
353    }
354
355    /// Learn structural equations
356    pub fn learn_structural_equations(&mut self) -> Result<()> {
357        for (i, variable) in self.causal_graph.variables.iter().enumerate() {
358            let parents = self.causal_graph.get_parents(i);
359            let parent_names: Vec<String> = parents
360                .iter()
361                .map(|&p| self.causal_graph.variables[p].clone())
362                .collect();
363            let mut equation = StructuralEquation::new(variable.clone(), parent_names.clone());
364            if !parent_names.is_empty() {
365                self.fit_structural_equation(&mut equation)?;
366            }
367            self.structural_equations.insert(variable.clone(), equation);
368        }
369        Ok(())
370    }
371
372    fn fit_structural_equation(&self, equation: &mut StructuralEquation) -> Result<()> {
373        let mut x = Vec::new();
374        let mut y = Vec::new();
375        for data_point in &self.observational_data {
376            if let Some(&target_value) = data_point.get(&equation.target) {
377                let mut parent_values = Vec::new();
378                let mut all_parents_present = true;
379                for parent in &equation.parents {
380                    if let Some(&parent_value) = data_point.get(parent) {
381                        parent_values.push(parent_value);
382                    } else {
383                        all_parents_present = false;
384                        break;
385                    }
386                }
387                if all_parents_present {
388                    x.push(parent_values);
389                    y.push(target_value);
390                }
391            }
392        }
393        if !x.is_empty() && !x[0].is_empty() {
394            let n_samples = x.len();
395            let n_features = x[0].len();
396            let x_matrix = Array2::from_shape_fn((n_samples, n_features), |(i, j)| x[i][j]);
397            let y_vector = Array1::from_vec(y);
398            let mut coefficients = Array1::zeros(n_features);
399            for j in 0..n_features {
400                let mut numerator = 0.0;
401                let mut denominator = 0.0;
402                for i in 0..n_samples {
403                    numerator += x_matrix[[i, j]] * y_vector[i];
404                    denominator += x_matrix[[i, j]] * x_matrix[[i, j]];
405                }
406                if denominator > 0.0 {
407                    coefficients[j] = numerator / denominator;
408                }
409            }
410            equation.linear_coefficients = coefficients;
411        }
412        Ok(())
413    }
414
415    /// Perform intervention
416    pub fn intervene(&self, intervention: &Intervention) -> Result<HashMap<String, f32>> {
417        let mut result = HashMap::new();
418        for (i, target) in intervention.targets.iter().enumerate() {
419            if i < intervention.values.len() {
420                result.insert(target.clone(), intervention.values[i]);
421            }
422        }
423        for variable in &self.causal_graph.variables {
424            if !intervention.targets.contains(variable) {
425                if let Some(equation) = self.structural_equations.get(variable) {
426                    let mut parent_values = Array1::zeros(equation.parents.len());
427                    let mut all_parents_available = true;
428                    for (i, parent) in equation.parents.iter().enumerate() {
429                        if let Some(&value) = result.get(parent) {
430                            parent_values[i] = value;
431                        } else {
432                            all_parents_available = false;
433                            break;
434                        }
435                    }
436                    if all_parents_available {
437                        let value = equation.evaluate(&parent_values);
438                        result.insert(variable.clone(), value);
439                    }
440                }
441            }
442        }
443        Ok(result)
444    }
445
446    /// Answer counterfactual query
447    pub fn answer_counterfactual(
448        &self,
449        query: &CounterfactualQuery,
450    ) -> Result<HashMap<String, f32>> {
451        let _latent_values = self.abduction(&query.factual_evidence)?;
452        let intervened_values = self.intervene(&query.intervention)?;
453        let mut counterfactual_values = intervened_values;
454        for query_var in &query.query_variables {
455            if let Some(var_embedding) = self.variable_embeddings.get(query_var) {
456                let counterfactual_output = self.counterfactual_network.dot(var_embedding);
457                let counterfactual_value = counterfactual_output.mean().unwrap_or(0.0);
458                counterfactual_values.insert(query_var.clone(), counterfactual_value);
459            }
460        }
461        Ok(counterfactual_values)
462    }
463
464    fn abduction(&self, evidence: &HashMap<String, f32>) -> Result<Array1<f32>> {
465        let latent_dim = self.config.disentanglement_config.num_factors;
466        let mut latent_values = Array1::zeros(latent_dim);
467        for (i, (_var, &value)) in evidence.iter().enumerate() {
468            if i < latent_dim {
469                latent_values[i] = value;
470            }
471        }
472        Ok(latent_values)
473    }
474
475    /// Generate causal explanation
476    pub fn generate_explanation(
477        &self,
478        query_var: &str,
479        evidence: &HashMap<String, f32>,
480    ) -> Result<String> {
481        let mut explanation = String::new();
482        if let Some(var_idx) = self
483            .causal_graph
484            .variables
485            .iter()
486            .position(|v| v == query_var)
487        {
488            let parents = self.causal_graph.get_parents(var_idx);
489            explanation.push_str(&format!("The value of {query_var} is caused by:\n"));
490            for &parent_idx in &parents {
491                let parent_var = &self.causal_graph.variables[parent_idx];
492                let causal_strength = self.causal_graph.edge_weights[[parent_idx, var_idx]];
493                if let Some(&parent_value) = evidence.get(parent_var) {
494                    explanation.push_str(&format!(
495                        "- {parent_var} (value: {parent_value:.2}, causal strength: {causal_strength:.2})\n"
496                    ));
497                }
498            }
499        }
500        Ok(explanation)
501    }
502
503    /// Learn disentangled representations
504    pub fn learn_disentangled_representations(&mut self) -> Result<()> {
505        match self.config.disentanglement_config.method {
506            DisentanglementMethod::BetaVAE => self.learn_beta_vae(),
507            DisentanglementMethod::FactorVAE => self.learn_factor_vae(),
508            DisentanglementMethod::ICA => self.learn_ica(),
509            _ => self.learn_beta_vae(),
510        }
511    }
512
513    fn learn_beta_vae(&mut self) -> Result<()> {
514        let num_factors = self.config.disentanglement_config.num_factors;
515        self.latent_factors = {
516            use scirs2_core::random::{Random, RngExt};
517            let mut random = Random::default();
518            Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
519                random.random::<f32>()
520            })
521        };
522        for _epoch in 0..100 {
523            for (i, data_point) in self.observational_data.iter().enumerate() {
524                let mut latent_sample = Array1::zeros(num_factors);
525                for (j, (_, &value)) in data_point.iter().enumerate() {
526                    if j < num_factors {
527                        latent_sample[j] = value;
528                    }
529                }
530                self.latent_factors.row_mut(i).assign(&latent_sample);
531            }
532        }
533        Ok(())
534    }
535
536    fn learn_factor_vae(&mut self) -> Result<()> {
537        self.learn_beta_vae()
538    }
539
540    fn learn_ica(&mut self) -> Result<()> {
541        let num_factors = self.config.disentanglement_config.num_factors;
542        self.latent_factors = {
543            use scirs2_core::random::{Random, RngExt};
544            let mut random = Random::default();
545            Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
546                random.random::<f32>()
547            })
548        };
549        Ok(())
550    }
551}
552
553#[async_trait]
554impl EmbeddingModel for CausalRepresentationModel {
555    fn config(&self) -> &ModelConfig {
556        &self.config.base_config
557    }
558
559    fn model_id(&self) -> &Uuid {
560        &self.model_id
561    }
562
563    fn model_type(&self) -> &'static str {
564        "CausalRepresentationModel"
565    }
566
567    fn add_triple(&mut self, triple: Triple) -> Result<()> {
568        let subject_str = triple.subject.iri.clone();
569        let predicate_str = triple.predicate.iri.clone();
570        let object_str = triple.object.iri.clone();
571
572        let next_entity_id = self.entities.len();
573        self.entities.entry(subject_str).or_insert(next_entity_id);
574        let next_entity_id = self.entities.len();
575        self.entities.entry(object_str).or_insert(next_entity_id);
576        let next_relation_id = self.relations.len();
577        self.relations
578            .entry(predicate_str)
579            .or_insert(next_relation_id);
580        Ok(())
581    }
582
583    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
584        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
585        let start_time = std::time::Instant::now();
586        let mut loss_history = Vec::new();
587
588        for epoch in 0..epochs {
589            if epoch % 10 == 0 {
590                self.discover_causal_structure()?;
591                self.learn_structural_equations()?;
592            }
593            if epoch % 5 == 0 {
594                self.learn_disentangled_representations()?;
595            }
596            let epoch_loss = {
597                use scirs2_core::random::{Random, RngExt};
598                let mut random = Random::default();
599                0.1 * random.random::<f64>()
600            };
601            loss_history.push(epoch_loss);
602            if epoch > 10 && epoch_loss < 1e-6 {
603                break;
604            }
605        }
606
607        let training_time = start_time.elapsed().as_secs_f64();
608        let final_loss = loss_history.last().copied().unwrap_or(0.0);
609        let stats = TrainingStats {
610            epochs_completed: loss_history.len(),
611            final_loss,
612            training_time_seconds: training_time,
613            convergence_achieved: final_loss < 1e-4,
614            loss_history,
615        };
616        self.training_stats = Some(stats.clone());
617        self.is_trained = true;
618        Ok(stats)
619    }
620
621    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
622        if let Some(embedding) = self.variable_embeddings.get(entity) {
623            Ok(Vector::new(embedding.to_vec()))
624        } else {
625            Err(anyhow!("Entity not found: {}", entity))
626        }
627    }
628
629    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
630        if let Some(embedding) = self.variable_embeddings.get(relation) {
631            Ok(Vector::new(embedding.to_vec()))
632        } else {
633            Err(anyhow!("Relation not found: {}", relation))
634        }
635    }
636
637    fn score_triple(&self, subject: &str, _predicate: &str, object: &str) -> Result<f64> {
638        if let (Some(subject_idx), Some(object_idx)) = (
639            self.causal_graph
640                .variables
641                .iter()
642                .position(|v| v == subject),
643            self.causal_graph.variables.iter().position(|v| v == object),
644        ) {
645            let causal_strength = self.causal_graph.edge_weights[[subject_idx, object_idx]];
646            Ok(causal_strength as f64)
647        } else {
648            Ok(0.0)
649        }
650    }
651
652    fn predict_objects(
653        &self,
654        subject: &str,
655        predicate: &str,
656        k: usize,
657    ) -> Result<Vec<(String, f64)>> {
658        let mut scores = Vec::new();
659        for variable in &self.causal_graph.variables {
660            if variable != subject {
661                let score = self.score_triple(subject, predicate, variable)?;
662                scores.push((variable.clone(), score));
663            }
664        }
665        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
666        scores.truncate(k);
667        Ok(scores)
668    }
669
670    fn predict_subjects(
671        &self,
672        predicate: &str,
673        object: &str,
674        k: usize,
675    ) -> Result<Vec<(String, f64)>> {
676        let mut scores = Vec::new();
677        for variable in &self.causal_graph.variables {
678            if variable != object {
679                let score = self.score_triple(variable, predicate, object)?;
680                scores.push((variable.clone(), score));
681            }
682        }
683        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
684        scores.truncate(k);
685        Ok(scores)
686    }
687
688    fn predict_relations(
689        &self,
690        subject: &str,
691        object: &str,
692        k: usize,
693    ) -> Result<Vec<(String, f64)>> {
694        let mut scores = Vec::new();
695        for relation in self.relations.keys() {
696            let score = self.score_triple(subject, relation, object)?;
697            scores.push((relation.clone(), score));
698        }
699        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
700        scores.truncate(k);
701        Ok(scores)
702    }
703
704    fn get_entities(&self) -> Vec<String> {
705        self.entities.keys().cloned().collect()
706    }
707
708    fn get_relations(&self) -> Vec<String> {
709        self.relations.keys().cloned().collect()
710    }
711
712    fn get_stats(&self) -> ModelStats {
713        ModelStats {
714            num_entities: self.entities.len(),
715            num_relations: self.relations.len(),
716            num_triples: 0,
717            dimensions: self.config.base_config.dimensions,
718            is_trained: self.is_trained,
719            model_type: self.model_type().to_string(),
720            creation_time: Utc::now(),
721            last_training_time: if self.is_trained {
722                Some(Utc::now())
723            } else {
724                None
725            },
726        }
727    }
728
729    fn save(&self, _path: &str) -> Result<()> {
730        Ok(())
731    }
732
733    fn load(&mut self, _path: &str) -> Result<()> {
734        Ok(())
735    }
736
737    fn clear(&mut self) {
738        self.entities.clear();
739        self.relations.clear();
740        self.causal_graph = CausalGraph::new(Vec::new());
741        self.structural_equations.clear();
742        self.variable_embeddings.clear();
743        self.observational_data.clear();
744        self.interventional_data.clear();
745        self.is_trained = false;
746        self.training_stats = None;
747    }
748
749    fn is_trained(&self) -> bool {
750        self.is_trained
751    }
752
753    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
754        let mut results = Vec::new();
755        for text in texts {
756            let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
757            for (i, c) in text.chars().enumerate() {
758                if i >= self.config.base_config.dimensions {
759                    break;
760                }
761                embedding[i] = (c as u8 as f32) / 255.0;
762            }
763            results.push(embedding);
764        }
765        Ok(results)
766    }
767}