rig/embeddings/
embedding.rs1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, thiserror::Error)]
12pub enum EmbeddingError {
13 #[error("HttpError: {0}")]
15 HttpError(#[from] reqwest::Error),
16
17 #[error("JsonError: {0}")]
19 JsonError(#[from] serde_json::Error),
20
21 #[error("DocumentError: {0}")]
23 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
24
25 #[error("ResponseError: {0}")]
27 ResponseError(String),
28
29 #[error("ProviderError: {0}")]
31 ProviderError(String),
32}
33
34pub trait EmbeddingModel: Clone + Sync + Send {
36 const MAX_DOCUMENTS: usize;
38
39 fn ndims(&self) -> usize;
41
42 fn embed_texts(
44 &self,
45 texts: impl IntoIterator<Item = String> + Send,
46 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
47
48 fn embed_text(
50 &self,
51 text: &str,
52 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
53 async {
54 Ok(self
55 .embed_texts(vec![text.to_string()])
56 .await?
57 .pop()
58 .expect("There should be at least one embedding"))
59 }
60 }
61}
62
63pub trait ImageEmbeddingModel: Clone + Sync + Send {
65 const MAX_DOCUMENTS: usize;
67
68 fn ndims(&self) -> usize;
70
71 fn embed_images(
73 &self,
74 images: impl IntoIterator<Item = Vec<u8>> + Send,
75 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
76
77 fn embed_image<'a>(
79 &'a self,
80 bytes: &'a [u8],
81 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
82 async move {
83 Ok(self
84 .embed_images(vec![bytes.to_owned()])
85 .await?
86 .pop()
87 .expect("There should be at least one embedding"))
88 }
89 }
90}
91
92#[derive(Clone, Default, Deserialize, Serialize, Debug)]
94pub struct Embedding {
95 pub document: String,
97 pub vec: Vec<f64>,
99}
100
101impl PartialEq for Embedding {
102 fn eq(&self, other: &Self) -> bool {
103 self.document == other.document
104 }
105}
106
// Marker impl with no methods: it declares that the `PartialEq` impl for
// `Embedding` (which compares the `document` strings only) is a full
// equivalence relation. `Eq` cannot be derived for this struct because the
// `vec` field is a `Vec<f64>`, and `f64` does not implement `Eq`.
impl Eq for Embedding {}