1use std::collections::HashMap;
12
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError},
16 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17 extractor::ExtractorBuilder,
18 json_utils, Embed,
19};
20
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Base URL used by [`Client::new`] when no explicit URL is supplied.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
/// Handle used to issue requests against the Cohere HTTP API.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// HTTP client; its default headers carry the `Authorization` bearer token.
    http_client: reqwest::Client,
}
35
36impl Client {
37 pub fn new(api_key: &str) -> Self {
38 Self::from_url(api_key, COHERE_API_BASE_URL)
39 }
40
41 pub fn from_url(api_key: &str, base_url: &str) -> Self {
42 Self {
43 base_url: base_url.to_string(),
44 http_client: reqwest::Client::builder()
45 .default_headers({
46 let mut headers = reqwest::header::HeaderMap::new();
47 headers.insert(
48 "Authorization",
49 format!("Bearer {}", api_key)
50 .parse()
51 .expect("Bearer token should parse"),
52 );
53 headers
54 })
55 .build()
56 .expect("Cohere reqwest client should build"),
57 }
58 }
59
60 pub fn from_env() -> Self {
63 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
64 Self::new(&api_key)
65 }
66
67 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url)
70 }
71
72 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
75 let ndims = match model {
76 EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
77 EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
78 EMBED_ENGLISH_V2 => 4096,
79 EMBED_MULTILINGUAL_V2 => 768,
80 _ => 0,
81 };
82 EmbeddingModel::new(self.clone(), model, input_type, ndims)
83 }
84
85 pub fn embedding_model_with_ndims(
87 &self,
88 model: &str,
89 input_type: &str,
90 ndims: usize,
91 ) -> EmbeddingModel {
92 EmbeddingModel::new(self.clone(), model, input_type, ndims)
93 }
94
95 pub fn embeddings<D: Embed>(
96 &self,
97 model: &str,
98 input_type: &str,
99 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
100 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
101 }
102
103 pub fn completion_model(&self, model: &str) -> CompletionModel {
104 CompletionModel::new(self.clone(), model)
105 }
106
107 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
108 AgentBuilder::new(self.completion_model(model))
109 }
110
111 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
112 &self,
113 model: &str,
114 ) -> ExtractorBuilder<T, CompletionModel> {
115 ExtractorBuilder::new(self.completion_model(model))
116 }
117}
118
/// Error payload shape; its `message` is surfaced to callers as a provider error.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text taken from the response body.
    message: String,
}
123
/// Response body that is either the expected payload `T` or an error payload.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error message.
    Err(ApiErrorResponse),
}
130
// Embedding model names recognised by `Client::embedding_model`, which maps
// each of them to a fixed number of dimensions.

