rig/providers/cohere/
embeddings.rs1use super::{Client, client::ApiResponse};
2
3use crate::{
4 embeddings::{self, EmbeddingError},
5 http_client::HttpClientExt,
6 wasm_compat::*,
7};
8
9use serde::Deserialize;
10use serde_json::json;
11
12#[derive(Deserialize)]
13pub struct EmbeddingResponse {
14 #[serde(default)]
15 pub response_type: Option<String>,
16 pub id: String,
17 pub embeddings: Vec<Vec<f64>>,
18 pub texts: Vec<String>,
19 #[serde(default)]
20 pub meta: Option<Meta>,
21}
22
23#[derive(Deserialize)]
24pub struct Meta {
25 pub api_version: ApiVersion,
26 pub billed_units: BilledUnits,
27 #[serde(default)]
28 pub warnings: Vec<String>,
29}
30
31#[derive(Deserialize)]
32pub struct ApiVersion {
33 pub version: String,
34 #[serde(default)]
35 pub is_deprecated: Option<bool>,
36 #[serde(default)]
37 pub is_experimental: Option<bool>,
38}
39
40#[derive(Deserialize, Debug)]
41pub struct BilledUnits {
42 #[serde(default)]
43 pub input_tokens: u32,
44 #[serde(default)]
45 pub output_tokens: u32,
46 #[serde(default)]
47 pub search_units: u32,
48 #[serde(default)]
49 pub classifications: u32,
50}
51
52impl std::fmt::Display for BilledUnits {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(
55 f,
56 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
57 self.input_tokens, self.output_tokens, self.search_units, self.classifications
58 )
59 }
60}
61
62#[derive(Clone)]
63pub struct EmbeddingModel<T = reqwest::Client> {
64 client: Client<T>,
65 pub model: String,
66 pub input_type: String,
67 ndims: usize,
68}
69
70impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
71where
72 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
73{
74 const MAX_DOCUMENTS: usize = 96;
75
76 fn ndims(&self) -> usize {
77 self.ndims
78 }
79
80 #[cfg_attr(feature = "worker", worker::send)]
81 async fn embed_texts(
82 &self,
83 documents: impl IntoIterator<Item = String>,
84 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
85 let documents = documents.into_iter().collect::<Vec<_>>();
86
87 let body = json!({
88 "model": self.model,
89 "texts": documents,
90 "input_type": self.input_type
91 });
92
93 let body = serde_json::to_vec(&body)?;
94
95 let req = self
96 .client
97 .post("/v1/embed")?
98 .header("Content-Type", "application/json")
99 .body(body)
100 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
101
102 let response = self
103 .client
104 .send::<_, Vec<u8>>(req)
105 .await
106 .map_err(EmbeddingError::HttpError)?;
107
108 if response.status().is_success() {
109 let body: ApiResponse<EmbeddingResponse> =
110 serde_json::from_slice(response.into_body().await?.as_slice())?;
111
112 match body {
113 ApiResponse::Ok(response) => {
114 match response.meta {
115 Some(meta) => tracing::info!(target: "rig",
116 "Cohere embeddings billed units: {}",
117 meta.billed_units,
118 ),
119 None => tracing::info!(target: "rig",
120 "Cohere embeddings billed units: n/a",
121 ),
122 };
123
124 if response.embeddings.len() != documents.len() {
125 return Err(EmbeddingError::DocumentError(
126 format!(
127 "Expected {} embeddings, got {}",
128 documents.len(),
129 response.embeddings.len()
130 )
131 .into(),
132 ));
133 }
134
135 Ok(response
136 .embeddings
137 .into_iter()
138 .zip(documents.into_iter())
139 .map(|(embedding, document)| embeddings::Embedding {
140 document,
141 vec: embedding,
142 })
143 .collect())
144 }
145 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
146 }
147 } else {
148 let text = String::from_utf8_lossy(&response.into_body().await?).into();
149 Err(EmbeddingError::ProviderError(text))
150 }
151 }
152}
153
154impl<T> EmbeddingModel<T> {
155 pub fn new(client: Client<T>, model: &str, input_type: &str, ndims: usize) -> Self {
156 Self {
157 client,
158 model: model.to_string(),
159 input_type: input_type.to_string(),
160 ndims,
161 }
162 }
163}