rig/providers/cohere/embeddings.rs

1use super::{client::ApiResponse, client::Client};
2use crate::{
3    embeddings::{self, EmbeddingError},
4    http_client::HttpClientExt,
5    wasm_compat::*,
6};
7use serde::Deserialize;
8use serde_json::json;
9
10#[derive(Deserialize)]
11pub struct EmbeddingResponse {
12    #[serde(default)]
13    pub response_type: Option<String>,
14    pub id: String,
15    pub embeddings: Vec<Vec<f64>>,
16    pub texts: Vec<String>,
17    #[serde(default)]
18    pub meta: Option<Meta>,
19}
20
21#[derive(Deserialize)]
22pub struct Meta {
23    pub api_version: ApiVersion,
24    pub billed_units: BilledUnits,
25    #[serde(default)]
26    pub warnings: Vec<String>,
27}
28
29#[derive(Deserialize)]
30pub struct ApiVersion {
31    pub version: String,
32    #[serde(default)]
33    pub is_deprecated: Option<bool>,
34    #[serde(default)]
35    pub is_experimental: Option<bool>,
36}
37
38#[derive(Deserialize, Debug)]
39pub struct BilledUnits {
40    #[serde(default)]
41    pub input_tokens: u32,
42    #[serde(default)]
43    pub output_tokens: u32,
44    #[serde(default)]
45    pub search_units: u32,
46    #[serde(default)]
47    pub classifications: u32,
48}
49
50impl std::fmt::Display for BilledUnits {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(
53            f,
54            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
55            self.input_tokens, self.output_tokens, self.search_units, self.classifications
56        )
57    }
58}
59
60#[derive(Clone)]
61pub struct EmbeddingModel<T = reqwest::Client> {
62    client: Client<T>,
63    pub model: String,
64    pub input_type: String,
65    ndims: usize,
66}
67
68impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
69where
70    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
71{
72    const MAX_DOCUMENTS: usize = 96;
73    type Client = Client<T>;
74
75    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
76        let model = model.into();
77        let dims = dims
78            .or(super::model_dimensions_from_identifier(&model))
79            .unwrap_or_default();
80
81        Self::new(client.clone(), model, "search_document", dims)
82    }
83
84    fn ndims(&self) -> usize {
85        self.ndims
86    }
87
88    async fn embed_texts(
89        &self,
90        documents: impl IntoIterator<Item = String>,
91    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
92        let documents = documents.into_iter().collect::<Vec<_>>();
93
94        let body = json!({
95            "model": self.model.to_string(),
96            "texts": documents,
97            "input_type": self.input_type
98        });
99
100        let body = serde_json::to_vec(&body)?;
101
102        let req = self
103            .client
104            .post("/v1/embed")?
105            .body(body)
106            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
107
108        let response = self
109            .client
110            .send::<_, Vec<u8>>(req)
111            .await
112            .map_err(EmbeddingError::HttpError)?;
113
114        if response.status().is_success() {
115            let body: ApiResponse<EmbeddingResponse> =
116                serde_json::from_slice(response.into_body().await?.as_slice())?;
117
118            match body {
119                ApiResponse::Ok(response) => {
120                    match response.meta {
121                        Some(meta) => tracing::info!(target: "rig",
122                            "Cohere embeddings billed units: {}",
123                            meta.billed_units,
124                        ),
125                        None => tracing::info!(target: "rig",
126                            "Cohere embeddings billed units: n/a",
127                        ),
128                    };
129
130                    if response.embeddings.len() != documents.len() {
131                        return Err(EmbeddingError::DocumentError(
132                            format!(
133                                "Expected {} embeddings, got {}",
134                                documents.len(),
135                                response.embeddings.len()
136                            )
137                            .into(),
138                        ));
139                    }
140
141                    Ok(response
142                        .embeddings
143                        .into_iter()
144                        .zip(documents.into_iter())
145                        .map(|(embedding, document)| embeddings::Embedding {
146                            document,
147                            vec: embedding,
148                        })
149                        .collect())
150                }
151                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
152            }
153        } else {
154            let text = String::from_utf8_lossy(&response.into_body().await?).into();
155            Err(EmbeddingError::ProviderError(text))
156        }
157    }
158}
159
160impl<T> EmbeddingModel<T> {
161    pub fn new(
162        client: Client<T>,
163        model: impl Into<String>,
164        input_type: &str,
165        ndims: usize,
166    ) -> Self {
167        Self {
168            client,
169            model: model.into(),
170            input_type: input_type.to_string(),
171            ndims,
172        }
173    }
174
175    pub fn with_model(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
176        Self {
177            client,
178            model: model.into(),
179            input_type: input_type.into(),
180            ndims,
181        }
182    }
183}