oxirs-embed 0.3.1

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
//! Causal Representation Learning — Tests

#[cfg(test)]
mod tests {
    // Unit tests for the causal representation learning module: configuration
    // defaults, the graph / structural-equation / intervention constructors, and
    // the model's training, causal-discovery and counterfactual entry points.
    use crate::causal_representation_learning::{
        CausalDiscoveryAlgorithm, CausalGraph, CausalRepresentationConfig,
        CausalRepresentationModel, CounterfactualQuery, Intervention, InterventionType,
        StructuralEquation,
    };
    use crate::EmbeddingModel;
    use scirs2_core::ndarray_ext::Array1;
    use std::collections::HashMap;

    /// The default configuration selects the PC discovery algorithm with a
    /// 0.05 significance threshold.
    #[test]
    fn test_causal_representation_config_default() {
        let config = CausalRepresentationConfig::default();
        assert!(matches!(
            config.causal_discovery.algorithm,
            CausalDiscoveryAlgorithm::PC
        ));
        assert_eq!(config.causal_discovery.significance_threshold, 0.05);
    }

    /// Edges added via `add_edge` are reported by `get_children` / `get_parents`,
    /// and the chain 0 -> 1 -> 2 (X -> Y -> Z) is acyclic.
    #[test]
    fn test_causal_graph_creation() {
        let variables = vec!["X".to_string(), "Y".to_string(), "Z".to_string()];
        let mut graph = CausalGraph::new(variables);

        graph.add_edge(0, 1, 0.5);
        graph.add_edge(1, 2, 0.8);

        assert_eq!(graph.get_children(0), vec![1]);
        assert_eq!(graph.get_parents(1), vec![0]);
        // The second edge must be recorded exactly like the first one.
        assert_eq!(graph.get_children(1), vec![2]);
        assert_eq!(graph.get_parents(2), vec![1]);
        assert!(graph.is_acyclic());
    }

    /// A structural equation keeps the target variable and parent list it was
    /// constructed with.
    #[test]
    fn test_structural_equation_creation() {
        let equation = StructuralEquation::new("Y".to_string(), vec!["X".to_string()]);
        assert_eq!(equation.target, "Y");
        assert_eq!(equation.parents, vec!["X".to_string()]);
    }

    /// An intervention keeps its target list and intervention type.
    #[test]
    fn test_intervention_creation() {
        let intervention = Intervention::new(
            vec!["X".to_string()],
            Array1::from_vec(vec![1.0]),
            InterventionType::Do,
        );
        assert_eq!(intervention.targets, vec!["X".to_string()]);
        assert!(matches!(
            intervention.intervention_type,
            InterventionType::Do
        ));
    }

    /// A freshly constructed model has no entities, an empty causal graph, and
    /// is not marked as trained.
    #[test]
    fn test_causal_representation_model_creation() {
        let config = CausalRepresentationConfig::default();
        let model = CausalRepresentationModel::new(config);
        assert_eq!(model.entities.len(), 0);
        assert_eq!(model.causal_graph.variables.len(), 0);
        assert!(!model.is_trained);
    }

    /// Training with an explicit epoch budget reports exactly that many
    /// completed epochs and leaves the model marked as trained.
    #[tokio::test]
    async fn test_causal_training() {
        let config = CausalRepresentationConfig::default();
        let mut model = CausalRepresentationModel::new(config);
        model.add_observational_data(HashMap::from([
            ("X".to_string(), 1.0),
            ("Y".to_string(), 2.0),
        ]));

        let stats = model.train(Some(5)).await.expect("should succeed");
        assert_eq!(stats.epochs_completed, 5);
        assert!(model.is_trained());
    }

    /// Causal structure discovery succeeds on a single observational sample.
    #[test]
    fn test_causal_discovery() {
        let config = CausalRepresentationConfig::default();
        let mut model = CausalRepresentationModel::new(config);
        model.add_observational_data(HashMap::from([
            ("X".to_string(), 1.0),
            ("Y".to_string(), 2.0),
        ]));

        // `expect` (instead of `assert!(result.is_ok())`) puts the error value
        // into the panic message, so a failure is diagnosable from test output.
        model
            .discover_causal_structure()
            .expect("causal structure discovery should succeed");
    }

    /// A do-intervention counterfactual query can be answered by a freshly
    /// constructed model that has been given no observational data.
    #[test]
    fn test_counterfactual_query() {
        let config = CausalRepresentationConfig::default();
        let model = CausalRepresentationModel::new(config);

        let intervention = Intervention::new(
            vec!["X".to_string()],
            Array1::from_vec(vec![2.0]),
            InterventionType::Do,
        );
        let query = CounterfactualQuery {
            factual_evidence: HashMap::from([("X".to_string(), 1.0)]),
            intervention,
            query_variables: vec!["Y".to_string()],
        };

        // As above: surface the error value if the query fails.
        model
            .answer_counterfactual(&query)
            .expect("counterfactual query should be answerable");
    }
}