rig/providers/cohere/
embeddings.rs1use super::{client::ApiResponse, client::Client};
2use crate::{
3 embeddings::{self, EmbeddingError},
4 http_client::HttpClientExt,
5 wasm_compat::*,
6};
7use serde::Deserialize;
8use serde_json::json;
9
10#[derive(Deserialize)]
11pub struct EmbeddingResponse {
12 #[serde(default)]
13 pub response_type: Option<String>,
14 pub id: String,
15 pub embeddings: Vec<Vec<f64>>,
16 pub texts: Vec<String>,
17 #[serde(default)]
18 pub meta: Option<Meta>,
19}
20
21#[derive(Deserialize)]
22pub struct Meta {
23 pub api_version: ApiVersion,
24 pub billed_units: BilledUnits,
25 #[serde(default)]
26 pub warnings: Vec<String>,
27}
28
29#[derive(Deserialize)]
30pub struct ApiVersion {
31 pub version: String,
32 #[serde(default)]
33 pub is_deprecated: Option<bool>,
34 #[serde(default)]
35 pub is_experimental: Option<bool>,
36}
37
38#[derive(Deserialize, Debug)]
39pub struct BilledUnits {
40 #[serde(default)]
41 pub input_tokens: u32,
42 #[serde(default)]
43 pub output_tokens: u32,
44 #[serde(default)]
45 pub search_units: u32,
46 #[serde(default)]
47 pub classifications: u32,
48}
49
50impl std::fmt::Display for BilledUnits {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(
53 f,
54 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
55 self.input_tokens, self.output_tokens, self.search_units, self.classifications
56 )
57 }
58}
59
60#[derive(Clone)]
61pub struct EmbeddingModel<T = reqwest::Client> {
62 client: Client<T>,
63 pub model: String,
64 pub input_type: String,
65 ndims: usize,
66}
67
68impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
69where
70 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
71{
72 const MAX_DOCUMENTS: usize = 96;
73 type Client = Client<T>;
74
75 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
76 let model = model.into();
77 let dims = dims
78 .or(super::model_dimensions_from_identifier(&model))
79 .unwrap_or_default();
80
81 Self::new(client.clone(), model, "search_document", dims)
82 }
83
84 fn ndims(&self) -> usize {
85 self.ndims
86 }
87
88 #[cfg_attr(feature = "worker", worker::send)]
89 async fn embed_texts(
90 &self,
91 documents: impl IntoIterator<Item = String>,
92 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
93 let documents = documents.into_iter().collect::<Vec<_>>();
94
95 let body = json!({
96 "model": self.model.to_string(),
97 "texts": documents,
98 "input_type": self.input_type
99 });
100
101 let body = serde_json::to_vec(&body)?;
102
103 let req = self
104 .client
105 .post("/v1/embed")?
106 .body(body)
107 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
108
109 let response = self
110 .client
111 .send::<_, Vec<u8>>(req)
112 .await
113 .map_err(EmbeddingError::HttpError)?;
114
115 if response.status().is_success() {
116 let body: ApiResponse<EmbeddingResponse> =
117 serde_json::from_slice(response.into_body().await?.as_slice())?;
118
119 match body {
120 ApiResponse::Ok(response) => {
121 match response.meta {
122 Some(meta) => tracing::info!(target: "rig",
123 "Cohere embeddings billed units: {}",
124 meta.billed_units,
125 ),
126 None => tracing::info!(target: "rig",
127 "Cohere embeddings billed units: n/a",
128 ),
129 };
130
131 if response.embeddings.len() != documents.len() {
132 return Err(EmbeddingError::DocumentError(
133 format!(
134 "Expected {} embeddings, got {}",
135 documents.len(),
136 response.embeddings.len()
137 )
138 .into(),
139 ));
140 }
141
142 Ok(response
143 .embeddings
144 .into_iter()
145 .zip(documents.into_iter())
146 .map(|(embedding, document)| embeddings::Embedding {
147 document,
148 vec: embedding,
149 })
150 .collect())
151 }
152 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
153 }
154 } else {
155 let text = String::from_utf8_lossy(&response.into_body().await?).into();
156 Err(EmbeddingError::ProviderError(text))
157 }
158 }
159}
160
161impl<T> EmbeddingModel<T> {
162 pub fn new(
163 client: Client<T>,
164 model: impl Into<String>,
165 input_type: &str,
166 ndims: usize,
167 ) -> Self {
168 Self {
169 client,
170 model: model.into(),
171 input_type: input_type.to_string(),
172 ndims,
173 }
174 }
175
176 pub fn with_model(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
177 Self {
178 client,
179 model: model.into(),
180 input_type: input_type.into(),
181 ndims,
182 }
183 }
184}