// rig/providers/cohere/embeddings.rs
use super::{client::ApiResponse, client::Client};
use crate::{
    embeddings::{self, EmbeddingError},
    http_client::HttpClientExt,
    wasm_compat::*,
};
use serde::Deserialize;
use serde_json::json;

10#[derive(Deserialize)]
11pub struct EmbeddingResponse {
12 #[serde(default)]
13 pub response_type: Option<String>,
14 pub id: String,
15 pub embeddings: Vec<Vec<f64>>,
16 pub texts: Vec<String>,
17 #[serde(default)]
18 pub meta: Option<Meta>,
19}
20
21#[derive(Deserialize)]
22pub struct Meta {
23 pub api_version: ApiVersion,
24 pub billed_units: BilledUnits,
25 #[serde(default)]
26 pub warnings: Vec<String>,
27}
28
29#[derive(Deserialize)]
30pub struct ApiVersion {
31 pub version: String,
32 #[serde(default)]
33 pub is_deprecated: Option<bool>,
34 #[serde(default)]
35 pub is_experimental: Option<bool>,
36}
37
38#[derive(Deserialize, Debug)]
39pub struct BilledUnits {
40 #[serde(default)]
41 pub input_tokens: u32,
42 #[serde(default)]
43 pub output_tokens: u32,
44 #[serde(default)]
45 pub search_units: u32,
46 #[serde(default)]
47 pub classifications: u32,
48}
49
50impl std::fmt::Display for BilledUnits {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(
53 f,
54 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
55 self.input_tokens, self.output_tokens, self.search_units, self.classifications
56 )
57 }
58}
59
60#[derive(Clone)]
61pub struct EmbeddingModel<T = reqwest::Client> {
62 client: Client<T>,
63 pub model: String,
64 pub input_type: String,
65 ndims: usize,
66}
67
68impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
69where
70 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
71{
72 const MAX_DOCUMENTS: usize = 96;
73 type Client = Client<T>;
74
75 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
76 let model = model.into();
77 let dims = dims
78 .or(super::model_dimensions_from_identifier(&model))
79 .unwrap_or_default();
80
81 Self::new(client.clone(), model, "search_document", dims)
82 }
83
84 fn ndims(&self) -> usize {
85 self.ndims
86 }
87
88 async fn embed_texts(
89 &self,
90 documents: impl IntoIterator<Item = String>,
91 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
92 let documents = documents.into_iter().collect::<Vec<_>>();
93
94 let body = json!({
95 "model": self.model.to_string(),
96 "texts": documents,
97 "input_type": self.input_type
98 });
99
100 let body = serde_json::to_vec(&body)?;
101
102 let req = self
103 .client
104 .post("/v1/embed")?
105 .body(body)
106 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
107
108 let response = self
109 .client
110 .send::<_, Vec<u8>>(req)
111 .await
112 .map_err(EmbeddingError::HttpError)?;
113
114 if response.status().is_success() {
115 let body: ApiResponse<EmbeddingResponse> =
116 serde_json::from_slice(response.into_body().await?.as_slice())?;
117
118 match body {
119 ApiResponse::Ok(response) => {
120 match response.meta {
121 Some(meta) => tracing::info!(target: "rig",
122 "Cohere embeddings billed units: {}",
123 meta.billed_units,
124 ),
125 None => tracing::info!(target: "rig",
126 "Cohere embeddings billed units: n/a",
127 ),
128 };
129
130 if response.embeddings.len() != documents.len() {
131 return Err(EmbeddingError::DocumentError(
132 format!(
133 "Expected {} embeddings, got {}",
134 documents.len(),
135 response.embeddings.len()
136 )
137 .into(),
138 ));
139 }
140
141 Ok(response
142 .embeddings
143 .into_iter()
144 .zip(documents.into_iter())
145 .map(|(embedding, document)| embeddings::Embedding {
146 document,
147 vec: embedding,
148 })
149 .collect())
150 }
151 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
152 }
153 } else {
154 let text = String::from_utf8_lossy(&response.into_body().await?).into();
155 Err(EmbeddingError::ProviderError(text))
156 }
157 }
158}
159
160impl<T> EmbeddingModel<T> {
161 pub fn new(
162 client: Client<T>,
163 model: impl Into<String>,
164 input_type: &str,
165 ndims: usize,
166 ) -> Self {
167 Self {
168 client,
169 model: model.into(),
170 input_type: input_type.to_string(),
171 ndims,
172 }
173 }
174
175 pub fn with_model(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
176 Self {
177 client,
178 model: model.into(),
179 input_type: input_type.into(),
180 ndims,
181 }
182 }
183}