use std::sync::Arc;
use crate::backend::Backend;
use crate::model::LlamaModel;
use crate::tensor::Tensor;
use super::{RagError, RagResult};
/// Produces fixed-width embedding vectors for text.
///
/// The vector width is read from the model's `hidden_size` when the generator is built, and
/// output is L2-normalized unless disabled with [`EmbeddingGenerator::with_normalize`].
pub struct EmbeddingGenerator {
    // Model handle; not yet read by `embed` (see the NOTE on that method).
    model: LlamaModel,
    // Compute backend; not yet read by `embed`.
    backend: Arc<dyn Backend>,
    // Embedding width, copied from `model.config().hidden_size` in `new`.
    dim: usize,
    // Whether `embed` L2-normalizes its output; `new` sets this to true.
    normalize: bool,
}

impl EmbeddingGenerator {
    /// Wraps `model` and `backend`; embeddings are `hidden_size` wide and normalized by default.
    pub fn new(model: LlamaModel, backend: Arc<dyn Backend>) -> Self {
        let hidden = model.config().hidden_size;
        Self {
            dim: hidden,
            normalize: true,
            model,
            backend,
        }
    }

    /// Builder-style toggle for L2 normalization of the vectors returned by `embed`.
    pub fn with_normalize(self, normalize: bool) -> Self {
        let mut this = self;
        this.normalize = normalize;
        this
    }

    /// Width of the vectors returned by `embed`.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the embedding for `text`.
    ///
    /// NOTE(review): placeholder implementation. `text`, `model` and `backend` are never used;
    /// this always yields `dim` zeros (normalizing a zero vector returns it unchanged), so every
    /// input maps to the same vector. A real forward pass still needs to be wired in.
    pub fn embed(&mut self, text: &str) -> RagResult<Vec<f32>> {
        let _ = text;
        let zeros = vec![0.0f32; self.dim];
        let out = if self.normalize {
            Self::l2_normalize(&zeros)
        } else {
            zeros
        };
        Ok(out)
    }