/// `embed-english-v3.0` embedding model name.
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model name.
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model name.
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model name.
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model name.
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model name.
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model name.
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
148
/// Successful body of a `/v1/embed` request.
#[derive(Deserialize)]
pub struct EmbeddingResponse {
    /// Optional `response_type` field; `None` when absent from the body.
    #[serde(default)]
    pub response_type: Option<String>,
    /// Identifier returned with the response.
    pub id: String,
    /// Embedding vectors; `embed_texts` requires one per submitted document.
    pub embeddings: Vec<Vec<f64>>,
    /// The `texts` field of the response (presumably the embedded inputs — confirm).
    pub texts: Vec<String>,
    /// Optional metadata, including billed units; `None` when absent.
    #[serde(default)]
    pub meta: Option<Meta>,
}
159
/// Metadata attached to an embedding response.
#[derive(Deserialize)]
pub struct Meta {
    /// API version information reported with the response.
    pub api_version: ApiVersion,
    /// Usage counters reported with the response.
    pub billed_units: BilledUnits,
    /// Warning strings; empty when the field is absent.
    #[serde(default)]
    pub warnings: Vec<String>,
}
167
/// API version information carried in [`Meta`].
#[derive(Deserialize)]
pub struct ApiVersion {
    /// Version string.
    pub version: String,
    /// Deprecation flag; `None` when absent from the body.
    #[serde(default)]
    pub is_deprecated: Option<bool>,
    /// Experimental flag; `None` when absent from the body.
    #[serde(default)]
    pub is_experimental: Option<bool>,
}
176
/// Usage counters carried in [`Meta`]; every counter defaults to `0` when absent.
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
    /// Value of the `input_tokens` counter.
    #[serde(default)]
    pub input_tokens: u32,
    /// Value of the `output_tokens` counter.
    #[serde(default)]
    pub output_tokens: u32,
    /// Value of the `search_units` counter.
    #[serde(default)]
    pub search_units: u32,
    /// Value of the `classifications` counter.
    #[serde(default)]
    pub classifications: u32,
}
188
189impl std::fmt::Display for BilledUnits {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 write!(
192 f,
193 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
194 self.input_tokens, self.output_tokens, self.search_units, self.classifications
195 )
196 }
197}
198
/// Embedding model handle bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send embedding requests.
    client: Client,
    /// Model name sent as the `model` field of the request.
    pub model: String,
    /// Value sent as the `input_type` field of the request.
    pub input_type: String,
    /// Number of dimensions reported by `ndims()`.
    ndims: usize,
}
206
207impl embeddings::EmbeddingModel for EmbeddingModel {
208 const MAX_DOCUMENTS: usize = 96;
209
210 fn ndims(&self) -> usize {
211 self.ndims
212 }
213
214 async fn embed_texts(
215 &self,
216 documents: impl IntoIterator<Item = String>,
217 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
218 let documents = documents.into_iter().collect::<Vec<_>>();
219
220 let response = self
221 .client
222 .post("/v1/embed")
223 .json(&json!({
224 "model": self.model,
225 "texts": documents,
226 "input_type": self.input_type,
227 }))
228 .send()
229 .await?;
230
231 if response.status().is_success() {
232 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
233 ApiResponse::Ok(response) => {
234 match response.meta {
235 Some(meta) => tracing::info!(target: "rig",
236 "Cohere embeddings billed units: {}",
237 meta.billed_units,
238 ),
239 None => tracing::info!(target: "rig",
240 "Cohere embeddings billed units: n/a",
241 ),
242 };
243
244 if response.embeddings.len() != documents.len() {
245 return Err(EmbeddingError::DocumentError(
246 format!(
247 "Expected {} embeddings, got {}",
248 documents.len(),
249 response.embeddings.len()
250 )
251 .into(),
252 ));
253 }
254
255 Ok(response
256 .embeddings
257 .into_iter()
258 .zip(documents.into_iter())
259 .map(|(embedding, document)| embeddings::Embedding {
260 document,
261 vec: embedding,
262 })
263 .collect())
264 }
265 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
266 }
267 } else {
268 Err(EmbeddingError::ProviderError(response.text().await?))
269 }
270 }
271}
272
273impl EmbeddingModel {
274 pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
275 Self {
276 client,
277 model: model.to_string(),
278 input_type: input_type.to_string(),
279 ndims,
280 }
281 }
282}
283
// Completion model names accepted by `Client::completion_model`.

