rig/embeddings/
embedding.rs1use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, thiserror::Error)]
13pub enum EmbeddingError {
14 #[error("HttpError: {0}")]
16 HttpError(#[from] reqwest::Error),
17
18 #[error("JsonError: {0}")]
20 JsonError(#[from] serde_json::Error),
21
22 #[error("DocumentError: {0}")]
24 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
25
26 #[error("ResponseError: {0}")]
28 ResponseError(String),
29
30 #[error("ProviderError: {0}")]
32 ProviderError(String),
33}
34
35pub trait EmbeddingModel: Clone + Sync + Send {
37 const MAX_DOCUMENTS: usize;
39
40 fn ndims(&self) -> usize;
42
43 fn embed_texts(
45 &self,
46 texts: impl IntoIterator<Item = String> + Send,
47 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
48
49 fn embed_text(
51 &self,
52 text: &str,
53 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
54 async {
55 Ok(self
56 .embed_texts(vec![text.to_string()])
57 .await?
58 .pop()
59 .expect("There should be at least one embedding"))
60 }
61 }
62}
63
64pub trait EmbeddingModelDyn: Sync + Send {
65 fn max_documents(&self) -> usize;
66 fn ndims(&self) -> usize;
67 fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>>;
68 fn embed_texts(
69 &self,
70 texts: Vec<String>,
71 ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
72}
73
74impl<T: EmbeddingModel> EmbeddingModelDyn for T {
75 fn max_documents(&self) -> usize {
76 T::MAX_DOCUMENTS
77 }
78
79 fn ndims(&self) -> usize {
80 self.ndims()
81 }
82
83 fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>> {
84 Box::pin(self.embed_text(text))
85 }
86
87 fn embed_texts(&self, texts: Vec<String>) -> BoxFuture<Result<Vec<Embedding>, EmbeddingError>> {
88 Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
89 }
90}
91
92pub trait ImageEmbeddingModel: Clone + Sync + Send {
94 const MAX_DOCUMENTS: usize;
96
97 fn ndims(&self) -> usize;
99
100 fn embed_images(
102 &self,
103 images: impl IntoIterator<Item = Vec<u8>> + Send,
104 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
105
106 fn embed_image<'a>(
108 &'a self,
109 bytes: &'a [u8],
110 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
111 async move {
112 Ok(self
113 .embed_images(vec![bytes.to_owned()])
114 .await?
115 .pop()
116 .expect("There should be at least one embedding"))
117 }
118 }
119}
120
121#[derive(Clone, Default, Deserialize, Serialize, Debug)]
123pub struct Embedding {
124 pub document: String,
126 pub vec: Vec<f64>,
128}
129
130impl PartialEq for Embedding {
131 fn eq(&self, other: &Self) -> bool {
132 self.document == other.document
133 }
134}
135
136impl Eq for Embedding {}