use std::sync::Arc;
use rustagents::{InMemoryVectorStore, MockEmbeddingModel, Retriever};
use serde_json::json;
/// Builds the shared fixture: a [`Retriever`] wired to a `MockEmbeddingModel`
/// (constructed with `64`, presumably the embedding width; confirm against the
/// model's constructor) and a fresh `InMemoryVectorStore`, with three documents
/// already indexed under the ids `cats`, `finance` and `rust`.
///
/// # Panics
///
/// Panics with "indexing succeeds" if the indexing call returns an error.
async fn indexed_retriever() -> Retriever {
    let embedder = Arc::new(MockEmbeddingModel::new(64));
    let store = Arc::new(InMemoryVectorStore::new());
    let retriever = Retriever::new(embedder, store);

    // One (id, text, metadata) triple per document; each covers a distinct topic.
    let corpus = vec![
        (
            "cats".into(),
            "cats are great pets that purr".into(),
            json!({ "topic": "animals" }),
        ),
        (
            "finance".into(),
            "the stock market crashed today".into(),
            json!({ "topic": "finance" }),
        ),
        (
            "rust".into(),
            "rust is a systems programming language".into(),
            json!({ "topic": "programming" }),
        ),
    ];

    retriever.index(corpus).await.expect("indexing succeeds");
    retriever
}
/// Querying with the exact text of the `rust` document must return all three
/// documents, with `rust` first at a score of roughly 1.0, its metadata intact,
/// and the remaining hits ordered by non-increasing score.
#[tokio::test]
async fn retrieve_ranks_exact_match_first() {
    let query = "rust is a systems programming language";
    let retriever = indexed_retriever().await;
    let hits = retriever
        .retrieve(query, 3)
        .await
        .expect("retrieve succeeds");

    // Check the length first so the indexing below cannot go out of bounds.
    assert_eq!(hits.len(), 3, "top_k returns all three docs");

    let best = &hits[0];
    assert_eq!(best.id, "rust", "the exact match ranks first");
    assert!(
        (best.score - 1.0).abs() < 1e-3,
        "exact match scores ~1.0, got {}",
        best.score
    );
    assert_eq!(best.metadata, json!({ "topic": "programming" }));

    // Compare every hit with its successor to verify the ranking order.
    for (earlier, later) in hits.iter().zip(hits.iter().skip(1)) {
        assert!(
            earlier.score >= later.score,
            "scores are non-increasing: {} then {}",
            earlier.score,
            later.score
        );
    }
}
/// `retrieve` must cap its result at `top_k`: asking for one hit yields exactly
/// one, and asking for more hits than there are documents yields all of them.
#[tokio::test]
async fn retrieve_respects_top_k() {
    let retriever = indexed_retriever().await;

    // top_k smaller than the corpus: only the best match comes back.
    let finance_query = "the stock market crashed today";
    let single = retriever
        .retrieve(finance_query, 1)
        .await
        .expect("retrieve succeeds");
    assert_eq!(single.len(), 1, "top_k=1 returns a single hit");
    assert_eq!(single[0].id, "finance", "finance doc is the closest match");

    // top_k larger than the corpus: every indexed document comes back.
    let cats_query = "cats are great pets that purr";
    let everything = retriever
        .retrieve(cats_query, 10)
        .await
        .expect("retrieve succeeds");
    assert_eq!(everything.len(), 3);
    assert_eq!(everything[0].id, "cats");
}
/// Indexing a document under an id that already exists must replace the stored
/// entry rather than append a second one: the store size stays at three, and
/// retrieving with the new text returns the `cats` id with the new metadata.
#[tokio::test]
async fn reindexing_same_id_updates_in_place() {
    // Keep a handle to the store so the test can inspect its size directly.
    let store = Arc::new(InMemoryVectorStore::new());
    let retriever = Retriever::new(Arc::new(MockEmbeddingModel::new(64)), store.clone());

    retriever
        .index(vec![
            (
                "cats".into(),
                "cats are great pets that purr".into(),
                json!({ "topic": "animals" }),
            ),
            (
                "finance".into(),
                "the stock market crashed today".into(),
                json!({ "topic": "finance" }),
            ),
            (
                "rust".into(),
                "rust is a systems programming language".into(),
                json!({ "topic": "programming" }),
            ),
        ])
        .await
        .expect("indexing succeeds");
    assert_eq!(store.len(), 3, "three distinct documents indexed");

    // Re-use the "cats" id with different text and extended metadata.
    retriever
        .index(vec![(
            "cats".into(),
            "kittens love to nap in the sun".into(),
            json!({ "topic": "animals", "v": 2 }),
        )])
        .await
        .expect("re-indexing succeeds");
    assert_eq!(store.len(), 3, "re-indexing does not add a duplicate entry");

    let hits = retriever
        .retrieve("kittens love to nap in the sun", 1)
        .await
        .expect("retrieve succeeds");
    // Check the length before indexing: an empty result would otherwise fail
    // with an opaque out-of-bounds panic instead of a readable assertion.
    assert_eq!(hits.len(), 1, "top_k=1 returns exactly one hit after re-indexing");
    assert_eq!(hits[0].id, "cats");
    assert_eq!(hits[0].metadata, json!({ "topic": "animals", "v": 2 }));
}