use umi_memory::extraction::{EntityExtractor, ExtractionOptions};
use umi_memory::llm::OpenAIProvider;
use umi_memory::embedding::{EmbeddingProvider, OpenAIEmbeddingProvider, SimEmbeddingProvider};
use umi_memory::retrieval::DualRetriever;
use umi_memory::storage::{SimStorageBackend, SimVectorBackend};
use umi_memory::dst::SimConfig;
use std::env;
use umi_memory::extraction::EntityType as ExtractionEntityType;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Testing OpenAI LLM and Embedding Providers ===\n");

    // Fail fast with setup instructions if the API key is missing.
    let api_key = match env::var("OPENAI_API_KEY") {
        Ok(key) => key,
        Err(_) => {
            eprintln!("❌ ERROR: OPENAI_API_KEY environment variable not set");
            eprintln!("\nSetup:");
            eprintln!(" 1. Get API key from https://platform.openai.com/api-keys");
            eprintln!(" 2. Export OPENAI_API_KEY=<your-key>");
            eprintln!(" 3. Run: cargo run --example test_openai --features openai,embedding-openai\n");
            std::process::exit(1);
        }
    };

    // BUG FIX: `&api_key[..8]` panics when the key is shorter than 8 bytes or
    // when byte offset 8 falls inside a multi-byte UTF-8 character. Taking
    // chars is safe for any key content/length.
    let key_preview: String = api_key.chars().take(8).collect();
    println!("✓ API key found ({}...)", key_preview);
    println!();

    // One LLM provider for extraction/rewriting, one embedding provider.
    let llm = OpenAIProvider::new(api_key.clone());
    let embedder = OpenAIEmbeddingProvider::new(api_key.clone());
    println!("✓ Created OpenAI providers (LLM + Embeddings)\n");

    println!("--- Test 1: Entity Extraction (LLM) ---");
    test_entity_extraction(&llm).await?;
    println!();

    println!("--- Test 2: Query Rewriting (LLM) ---");
    test_query_rewriting(&llm).await?;
    println!();

    println!("--- Test 3: Embedding Generation ---");
    test_embedding_generation(&embedder).await?;
    println!();

    println!("--- Test 4: Full Pipeline (LLM + Embeddings) ---");
    test_full_pipeline(&llm, &embedder).await?;
    println!();

    println!("=== All Tests Passed! ===");
    println!("\n✅ OpenAI integration works correctly");
    println!("💰 Approximate cost: $0.01-0.02");
    Ok(())
}
/// Runs the LLM-backed entity extractor over a short biography sentence and
/// checks that at least one real (non-fallback) entity is produced.
async fn test_entity_extraction(llm: &OpenAIProvider) -> Result<(), Box<dyn std::error::Error>> {
    let extractor = EntityExtractor::new(llm.clone());
    let text = "Bob is the CTO at TechCo. He has 15 years of experience in AI and machine learning.";
    println!(" Input: \"{}\"", text);

    let outcome = extractor.extract(text, ExtractionOptions::default()).await?;

    println!(" ✓ Extraction succeeded");
    println!(" Found {} entities:", outcome.entities.len());
    for e in &outcome.entities {
        println!(" - Type: {:?}, Name: {}", e.entity_type, e.name);
        println!(" Content: {}", e.content);
        println!(" Confidence: {:.2}", e.confidence);
    }

    // Must have extracted something at all.
    assert!(!outcome.entities.is_empty(), "Expected at least one entity");

    // A Note entity signals the extractor fell back instead of parsing real
    // structure out of the LLM response — treat that as a failure here.
    let used_fallback = outcome
        .entities
        .iter()
        .any(|e| e.entity_type == ExtractionEntityType::Note);
    assert!(
        !used_fallback,
        "Expected real entity extraction, got fallback (ExtractionEntityType::Note)"
    );
    Ok(())
}
/// Exercises query rewriting through a DualRetriever wired with deterministic
/// sim backends, so only the LLM call is real. Accepts single-variation
/// output as graceful fallback behavior.
async fn test_query_rewriting(llm: &OpenAIProvider) -> Result<(), Box<dyn std::error::Error>> {
    // Seeded simulation backends keep everything but the LLM deterministic.
    let embedder = SimEmbeddingProvider::with_seed(42);
    let vector = SimVectorBackend::new(42);
    let storage = SimStorageBackend::new(SimConfig::with_seed(42));
    let retriever = DualRetriever::new(llm.clone(), embedder, vector, storage);

    let query = "What companies are people working at?";
    println!(" Input query: \"{}\"", query);

    let variations = retriever.rewrite_query(query).await;
    println!(" ✓ Query rewriting succeeded");
    println!(" Generated {} variations:", variations.len());
    let mut ordinal = 0usize;
    for v in &variations {
        ordinal += 1;
        println!(" {}. {}", ordinal, v);
    }

    // The original query must always survive rewriting.
    assert!(
        variations.contains(&query.to_string()),
        "Expected variations to include original query"
    );

    match variations.len() {
        1 => {
            println!(" ℹ Query rewriting used fallback (returned original query only)");
            println!(" This is acceptable - graceful degradation working correctly");
        }
        n => println!(" ✓ Query expansion generated {} variations", n),
    }
    Ok(())
}
/// Embeds three sentences and sanity-checks dimensions plus semantic
/// ordering: the two work-related sentences should be closer to each other
/// than either is to the weather sentence.
async fn test_embedding_generation(embedder: &OpenAIEmbeddingProvider) -> Result<(), Box<dyn std::error::Error>> {
    let texts = vec![
        "Alice is a software engineer",
        "Bob is a data scientist",
        "The weather is nice today",
    ];
    println!(" Generating embeddings for {} texts:", texts.len());
    for (idx, sentence) in texts.iter().enumerate() {
        println!(" {}. \"{}\"", idx + 1, sentence);
    }

    // Embed each text sequentially, short-circuiting on the first error.
    let mut embeddings = Vec::with_capacity(texts.len());
    for sentence in &texts {
        embeddings.push(embedder.embed(sentence).await?);
    }

    println!(" ✓ Embedding generation succeeded");
    println!(" Generated {} embeddings", embeddings.len());
    println!(" Embedding dimensions: {}", embeddings[0].len());

    assert_eq!(
        embeddings.len(),
        texts.len(),
        "Expected {} embeddings, got {}",
        texts.len(),
        embeddings.len()
    );

    // OpenAI embedding models produce vectors in this dimensional range.
    for vector in &embeddings {
        assert!(
            vector.len() >= 512 && vector.len() <= 4096,
            "Expected embedding dimensions 512-4096, got {}",
            vector.len()
        );
    }

    let sim_work = cosine_similarity(&embeddings[0], &embeddings[1]);
    let sim_weather = cosine_similarity(&embeddings[0], &embeddings[2]);
    println!(" Cosine similarity (work-related): {:.4}", sim_work);
    println!(" Cosine similarity (work vs weather): {:.4}", sim_weather);

    // Semantic check: related sentences must embed closer together.
    assert!(
        sim_work > sim_weather,
        "Expected work-related texts to be more similar than unrelated texts"
    );
    Ok(())
}
/// End-to-end smoke test: extract entities with the LLM, then embed each
/// extracted entity's "name: content" rendering.
async fn test_full_pipeline(llm: &OpenAIProvider, embedder: &OpenAIEmbeddingProvider) -> Result<(), Box<dyn std::error::Error>> {
    println!(" Testing complete extraction + embedding pipeline");

    let extractor = EntityExtractor::new(llm.clone());
    let text = "Carol is a product manager at StartupX";
    let extraction = extractor.extract(text, ExtractionOptions::default()).await?;
    println!(" ✓ Extracted {} entities", extraction.entities.len());

    // Embed one vector per extracted entity.
    let mut embeddings = Vec::with_capacity(extraction.entities.len());
    for entity in &extraction.entities {
        let rendered = format!("{}: {}", entity.name, entity.content);
        embeddings.push(embedder.embed(&rendered).await?);
    }

    println!(" ✓ Generated {} embeddings", embeddings.len());
    println!(" ✓ Full pipeline works end-to-end!");
    Ok(())
}
/// Cosine similarity between two equal-length vectors, in [-1.0, 1.0].
///
/// Returns 0.0 when either vector has zero magnitude. The previous version
/// divided by zero in that case, yielding NaN — and NaN comparisons are
/// always false, which would silently defeat the `sim_work > sim_weather`
/// assertion in the embedding test.
///
/// # Panics
/// Panics if the slices differ in length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vectors must have same length");
    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against NaN from a zero-magnitude vector.
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }
    dot_product / (magnitude_a * magnitude_b)
}