use std::sync::Arc;
use oxirs_embed::distributed_training::{
ModelShardManager, ParameterServer, ParameterServerConfig, ShardingStrategy, TripleSample,
UpdateMode, Worker, WorkerConfig,
};
/// Builds a 24-triple toy knowledge graph over entities `e0`..`e15`.
///
/// The first 16 triples form a ring under `rel0` (`e{i}` -> `e{(i + 1) % 16}`).
/// They are followed by 8 `rel1` triples that start at every even entity and
/// skip ahead by two (`e{i}` -> `e{(i + 2) % 16}`).
fn toy_kg() -> Vec<TripleSample> {
    let ring = (0..16).map(|i| {
        TripleSample::new(format!("e{i}"), "rel0", format!("e{}", (i + 1) % 16))
    });
    let skips = (0..16).filter(|i| i % 2 == 0).map(|i| {
        TripleSample::new(format!("e{i}"), "rel1", format!("e{}", (i + 2) % 16))
    });
    ring.chain(skips).collect()
}
/// Returns the entity identifiers `e0`, `e1`, …, `e{n-1}`, in ascending order.
fn entity_ids(n: usize) -> Vec<String> {
    let mut ids = Vec::with_capacity(n);
    for idx in 0..n {
        ids.push(format!("e{idx}"));
    }
    ids
}
/// Returns the two relation identifiers used by the toy graph: `rel0` and `rel1`.
fn relation_ids() -> Vec<String> {
    ["rel0", "rel1"].iter().map(|rel| rel.to_string()).collect()
}
/// Builds a shared parameter server for the toy graph: 16 entities, 2 relations,
/// and 16-dimensional embeddings.
///
/// * `workers` — stored as `expected_workers` in the server config.
/// * `mode` — the `UpdateMode` the server runs in.
/// * `num_shards` — used both in the config and for the `EntityHash` shard manager.
///
/// Panics if the server cannot be constructed.
fn build_server(workers: usize, mode: UpdateMode, num_shards: usize) -> Arc<ParameterServer> {
    let config = ParameterServerConfig {
        embedding_dim: 16,
        num_entities: 16,
        num_relations: 2,
        num_shards,
        expected_workers: workers,
        update_mode: mode,
        learning_rate: 0.05,
        max_staleness: 32,
        barrier_timeout: std::time::Duration::from_secs(2),
    };
    let shard_manager = ModelShardManager::new(num_shards, ShardingStrategy::EntityHash);
    let server = ParameterServer::new(config, entity_ids(16), relation_ids(), shard_manager)
        .expect("parameter server construction");
    Arc::new(server)
}
/// `EntityHash` routing must depend only on the id and the shard count. Two
/// independently constructed managers with the same shard count must therefore
/// agree on every id.
///
/// The body is fully synchronous and never awaits, so a plain `#[test]` is used
/// instead of starting a multi-threaded tokio runtime.
#[test]
fn shard_partition_stable_across_managers() {
    let mgr_a = ModelShardManager::new(4, ShardingStrategy::EntityHash);
    let mgr_b = ModelShardManager::new(4, ShardingStrategy::EntityHash);
    for id in entity_ids(64) {
        assert_eq!(
            mgr_a.shard_for(&id),
            mgr_b.shard_for(&id),
            "two managers with the same num_shards must route the same id to the same shard"
        );
    }
}
/// Pulls every shard of a 4-shard server and checks two things. Each snapshot
/// must carry both relations. The per-shard entity counts must sum to the
/// server's 16 entities.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn parameter_server_pull_returns_full_table() {
    let server = build_server(1, UpdateMode::Async, 4);
    let shard_count = server.num_shards();
    let mut entities_seen = 0usize;
    for shard_idx in 0..shard_count {
        let snapshot = server.pull(shard_idx).await.expect("pull");
        assert_eq!(snapshot.relations.len(), 2);
        entities_seen += snapshot.entity_ids.len();
    }
    assert_eq!(entities_seen, 16);
}
/// Runs one async worker for 10 steps on the toy graph. The worker must record
/// at least one loss, and every recorded loss must be finite.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn single_worker_async_training_completes() {
    let server = build_server(1, UpdateMode::Async, 4);
    let worker = Worker::new(
        WorkerConfig {
            worker_id: 0,
            max_steps: 10,
            margin: 1.0,
            l2_reg: 0.0,
            seed: 42,
        },
        Arc::clone(&server),
        toy_kg(),
    );
    let report = worker.run().await.expect("worker run");
    let history = &report.history;
    assert!(!history.is_empty(), "worker recorded zero losses");
    assert!(history.iter().all(|x| x.is_finite()));
}
/// Compares four async workers sharing one server against a single worker
/// training alone on the same graph. The four-worker run must reach roughly the
/// same tail loss as the baseline, within `epsilon`. No worker may end more
/// than 0.5 above its starting loss.
///
/// "Head" and "tail" are the means of the first and last 10% of a worker's
/// loss history.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn four_workers_async_match_single_worker_baseline() {
    let kg = toy_kg();

    // Baseline: one worker on a server that expects one worker.
    let baseline_server = build_server(1, UpdateMode::Async, 4);
    let baseline_worker = Worker::new(
        WorkerConfig {
            worker_id: 0,
            max_steps: 30,
            margin: 1.0,
            l2_reg: 0.0,
            seed: 11,
        },
        Arc::clone(&baseline_server),
        kg.clone(),
    );
    let baseline_loss = baseline_worker.run().await.expect("baseline run");
    let baseline_tail_mean = tail_mean(&baseline_loss.history, 0.1);

    // `expected_workers` must match the number of workers that actually
    // attach to this server (four).
    let multi_server = build_server(4, UpdateMode::Async, 4);
    let mut workers = Vec::new();
    for i in 0..4 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 30,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 11 + i as u64,
            },
            Arc::clone(&multi_server),
            kg.clone(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("multi-worker run");
    assert_eq!(losses.len(), 4);

    let multi_tail_mean: f64 = losses
        .iter()
        .map(|l| tail_mean(&l.history, 0.1))
        .sum::<f64>()
        / losses.len() as f64;
    let epsilon = 1.0_f64;
    let delta = (multi_tail_mean - baseline_tail_mean).abs();
    assert!(
        delta < epsilon,
        "4-worker mean tail loss {multi_tail_mean} differs from baseline {baseline_tail_mean} by {delta}, exceeding ε={epsilon}"
    );

    // Per-worker sanity check. Distinct local names keep the `head_mean` /
    // `tail_mean` helpers from being shadowed.
    for (i, l) in losses.iter().enumerate() {
        let head = head_mean(&l.history, 0.1);
        let tail = tail_mean(&l.history, 0.1);
        assert!(
            tail <= head + 0.5,
            "worker {i} did not improve: head={head}, tail={tail}"
        );
    }
}
/// Runs four workers against a `UpdateMode::Sync` server configured to expect
/// four workers. The server must report at least one completed barrier
/// afterwards.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn four_workers_sync_barrier_advances_step() {
    // `expected_workers` = 4 so the barrier has to gather all four workers.
    // With an expected count of 1, any single worker would presumably satisfy
    // it on its own, and the test would not exercise a multi-worker barrier.
    let server = build_server(4, UpdateMode::Sync, 4);
    let mut workers = Vec::new();
    for i in 0..4 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 2,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 1 + i as u64,
            },
            Arc::clone(&server),
            toy_kg(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("sync run");
    assert_eq!(losses.len(), 4);
    let stats = server.stats().await;
    assert!(
        stats.barriers_completed > 0,
        "sync server completed {} barriers, expected ≥1",
        stats.barriers_completed
    );
}
/// Runs eight async workers, more than the runtime's four threads, against one
/// server configured to expect eight. All eight must finish, and every
/// recorded loss must be finite.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn parameter_server_runs_with_eight_workers() {
    // `expected_workers` matches the number of workers attached below.
    let server = build_server(8, UpdateMode::Async, 4);
    let mut workers = Vec::new();
    for i in 0..8 {
        workers.push(Worker::new(
            WorkerConfig {
                worker_id: i,
                max_steps: 3,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 1 + i as u64,
            },
            Arc::clone(&server),
            toy_kg(),
        ));
    }
    let losses = oxirs_embed::distributed_training::worker::run_workers(workers)
        .await
        .expect("8-worker run");
    assert_eq!(losses.len(), 8);
    assert!(
        losses
            .iter()
            .all(|l| l.history.iter().all(|x| x.is_finite())),
        "8-worker run produced a non-finite loss"
    );
}
/// Mean of the first `ceil(len * frac)` losses.
///
/// The window is clamped to hold at least one and at most all elements.
/// Returns `0.0` for an empty slice.
fn head_mean(losses: &[f64], frac: f64) -> f64 {
    match losses.len() {
        0 => 0.0,
        len => {
            let window = (((len as f64) * frac).ceil() as usize).clamp(1, len);
            let total: f64 = losses.iter().take(window).sum();
            total / window as f64
        }
    }
}
/// Mean of the last `ceil(len * frac)` losses.
///
/// The window is clamped to hold at least one and at most all elements.
/// Returns `0.0` for an empty slice.
fn tail_mean(losses: &[f64], frac: f64) -> f64 {
    match losses.len() {
        0 => 0.0,
        len => {
            let window = (((len as f64) * frac).ceil() as usize).clamp(1, len);
            let total: f64 = losses.iter().skip(len - window).sum();
            total / window as f64
        }
    }
}
/// Checks that the Helm chart under `deploy/helm/oxirs-embed` contains
/// `Chart.yaml`, `values.yaml`, and the expected templates.
#[test]
fn helm_chart_yaml_files_exist() {
    use std::path::PathBuf;
    let chart_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "deploy", "helm", "oxirs-embed"]
        .iter()
        .collect();
    assert!(
        chart_dir.join("Chart.yaml").is_file(),
        "missing Chart.yaml at {chart_dir:?}"
    );
    assert!(chart_dir.join("values.yaml").is_file(), "missing values.yaml");

    let templates_dir = chart_dir.join("templates");
    assert!(templates_dir.is_dir(), "missing templates dir");
    let required = [
        "deployment.yaml",
        "service.yaml",
        "configmap.yaml",
        "hpa.yaml",
        "pdb.yaml",
    ];
    if let Some(missing) = required
        .iter()
        .find(|name| !templates_dir.join(name).is_file())
    {
        panic!("missing helm template {missing}");
    }
}
/// Checks that the plain Kubernetes manifests exist under `deploy/k8s`.
#[test]
fn raw_k8s_manifests_exist() {
    use std::path::PathBuf;
    let k8s_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("deploy")
        .join("k8s");
    let manifests = [
        "deployment.yaml",
        "service.yaml",
        "configmap.yaml",
        "hpa.yaml",
        "pdb.yaml",
    ];
    for manifest in manifests.iter() {
        assert!(
            k8s_dir.join(manifest).is_file(),
            "missing raw manifest {manifest} at {k8s_dir:?}"
        );
    }
}
/// Checks that `deploy/` contains the Dockerfile, the compose file, a README,
/// and the Prometheus/Grafana monitoring configs.
#[test]
fn deploy_dockerfile_and_compose_exist() {
    use std::path::PathBuf;
    let deploy_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("deploy");
    let monitoring_dir = deploy_dir.join("monitoring");

    assert!(
        deploy_dir.join("Dockerfile").is_file(),
        "missing Dockerfile at {deploy_dir:?}"
    );
    assert!(
        deploy_dir.join("docker-compose.yml").is_file(),
        "missing docker-compose.yml"
    );
    assert!(deploy_dir.join("README.md").is_file(), "missing deploy README");
    assert!(
        monitoring_dir.join("prometheus.yml").is_file(),
        "missing prometheus.yml"
    );
    assert!(
        monitoring_dir.join("grafana-dashboard.json").is_file(),
        "missing grafana-dashboard.json"
    );
}
/// Renders the Helm chart with `helm template`. Rendering must succeed, and the
/// output must contain a Deployment and a Service.
///
/// `--ignored` is a test-harness flag, so it has to come after `--`:
/// `cargo test -- --ignored helm_template_smoke`.
#[test]
#[ignore = "requires `helm` binary on PATH; run manually with `cargo test -- --ignored helm_template_smoke`"]
fn helm_template_smoke() {
    use std::path::PathBuf;
    use std::process::Command;
    let mut chart_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    chart_dir.push("deploy");
    chart_dir.push("helm");
    chart_dir.push("oxirs-embed");
    let out = Command::new("helm")
        .args(["template", "oxirs-embed-test"])
        .arg(&chart_dir)
        .output()
        .expect("helm binary not available");
    // Both streams are labelled, so the rendered stdout is not mistaken for
    // the error text.
    assert!(
        out.status.success(),
        "helm template failed\n--- stdout ---\n{}\n--- stderr ---\n{}",
        String::from_utf8_lossy(&out.stdout),
        String::from_utf8_lossy(&out.stderr)
    );
    let rendered = String::from_utf8_lossy(&out.stdout);
    assert!(rendered.contains("kind: Deployment"));
    assert!(rendered.contains("kind: Service"));
}
/// Builds the image from `deploy/Dockerfile`, using the crate directory as the
/// build context, and tags it `oxirs-embed:test`. The build must succeed.
///
/// `--ignored` is a test-harness flag, so it has to come after `--`:
/// `cargo test -- --ignored docker_build_smoke`.
#[test]
#[ignore = "requires `docker` binary on PATH and a network connection; run manually with `cargo test -- --ignored docker_build_smoke`"]
fn docker_build_smoke() {
    use std::path::PathBuf;
    use std::process::Command;
    let mut deploy_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    deploy_dir.push("deploy");
    let out = Command::new("docker")
        .args(["build", "-f"])
        .arg(deploy_dir.join("Dockerfile"))
        .arg("-t")
        .arg("oxirs-embed:test")
        // Build context is the crate root, not `deploy/`.
        .arg(env!("CARGO_MANIFEST_DIR"))
        .output()
        .expect("docker binary not available");
    assert!(
        out.status.success(),
        "docker build failed: stderr={}",
        String::from_utf8_lossy(&out.stderr)
    );
}