#![deny(warnings)]
#![deny(missing_docs)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A piece of text that can be retrieved, together with its identifier,
/// optional metadata and a relevance score.
///
/// The type is serializable and deserializable through `serde`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
/// Identifier of the document.
pub id: String,
/// Text content. This is what [`KeywordRetriever`] matches query terms
/// against and what [`PrependStrategy`] writes into the grounded context.
pub content: String,
/// Optional free-form JSON metadata; [`Document::new`] leaves it as `None`.
pub metadata: Option<serde_json::Value>,
/// Relevance score. [`Document::new`] initialises it to `0.0`;
/// [`KeywordRetriever`] overwrites it on the copies it returns.
pub score: f64,
}
impl Document {
    /// Creates a document from an identifier and its text content.
    ///
    /// Both arguments accept anything convertible into a `String`. The new
    /// document carries no metadata (`None`) and a score of `0.0`.
    #[must_use]
    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        let id: String = id.into();
        let content: String = content.into();
        Self {
            id,
            content,
            metadata: None,
            score: 0.0,
        }
    }
}
/// An asynchronous source of [`Document`]s for a textual query.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait Retriever: Send + Sync + 'static {
/// Retrieves documents for `query`.
///
/// `limit` is presumably the maximum number of documents to return; the
/// in-file [`KeywordRetriever`] implementation truncates its result to
/// `limit` entries.
///
/// # Errors
///
/// Returns whatever error the implementation reports through
/// `traitclaw_core::Result`.
async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>>;
}
/// Strategy for turning a set of [`Document`]s into a single block of text.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait GroundingStrategy: Send + Sync + 'static {
/// Renders `documents` into one `String`.
fn ground(&self, documents: &[Document]) -> String;
}
/// [`GroundingStrategy`] that lists each document's content, numbered from 1,
/// under a `Relevant context:` header.
pub struct PrependStrategy;
impl GroundingStrategy for PrependStrategy {
    /// Renders `documents` as a numbered list under a `Relevant context:` header.
    ///
    /// Each entry has the form `[n] <content>` followed by a blank line, with
    /// `n` starting at 1 and following slice order. An empty slice yields an
    /// empty string (no header).
    fn ground(&self, documents: &[Document]) -> String {
        use std::fmt::Write as _;

        if documents.is_empty() {
            return String::new();
        }

        documents.iter().zip(1usize..).fold(
            String::from("Relevant context:\n\n"),
            |mut rendered, (doc, position)| {
                // Writing into a `String` cannot fail, so the result is discarded.
                let _ = write!(rendered, "[{}] {}\n\n", position, doc.content);
                rendered
            },
        )
    }
}
/// In-memory [`Retriever`] that ranks its stored documents by
/// case-insensitive, whitespace-delimited keyword matches against the query.
pub struct KeywordRetriever {
// Documents available for retrieval, kept in insertion order.
documents: Vec<Document>,
}
impl KeywordRetriever {
    /// Creates a retriever that holds no documents.
    #[must_use]
    pub fn new() -> Self {
        let documents = Vec::new();
        Self { documents }
    }

    /// Appends a single document to the collection.
    pub fn add(&mut self, doc: Document) {
        self.documents.push(doc);
    }

    /// Appends every document produced by `docs`, preserving iteration order.
    pub fn add_many(&mut self, docs: impl IntoIterator<Item = Document>) {
        self.documents.extend(docs);
    }

