1use crate::causal_representation_learning_types::{
7 CausalDiscoveryAlgorithm, CausalGraph, CausalRepresentationConfig, CounterfactualQuery,
8 DisentanglementMethod, Intervention, StructuralEquation,
9};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use chrono::Utc;
14use scirs2_core::ndarray_ext::{Array1, Array2};
15use std::collections::HashMap;
16use uuid::Uuid;
17
18#[derive(Debug)]
20pub struct CausalRepresentationModel {
21 pub config: CausalRepresentationConfig,
22 pub model_id: Uuid,
23
24 pub causal_graph: CausalGraph,
25 pub structural_equations: HashMap<String, StructuralEquation>,
26
27 pub variable_embeddings: HashMap<String, Array1<f32>>,
28 pub latent_factors: Array2<f32>,
29
30 pub factual_network: Array2<f32>,
31 pub counterfactual_network: Array2<f32>,
32 pub shared_network: Array2<f32>,
33
34 pub observational_data: Vec<HashMap<String, f32>>,
35 pub interventional_data: Vec<(HashMap<String, f32>, Intervention)>,
36
37 pub entities: HashMap<String, usize>,
38 pub relations: HashMap<String, usize>,
39
40 pub training_stats: Option<TrainingStats>,
41 pub is_trained: bool,
42}
43
44impl CausalRepresentationModel {
45 pub fn new(config: CausalRepresentationConfig) -> Self {
47 let model_id = Uuid::new_v4();
48 let dimensions = config.base_config.dimensions;
49
50 Self {
51 config,
52 model_id,
53 causal_graph: CausalGraph::new(Vec::new()),
54 structural_equations: HashMap::new(),
55 variable_embeddings: HashMap::new(),
56 latent_factors: Array2::zeros((0, dimensions)),
57 factual_network: {
58 use scirs2_core::random::{Random, RngExt};
59 let mut random = Random::default();
60 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
61 },
62 counterfactual_network: {
63 use scirs2_core::random::{Random, RngExt};
64 let mut random = Random::default();
65 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
66 },
67 shared_network: {
68 use scirs2_core::random::{Random, RngExt};
69 let mut random = Random::default();
70 Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1)
71 },
72 observational_data: Vec::new(),
73 interventional_data: Vec::new(),
74 entities: HashMap::new(),
75 relations: HashMap::new(),
76 training_stats: None,
77 is_trained: false,
78 }
79 }
80
81 pub fn add_observational_data(&mut self, data: HashMap<String, f32>) {
83 self.observational_data.push(data);
84 }
85
86 pub fn add_interventional_data(
88 &mut self,
89 data: HashMap<String, f32>,
90 intervention: Intervention,
91 ) {
92 self.interventional_data.push((data, intervention));
93 }
94
95 pub fn discover_causal_structure(&mut self) -> Result<()> {
97 match self.config.causal_discovery.algorithm {
98 CausalDiscoveryAlgorithm::PC => self.run_pc_algorithm(),
99 CausalDiscoveryAlgorithm::GES => self.run_ges_algorithm(),
100 CausalDiscoveryAlgorithm::NOTEARS => self.run_notears_algorithm(),
101 _ => self.run_pc_algorithm(),
102 }
103 }
104
105 fn run_pc_algorithm(&mut self) -> Result<()> {
106 if self.observational_data.is_empty() {
107 return Ok(());
108 }
109 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
110 self.causal_graph = CausalGraph::new(variables.clone());
111
112 for i in 0..variables.len() {
113 for j in (i + 1)..variables.len() {
114 if self.independence_test(&variables[i], &variables[j], &[])? {
115 continue;
116 } else {
117 self.causal_graph.add_edge(i, j, 1.0);
118 self.causal_graph.add_edge(j, i, 1.0);
119 }
120 }
121 }
122 self.orient_edges()?;
123 Ok(())
124 }
125
126 fn run_ges_algorithm(&mut self) -> Result<()> {
127 if self.observational_data.is_empty() {
128 return Ok(());
129 }
130 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
131 self.causal_graph = CausalGraph::new(variables.clone());
132
133 let mut current_score = self.compute_bic_score()?;
134 let mut improved = true;
135
136 while improved {
137 improved = false;
138 let mut best_score = current_score;
139 let mut best_operation = None;
140
141 for i in 0..variables.len() {
142 for j in 0..variables.len() {
143 if i != j && self.causal_graph.adjacency[[i, j]] == 0.0 {
144 self.causal_graph.add_edge(i, j, 1.0);
145 if self.causal_graph.is_acyclic() {
146 let score = self.compute_bic_score()?;
147 if score > best_score {
148 best_score = score;
149 best_operation = Some((i, j, true));
150 }
151 }
152 self.causal_graph.remove_edge(i, j);
153 }
154 }
155 }
156
157 for i in 0..variables.len() {
158 for j in 0..variables.len() {
159 if self.causal_graph.adjacency[[i, j]] > 0.0 {
160 self.causal_graph.remove_edge(i, j);
161 let score = self.compute_bic_score()?;
162 if score > best_score {
163 best_score = score;
164 best_operation = Some((i, j, false));
165 }
166 self.causal_graph.add_edge(i, j, 1.0);
167 }
168 }
169 }
170
171 if let Some((i, j, add)) = best_operation {
172 if add {
173 self.causal_graph.add_edge(i, j, 1.0);
174 } else {
175 self.causal_graph.remove_edge(i, j);
176 }
177 current_score = best_score;
178 improved = true;
179 }
180 }
181 Ok(())
182 }
183
184 fn run_notears_algorithm(&mut self) -> Result<()> {
185 if self.observational_data.is_empty() {
186 return Ok(());
187 }
188 let variables: Vec<String> = self.observational_data[0].keys().cloned().collect();
189 self.causal_graph = CausalGraph::new(variables.clone());
190
191 let n = variables.len();
192 let mut weights = {
193 use scirs2_core::random::{Random, RngExt};
194 let mut random = Random::default();
195 Array2::from_shape_fn((n, n), |_| random.random::<f32>() * 0.1)
196 };
197
198 for _iteration in 0..100 {
199 let data_loss = self.compute_likelihood_loss(&weights)?;
200 let acyclicity_loss = self.compute_acyclicity_constraint(&weights);
201 let _total_loss = data_loss + acyclicity_loss;
202 weights *= 0.99;
203 weights.mapv_inplace(|x| if x.abs() < 0.1 { 0.0 } else { x });
204 }
205
206 for i in 0..n {
207 for j in 0..n {
208 if weights[[i, j]].abs() > 0.1 {
209 self.causal_graph.add_edge(i, j, weights[[i, j]]);
210 }
211 }
212 }
213 Ok(())
214 }
215
216 fn independence_test(
217 &self,
218 var1: &str,
219 var2: &str,
220 _conditioning_set: &[&str],
221 ) -> Result<bool> {
222 let data1: Vec<f32> = self
223 .observational_data
224 .iter()
225 .filter_map(|row| row.get(var1))
226 .cloned()
227 .collect();
228 let data2: Vec<f32> = self
229 .observational_data
230 .iter()
231 .filter_map(|row| row.get(var2))
232 .cloned()
233 .collect();
234
235 if data1.len() != data2.len() || data1.is_empty() {
236 return Ok(true);
237 }
238 let correlation = self.compute_correlation(&data1, &data2);
239 let threshold = self.config.causal_discovery.significance_threshold;
240 Ok(correlation.abs() < threshold)
241 }
242
243 fn compute_correlation(&self, data1: &[f32], data2: &[f32]) -> f32 {
244 if data1.len() != data2.len() || data1.is_empty() {
245 return 0.0;
246 }
247 let mean1 = data1.iter().sum::<f32>() / data1.len() as f32;
248 let mean2 = data2.iter().sum::<f32>() / data2.len() as f32;
249
250 let mut numerator = 0.0;
251 let mut denominator1 = 0.0;
252 let mut denominator2 = 0.0;
253 for i in 0..data1.len() {
254 let diff1 = data1[i] - mean1;
255 let diff2 = data2[i] - mean2;
256 numerator += diff1 * diff2;
257 denominator1 += diff1 * diff1;
258 denominator2 += diff2 * diff2;
259 }
260 if denominator1 == 0.0 || denominator2 == 0.0 {
261 0.0
262 } else {
263 numerator / (denominator1 * denominator2).sqrt()
264 }
265 }
266
267 fn orient_edges(&mut self) -> Result<()> {
268 let n = self.causal_graph.variables.len();
269 for i in 0..n {
270 for j in 0..n {
271 if i != j
272 && self.causal_graph.adjacency[[i, j]] > 0.0
273 && self.causal_graph.adjacency[[j, i]] > 0.0
274 {
275 let score_ij = self.compute_edge_score(i, j)?;
276 let score_ji = self.compute_edge_score(j, i)?;
277 if score_ij > score_ji {
278 self.causal_graph.remove_edge(j, i);
279 } else {
280 self.causal_graph.remove_edge(i, j);
281 }
282 }
283 }
284 }
285 Ok(())
286 }
287
288 fn compute_edge_score(&self, from: usize, to: usize) -> Result<f32> {
289 if from >= self.causal_graph.variables.len() || to >= self.causal_graph.variables.len() {
290 return Ok(0.0);
291 }
292 let var1 = &self.causal_graph.variables[from];
293 let var2 = &self.causal_graph.variables[to];
294 let data1: Vec<f32> = self
295 .observational_data
296 .iter()
297 .filter_map(|row| row.get(var1))
298 .cloned()
299 .collect();
300 let data2: Vec<f32> = self
301 .observational_data
302 .iter()
303 .filter_map(|row| row.get(var2))
304 .cloned()
305 .collect();
306 Ok(self.compute_correlation(&data1, &data2))
307 }
308
309 fn compute_bic_score(&self) -> Result<f32> {
310 let n_variables = self.causal_graph.variables.len() as f32;
311 let n_edges = self.causal_graph.adjacency.sum();
312 let log_likelihood = self.compute_log_likelihood()?;
313 let penalty = (n_edges * n_variables.ln()) / 2.0;
314 Ok(log_likelihood - penalty)
315 }
316
317 fn compute_log_likelihood(&self) -> Result<f32> {
318 let mut total_likelihood = 0.0;
319 for data_point in &self.observational_data {
320 let mut point_likelihood = 0.0;
321 for &value in data_point.values() {
322 let variance: f32 = 1.0;
323 point_likelihood += -0.5 * (value * value / variance + variance.ln());
324 }
325 total_likelihood += point_likelihood;
326 }
327 Ok(total_likelihood)
328 }
329
330 fn compute_likelihood_loss(&self, weights: &Array2<f32>) -> Result<f32> {
331 let mut loss = 0.0;
332 for data_point in &self.observational_data {
333 for (i, var) in self.causal_graph.variables.iter().enumerate() {
334 if let Some(&value) = data_point.get(var) {
335 let mut predicted = 0.0;
336 for (j, parent_var) in self.causal_graph.variables.iter().enumerate() {
337 if let Some(&parent_value) = data_point.get(parent_var) {
338 predicted += weights[[j, i]] * parent_value;
339 }
340 }
341 let error = value - predicted;
342 loss += error * error;
343 }
344 }
345 }
346 Ok(loss)
347 }
348
349 fn compute_acyclicity_constraint(&self, weights: &Array2<f32>) -> f32 {
350 let w_squared = weights * weights;
351 let trace = w_squared.diag().sum();
352 trace - self.causal_graph.variables.len() as f32
353 }
354
355 pub fn learn_structural_equations(&mut self) -> Result<()> {
357 for (i, variable) in self.causal_graph.variables.iter().enumerate() {
358 let parents = self.causal_graph.get_parents(i);
359 let parent_names: Vec<String> = parents
360 .iter()
361 .map(|&p| self.causal_graph.variables[p].clone())
362 .collect();
363 let mut equation = StructuralEquation::new(variable.clone(), parent_names.clone());
364 if !parent_names.is_empty() {
365 self.fit_structural_equation(&mut equation)?;
366 }
367 self.structural_equations.insert(variable.clone(), equation);
368 }
369 Ok(())
370 }
371
372 fn fit_structural_equation(&self, equation: &mut StructuralEquation) -> Result<()> {
373 let mut x = Vec::new();
374 let mut y = Vec::new();
375 for data_point in &self.observational_data {
376 if let Some(&target_value) = data_point.get(&equation.target) {
377 let mut parent_values = Vec::new();
378 let mut all_parents_present = true;
379 for parent in &equation.parents {
380 if let Some(&parent_value) = data_point.get(parent) {
381 parent_values.push(parent_value);
382 } else {
383 all_parents_present = false;
384 break;
385 }
386 }
387 if all_parents_present {
388 x.push(parent_values);
389 y.push(target_value);
390 }
391 }
392 }
393 if !x.is_empty() && !x[0].is_empty() {
394 let n_samples = x.len();
395 let n_features = x[0].len();
396 let x_matrix = Array2::from_shape_fn((n_samples, n_features), |(i, j)| x[i][j]);
397 let y_vector = Array1::from_vec(y);
398 let mut coefficients = Array1::zeros(n_features);
399 for j in 0..n_features {
400 let mut numerator = 0.0;
401 let mut denominator = 0.0;
402 for i in 0..n_samples {
403 numerator += x_matrix[[i, j]] * y_vector[i];
404 denominator += x_matrix[[i, j]] * x_matrix[[i, j]];
405 }
406 if denominator > 0.0 {
407 coefficients[j] = numerator / denominator;
408 }
409 }
410 equation.linear_coefficients = coefficients;
411 }
412 Ok(())
413 }
414
415 pub fn intervene(&self, intervention: &Intervention) -> Result<HashMap<String, f32>> {
417 let mut result = HashMap::new();
418 for (i, target) in intervention.targets.iter().enumerate() {
419 if i < intervention.values.len() {
420 result.insert(target.clone(), intervention.values[i]);
421 }
422 }
423 for variable in &self.causal_graph.variables {
424 if !intervention.targets.contains(variable) {
425 if let Some(equation) = self.structural_equations.get(variable) {
426 let mut parent_values = Array1::zeros(equation.parents.len());
427 let mut all_parents_available = true;
428 for (i, parent) in equation.parents.iter().enumerate() {
429 if let Some(&value) = result.get(parent) {
430 parent_values[i] = value;
431 } else {
432 all_parents_available = false;
433 break;
434 }
435 }
436 if all_parents_available {
437 let value = equation.evaluate(&parent_values);
438 result.insert(variable.clone(), value);
439 }
440 }
441 }
442 }
443 Ok(result)
444 }
445
446 pub fn answer_counterfactual(
448 &self,
449 query: &CounterfactualQuery,
450 ) -> Result<HashMap<String, f32>> {
451 let _latent_values = self.abduction(&query.factual_evidence)?;
452 let intervened_values = self.intervene(&query.intervention)?;
453 let mut counterfactual_values = intervened_values;
454 for query_var in &query.query_variables {
455 if let Some(var_embedding) = self.variable_embeddings.get(query_var) {
456 let counterfactual_output = self.counterfactual_network.dot(var_embedding);
457 let counterfactual_value = counterfactual_output.mean().unwrap_or(0.0);
458 counterfactual_values.insert(query_var.clone(), counterfactual_value);
459 }
460 }
461 Ok(counterfactual_values)
462 }
463
464 fn abduction(&self, evidence: &HashMap<String, f32>) -> Result<Array1<f32>> {
465 let latent_dim = self.config.disentanglement_config.num_factors;
466 let mut latent_values = Array1::zeros(latent_dim);
467 for (i, (_var, &value)) in evidence.iter().enumerate() {
468 if i < latent_dim {
469 latent_values[i] = value;
470 }
471 }
472 Ok(latent_values)
473 }
474
475 pub fn generate_explanation(
477 &self,
478 query_var: &str,
479 evidence: &HashMap<String, f32>,
480 ) -> Result<String> {
481 let mut explanation = String::new();
482 if let Some(var_idx) = self
483 .causal_graph
484 .variables
485 .iter()
486 .position(|v| v == query_var)
487 {
488 let parents = self.causal_graph.get_parents(var_idx);
489 explanation.push_str(&format!("The value of {query_var} is caused by:\n"));
490 for &parent_idx in &parents {
491 let parent_var = &self.causal_graph.variables[parent_idx];
492 let causal_strength = self.causal_graph.edge_weights[[parent_idx, var_idx]];
493 if let Some(&parent_value) = evidence.get(parent_var) {
494 explanation.push_str(&format!(
495 "- {parent_var} (value: {parent_value:.2}, causal strength: {causal_strength:.2})\n"
496 ));
497 }
498 }
499 }
500 Ok(explanation)
501 }
502
503 pub fn learn_disentangled_representations(&mut self) -> Result<()> {
505 match self.config.disentanglement_config.method {
506 DisentanglementMethod::BetaVAE => self.learn_beta_vae(),
507 DisentanglementMethod::FactorVAE => self.learn_factor_vae(),
508 DisentanglementMethod::ICA => self.learn_ica(),
509 _ => self.learn_beta_vae(),
510 }
511 }
512
513 fn learn_beta_vae(&mut self) -> Result<()> {
514 let num_factors = self.config.disentanglement_config.num_factors;
515 self.latent_factors = {
516 use scirs2_core::random::{Random, RngExt};
517 let mut random = Random::default();
518 Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
519 random.random::<f32>()
520 })
521 };
522 for _epoch in 0..100 {
523 for (i, data_point) in self.observational_data.iter().enumerate() {
524 let mut latent_sample = Array1::zeros(num_factors);
525 for (j, (_, &value)) in data_point.iter().enumerate() {
526 if j < num_factors {
527 latent_sample[j] = value;
528 }
529 }
530 self.latent_factors.row_mut(i).assign(&latent_sample);
531 }
532 }
533 Ok(())
534 }
535
536 fn learn_factor_vae(&mut self) -> Result<()> {
537 self.learn_beta_vae()
538 }
539
540 fn learn_ica(&mut self) -> Result<()> {
541 let num_factors = self.config.disentanglement_config.num_factors;
542 self.latent_factors = {
543 use scirs2_core::random::{Random, RngExt};
544 let mut random = Random::default();
545 Array2::from_shape_fn((self.observational_data.len(), num_factors), |_| {
546 random.random::<f32>()
547 })
548 };
549 Ok(())
550 }
551}
552
553#[async_trait]
554impl EmbeddingModel for CausalRepresentationModel {
555 fn config(&self) -> &ModelConfig {
556 &self.config.base_config
557 }
558
559 fn model_id(&self) -> &Uuid {
560 &self.model_id
561 }
562
563 fn model_type(&self) -> &'static str {
564 "CausalRepresentationModel"
565 }
566
567 fn add_triple(&mut self, triple: Triple) -> Result<()> {
568 let subject_str = triple.subject.iri.clone();
569 let predicate_str = triple.predicate.iri.clone();
570 let object_str = triple.object.iri.clone();
571
572 let next_entity_id = self.entities.len();
573 self.entities.entry(subject_str).or_insert(next_entity_id);
574 let next_entity_id = self.entities.len();
575 self.entities.entry(object_str).or_insert(next_entity_id);
576 let next_relation_id = self.relations.len();
577 self.relations
578 .entry(predicate_str)
579 .or_insert(next_relation_id);
580 Ok(())
581 }
582
583 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
584 let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
585 let start_time = std::time::Instant::now();
586 let mut loss_history = Vec::new();
587
588 for epoch in 0..epochs {
589 if epoch % 10 == 0 {
590 self.discover_causal_structure()?;
591 self.learn_structural_equations()?;
592 }
593 if epoch % 5 == 0 {
594 self.learn_disentangled_representations()?;
595 }
596 let epoch_loss = {
597 use scirs2_core::random::{Random, RngExt};
598 let mut random = Random::default();
599 0.1 * random.random::<f64>()
600 };
601 loss_history.push(epoch_loss);
602 if epoch > 10 && epoch_loss < 1e-6 {
603 break;
604 }
605 }
606
607 let training_time = start_time.elapsed().as_secs_f64();
608 let final_loss = loss_history.last().copied().unwrap_or(0.0);
609 let stats = TrainingStats {
610 epochs_completed: loss_history.len(),
611 final_loss,
612 training_time_seconds: training_time,
613 convergence_achieved: final_loss < 1e-4,
614 loss_history,
615 };
616 self.training_stats = Some(stats.clone());
617 self.is_trained = true;
618 Ok(stats)
619 }
620
621 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
622 if let Some(embedding) = self.variable_embeddings.get(entity) {
623 Ok(Vector::new(embedding.to_vec()))
624 } else {
625 Err(anyhow!("Entity not found: {}", entity))
626 }
627 }
628
629 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
630 if let Some(embedding) = self.variable_embeddings.get(relation) {
631 Ok(Vector::new(embedding.to_vec()))
632 } else {
633 Err(anyhow!("Relation not found: {}", relation))
634 }
635 }
636
637 fn score_triple(&self, subject: &str, _predicate: &str, object: &str) -> Result<f64> {
638 if let (Some(subject_idx), Some(object_idx)) = (
639 self.causal_graph
640 .variables
641 .iter()
642 .position(|v| v == subject),
643 self.causal_graph.variables.iter().position(|v| v == object),
644 ) {
645 let causal_strength = self.causal_graph.edge_weights[[subject_idx, object_idx]];
646 Ok(causal_strength as f64)
647 } else {
648 Ok(0.0)
649 }
650 }
651
652 fn predict_objects(
653 &self,
654 subject: &str,
655 predicate: &str,
656 k: usize,
657 ) -> Result<Vec<(String, f64)>> {
658 let mut scores = Vec::new();
659 for variable in &self.causal_graph.variables {
660 if variable != subject {
661 let score = self.score_triple(subject, predicate, variable)?;
662 scores.push((variable.clone(), score));
663 }
664 }
665 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
666 scores.truncate(k);
667 Ok(scores)
668 }
669
670 fn predict_subjects(
671 &self,
672 predicate: &str,
673 object: &str,
674 k: usize,
675 ) -> Result<Vec<(String, f64)>> {
676 let mut scores = Vec::new();
677 for variable in &self.causal_graph.variables {
678 if variable != object {
679 let score = self.score_triple(variable, predicate, object)?;
680 scores.push((variable.clone(), score));
681 }
682 }
683 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
684 scores.truncate(k);
685 Ok(scores)
686 }
687
688 fn predict_relations(
689 &self,
690 subject: &str,
691 object: &str,
692 k: usize,
693 ) -> Result<Vec<(String, f64)>> {
694 let mut scores = Vec::new();
695 for relation in self.relations.keys() {
696 let score = self.score_triple(subject, relation, object)?;
697 scores.push((relation.clone(), score));
698 }
699 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
700 scores.truncate(k);
701 Ok(scores)
702 }
703
704 fn get_entities(&self) -> Vec<String> {
705 self.entities.keys().cloned().collect()
706 }
707
708 fn get_relations(&self) -> Vec<String> {
709 self.relations.keys().cloned().collect()
710 }
711
712 fn get_stats(&self) -> ModelStats {
713 ModelStats {
714 num_entities: self.entities.len(),
715 num_relations: self.relations.len(),
716 num_triples: 0,
717 dimensions: self.config.base_config.dimensions,
718 is_trained: self.is_trained,
719 model_type: self.model_type().to_string(),
720 creation_time: Utc::now(),
721 last_training_time: if self.is_trained {
722 Some(Utc::now())
723 } else {
724 None
725 },
726 }
727 }
728
729 fn save(&self, _path: &str) -> Result<()> {
730 Ok(())
731 }
732
733 fn load(&mut self, _path: &str) -> Result<()> {
734 Ok(())
735 }
736
737 fn clear(&mut self) {
738 self.entities.clear();
739 self.relations.clear();
740 self.causal_graph = CausalGraph::new(Vec::new());
741 self.structural_equations.clear();
742 self.variable_embeddings.clear();
743 self.observational_data.clear();
744 self.interventional_data.clear();
745 self.is_trained = false;
746 self.training_stats = None;
747 }
748
749 fn is_trained(&self) -> bool {
750 self.is_trained
751 }
752
753 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
754 let mut results = Vec::new();
755 for text in texts {
756 let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
757 for (i, c) in text.chars().enumerate() {
758 if i >= self.config.base_config.dimensions {
759 break;
760 }
761 embedding[i] = (c as u8 as f32) / 255.0;
762 }
763 results.push(embedding);
764 }
765 Ok(results)
766 }
767}