// oxirs-embed 0.3.1
//
// Knowledge graph embeddings with TransE, ComplEx, and custom models
// Documentation
use crate::continual_learning_types::{
    ContinualLearningConfig, ContinualLearningModel, EWCState, MemoryConfig, MemoryEntry,
    MemoryType, MemoryUpdateStrategy, TaskInfo,
};
use crate::ModelConfig;
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::ndarray_ext::Array2;
use scirs2_core::random::{Random, RngExt};

/// Checks the defaults of `ContinualLearningConfig`: the memory type is
/// `MemoryType::EpisodicMemory` and the memory capacity is 10 000.
#[test]
fn test_continual_learning_config_default() {
    let defaults = ContinualLearningConfig::default();
    let memory = &defaults.memory_config;

    assert!(matches!(memory.memory_type, MemoryType::EpisodicMemory));
    assert_eq!(memory.memory_capacity, 10_000);
}

/// Checks that `TaskInfo::new` stores the given id and type and starts
/// with an `examples_seen` counter of zero.
#[test]
fn test_task_info_creation() {
    let (id, kind) = ("task1", "classification");
    let info = TaskInfo::new(id.to_string(), kind.to_string());

    assert_eq!(info.task_id, id);
    assert_eq!(info.task_type, kind);
    assert_eq!(info.examples_seen, 0);
}

/// Checks that `MemoryEntry::new` records the task id and initialises
/// `importance` to 1.0 and `access_count` to 0.
#[test]
fn test_memory_entry_creation() {
    let entry = MemoryEntry::new(
        Array1::from_vec(vec![1.0, 2.0, 3.0]),
        Array1::from_vec(vec![0.0, 1.0]),
        String::from("task1"),
    );

    assert_eq!(entry.task_id, "task1");
    assert_eq!(entry.importance, 1.0);
    assert_eq!(entry.access_count, 0);
}

/// Checks the initial state of a model built from the default configuration:
/// no entities, no examples seen, and no current task.
#[test]
fn test_continual_learning_model_creation() {
    let fresh = ContinualLearningModel::new(ContinualLearningConfig::default());

    assert_eq!(fresh.entities.len(), 0);
    assert_eq!(fresh.examples_seen, 0);
    assert!(fresh.current_task.is_none());
}

/// Checks task switching: the task passed to `start_task` becomes the current
/// task, and after starting a second task the history holds exactly one entry.
#[tokio::test]
async fn test_task_management() {
    let mut model = ContinualLearningModel::new(ContinualLearningConfig::default());

    // First task: it must be reported as the current task.
    model
        .start_task(String::from("task1"), String::from("test"))
        .expect("should succeed");
    let active = model.current_task.as_ref();
    assert!(active.is_some());
    assert_eq!(active.expect("should succeed").task_id, "task1");

    // Second task: the history grows to one entry and the current task changes.
    model
        .start_task(String::from("task2"), String::from("test"))
        .expect("should succeed");
    assert_eq!(model.task_history.len(), 1);
    let active = model.current_task.as_ref().expect("should succeed");
    assert_eq!(active.task_id, "task2");
}

/// Checks that adding one example to a 3-dimensional model bumps the model's
/// and the current task's `examples_seen` counters to 1 and leaves one entry
/// in episodic memory.
#[tokio::test]
async fn test_add_example() {
    let base_config = ModelConfig {
        dimensions: 3,
        ..Default::default()
    };
    let mut model = ContinualLearningModel::new(ContinualLearningConfig {
        base_config,
        ..Default::default()
    });

    model
        .start_task(String::from("task1"), String::from("test"))
        .expect("should succeed");

    // Input and target carry the same three values.
    let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    let expected = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    model
        .add_example(input, expected, Some(String::from("task1")))
        .await
        .expect("should succeed");

    assert_eq!(model.examples_seen, 1);
    assert_eq!(model.episodic_memory.len(), 1);

    let current = model.current_task.as_ref().expect("should succeed");
    assert_eq!(current.examples_seen, 1);
}

/// Checks that episodic memory holds exactly 3 entries after five examples
/// are added with a capacity of 3 and the FIFO update strategy.
#[tokio::test]
async fn test_memory_management() {
    let memory_config = MemoryConfig {
        memory_capacity: 3,
        update_strategy: MemoryUpdateStrategy::FIFO,
        ..Default::default()
    };
    let mut model = ContinualLearningModel::new(ContinualLearningConfig {
        memory_config,
        ..Default::default()
    });

    model
        .start_task(String::from("task1"), String::from("test"))
        .expect("should succeed");

    // Add more examples (five) than the configured capacity (three).
    for step in 0..5 {
        let value = step as f32;
        model
            .add_example(
                Array1::from_vec(vec![value]),
                Array1::from_vec(vec![value]),
                Some(String::from("task1")),
            )
            .await
            .expect("should succeed");
    }

    assert_eq!(model.episodic_memory.len(), 3);
}

/// Checks that training for ten epochs reports ten completed epochs, marks
/// the model as trained, and leaves the task history non-empty.
#[tokio::test]
async fn test_continual_training() {
    use crate::EmbeddingModel;

    let base_config = ModelConfig {
        dimensions: 3,
        max_epochs: 10,
        ..Default::default()
    };
    let mut model = ContinualLearningModel::new(ContinualLearningConfig {
        base_config,
        ..Default::default()
    });

    model
        .start_task(String::from("initial_task"), String::from("training"))
        .expect("should succeed");

    let training_stats = model.train(Some(10)).await.expect("should succeed");

    assert_eq!(training_stats.epochs_completed, 10);
    assert!(model.is_trained());
    assert!(!model.task_history.is_empty());
}

/// Checks that `evaluate_forgetting` returns 0.0 for a model built from the
/// default configuration with no tasks started.
#[test]
fn test_forgetting_evaluation() {
    let untrained = ContinualLearningModel::new(ContinualLearningConfig::default());

    assert_eq!(untrained.evaluate_forgetting(), 0.0);
}

/// Checks that an `EWCState` built from two random 5x5 matrices keeps the
/// task id and importance it was given.
#[test]
fn test_ewc_state_creation() {
    let mut rng = Random::default();
    // Each call draws a fresh 5x5 matrix of random `f32` values from `rng`.
    let mut random_matrix = || Array2::from_shape_fn((5, 5), |_| rng.random::<f32>());

    // Field expressions run in written order: Fisher matrix first, then parameters.
    let state = EWCState {
        fisher_information: random_matrix(),
        optimal_parameters: random_matrix(),
        task_id: String::from("task1"),
        importance: 1.0,
    };

    assert_eq!(state.task_id, "task1");
    assert_eq!(state.importance, 1.0);
}