#[cfg(feature = "hf-embeddings")]
mod hf_integration_tests {
use indra_db::embedding::{cosine_similarity, Embedder, HFEmbedder};
use std::path::PathBuf;
/// Reports whether the tests appear to be running under a CI system.
///
/// Returns `true` when either the `CI` or the `GITHUB_ACTIONS` environment
/// variable can be read successfully, regardless of its value.
fn is_ci() -> bool {
    ["CI", "GITHUB_ACTIONS"]
        .iter()
        .any(|name| std::env::var(name).is_ok())
}
/// Checks whether a model has weight files in the local Hugging Face cache.
///
/// `model_name` is a hub identifier such as `org/name`; every `/` is replaced
/// by `--` to form the directory name `models--org--name`. The function looks
/// in `<cache_dir>/hub/<that directory>/snapshots` and returns `true` as soon
/// as one entry there contains `model.safetensors` or `pytorch_model.bin`.
/// A missing or unreadable snapshots directory yields `false`.
fn model_cached(model_name: &str) -> bool {
    const WEIGHT_FILES: [&str; 2] = ["model.safetensors", "pytorch_model.bin"];

    let folder = format!("models--{}", model_name.replace('/', "--"));
    let snapshots = HFEmbedder::cache_dir()
        .join("hub")
        .join(&folder)
        .join("snapshots");
    if !snapshots.exists() {
        return false;
    }

    match std::fs::read_dir(snapshots) {
        // Entries that fail to read are skipped rather than treated as errors.
        Ok(listing) => listing
            .flatten()
            .map(|item| item.path())
            .any(|snapshot| WEIGHT_FILES.iter().any(|file| snapshot.join(file).exists())),
        Err(_) => false,
    }
}
/// Loads the locally cached MiniLM model and checks its reported dimension
/// and name, then verifies a single embedding has 384 components and unit norm.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_minilm_basic() {
    let model = "sentence-transformers/all-MiniLM-L6-v2";

    let runnable = !is_ci() && model_cached(model);
    if !runnable {
        eprintln!("⏭️ Skipping: sentence-transformers/all-MiniLM-L6-v2 not in cache");
        return;
    }
    println!("✓ Found cached model: sentence-transformers/all-MiniLM-L6-v2");

    let minilm = HFEmbedder::new(model).await.expect("Failed to load model");
    assert_eq!(minilm.dimension(), 384);
    assert_eq!(minilm.model_name(), model);

    let vector = minilm.embed("Hello world").expect("Failed to embed");
    assert_eq!(vector.len(), 384);

