// rig/providers/cohere/embeddings.rs

1use super::{Client, client::ApiResponse};
2
3use crate::embeddings::{self, EmbeddingError};
4
5use serde::Deserialize;
6use serde_json::json;
7
8#[derive(Deserialize)]
9pub struct EmbeddingResponse {
10    #[serde(default)]
11    pub response_type: Option<String>,
12    pub id: String,
13    pub embeddings: Vec<Vec<f64>>,
14    pub texts: Vec<String>,
15    #[serde(default)]
16    pub meta: Option<Meta>,
17}
18
19#[derive(Deserialize)]
20pub struct Meta {
21    pub api_version: ApiVersion,
22    pub billed_units: BilledUnits,
23    #[serde(default)]
24    pub warnings: Vec<String>,
25}
26
27#[derive(Deserialize)]
28pub struct ApiVersion {
29    pub version: String,
30    #[serde(default)]
31    pub is_deprecated: Option<bool>,
32    #[serde(default)]
33    pub is_experimental: Option<bool>,
34}
35
36#[derive(Deserialize, Debug)]
37pub struct BilledUnits {
38    #[serde(default)]
39    pub input_tokens: u32,
40    #[serde(default)]
41    pub output_tokens: u32,
42    #[serde(default)]
43    pub search_units: u32,
44    #[serde(default)]
45    pub classifications: u32,
46}
47
48impl std::fmt::Display for BilledUnits {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(
51            f,
52            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
53            self.input_tokens, self.output_tokens, self.search_units, self.classifications
54        )
55    }
56}
57
58#[derive(Clone)]
59pub struct EmbeddingModel {
60    client: Client,
61    pub model: String,
62    pub input_type: String,
63    ndims: usize,
64}
65
66impl embeddings::EmbeddingModel for EmbeddingModel {
67    const MAX_DOCUMENTS: usize = 96;
68
69    fn ndims(&self) -> usize {
70        self.ndims
71    }
72
73    #[cfg_attr(feature = "worker", worker::send)]
74    async fn embed_texts(
75        &self,
76        documents: impl IntoIterator<Item = String>,
77    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
78        let documents = documents.into_iter().collect::<Vec<_>>();
79
80        let response = self
81            .client
82            .post("/v1/embed")
83            .json(&json!({
84                "model": self.model,
85                "texts": documents,
86                "input_type": self.input_type,
87            }))
88            .send()
89            .await?;
90
91        if response.status().is_success() {
92            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
93                ApiResponse::Ok(response) => {
94                    match response.meta {
95                        Some(meta) => tracing::info!(target: "rig",
96                            "Cohere embeddings billed units: {}",
97                            meta.billed_units,
98                        ),
99                        None => tracing::info!(target: "rig",
100                            "Cohere embeddings billed units: n/a",
101                        ),
102                    };
103
104                    if response.embeddings.len() != documents.len() {
105                        return Err(EmbeddingError::DocumentError(
106                            format!(
107                                "Expected {} embeddings, got {}",
108                                documents.len(),
109                                response.embeddings.len()
110                            )
111                            .into(),
112                        ));
113                    }
114
115                    Ok(response
116                        .embeddings
117                        .into_iter()
118                        .zip(documents.into_iter())
119                        .map(|(embedding, document)| embeddings::Embedding {
120                            document,
121                            vec: embedding,
122                        })
123                        .collect())
124                }
125                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
126            }
127        } else {
128            Err(EmbeddingError::ProviderError(response.text().await?))
129        }
130    }
131}
132
133impl EmbeddingModel {
134    pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
135        Self {
136            client,
137            model: model.to_string(),
138            input_type: input_type.to_string(),
139            ndims,
140        }
141    }
142}