rig/providers/cohere/
embeddings.rs

1use super::{Client, client::ApiResponse};
2
3use crate::{
4    embeddings::{self, EmbeddingError},
5    http_client::HttpClientExt,
6    wasm_compat::*,
7};
8
9use serde::Deserialize;
10use serde_json::json;
11
12#[derive(Deserialize)]
13pub struct EmbeddingResponse {
14    #[serde(default)]
15    pub response_type: Option<String>,
16    pub id: String,
17    pub embeddings: Vec<Vec<f64>>,
18    pub texts: Vec<String>,
19    #[serde(default)]
20    pub meta: Option<Meta>,
21}
22
23#[derive(Deserialize)]
24pub struct Meta {
25    pub api_version: ApiVersion,
26    pub billed_units: BilledUnits,
27    #[serde(default)]
28    pub warnings: Vec<String>,
29}
30
31#[derive(Deserialize)]
32pub struct ApiVersion {
33    pub version: String,
34    #[serde(default)]
35    pub is_deprecated: Option<bool>,
36    #[serde(default)]
37    pub is_experimental: Option<bool>,
38}
39
40#[derive(Deserialize, Debug)]
41pub struct BilledUnits {
42    #[serde(default)]
43    pub input_tokens: u32,
44    #[serde(default)]
45    pub output_tokens: u32,
46    #[serde(default)]
47    pub search_units: u32,
48    #[serde(default)]
49    pub classifications: u32,
50}
51
52impl std::fmt::Display for BilledUnits {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(
55            f,
56            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
57            self.input_tokens, self.output_tokens, self.search_units, self.classifications
58        )
59    }
60}
61
62#[derive(Clone)]
63pub struct EmbeddingModel<T = reqwest::Client> {
64    client: Client<T>,
65    pub model: String,
66    pub input_type: String,
67    ndims: usize,
68}
69
70impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
71where
72    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
73{
74    const MAX_DOCUMENTS: usize = 96;
75
76    fn ndims(&self) -> usize {
77        self.ndims
78    }
79
80    #[cfg_attr(feature = "worker", worker::send)]
81    async fn embed_texts(
82        &self,
83        documents: impl IntoIterator<Item = String>,
84    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
85        let documents = documents.into_iter().collect::<Vec<_>>();
86
87        let body = json!({
88            "model": self.model,
89            "texts": documents,
90            "input_type": self.input_type
91        });
92
93        let body = serde_json::to_vec(&body)?;
94
95        let req = self
96            .client
97            .post("/v1/embed")?
98            .header("Content-Type", "application/json")
99            .body(body)
100            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
101
102        let response = self
103            .client
104            .send::<_, Vec<u8>>(req)
105            .await
106            .map_err(EmbeddingError::HttpError)?;
107
108        if response.status().is_success() {
109            let body: ApiResponse<EmbeddingResponse> =
110                serde_json::from_slice(response.into_body().await?.as_slice())?;
111
112            match body {
113                ApiResponse::Ok(response) => {
114                    match response.meta {
115                        Some(meta) => tracing::info!(target: "rig",
116                            "Cohere embeddings billed units: {}",
117                            meta.billed_units,
118                        ),
119                        None => tracing::info!(target: "rig",
120                            "Cohere embeddings billed units: n/a",
121                        ),
122                    };
123
124                    if response.embeddings.len() != documents.len() {
125                        return Err(EmbeddingError::DocumentError(
126                            format!(
127                                "Expected {} embeddings, got {}",
128                                documents.len(),
129                                response.embeddings.len()
130                            )
131                            .into(),
132                        ));
133                    }
134
135                    Ok(response
136                        .embeddings
137                        .into_iter()
138                        .zip(documents.into_iter())
139                        .map(|(embedding, document)| embeddings::Embedding {
140                            document,
141                            vec: embedding,
142                        })
143                        .collect())
144                }
145                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
146            }
147        } else {
148            let text = String::from_utf8_lossy(&response.into_body().await?).into();
149            Err(EmbeddingError::ProviderError(text))
150        }
151    }
152}
153
154impl<T> EmbeddingModel<T> {
155    pub fn new(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
156        Self {
157            client,
158            model: model.to_string(),
159            input_type: input_type.to_string(),
160            ndims,
161        }
162    }
163}