    /// Scores `content` against already-lowercased `query_terms`.
    ///
    /// `content` is lowercased and split on whitespace. For each query term,
    /// `tf` is the number of words equal to the term (whole-word match, so
    /// attached punctuation prevents a match) and the term contributes
    /// `tf / (tf + 1)`. The result is the sum over all terms, in order; a term
    /// listed twice contributes twice. Content without any words scores `0.0`.
    // Word counts are cast to `f64`; precision would only be lost for counts
    // far beyond any realistic document size.
    #[allow(clippy::cast_precision_loss)]
    fn score(query_terms: &[String], content: &str) -> f64 {
        let lowered = content.to_lowercase();
        let tokens: Vec<&str> = lowered.split_whitespace().collect();
        if tokens.is_empty() {
            return 0.0;
        }

        query_terms.iter().fold(0.0, |acc, term| {
            let hits = tokens.iter().filter(|&&tok| tok == term.as_str()).count() as f64;
            acc + hits / (hits + 1.0)
        })
    }
}
impl Default for KeywordRetriever {
/// Returns an empty retriever; delegates to [`KeywordRetriever::new`].
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Retriever for KeywordRetriever {
    /// Returns copies of the stored documents that match `query`, best first.
    ///
    /// The query is lowercased and split on whitespace into terms (a term
    /// repeated in the query counts once per repetition). Every stored
    /// document is scored with `Self::score`; only documents with a score
    /// above `0.0` are returned, each as a clone whose `score` field holds the
    /// computed value. Results are ordered by descending score, ties keep
    /// insertion order, and at most `limit` documents are returned.
    ///
    /// An empty or whitespace-only query, or a `limit` of zero, yields an
    /// empty vector.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
        let terms: Vec<String> = query
            .to_lowercase()
            .split_whitespace()
            .map(String::from)
            .collect();

        // Without terms no document can score above zero, and a zero limit
        // discards everything anyway, so skip scoring altogether.
        if terms.is_empty() || limit == 0 {
            return Ok(Vec::new());
        }

        // Score first and clone only the matches, so documents that do not
        // match the query are never deep-copied.
        let mut scored: Vec<Document> = self
            .documents
            .iter()
            .filter_map(|doc| {
                let score = Self::score(&terms, &doc.content);
                if score > 0.0 {
                    let mut hit = doc.clone();
                    hit.score = score;
                    Some(hit)
                } else {
                    None
                }
            })
            .collect();

        // `sort_by` is stable, so equally scored documents stay in insertion order.
        scored.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        scored.truncate(limit);
        Ok(scored)
    }
}
// Unit tests for `KeywordRetriever`, `PrependStrategy` and `Document::new`.
#[cfg(test)]
mod tests {
use super::*;
// Two of the three documents contain "rust"; document "1" also contains
// "programming", so it scores highest and must come first.
#[tokio::test]
async fn test_keyword_retriever_basic() {
let mut r = KeywordRetriever::new();
r.add(Document::new("1", "Rust is a systems programming language"));
r.add(Document::new("2", "Python is great for data science"));
r.add(Document::new("3", "Rust has zero-cost abstractions"));
let results = r.retrieve("Rust programming", 10).await.unwrap();
assert!(!results.is_empty());
assert!(results.len() >= 2);
assert_eq!(results[0].id, "1");
}
// An empty query produces no terms, so no document can score above zero.
#[tokio::test]
async fn test_keyword_retriever_empty_query() {
let mut r = KeywordRetriever::new();
r.add(Document::new("1", "Some content"));
let results = r.retrieve("", 10).await.unwrap();
assert!(results.is_empty());
}
// Documents that share no word with the query are filtered out.
#[tokio::test]
async fn test_keyword_retriever_no_match() {
let mut r = KeywordRetriever::new();
r.add(Document::new("1", "Hello world"));
let results = r.retrieve("quantum computing", 10).await.unwrap();
assert!(results.is_empty());
}
// All ten documents match "rust"; the result must be cut down to `limit`.
#[tokio::test]
async fn test_keyword_retriever_limit() {
let mut r = KeywordRetriever::new();
for i in 0..10 {
r.add(Document::new(format!("{i}"), format!("rust item {i}")));
}
let results = r.retrieve("rust", 3).await.unwrap();
assert_eq!(results.len(), 3);
}
// Entries are numbered from 1 in slice order.
#[test]
fn test_prepend_strategy() {
let docs = vec![
Document::new("1", "First doc"),
Document::new("2", "Second doc"),
];
let ctx = PrependStrategy.ground(&docs);
assert!(ctx.contains("[1] First doc"));
assert!(ctx.contains("[2] Second doc"));
}
// No documents means no output at all, not even the header.
#[test]
fn test_prepend_strategy_empty() {
let ctx = PrependStrategy.ground(&[]);
assert!(ctx.is_empty());
}
// `Document::new` stores the given id/content and defaults the other fields.
#[test]
fn test_document_new() {
let doc = Document::new("id1", "content1");
assert_eq!(doc.id, "id1");
assert_eq!(doc.content, "content1");
assert!(doc.metadata.is_none());
assert!((doc.score - 0.0).abs() < f64::EPSILON);
}
}