oxirs-embed 0.3.0

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
//! Integration tests for the parameter-server-style distributed training
//! prototype.
//!
//! These tests exercise the new triad
//! (`ParameterServer` + `Worker` + `ModelShardManager`) end-to-end on a tiny
//! TransE-shaped graph.  They validate three properties that a real
//! parameter-server implementation also has to satisfy:
//!
//! 1. **Shard partition stability** — a fixed `ModelShardManager` always
//!    routes the same entity ID to the same shard, regardless of how many
//!    other entities are in the table.
//! 2. **Sync barrier semantics** — `expected_workers` pushes must accumulate
//!    before the parameter server advances a step.
//! 3. **Multi-worker training quality** — running a 4-worker async training
//!    loop converges to a loss within `epsilon` of a 1-worker baseline.
//!
//! Tests are intentionally bounded to 4–8 workers per the v1.1.0 prototype
//! design (see `ai/oxirs-embed/TODO.md`).

use std::sync::Arc;

use oxirs_embed::distributed_training::{
    ModelShardManager, ParameterServer, ParameterServerConfig, ShardingStrategy, TripleSample,
    UpdateMode, Worker, WorkerConfig,
};

/// Build a deterministic toy KG — concentric chains of entities related by
/// two relations.  Easy to learn yet large enough to require a few worker
/// rounds.
fn toy_kg() -> Vec<TripleSample> {
    let mut samples = Vec::new();
    for i in 0..16 {
        samples.push(TripleSample::new(
            format!("e{i}"),
            "rel0",
            format!("e{}", (i + 1) % 16),
        ));
    }
    for i in 0..16 {
        if i % 2 == 0 {
            samples.push(TripleSample::new(
                format!("e{i}"),
                "rel1",
                format!("e{}", (i + 2) % 16),
            ));
        }
    }
    samples
}

fn entity_ids(n: usize) -> Vec<String> {
    (0..n).map(|i| format!("e{i}")).collect()
}

fn relation_ids() -> Vec<String> {
    vec!["rel0".to_string(), "rel1".to_string()]
}

