pub mod gemini;
pub mod openai;
pub mod ollama;
pub mod synthetic;
pub mod bedrock;
use crate::core::config::Config;
use crate::core::error::Result;
use crate::core::types::{EmbeddingKind, EmbeddingResult, Sector};
use async_trait::async_trait;
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Produce an embedding for a single `text` within the given `sector`.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult>;

    /// Embed many `(text, sector)` pairs.
    ///
    /// The default implementation calls [`EmbeddingProvider::embed`] once per
    /// pair, sequentially, and fails fast on the first error. Providers with a
    /// native batch endpoint should override this together with
    /// `supports_batch`.
    async fn embed_batch(
        &self,
        texts: &[(&str, &Sector)],
    ) -> Result<Vec<EmbeddingResult>> {
        let mut out = Vec::with_capacity(texts.len());
        for &(text, sector) in texts {
            out.push(self.embed(text, sector).await?);
        }
        Ok(out)
    }

    /// Dimensionality of the vectors this provider returns.
    fn dimensions(&self) -> usize;

    /// Short static identifier for this provider (e.g. for logging).
    fn name(&self) -> &'static str;

    /// Whether `embed_batch` is backed by a native batch API.
    /// Defaults to `false`, i.e. the sequential fallback above.
    fn supports_batch(&self) -> bool {
        false
    }
}
/// Instantiate the embedding provider selected by `config.embedding_kind`.
///
/// # Panics
///
/// Panics when `EmbeddingKind::Bedrock` is requested but the crate was built
/// without the `aws` feature.
pub fn create_provider(config: &Config) -> Box<dyn EmbeddingProvider> {
    match config.embedding_kind {
        EmbeddingKind::Synthetic => Box::new(synthetic::SyntheticProvider::new(config.vec_dim)),
        EmbeddingKind::OpenAI => Box::new(openai::OpenAIProvider::new(config)),
        EmbeddingKind::Ollama => Box::new(ollama::OllamaProvider::new(config)),
        EmbeddingKind::Gemini => Box::new(gemini::GeminiProvider::new(config)),
        #[cfg(feature = "aws")]
        EmbeddingKind::Bedrock => Box::new(bedrock::BedrockProvider::new(config)),
        #[cfg(not(feature = "aws"))]
        EmbeddingKind::Bedrock => {
            panic!("AWS Bedrock support requires the 'aws' feature to be enabled")
        }
    }
}
/// Down-sample `v` to `target_dim` dimensions by averaging contiguous bins,
/// then L2-normalize the result.
///
/// If `v` already has `target_dim` or fewer elements, it is returned as-is
/// (the passthrough copy is *not* re-normalized).
pub fn compress_vector(v: &[f32], target_dim: usize) -> Vec<f32> {
    if v.len() <= target_dim {
        return v.to_vec();
    }
    let mut compressed = Vec::with_capacity(target_dim);
    for i in 0..target_dim {
        // Integer bin boundaries: exact, contiguous, and gap-free. The previous
        // f32-based boundary computation could drift for large inputs because
        // `i as f32 * bin_size` loses precision once values exceed 2^24.
        let start = i * v.len() / target_dim;
        let end = ((i + 1) * v.len() / target_dim).min(v.len());
        let bin = &v[start..end];
        // Since v.len() > target_dim, every bin holds at least one element,
        // but keep the guard so an empty bin can never divide by zero.
        let mean = if bin.is_empty() {
            0.0
        } else {
            bin.iter().sum::<f32>() / bin.len() as f32
        };
        compressed.push(mean);
    }
    // L2-normalize; the epsilon guard avoids blowing up on an all-zero result.
    let norm: f32 = compressed.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in &mut compressed {
            *x /= norm;
        }
    }
    compressed
}
/// Force `v` to exactly `target_dim` elements: truncate when too long,
/// zero-pad when too short, copy when already the right size.
pub fn resize_vector(v: &[f32], target_dim: usize) -> Vec<f32> {
    // Copy at most `target_dim` elements, then pad the remainder with zeros.
    // `Vec::resize` is a no-op when the lengths already match.
    let mut out = Vec::with_capacity(target_dim);
    out.extend(v.iter().take(target_dim).copied());
    out.resize(target_dim, 0.0);
    out
}
/// Relative contribution of the synthetic embedding in a fused vector.
const SYNTHETIC_WEIGHT: f32 = 0.6;
/// Relative contribution of the semantic embedding in a fused vector.
const SEMANTIC_WEIGHT: f32 = 0.4;

/// Concatenate a weighted synthetic embedding and a weighted semantic
/// embedding into a single L2-normalized vector of length
/// `synthetic.len() + semantic.len()`.
///
/// If both inputs are (near-)zero, the unnormalized weighted concatenation
/// is returned instead of dividing by a vanishing norm.
pub fn fuse_vectors(synthetic: &[f32], semantic: &[f32]) -> Vec<f32> {
    let mut fused: Vec<f32> = synthetic
        .iter()
        .map(|&x| x * SYNTHETIC_WEIGHT)
        .chain(semantic.iter().map(|&x| x * SEMANTIC_WEIGHT))
        .collect();
    // L2-normalize; the epsilon guard avoids NaN from a zero-norm vector.
    let norm: f32 = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in &mut fused {
            *x /= norm;
        }
    }
    fused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressing 8 dims down to 4 yields a unit-norm vector of the right size.
    #[test]
    fn test_compress_vector() {
        let input: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let out = compress_vector(&input, 4);
        assert_eq!(out.len(), 4);
        let magnitude: f32 = out.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-5);
    }

    /// A longer vector is truncated to the target dimension.
    #[test]
    fn test_resize_vector_truncate() {
        assert_eq!(resize_vector(&[1.0, 2.0, 3.0, 4.0], 2), vec![1.0, 2.0]);
    }

    /// A shorter vector is zero-padded up to the target dimension.
    #[test]
    fn test_resize_vector_pad() {
        assert_eq!(resize_vector(&[1.0, 2.0], 4), vec![1.0, 2.0, 0.0, 0.0]);
    }

    /// Fusion concatenates both inputs and produces a unit-norm result.
    #[test]
    fn test_fuse_vectors() {
        let fused = fuse_vectors(&[1.0, 0.0], &[0.0, 1.0]);
        assert_eq!(fused.len(), 4);
        let magnitude: f32 = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-5);
    }
}