rig/providers/cohere/
embeddings.rs1use super::{Client, client::ApiResponse};
2
3use crate::embeddings::{self, EmbeddingError};
4
5use serde::Deserialize;
6use serde_json::json;
7
8#[derive(Deserialize)]
9pub struct EmbeddingResponse {
10 #[serde(default)]
11 pub response_type: Option<String>,
12 pub id: String,
13 pub embeddings: Vec<Vec<f64>>,
14 pub texts: Vec<String>,
15 #[serde(default)]
16 pub meta: Option<Meta>,
17}
18
19#[derive(Deserialize)]
20pub struct Meta {
21 pub api_version: ApiVersion,
22 pub billed_units: BilledUnits,
23 #[serde(default)]
24 pub warnings: Vec<String>,
25}
26
27#[derive(Deserialize)]
28pub struct ApiVersion {
29 pub version: String,
30 #[serde(default)]
31 pub is_deprecated: Option<bool>,
32 #[serde(default)]
33 pub is_experimental: Option<bool>,
34}
35
36#[derive(Deserialize, Debug)]
37pub struct BilledUnits {
38 #[serde(default)]
39 pub input_tokens: u32,
40 #[serde(default)]
41 pub output_tokens: u32,
42 #[serde(default)]
43 pub search_units: u32,
44 #[serde(default)]
45 pub classifications: u32,
46}
47
48impl std::fmt::Display for BilledUnits {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(
51 f,
52 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
53 self.input_tokens, self.output_tokens, self.search_units, self.classifications
54 )
55 }
56}
57
58#[derive(Clone)]
59pub struct EmbeddingModel {
60 client: Client,
61 pub model: String,
62 pub input_type: String,
63 ndims: usize,
64}
65
66impl embeddings::EmbeddingModel for EmbeddingModel {
67 const MAX_DOCUMENTS: usize = 96;
68
69 fn ndims(&self) -> usize {
70 self.ndims
71 }
72
73 #[cfg_attr(feature = "worker", worker::send)]
74 async fn embed_texts(
75 &self,
76 documents: impl IntoIterator<Item = String>,
77 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
78 let documents = documents.into_iter().collect::<Vec<_>>();
79
80 let response = self
81 .client
82 .post("/v1/embed")
83 .json(&json!({
84 "model": self.model,
85 "texts": documents,
86 "input_type": self.input_type,
87 }))
88 .send()
89 .await?;
90
91 if response.status().is_success() {
92 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
93 ApiResponse::Ok(response) => {
94 match response.meta {
95 Some(meta) => tracing::info!(target: "rig",
96 "Cohere embeddings billed units: {}",
97 meta.billed_units,
98 ),
99 None => tracing::info!(target: "rig",
100 "Cohere embeddings billed units: n/a",
101 ),
102 };
103
104 if response.embeddings.len() != documents.len() {
105 return Err(EmbeddingError::DocumentError(
106 format!(
107 "Expected {} embeddings, got {}",
108 documents.len(),
109 response.embeddings.len()
110 )
111 .into(),
112 ));
113 }
114
115 Ok(response
116 .embeddings
117 .into_iter()
118 .zip(documents.into_iter())
119 .map(|(embedding, document)| embeddings::Embedding {
120 document,
121 vec: embedding,
122 })
123 .collect())
124 }
125 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
126 }
127 } else {
128 Err(EmbeddingError::ProviderError(response.text().await?))
129 }
130 }
131}
132
133impl EmbeddingModel {
134 pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
135 Self {
136 client,
137 model: model.to_string(),
138 input_type: input_type.to_string(),
139 ndims,
140 }
141 }
142}