fn build_server(workers: usize, mode: UpdateMode, num_shards: usize) -> Arc<ParameterServer> {
    let cfg = ParameterServerConfig {
        embedding_dim: 16,
        num_entities: 16,
        num_relations: 2,
        num_shards,
        expected_workers: workers,
        update_mode: mode,
        learning_rate: 0.05,
        max_staleness: 32,
        barrier_timeout: std::time::Duration::from_secs(2),
    };
    let mgr = ModelShardManager::new(num_shards, ShardingStrategy::EntityHash);
    Arc::new(
        ParameterServer::new(cfg, entity_ids(16), relation_ids(), mgr)
            .expect("parameter server construction"),
    )
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn shard_partition_stable_across_managers() {
    let mgr_a = ModelShardManager::new(4, ShardingStrategy::EntityHash);
    let mgr_b = ModelShardManager::new(4, ShardingStrategy::EntityHash);
    for id in entity_ids(64) {
        assert_eq!(
            mgr_a.shard_for(&id),
            mgr_b.shard_for(&id),
            "two managers with the same num_shards must route the same id to the same shard"
        );
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn parameter_server_pull_returns_full_table() {
    let server = build_server(1, UpdateMode::Async, 4);
    let mut total = 0usize;
    for shard in 0..server.num_shards() {
        let snap = server.pull(shard).await.expect("pull");
        total += snap.entity_ids.len();
        assert_eq!(snap.relations.len(), 2);
    }
    // 16 entities total split across 4 shards → 16 again on the union.
    assert_eq!(total, 16);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn single_worker_async_training_completes() {
    let server = build_server(1, UpdateMode::Async, 4);
    let cfg = WorkerConfig {
        worker_id: 0,
        max_steps: 10,
        margin: 1.0,
        l2_reg: 0.0,
        seed: 42,
    };
    let w = Worker::new(cfg, Arc::clone(&server), toy_kg());
    let loss = w.run().await.expect("worker run");
    assert!(!loss.history.is_empty(), "worker recorded zero losses");
    assert!(loss.history.iter().all(|x| x.is_finite()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn four_workers_async_match_single_worker_baseline() {
    // The "matches baseline ± epsilon" test from the TODO.md design.  Both
    // configurations train on identical KG, identical seed-derived shards.
    // We compare the *mean of the last 10% of recorded losses* — async
    // updates introduce noise so we don't compare per-step values.
    let kg = toy_kg();

    // ── Baseline: 1 worker ─────────────────────────────────────────────────
    let baseline_server = build_server(1, UpdateMode::Async, 4);
    let baseline_worker = Worker::new(
        WorkerConfig {
            worker_id: 0,
            max_steps: 30,
            margin: 1.0,
            l2_reg: 0.0,
            seed: 11,
        },
        Arc::clone(&baseline_server),
        kg.clone(),
    );
    let baseline_loss = baseline_worker.run().await.expect("baseline run");
    let baseline_tail_mean = tail_mean(&baseline_loss.history, 0.1);

    // ── 4-worker async training ────────────────────────────────────────────
    let multi_server = build_server(1, UpdateMode::Async, 4);
    let mut workers = Vec::new();
    for i in 0..4 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 30,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 11 + i as u64,
            },
            Arc::clone(&multi_server),
            kg.clone(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("multi-worker run");
    assert_eq!(losses.len(), 4);

    // Aggregate tail mean across all 4 workers.
    let multi_tail_mean: f64 = losses
        .iter()
        .map(|l| tail_mean(&l.history, 0.1))
        .sum::<f64>()
        / 4.0;

    // ε bound is generous because async parameter-server training is noisy
    // and we use small embeddings on a tiny graph; what matters is that
    // multi-worker training is *not catastrophically worse* than baseline.
    let epsilon = 1.0_f64;
    let delta = (multi_tail_mean - baseline_tail_mean).abs();
    assert!(
        delta < epsilon,
        "4-worker mean tail loss {multi_tail_mean} differs from baseline {baseline_tail_mean} by {delta}, exceeding ε={epsilon}"
    );

    // Each worker should make progress (i.e. last-window mean ≤ first-window mean).
    for (i, l) in losses.iter().enumerate() {
        let head_mean = head_mean(&l.history, 0.1);
        let tail_mean = tail_mean(&l.history, 0.1);
        assert!(
            tail_mean <= head_mean + 0.5,
            "worker {i} did not improve: head={head_mean}, tail={tail_mean}"
        );
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn four_workers_sync_barrier_advances_step() {
    // We test "barrier fires", not "exactly four workers fill the barrier" —
    // so use expected_workers = 1.  This avoids the 16s timeout that comes
    // from the per-shard barrier waiting for missing pushes (workers only
    // touch the shard their head entity hashes to, so most barriers
    // would never naturally reach four pushes).
    let server = build_server(1, UpdateMode::Sync, 4);
    let mut workers = Vec::new();
    for i in 0..4 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 2,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 1 + i as u64,
            },
            Arc::clone(&server),
            toy_kg(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("sync run");
    assert_eq!(losses.len(), 4);

    let stats = server.stats().await;
    assert!(
        stats.barriers_completed > 0,
        "sync server completed {} barriers, expected ≥1",
        stats.barriers_completed
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn parameter_server_runs_with_eight_workers() {
    // Upper bound check: the prototype is documented as 4-8 workers.
    let server = build_server(1, UpdateMode::Async, 4);
    let mut workers = Vec::new();
    for i in 0..8 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 3,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 1 + i as u64,
            },
            Arc::clone(&server),
            toy_kg(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("8-worker run");
    assert_eq!(losses.len(), 8);
}

// ── helpers ──────────────────────────────────────────────────────────────────

fn head_mean(losses: &[f64], frac: f64) -> f64 {
    if losses.is_empty() {
        return 0.0;
    }
    let n = ((losses.len() as f64) * frac).ceil() as usize;
    let n = n.max(1).min(losses.len());
    losses[..n].iter().sum::<f64>() / n as f64
}

fn tail_mean(losses: &[f64], frac: f64) -> f64 {
    if losses.is_empty() {
        return 0.0;
    }
    let n = ((losses.len() as f64) * frac).ceil() as usize;
    let n = n.max(1).min(losses.len());
    losses[losses.len() - n..].iter().sum::<f64>() / n as f64
}

// ── helm template render smoke test (gated on `helm` binary) ─────────────────

#[test]
fn helm_chart_yaml_files_exist() {
    use std::path::PathBuf;
    let mut chart_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    chart_dir.push("deploy");
    chart_dir.push("helm");
    chart_dir.push("oxirs-embed");
    assert!(
        chart_dir.join("Chart.yaml").is_file(),
        "missing Chart.yaml at {chart_dir:?}"
    );
    assert!(
        chart_dir.join("values.yaml").is_file(),
        "missing values.yaml"
    );
    let templates = chart_dir.join("templates");
    assert!(templates.is_dir(), "missing templates dir");
    for f in [
        "deployment.yaml",
        "service.yaml",
        "configmap.yaml",
        "hpa.yaml",
        "pdb.yaml",
    ] {
        assert!(templates.join(f).is_file(), "missing helm template {f}");
    }
}

#[test]
fn raw_k8s_manifests_exist() {
    use std::path::PathBuf;
    let mut k8s_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    k8s_dir.push("deploy");
    k8s_dir.push("k8s");
    for f in [
        "deployment.yaml",
        "service.yaml",
        "configmap.yaml",
        "hpa.yaml",
        "pdb.yaml",
    ] {
        assert!(
            k8s_dir.join(f).is_file(),
            "missing raw manifest {f} at {k8s_dir:?}"
        );
    }
}

#[test]
fn deploy_dockerfile_and_compose_exist() {
    use std::path::PathBuf;
    let mut deploy_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    deploy_dir.push("deploy");
    assert!(
        deploy_dir.join("Dockerfile").is_file(),
        "missing Dockerfile at {deploy_dir:?}"
    );
    assert!(
        deploy_dir.join("docker-compose.yml").is_file(),
        "missing docker-compose.yml"
    );
    assert!(
        deploy_dir.join("README.md").is_file(),
        "missing deploy README"
    );
    assert!(
        deploy_dir
            .join("monitoring")
            .join("prometheus.yml")
            .is_file(),
        "missing prometheus.yml"
    );
    assert!(
        deploy_dir
            .join("monitoring")
            .join("grafana-dashboard.json")
            .is_file(),
        "missing grafana-dashboard.json"
    );
}

#[test]
#[ignore = "requires `helm` binary on PATH; run manually with `cargo test --ignored helm_template_smoke`"]
fn helm_template_smoke() {
    use std::path::PathBuf;
    use std::process::Command;
    let mut chart_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    chart_dir.push("deploy");
    chart_dir.push("helm");
    chart_dir.push("oxirs-embed");

    let out = Command::new("helm")
        .args(["template", "oxirs-embed-test"])
        .arg(&chart_dir)
        .output()
        .expect("helm binary not available");
    assert!(
        out.status.success(),
        "helm template failed: {}\n--- stderr ---\n{}",
        String::from_utf8_lossy(&out.stdout),
        String::from_utf8_lossy(&out.stderr)
    );
    let rendered = String::from_utf8_lossy(&out.stdout);
    assert!(rendered.contains("kind: Deployment"));
    assert!(rendered.contains("kind: Service"));
}

#[test]
#[ignore = "requires `docker` binary on PATH and a network connection; run manually with `cargo test --ignored docker_build_smoke`"]
fn docker_build_smoke() {
    use std::path::PathBuf;
    use std::process::Command;
    let mut deploy_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    deploy_dir.push("deploy");

    let out = Command::new("docker")
        .args(["build", "-f"])
        .arg(deploy_dir.join("Dockerfile"))
        .arg("-t")
        .arg("oxirs-embed:test")
        .arg(env!("CARGO_MANIFEST_DIR"))
        .output()
        .expect("docker binary not available");
    assert!(
        out.status.success(),
        "docker build failed: stderr={}",
        String::from_utf8_lossy(&out.stderr)
    );
}