    /// Embeds each text in order, stopping at (and returning) the first error.
    pub fn embed_batch(&mut self, texts: &[&str]) -> RagResult<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            embeddings.push(self.embed(text)?);
        }
        Ok(embeddings)
    }

    /// Scales `v` to unit Euclidean length.
    ///
    /// A vector whose norm is not strictly positive (all zeros, or a NaN norm) is returned as an
    /// unchanged copy.
    fn l2_normalize(v: &[f32]) -> Vec<f32> {
        let sum_sq: f32 = v.iter().map(|x| x * x).sum();
        let norm = sum_sq.sqrt();
        if norm > 0.0 {
            v.iter().map(|&x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }
}
/// Splits text into size-bounded, overlapping chunks of separator-delimited words.
///
/// All lengths (`chunk_size`, `chunk_overlap`) are measured in bytes, as reported by `str::len`.
#[derive(Debug, Clone)]
pub struct TextChunker {
    /// Maximum byte length of a chunk; a single word longer than this still forms its own chunk.
    chunk_size: usize,
    /// Byte budget for the trailing whole words repeated at the start of the next chunk.
    chunk_overlap: usize,
    /// String used both to split the input and to re-join words within a chunk.
    separator: String,
}

impl Default for TextChunker {
    /// 500-byte chunks, 50-byte overlap, splitting on a single space.
    fn default() -> Self {
        Self {
            chunk_size: 500,
            chunk_overlap: 50,
            separator: " ".to_string(),
        }
    }
}

impl TextChunker {
    /// Creates a chunker with the given maximum chunk size (bytes) and default overlap/separator.
    pub fn new(chunk_size: usize) -> Self {
        Self {
            chunk_size,
            ..Default::default()
        }
    }

    /// Sets the overlap budget, in bytes, shared between consecutive chunks.
    #[must_use]
    pub fn with_overlap(mut self, overlap: usize) -> Self {
        self.chunk_overlap = overlap;
        self
    }

    /// Sets the separator used to split the input and to join words inside a chunk.
    #[must_use]
    pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
        self.separator = sep.into();
        self
    }

    /// Splits `text` into overlapping chunks.
    ///
    /// `text` is split on the separator into words, which are greedily packed into chunks and
    /// re-joined with the separator (whose full length counts toward the chunk size). A chunk
    /// never exceeds `chunk_size` bytes, except that a single word longer than `chunk_size`
    /// forms a chunk on its own. Each following chunk starts by repeating the longest run of
    /// trailing words from the previous chunk that fits in `chunk_overlap` bytes, but always
    /// advances by at least one word. Chunking stops once a chunk reaches the end of the text.
    ///
    /// Returns an empty vector for empty input.
    pub fn chunk(&self, text: &str) -> Vec<String> {
        if text.is_empty() {
            return Vec::new();
        }
        let words: Vec<&str> = text.split(self.separator.as_str()).collect();
        let mut chunks = Vec::new();
        let mut start = 0;
        while start < words.len() {
            let end = self.chunk_end(&words, start);
            chunks.push(words[start..end].join(self.separator.as_str()));
            // Once the text is exhausted, stepping back for overlap would only emit chunks that
            // are suffixes of the one just pushed.
            if end == words.len() {
                break;
            }
            // `overlap_len` is < end - start, so `start` strictly increases.
            start = end - self.overlap_len(&words[start..end]);
        }
        chunks
    }

    /// Exclusive end index of the chunk beginning at `start` (requires `start < words.len()`):
    /// the longest run of words whose joined length fits in `chunk_size`, at least one word.
    fn chunk_end(&self, words: &[&str], start: usize) -> usize {
        let sep_len = self.separator.len();
        let mut len = words[start].len();
        let mut end = start + 1;
        while end < words.len() {
            let next = len + sep_len + words[end].len();
            if next > self.chunk_size {
                break;
            }
            len = next;
            end += 1;
        }
        end
    }

    /// Number of trailing words of `chunk` to repeat at the start of the next chunk: the longest
    /// suffix whose joined length fits in `chunk_overlap`, capped at `chunk.len() - 1` so the
    /// next chunk always begins at least one word later.
    fn overlap_len(&self, chunk: &[&str]) -> usize {
        let sep_len = self.separator.len();
        let mut len = 0;
        let mut count = 0;
        for word in chunk.iter().rev().take(chunk.len().saturating_sub(1)) {
            let added = if count == 0 {
                word.len()
            } else {
                sep_len + word.len()
            };
            if len + added > self.chunk_overlap {
                break;
            }
            len += added;
            count += 1;
        }
        count
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunks must be non-empty, respect `chunk_size`, and be verbatim slices of the input.
    #[test]
    fn test_chunker() {
        let chunker = TextChunker::new(50).with_overlap(10);
        let text = "This is a test sentence. It has multiple words. We want to chunk it.";
        let chunks = chunker.chunk(text);
        assert!(!chunks.is_empty());
        // No word in `text` exceeds 50 bytes, so every chunk must fit chunk_size exactly
        // (no slack allowance).
        for chunk in &chunks {
            assert!(chunk.len() <= 50, "chunk exceeds chunk_size: {:?}", chunk);
            // Words are re-joined with the same separator they were split on.
            assert!(text.contains(chunk.as_str()), "chunk not found in text: {:?}", chunk);
        }
    }

    /// A 3-4-5 vector normalizes to unit length with components 0.6 and 0.8.
    #[test]
    fn test_l2_normalize() {
        let v = [3.0f32, 4.0];
        let normalized = EmbeddingGenerator::l2_normalize(&v);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);
    }

    /// A zero vector has no direction; it must come back unchanged rather than as NaNs.
    #[test]
    fn test_l2_normalize_zero_vector_is_unchanged() {
        assert_eq!(EmbeddingGenerator::l2_normalize(&[0.0, 0.0]), vec![0.0, 0.0]);
    }
}