/// `command-r-plus` completion model name.
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model name.
pub const COMMAND_R: &str = "command-r";
/// `command` completion model name.
pub const COMMAND: &str = "command";
/// `command-nightly` completion model name.
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model name.
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model name.
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
299
/// Successful body of a `/v1/chat` request.
///
/// Fields marked `#[serde(default)]` fall back to their default value when
/// absent from the body.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Generated text; used as the model's message when there are no tool calls.
    pub text: String,
    /// Identifier of this generation.
    pub generation_id: String,
    /// Citations returned with the response.
    #[serde(default)]
    pub citations: Vec<Citation>,
    /// Documents returned with the response.
    #[serde(default)]
    pub documents: Vec<Document>,
    /// Optional `is_search_required` flag.
    #[serde(default)]
    pub is_search_required: Option<bool>,
    /// Search queries returned with the response.
    #[serde(default)]
    pub search_queries: Vec<SearchQuery>,
    /// Search results returned with the response.
    #[serde(default)]
    pub search_results: Vec<SearchResult>,
    /// Reason string reported for why generation finished.
    pub finish_reason: String,
    /// Tool calls requested by the model; only the first is surfaced as the choice.
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
    /// Chat history returned with the response.
    #[serde(default)]
    pub chat_history: Vec<ChatHistory>,
}
320
321impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
322 fn from(response: CompletionResponse) -> Self {
323 let CompletionResponse {
324 text, tool_calls, ..
325 } = &response;
326
327 let model_response = if !tool_calls.is_empty() {
328 completion::ModelChoice::ToolCall(
329 tool_calls.first().unwrap().name.clone(),
330 tool_calls.first().unwrap().parameters.clone(),
331 )
332 } else {
333 completion::ModelChoice::Message(text.clone())
334 };
335
336 completion::CompletionResponse {
337 choice: model_response,
338 raw_response: response,
339 }
340 }
341}
342
/// One entry of `citations` in a chat response.
#[derive(Debug, Deserialize)]
pub struct Citation {
    /// Start position of the cited span (presumably an offset into the generated text — confirm).
    pub start: u32,
    /// End position of the cited span (presumably an offset into the generated text — confirm).
    pub end: u32,
    /// Cited text.
    pub text: String,
    /// Identifiers of the documents backing this citation.
    pub document_ids: Vec<String>,
}
350
/// One entry of `documents` in a chat response.
#[derive(Debug, Deserialize)]
pub struct Document {
    /// Document identifier.
    pub id: String,
    /// Every other key of the document object, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub additional_prop: HashMap<String, serde_json::Value>,
}
357
/// One entry of `search_queries` in a chat response; also embedded in [`SearchResult`].
#[derive(Debug, Deserialize)]
pub struct SearchQuery {
    /// Query text.
    pub text: String,
    /// Generation identifier associated with the query.
    pub generation_id: String,
}
363
/// One entry of `search_results` in a chat response.
#[derive(Debug, Deserialize)]
pub struct SearchResult {
    /// Query this result belongs to.
    pub search_query: SearchQuery,
    /// Connector that produced the result.
    pub connector: Connector,
    /// Identifiers of the documents in this result.
    pub document_ids: Vec<String>,
    /// Optional error message; `None` when absent.
    #[serde(default)]
    pub error_message: Option<String>,
    /// `continue_on_failure` flag; `false` when absent.
    #[serde(default)]
    pub continue_on_failure: bool,
}
374
/// Connector reference carried in a [`SearchResult`].
#[derive(Debug, Deserialize)]
pub struct Connector {
    /// Connector identifier.
    pub id: String,
}
379
/// One entry of `tool_calls` in a chat response.
#[derive(Debug, Deserialize)]
pub struct ToolCall {
    /// Name of the tool the model wants to call.
    pub name: String,
    /// Arguments for the call, kept as raw JSON.
    pub parameters: serde_json::Value,
}
385
/// One entry of `chat_history` in a chat response.
#[derive(Debug, Deserialize)]
pub struct ChatHistory {
    /// Role string of the entry.
    pub role: String,
    /// Message text of the entry.
    pub message: String,
}
391
/// Definition of a single tool parameter inside [`ToolDefinition`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Parameter {
    /// Description of the parameter.
    pub description: String,
    /// Type name of the parameter, serialized under the key `type`.
    pub r#type: String,
    /// Whether the parameter is listed as required.
    pub required: bool,
}
398
/// Tool definition in the shape serialized into the `tools` field of a chat request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Tool description.
    pub description: String,
    /// Parameter definitions keyed by parameter name.
    pub parameter_definitions: HashMap<String, Parameter>,
}
405
406impl From<completion::ToolDefinition> for ToolDefinition {
407 fn from(tool: completion::ToolDefinition) -> Self {
408 fn convert_type(r#type: &serde_json::Value) -> String {
409 fn convert_type_str(r#type: &str) -> String {
410 match r#type {
411 "string" => "string".to_owned(),
412 "number" => "number".to_owned(),
413 "integer" => "integer".to_owned(),
414 "boolean" => "boolean".to_owned(),
415 "array" => "array".to_owned(),
416 "object" => "object".to_owned(),
417 _ => "string".to_owned(),
418 }
419 }
420 match r#type {
421 serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
422 serde_json::Value::Array(types) => convert_type_str(
423 types
424 .iter()
425 .find(|t| t.as_str() != Some("null"))
426 .and_then(|t| t.as_str())
427 .unwrap_or("string"),
428 ),
429 _ => "string".to_owned(),
430 }
431 }
432
433 let maybe_required = tool
434 .parameters
435 .get("required")
436 .and_then(|v| v.as_array())
437 .map(|required| {
438 required
439 .iter()
440 .filter_map(|v| v.as_str())
441 .collect::<Vec<_>>()
442 })
443 .unwrap_or_default();
444
445 Self {
446 name: tool.name,
447 description: tool.description,
448 parameter_definitions: tool
449 .parameters
450 .get("properties")
451 .expect("Tool properties should exist")
452 .as_object()
453 .expect("Tool properties should be an object")
454 .iter()
455 .map(|(argname, argdef)| {
456 (
457 argname.clone(),
458 Parameter {
459 description: argdef
460 .get("description")
461 .expect("Argument description should exist")
462 .as_str()
463 .expect("Argument description should be a string")
464 .to_string(),
465 r#type: convert_type(
466 argdef.get("type").expect("Argument type should exist"),
467 ),
468 required: maybe_required.contains(&argname.as_str()),
469 },
470 )
471 })
472 .collect::<HashMap<_, _>>(),
473 }
474 }
475}
476
/// Chat message in the shape serialized into the `chat_history` field of a chat request.
#[derive(Deserialize, Serialize)]
pub struct Message {
    /// Role string; the `From<completion::Message>` conversion produces
    /// `SYSTEM`, `USER`, or `CHATBOT`.
    pub role: String,
    /// Message text.
    pub message: String,
}
482
483impl From<completion::Message> for Message {
484 fn from(message: completion::Message) -> Self {
485 Self {
486 role: match message.role.as_str() {
487 "system" => "SYSTEM".to_owned(),
488 "user" => "USER".to_owned(),
489 "assistant" => "CHATBOT".to_owned(),
490 _ => "USER".to_owned(),
491 },
492 message: message.content,
493 }
494 }
495}
496
/// Completion model handle bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send chat requests.
    client: Client,
    /// Model name sent as the `model` field of the request.
    pub model: String,
}
502
503impl CompletionModel {
504 pub fn new(client: Client, model: &str) -> Self {
505 Self {
506 client,
507 model: model.to_string(),
508 }
509 }
510}
511
512impl completion::CompletionModel for CompletionModel {
513 type Response = CompletionResponse;
514
515 async fn completion(
516 &self,
517 completion_request: completion::CompletionRequest,
518 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
519 let request = json!({
520 "model": self.model,
521 "preamble": completion_request.preamble,
522 "message": completion_request.prompt,
523 "documents": completion_request.documents,
524 "chat_history": completion_request.chat_history.into_iter().map(Message::from).collect::<Vec<_>>(),
525 "temperature": completion_request.temperature,
526 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
527 });
528
529 let response = self
530 .client
531 .post("/v1/chat")
532 .json(
533 &if let Some(ref params) = completion_request.additional_params {
534 json_utils::merge(request.clone(), params.clone())
535 } else {
536 request.clone()
537 },
538 )
539 .send()
540 .await?;
541
542 if response.status().is_success() {
543 match response.json::<ApiResponse<CompletionResponse>>().await? {
544 ApiResponse::Ok(completion) => Ok(completion.into()),
545 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
546 }
547 } else {
548 Err(CompletionError::ProviderError(response.text().await?))
549 }
550 }
551}