mod types;
pub use types::*;
use async_trait::async_trait;
use serde_json::Value;
use crate::error::Result;
/// Cosine similarity between two vectors: `dot(a, b) / (|a| * |b|)`.
///
/// Degenerate inputs yield `0.0` rather than an error: slices of different
/// lengths, empty slices, or a vector whose squared magnitude sums to zero.
/// For finite inputs the result lies in roughly `[-1.0, 1.0]` (floating-point
/// rounding may overshoot slightly).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared magnitudes,
    // in the same element order as a plain loop would.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    if sq_a == 0.0 || sq_b == 0.0 {
        0.0
    } else {
        dot / (sq_a.sqrt() * sq_b.sqrt())
    }
}
impl InMemoryVectorStore {
    /// Creates a store through its `Default` implementation.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of entries currently stored.
    ///
    /// If the lock is poisoned, the guard is recovered and the real count is
    /// reported. This method has no error channel, and reporting `0` would
    /// wrongly claim that a populated store is empty.
    pub fn len(&self) -> usize {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[async_trait]
impl VectorStore for InMemoryVectorStore {
async fn add(&self, id: String, vector: Vec<f32>, metadata: Value) -> Result<()> {
let mut entries = self.entries.lock().map_err(|e| {
crate::error::TinyAgentsError::Embedding(format!("vector store lock poisoned: {e}"))
})?;
if let Some(existing) = entries.iter_mut().find(|e| e.id == id) {
existing.vector = vector;
existing.metadata = metadata;
} else {
entries.push(VectorEntry {
id,
vector,
metadata,
});
}
Ok(())
}
async fn query(&self, vector: &[f32], top_k: usize) -> Result<Vec<ScoredDoc>> {
if top_k == 0 {
return Ok(Vec::new());
}
let entries = self.entries.lock().map_err(|e| {
crate::error::TinyAgentsError::Embedding(format!("vector store lock poisoned: {e}"))
})?;
let mut scored: Vec<ScoredDoc> = entries
.iter()
.map(|e| ScoredDoc {
id: e.id.clone(),
score: cosine_similarity(vector, &e.vector),
metadata: e.metadata.clone(),
})
.collect();
scored.sort_by(|a, b| b.score.total_cmp(&a.score));
scored.truncate(top_k);
Ok(scored)
}
}
impl Retriever {
pub fn new(
model: std::sync::Arc<dyn EmbeddingModel>,
store: std::sync::Arc<dyn VectorStore>,
) -> Self {
Self { model, store }
}
pub fn model(&self) -> &std::sync::Arc<dyn EmbeddingModel> {
&self.model
}
pub fn store(&self) -> &std::sync::Arc<dyn VectorStore> {
&self.store
}
pub async fn index(&self, docs: Vec<(String, String, Value)>) -> Result<()> {
if docs.is_empty() {
return Ok(());
}
let texts: Vec<String> = docs.iter().map(|(_, text, _)| text.clone()).collect();
let vectors = self.model.embed(&texts).await?;
for ((id, _text, metadata), vector) in docs.into_iter().zip(vectors) {
self.store.add(id, vector, metadata).await?;
}
Ok(())
}
pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<ScoredDoc>> {
let mut vectors = self.model.embed(&[query.to_string()]).await?;
let query_vector = vectors.pop().unwrap_or_default();
self.store.query(&query_vector, top_k).await
}
}
#[cfg(feature = "openai")]
mod openai {
    use async_trait::async_trait;
    use serde_json::{Value, json};

    use super::EmbeddingModel;
    use crate::error::{Result, TinyAgentsError};

    const DEFAULT_MODEL: &str = "text-embedding-3-small";
    const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
    /// Vector length requested by default (sent as the `dimensions` request field).
    const DEFAULT_DIMENSIONS: usize = 1536;

    /// [`EmbeddingModel`] that calls an OpenAI-style `POST {base_url}/embeddings` endpoint.
    pub struct OpenAiEmbeddingModel {
        client: reqwest::Client,
        api_key: String,
        model: String,
        /// Kept without a trailing `/` so `{base_url}/embeddings` stays well-formed.
        base_url: String,
        /// Sent with every request and reported by `EmbeddingModel::dimensions`.
        dimensions: usize,
    }

    impl OpenAiEmbeddingModel {
        /// Creates a client for `text-embedding-3-small` at the public OpenAI
        /// endpoint, requesting 1536-dimensional vectors.
        pub fn new(api_key: impl Into<String>) -> Self {
            Self {
                client: reqwest::Client::new(),
                api_key: api_key.into(),
                model: DEFAULT_MODEL.to_string(),
                base_url: DEFAULT_BASE_URL.to_string(),
                dimensions: DEFAULT_DIMENSIONS,
            }
        }

        /// Sets the model name sent in the request's `model` field.
        pub fn with_model(mut self, model: impl Into<String>) -> Self {
            self.model = model.into();
            self
        }

        /// Points the client at another base URL (e.g. a proxy or compatible
        /// server). Trailing `/` characters are stripped.
        pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
            self.base_url = base_url.into().trim_end_matches('/').to_string();
            self
        }

        /// Sets the `dimensions` value sent with each request and reported by
        /// [`EmbeddingModel::dimensions`].
        ///
        /// NOTE(review): `dimensions` is always included in the request body.
        /// OpenAI documents it as supported only by `text-embedding-3` and
        /// later models. Confirm before pairing this client with older models.
        pub fn with_dimensions(mut self, dimensions: usize) -> Self {
            self.dimensions = dimensions;
            self
        }

        /// Builds a client from environment variables.
        ///
        /// - `OPENAI_API_KEY` is required.
        /// - `OPENAI_EMBEDDING_MODEL` and `OPENAI_BASE_URL` are optional overrides.
        ///
        /// Values are trimmed, so stray whitespace or a CR from a `.env` line
        /// is dropped. Blank values count as unset.
        ///
        /// # Errors
        /// `TinyAgentsError::Validation` if `OPENAI_API_KEY` is unset or blank.
        pub fn from_env() -> Result<Self> {
            let api_key = non_blank_env("OPENAI_API_KEY").ok_or_else(|| {
                TinyAgentsError::Validation(
                    "OPENAI_API_KEY is not set; export it or add it to a .env file".to_string(),
                )
            })?;
            let mut model = Self::new(api_key);
            if let Some(name) = non_blank_env("OPENAI_EMBEDDING_MODEL") {
                model = model.with_model(name);
            }
            if let Some(url) = non_blank_env("OPENAI_BASE_URL") {
                model = model.with_base_url(url);
            }
            Ok(model)
        }
    }

    /// Reads `key` from the environment and trims it. Returns `None` when the
    /// variable is unset, not valid Unicode, or blank.
    fn non_blank_env(key: &str) -> Option<String> {
        std::env::var(key)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }

    #[async_trait]
    impl EmbeddingModel for OpenAiEmbeddingModel {
        /// Embeds `texts` with one request to `{base_url}/embeddings` and
        /// returns exactly one vector per input, in input order.
        ///
        /// # Errors
        /// `TinyAgentsError::Embedding` in any of these cases:
        /// - the request fails in transport;
        /// - the response status is not 2xx;
        /// - the response has the wrong shape: missing fields, a vector
        ///   count that does not match the inputs, bad or duplicate
        ///   `index` values, or non-numeric components.
        ///
        /// Invalid JSON is propagated through the crate's `serde_json` conversion.
        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }
            let url = format!("{}/embeddings", self.base_url);
            let body = json!({
                "model": self.model,
                "input": texts,
                "dimensions": self.dimensions,
            });
            let response = self
                .client
                .post(&url)
                // Sends `Authorization: Bearer <key>` and marks the header sensitive.
                .bearer_auth(&self.api_key)
                .json(&body)
                .send()
                .await
                .map_err(|e| {
                    TinyAgentsError::Embedding(format!(
                        "openai embeddings request to {url} failed: {e}"
                    ))
                })?;
            let status = response.status();
            let text = response.text().await.map_err(|e| {
                TinyAgentsError::Embedding(format!("openai embeddings body read failed: {e}"))
            })?;
            if !status.is_success() {
                return Err(TinyAgentsError::Embedding(format!(
                    "openai embeddings returned HTTP {status}: {text}"
                )));
            }
            let value: Value = serde_json::from_str(&text)?;
            parse_embeddings(&value, texts.len())
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }
    }

    /// Extracts `expected` vectors from an `/embeddings` response body and
    /// orders them by each item's `index`.
    fn parse_embeddings(value: &Value, expected: usize) -> Result<Vec<Vec<f32>>> {
        let data = value.get("data").and_then(Value::as_array).ok_or_else(|| {
            TinyAgentsError::Embedding("openai embeddings response missing `data` array".into())
        })?;
        if data.len() != expected {
            return Err(TinyAgentsError::Embedding(format!(
                "openai embeddings returned {} vectors for {expected} inputs",
                data.len()
            )));
        }

        // Each item carries the `index` of the input it embeds. Place vectors
        // by that index instead of trusting array order. If a compatible
        // server omits `index` (or sends null), fall back to the array position.
        let mut slots: Vec<Option<Vec<f32>>> = vec![None; expected];
        for (position, item) in data.iter().enumerate() {
            let index = match item.get("index").filter(|raw| !raw.is_null()) {
                None => position,
                Some(raw) => raw
                    .as_u64()
                    .and_then(|i| usize::try_from(i).ok())
                    .ok_or_else(|| {
                        TinyAgentsError::Embedding(format!(
                            "openai embeddings response has invalid `index`: {raw}"
                        ))
                    })?,
            };
            let vector = parse_vector(item)?;
            let slot = slots.get_mut(index).ok_or_else(|| {
                TinyAgentsError::Embedding(format!(
                    "openai embeddings response `index` {index} is out of range for {expected} inputs"
                ))
            })?;
            if slot.replace(vector).is_some() {
                return Err(TinyAgentsError::Embedding(format!(
                    "openai embeddings response repeats `index` {index}"
                )));
            }
        }

        // Given the length check and no duplicate or out-of-range indices, every
        // slot is filled. This turns any unexpected gap into an error rather than a panic.
        slots.into_iter().collect::<Option<Vec<_>>>().ok_or_else(|| {
            TinyAgentsError::Embedding(
                "openai embeddings response is missing vectors for some inputs".into(),
            )
        })
    }

    /// Converts one item's `embedding` array to `f32`s, rejecting non-numeric
    /// components instead of silently replacing them with `0.0`.
    fn parse_vector(item: &Value) -> Result<Vec<f32>> {
        let embedding = item
            .get("embedding")
            .and_then(Value::as_array)
            .ok_or_else(|| {
                TinyAgentsError::Embedding(
                    "openai embeddings response missing `embedding` array".into(),
                )
            })?;
        embedding
            .iter()
            .map(|n| {
                // JSON numbers parse as f64; vectors are stored as f32.
                n.as_f64().map(|f| f as f32).ok_or_else(|| {
                    TinyAgentsError::Embedding(format!(
                        "openai embeddings response has non-numeric embedding value: {n}"
                    ))
                })
            })
            .collect()
    }
}
#[cfg(feature = "openai")]
pub use openai::OpenAiEmbeddingModel;
#[cfg(test)]
mod test;