Skip to main content

abu_rag/
embed.rs

1use abu_provider::{embed::EmbedRequest, EmbedProvide};
2use crate::document::Chunk;
3
/// Pairs an embedding provider with the model name sent on every request.
///
/// `Debug` and `Clone` are derived and therefore only available when `P`
/// itself implements them; the struct places no bounds on `P`.
#[derive(Debug, Clone)]
pub struct Embedder<P> {
    /// Backend whose `embed` method is invoked for each request.
    provider: P,
    /// Model identifier copied into each `EmbedRequest`.
    model: String,
}
8
9pub struct EmbeddedChunk {
10    pub chunk: Chunk,
11    pub embedding: Vec<f32>,
12}
13
14impl<P: EmbedProvide> Embedder<P> {
15    pub fn new(provider: P, model: impl Into<String>) -> Self {
16        Self {
17            provider,
18            model: model.into(),
19        }
20    }
21
22    pub async fn embed_chunks(&self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>, EmbedError> {
23        let texts = chunks.iter().map(|c| c.text.clone()).collect();
24        let embeddings = self.embed_texts(texts).await?;
25        let result = chunks
26            .into_iter()
27            .zip(embeddings)
28            .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
29            .collect();
30        Ok(result)
31    }
32
33    pub async fn embed_texts(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedError> {
34        let request = EmbedRequest { input: texts ,model: self.model.clone() };
35        let response = self.provider
36            .embed(&request).await
37            .map_err(|e| EmbedError::EmbedProvide(Box::new(e)))?;
38        Ok(response.embeddings)
39    }
40
41    pub async fn embed_text(&self, text: impl Into<String>) -> Result<Vec<f32>, EmbedError> {
42        let request = EmbedRequest { input: vec![text.into()] ,model: self.model.clone() };
43        let response = self.provider
44            .embed(&request).await
45            .map_err(|e| EmbedError::EmbedProvide(Box::new(e)))?;
46        
47        let embedding = response.embeddings.into_iter()
48            .next()
49            .ok_or_else(|| EmbedError::NoEmbedding)?;
50        Ok(embedding)
51    }
52}
53
54#[derive(Debug, thiserror::Error)]
55pub enum EmbedError {
56    #[error("load error: {0}")]
57    EmbedProvide(Box<dyn std::error::Error + 'static + Send + Sync>),
58
59    #[error("empty embedding result")]
60    NoEmbedding,
61}