    // L2 norm of the vector; it must be 1 within a small tolerance.
    let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 1e-4, "Embedding not normalized");

    println!("✓ Basic embedding test passed");
}
/// Checks that MiniLM ranks two paraphrases as more similar to each other
/// than a paraphrase and an unrelated sentence, and that the paraphrase
/// similarity exceeds 0.7.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_minilm_semantic_similarity() {
    let model = "sentence-transformers/all-MiniLM-L6-v2";
    if is_ci() || !model_cached(model) {
        eprintln!("⏭️ Skipping semantic similarity test");
        return;
    }

    let minilm = HFEmbedder::new(model).await.expect("Failed to load model");

    let cat_on_mat = minilm.embed("The cat sits on the mat").unwrap();
    let cat_paraphrase = minilm.embed("A cat is sitting on a mat").unwrap();
    let unrelated = minilm.embed("Python is a programming language").unwrap();

    let sim_similar = cosine_similarity(&cat_on_mat, &cat_paraphrase);
    let sim_different = cosine_similarity(&cat_on_mat, &unrelated);
    println!("Similar sentences similarity: {:.3}", sim_similar);
    println!("Different sentences similarity: {:.3}", sim_different);

    assert!(
        sim_similar > sim_different,
        "Similar sentences should have higher similarity"
    );
    assert!(sim_similar > 0.7, "Similar sentences should be > 0.7");

    println!("✓ Semantic similarity test passed");
}
/// Embeds the same sentence twice with MiniLM and requires the two results
/// to be identical, both in length and component by component.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_minilm_deterministic() {
    if is_ci() || !model_cached("sentence-transformers/all-MiniLM-L6-v2") {
        eprintln!("⏭️ Skipping deterministic test");
        return;
    }
    let embedder = HFEmbedder::new("sentence-transformers/all-MiniLM-L6-v2")
        .await
        .unwrap();
    let text = "This is a test sentence for determinism";
    let e1 = embedder.embed(text).unwrap();
    let e2 = embedder.embed(text).unwrap();

    // `zip` stops at the shorter input, so without this check two embeddings
    // of different length would pass the element-wise comparison below.
    assert_eq!(
        e1.len(),
        e2.len(),
        "Embeddings are not deterministic (length mismatch)"
    );
    for (a, b) in e1.iter().zip(e2.iter()) {
        // Exact equality is intentional: identical input must give identical output.
        assert_eq!(a, b, "Embeddings are not deterministic");
    }
    println!("✓ Determinism test passed");
}
/// Embeds four sentences in one batch with MiniLM and checks that four
/// vectors come back, each with 384 components and unit norm.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_minilm_batch() {
    let model = "sentence-transformers/all-MiniLM-L6-v2";
    let runnable = !is_ci() && model_cached(model);
    if !runnable {
        eprintln!("⏭️ Skipping batch test");
        return;
    }

    let minilm = HFEmbedder::new(model).await.unwrap();

    let texts = vec![
        "First sentence about cats",
        "Second sentence about dogs",
        "Third sentence about programming",
        "Fourth sentence about weather",
    ];
    let vectors = minilm.embed_batch(&texts).unwrap();
    assert_eq!(vectors.len(), 4);

    vectors.iter().enumerate().for_each(|(idx, vector)| {
        assert_eq!(vector.len(), 384);
        let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "Embedding {} not normalized", idx);
    });

    println!("✓ Batch embedding test passed");
}
/// Exercises the MPNet model when it is already in the local cache: checks
/// the 768 dimension and that one embedding has 768 components and unit norm.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_mpnet_if_cached() {
    let model = "sentence-transformers/all-mpnet-base-v2";
    let runnable = !is_ci() && model_cached(model);
    if !runnable {
        eprintln!("⏭️ Skipping: all-mpnet-base-v2 not in cache");
        return;
    }
    println!("✓ Found cached model: sentence-transformers/all-mpnet-base-v2");

    let mpnet = HFEmbedder::new(model)
        .await
        .expect("Failed to load mpnet model");
    assert_eq!(mpnet.dimension(), 768);

    let vector = mpnet.embed("Test with MPNet model").unwrap();
    assert_eq!(vector.len(), 768);

    let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 1e-4);

    println!("✓ MPNet model test passed");
}
/// Exercises the BGE small model when it is already in the local cache:
/// checks the 384 dimension and the length of one embedding.
///
/// Returns early without failing under CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_local_bge_if_cached() {
    let model = "BAAI/bge-small-en-v1.5";
    let runnable = !is_ci() && model_cached(model);
    if !runnable {
        eprintln!("⏭️ Skipping: BAAI/bge-small-en-v1.5 not in cache");
        return;
    }
    println!("✓ Found cached model: BAAI/bge-small-en-v1.5");

    let bge = HFEmbedder::new(model)
        .await
        .expect("Failed to load BGE model");
    assert_eq!(bge.dimension(), 384);

    let vector = bge.embed("Test with BGE model").unwrap();
    assert_eq!(vector.len(), 384);

    println!("✓ BGE model test passed");
}
/// Constructs the MiniLM embedder without requiring a cached copy, then
/// checks the 384 dimension and the length of one embedding.
///
/// Skipped (returns early) when running under CI.
#[tokio::test]
#[ignore]
async fn test_download_minilm() {
    if is_ci() {
        eprintln!("⏭️ Skipping download test in CI");
        return;
    }

    let model = "sentence-transformers/all-MiniLM-L6-v2";
    println!("📥 Downloading sentence-transformers/all-MiniLM-L6-v2...");
    println!(" (This may take a few minutes on first run)");

    let minilm = HFEmbedder::new(model)
        .await
        .expect("Failed to download/load model");
    assert_eq!(minilm.dimension(), 384);

    let vector = minilm.embed("Hello from fresh download").unwrap();
    assert_eq!(vector.len(), 384);

    println!("✓ Model downloaded and working!");
    println!(" Cached at: {:?}", HFEmbedder::cache_dir());
}
/// Constructs the MPNet embedder without requiring a cached copy, then
/// checks the 768 dimension and the length of one embedding.
///
/// Skipped (returns early) when running under CI.
#[tokio::test]
#[ignore]
async fn test_download_mpnet() {
    if is_ci() {
        eprintln!("⏭️ Skipping download test in CI");
        return;
    }

    let model = "sentence-transformers/all-mpnet-base-v2";
    println!("📥 Downloading sentence-transformers/all-mpnet-base-v2...");
    println!(" (This is ~400MB, may take several minutes)");

    let mpnet = HFEmbedder::new(model)
        .await
        .expect("Failed to download/load MPNet model");
    assert_eq!(mpnet.dimension(), 768);

    let vector = mpnet
        .embed("Testing MPNet after download")
        .expect("Failed to embed");
    assert_eq!(vector.len(), 768);

    println!("✓ MPNet model downloaded and working!");
}
/// Constructs the BGE small embedder without requiring a cached copy, then
/// checks the 384 dimension and the length of one embedding.
///
/// Skipped (returns early) when running under CI.
#[tokio::test]
#[ignore]
async fn test_download_bge_small() {
    if is_ci() {
        eprintln!("⏭️ Skipping download test in CI");
        return;
    }

    let model = "BAAI/bge-small-en-v1.5";
    println!("📥 Downloading BAAI/bge-small-en-v1.5...");
    println!(" (This is ~130MB)");

    let bge = HFEmbedder::new(model)
        .await
        .expect("Failed to download/load BGE model");
    assert_eq!(bge.dimension(), 384);

    let vector = bge.embed("Testing BGE after download").unwrap();
    assert_eq!(vector.len(), 384);

    println!("✓ BGE model downloaded and working!");
}
/// Verifies how `HFEmbedder::cache_dir()` reacts to the `HF_HOME` variable.
///
/// With `HF_HOME` unset, the returned path is expected to contain
/// "huggingface"; with `HF_HOME=/tmp/test_cache` it is expected to be exactly
/// that path.
///
/// Whatever value `HF_HOME` had before the test is put back when the test
/// ends, including when an assertion panics.
#[test]
fn test_cache_dir_detection() {
    /// Puts `HF_HOME` back to the captured value (or unsets it) on drop.
    struct RestoreHfHome(Option<std::ffi::OsString>);

    impl Drop for RestoreHfHome {
        fn drop(&mut self) {
            match self.0.take() {
                Some(previous) => std::env::set_var("HF_HOME", previous),
                None => std::env::remove_var("HF_HOME"),
            }
        }
    }

    // NOTE(review): environment variables are process-wide, so this test can
    // still race with other tests that call `HFEmbedder::cache_dir()` in
    // parallel — consider running it single-threaded.
    let _restore = RestoreHfHome(std::env::var_os("HF_HOME"));

    // Clear any pre-existing override so the default location is what is checked.
    std::env::remove_var("HF_HOME");
    let default_cache = HFEmbedder::cache_dir();
    // Lossy conversion avoids a panic on a non-UTF-8 home directory.
    assert!(default_cache.to_string_lossy().contains("huggingface"));

    std::env::set_var("HF_HOME", "/tmp/test_cache");
    let custom_cache = HFEmbedder::cache_dir();
    assert_eq!(custom_cache, PathBuf::from("/tmp/test_cache"));

    println!("✓ Cache directory detection works");
}
/// Prints every `models--*` directory found under `<cache_dir>/hub`, marking
/// each with "✓" when `model_cached` finds weights for it and "⚠️" otherwise.
///
/// Purely informational: the test makes no assertions and returns quietly
/// when the hub directory is missing or cannot be read.
#[test]
fn test_list_cached_models() {
    let hub_dir = HFEmbedder::cache_dir().join("hub");
    if !hub_dir.exists() {
        println!("ℹ️ No HF cache directory found at {:?}", hub_dir);
        return;
    }
    println!("📦 Cached models in {:?}:", hub_dir);

    let listing = match std::fs::read_dir(&hub_dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };

    for item in listing.flatten() {
        let file_name = item.file_name();
        let dir_name = file_name.to_string_lossy();
        // Directory names encode `org/name` as `models--org--name`.
        if let Some(encoded) = dir_name.strip_prefix("models--") {
            let model_name = encoded.replace("--", "/");
            let status = if model_cached(&model_name) { "✓" } else { "⚠️" };
            println!(" {} {}", status, model_name);
        }
    }
}
}