// rig/providers/cohere/embeddings.rs

1use super::{client::ApiResponse, client::Client};
2use crate::{
3    embeddings::{self, EmbeddingError},
4    http_client::HttpClientExt,
5    wasm_compat::*,
6};
7use serde::Deserialize;
8use serde_json::json;
9
10#[derive(Deserialize)]
11pub struct EmbeddingResponse {
12    #[serde(default)]
13    pub response_type: Option<String>,
14    pub id: String,
15    pub embeddings: Vec<Vec<f64>>,
16    pub texts: Vec<String>,
17    #[serde(default)]
18    pub meta: Option<Meta>,
19}
20
21#[derive(Deserialize)]
22pub struct Meta {
23    pub api_version: ApiVersion,
24    pub billed_units: BilledUnits,
25    #[serde(default)]
26    pub warnings: Vec<String>,
27}
28
29#[derive(Deserialize)]
30pub struct ApiVersion {
31    pub version: String,
32    #[serde(default)]
33    pub is_deprecated: Option<bool>,
34    #[serde(default)]
35    pub is_experimental: Option<bool>,
36}
37
38#[derive(Deserialize, Debug)]
39pub struct BilledUnits {
40    #[serde(default)]
41    pub input_tokens: u32,
42    #[serde(default)]
43    pub output_tokens: u32,
44    #[serde(default)]
45    pub search_units: u32,
46    #[serde(default)]
47    pub classifications: u32,
48}
49
50impl std::fmt::Display for BilledUnits {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(
53            f,
54            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
55            self.input_tokens, self.output_tokens, self.search_units, self.classifications
56        )
57    }
58}
59
/// Cohere embedding model handle: a client plus the model name, `input_type`,
/// and dimensionality used for embed requests.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Client used to issue the `/v1/embed` request.
    client: Client<T>,
    /// Model identifier sent as the `model` field of embed requests.
    pub model: String,
    /// Value sent as the `input_type` field of embed requests (e.g. `search_document`).
    pub input_type: String,
    /// Embedding dimensionality returned by `ndims()`.
    ndims: usize,
}
67
68impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
69where
70    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
71{
72    const MAX_DOCUMENTS: usize = 96;
73    type Client = Client<T>;
74
75    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
76        let model = model.into();
77        let dims = dims
78            .or(super::model_dimensions_from_identifier(&model))
79            .unwrap_or_default();
80
81        Self::new(client.clone(), model, "search_document", dims)
82    }
83
84    fn ndims(&self) -> usize {
85        self.ndims
86    }
87
88    #[cfg_attr(feature = "worker", worker::send)]
89    async fn embed_texts(
90        &self,
91        documents: impl IntoIterator<Item = String>,
92    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
93        let documents = documents.into_iter().collect::<Vec<_>>();
94
95        let body = json!({
96            "model": self.model.to_string(),
97            "texts": documents,
98            "input_type": self.input_type
99        });
100
101        let body = serde_json::to_vec(&body)?;
102
103        let req = self
104            .client
105            .post("/v1/embed")?
106            .body(body)
107            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
108
109        let response = self
110            .client
111            .send::<_, Vec<u8>>(req)
112            .await
113            .map_err(EmbeddingError::HttpError)?;
114
115        if response.status().is_success() {
116            let body: ApiResponse<EmbeddingResponse> =
117                serde_json::from_slice(response.into_body().await?.as_slice())?;
118
119            match body {
120                ApiResponse::Ok(response) => {
121                    match response.meta {
122                        Some(meta) => tracing::info!(target: "rig",
123                            "Cohere embeddings billed units: {}",
124                            meta.billed_units,
125                        ),
126                        None => tracing::info!(target: "rig",
127                            "Cohere embeddings billed units: n/a",
128                        ),
129                    };
130
131                    if response.embeddings.len() != documents.len() {
132                        return Err(EmbeddingError::DocumentError(
133                            format!(
134                                "Expected {} embeddings, got {}",
135                                documents.len(),
136                                response.embeddings.len()
137                            )
138                            .into(),
139                        ));
140                    }
141
142                    Ok(response
143                        .embeddings
144                        .into_iter()
145                        .zip(documents.into_iter())
146                        .map(|(embedding, document)| embeddings::Embedding {
147                            document,
148                            vec: embedding,
149                        })
150                        .collect())
151                }
152                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
153            }
154        } else {
155            let text = String::from_utf8_lossy(&response.into_body().await?).into();
156            Err(EmbeddingError::ProviderError(text))
157        }
158    }
159}
160
161impl<T> EmbeddingModel<T> {
162    pub fn new(
163        client: Client<T>,
164        model: impl Into<String>,
165        input_type: &str,
166        ndims: usize,
167    ) -> Self {
168        Self {
169            client,
170            model: model.into(),
171            input_type: input_type.to_string(),
172            ndims,
173        }
174    }
175
176    pub fn with_model(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
177        Self {
178            client,
179            model: model.into(),
180            input_type: input_type.into(),
181            ndims,
182        }
183    }
184}