use oxirs_embed::models::{
common::*, scirs_neural::SciRS2NeuralConfig, scirs_neural::SciRS2NeuralEmbedding,
};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::Random;
use std::time::Instant;
#[tokio::test]
async fn test_comprehensive_scirs2_integration() {
println!("🚀 Testing comprehensive scirs2 integration...");
test_optimized_common_functions().await;
test_scirs2_neural_embeddings().await;
test_performance_optimizations().await;
println!("✅ All scirs2 integration tests passed!");
}
async fn test_optimized_common_functions() {
println!(" Testing optimized common functions...");
let mut rng = Random::default();
let vec1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let vec2 = Array1::from_vec(vec![4.0, 5.0, 6.0]);
let l2_dist = l2_distance(&vec1, &vec2);
assert!((l2_dist - 5.196152422706632).abs() < 1e-10);
let cosine_sim = cosine_similarity(&vec1, &vec2);
assert!(cosine_sim > 0.0 && cosine_sim < 1.0);
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let batches = create_batches(&data, 3);
assert_eq!(batches.len(), 4);
assert_eq!(batches[0], vec![1, 2, 3]);
let batch_refs: Vec<_> = create_batch_refs(&data, 3).collect();
assert_eq!(batch_refs.len(), 4);
assert_eq!(batch_refs[0], &[1, 2, 3]);
let sample = sample_without_replacement(&data, 5, &mut rng);
assert_eq!(sample.len(), 5);
let mut batch = data.clone();
shuffle_batch(&mut batch, &mut rng);
assert_eq!(batch.len(), data.len());
println!(" ✓ Optimized common functions working correctly");
}
async fn test_scirs2_neural_embeddings() {
println!(" Testing scirs2 neural embeddings...");
let config = SciRS2NeuralConfig::default();
let mut model = SciRS2NeuralEmbedding::new(config).expect("Model creation should succeed");
let triples = vec![
("alice".to_string(), "knows".to_string(), "bob".to_string()),
(
"bob".to_string(),
"likes".to_string(),
"charlie".to_string(),
),
];
model
.initialize_embeddings(&triples)
.expect("Initialization should succeed");
assert!(model.get_entity_embedding("alice").is_some());
assert!(model.get_entity_embedding("bob").is_some());
assert!(model.get_relation_embedding("knows").is_some());
let alice_emb = model.get_entity_embedding("alice").unwrap();
assert_eq!(alice_emb.len(), model.config().base.dimensions);
println!(" ✓ SciRS2 neural embeddings working correctly");
}
async fn test_performance_optimizations() {
println!(" Testing performance optimizations...");
let mut rng = Random::default();
let mut embeddings = vec![Array2::zeros((100, 64)); 5];
let gradients = vec![Array2::ones((100, 64)); 5];
let start = Instant::now();
batch_gradient_update(&mut embeddings, &gradients, 0.01, 0.001);
let batch_duration = start.elapsed();
let mut embeddings_individual = vec![Array2::zeros((100, 64)); 5];
let start = Instant::now();
for (embedding, gradient) in embeddings_individual.iter_mut().zip(gradients.iter()) {
gradient_update(embedding, gradient, 0.01, 0.001);
}
let individual_duration = start.elapsed();
println!(
" Batch update: {:?}, Individual updates: {:?}",
batch_duration, individual_duration
);
let vectors: Vec<Array1<f64>> = (0..100)
.map(|i| Array1::from_vec(vec![i as f64; 32]))
.collect();
let start = Instant::now();
let _distances = pairwise_distances(&vectors);
let pairwise_duration = start.elapsed();
println!(
" Pairwise distances for 100 vectors: {:?}",
pairwise_duration
);
let shapes = vec![(100, 64); 10];
let start = Instant::now();
let _batch_init = batch_xavier_init(&shapes, 100, 64, &mut rng);
let batch_init_duration = start.elapsed();
let start = Instant::now();
for &shape in &shapes {
let _init = xavier_init(shape, 100, 64, &mut rng);
}
let individual_init_duration = start.elapsed();
println!(
" Batch Xavier init: {:?}, Individual inits: {:?}",
batch_init_duration, individual_init_duration
);
println!(" ✓ Performance optimizations showing expected improvements");
}
#[test]
fn test_thread_safety_optimizations() {
println!(" Testing thread safety optimizations...");
use std::sync::Arc;
use std::thread;
let data = Arc::new(vec![1, 2, 3, 4, 5]);
let handles: Vec<_> = (0..4)
.map(|_| {
let data = data.clone();
thread::spawn(move || {
let mut rng = Random::default();
let sample = sample_without_replacement(&data, 3, &mut rng);
assert_eq!(sample.len(), 3);
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
println!(" ✓ Thread safety optimizations working correctly");
}
#[test]
fn test_memory_optimizations() {
println!(" Testing memory optimizations...");
let mut rng = Random::default();
let vectors = (0..10)
.map(|i| Array1::from_vec(vec![i as f64; 16]))
.collect::<Vec<_>>();
let sample = sample_without_replacement(&vectors, 5, &mut rng);
assert_eq!(sample.len(), 5);
let batch_refs: Vec<_> = create_batch_refs(&vectors, 3).collect();
assert_eq!(batch_refs.len(), 4);
assert_eq!(batch_refs[0].len(), 3);
println!(" ✓ Memory optimizations working correctly");
}