atlas/providers/
cohere.rs

1//! Cohere API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::cohere;
6//!
7//! let client = cohere::Client::new("YOUR_API_KEY");
8//!
9//! let command_r = client.completion_model(cohere::COMMAND_R);
10//! ```
11use std::collections::HashMap;
12
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError},
16    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17    extractor::ExtractorBuilder,
18    json_utils, Embed,
19};
20
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
25// ================================================================
26// Main Cohere Client
27// ================================================================
/// Base URL used by `Client::new` when the caller does not supply one.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
30#[derive(Clone)]
31pub struct Client {
32    base_url: String,
33    http_client: reqwest::Client,
34}
35
36impl Client {
37    pub fn new(api_key: &str) -> Self {
38        Self::from_url(api_key, COHERE_API_BASE_URL)
39    }
40
41    pub fn from_url(api_key: &str, base_url: &str) -> Self {
42        Self {
43            base_url: base_url.to_string(),
44            http_client: reqwest::Client::builder()
45                .default_headers({
46                    let mut headers = reqwest::header::HeaderMap::new();
47                    headers.insert(
48                        "Authorization",
49                        format!("Bearer {}", api_key)
50                            .parse()
51                            .expect("Bearer token should parse"),
52                    );
53                    headers
54                })
55                .build()
56                .expect("Cohere reqwest client should build"),
57        }
58    }
59
60    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
61    /// Panics if the environment variable is not set.
62    pub fn from_env() -> Self {
63        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
64        Self::new(&api_key)
65    }
66
67    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69        self.http_client.post(url)
70    }
71
72    /// Note: default embedding dimension of 0 will be used if model is not known.
73    /// If this is the case, it's better to use function `embedding_model_with_ndims`
74    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
75        let ndims = match model {
76            EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
77            EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
78            EMBED_ENGLISH_V2 => 4096,
79            EMBED_MULTILINGUAL_V2 => 768,
80            _ => 0,
81        };
82        EmbeddingModel::new(self.clone(), model, input_type, ndims)
83    }
84
85    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
86    pub fn embedding_model_with_ndims(
87        &self,
88        model: &str,
89        input_type: &str,
90        ndims: usize,
91    ) -> EmbeddingModel {
92        EmbeddingModel::new(self.clone(), model, input_type, ndims)
93    }
94
95    pub fn embeddings<D: Embed>(
96        &self,
97        model: &str,
98        input_type: &str,
99    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
100        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
101    }
102
103    pub fn completion_model(&self, model: &str) -> CompletionModel {
104        CompletionModel::new(self.clone(), model)
105    }
106
107    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
108        AgentBuilder::new(self.completion_model(model))
109    }
110
111    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
112        &self,
113        model: &str,
114    ) -> ExtractorBuilder<T, CompletionModel> {
115        ExtractorBuilder::new(self.completion_model(model))
116    }
117}
118
119#[derive(Debug, Deserialize)]
120struct ApiErrorResponse {
121    message: String,
122}
123
124#[derive(Debug, Deserialize)]
125#[serde(untagged)]
126enum ApiResponse<T> {
127    Ok(T),
128    Err(ApiErrorResponse),
129}
130
131// ================================================================
132// Cohere Embedding API
133// ================================================================
/// Model name `embed-english-v3.0` (mapped to 1024 dimensions by `Client::embedding_model`).
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// Model name `embed-english-light-v3.0` (mapped to 384 dimensions by `Client::embedding_model`).
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// Model name `embed-multilingual-v3.0` (mapped to 1024 dimensions by `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// Model name `embed-multilingual-light-v3.0` (mapped to 384 dimensions by `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// Model name `embed-english-v2.0` (mapped to 4096 dimensions by `Client::embedding_model`).
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// Model name `embed-english-light-v2.0` (mapped to 1024 dimensions by `Client::embedding_model`).
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// Model name `embed-multilingual-v2.0` (mapped to 768 dimensions by `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
148
149#[derive(Deserialize)]
150pub struct EmbeddingResponse {
151    #[serde(default)]
152    pub response_type: Option<String>,
153    pub id: String,
154    pub embeddings: Vec<Vec<f64>>,
155    pub texts: Vec<String>,
156    #[serde(default)]
157    pub meta: Option<Meta>,
158}
159
160#[derive(Deserialize)]
161pub struct Meta {
162    pub api_version: ApiVersion,
163    pub billed_units: BilledUnits,
164    #[serde(default)]
165    pub warnings: Vec<String>,
166}
167
168#[derive(Deserialize)]
169pub struct ApiVersion {
170    pub version: String,
171    #[serde(default)]
172    pub is_deprecated: Option<bool>,
173    #[serde(default)]
174    pub is_experimental: Option<bool>,
175}
176
177#[derive(Deserialize, Debug)]
178pub struct BilledUnits {
179    #[serde(default)]
180    pub input_tokens: u32,
181    #[serde(default)]
182    pub output_tokens: u32,
183    #[serde(default)]
184    pub search_units: u32,
185    #[serde(default)]
186    pub classifications: u32,
187}
188
189impl std::fmt::Display for BilledUnits {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        write!(
192            f,
193            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
194            self.input_tokens, self.output_tokens, self.search_units, self.classifications
195        )
196    }
197}
198
199#[derive(Clone)]
200pub struct EmbeddingModel {
201    client: Client,
202    pub model: String,
203    pub input_type: String,
204    ndims: usize,
205}
206
207impl embeddings::EmbeddingModel for EmbeddingModel {
208    const MAX_DOCUMENTS: usize = 96;
209
210    fn ndims(&self) -> usize {
211        self.ndims
212    }
213
214    async fn embed_texts(
215        &self,
216        documents: impl IntoIterator<Item = String>,
217    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
218        let documents = documents.into_iter().collect::<Vec<_>>();
219
220        let response = self
221            .client
222            .post("/v1/embed")
223            .json(&json!({
224                "model": self.model,
225                "texts": documents,
226                "input_type": self.input_type,
227            }))
228            .send()
229            .await?;
230
231        if response.status().is_success() {
232            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
233                ApiResponse::Ok(response) => {
234                    match response.meta {
235                        Some(meta) => tracing::info!(target: "rig",
236                            "Cohere embeddings billed units: {}",
237                            meta.billed_units,
238                        ),
239                        None => tracing::info!(target: "rig",
240                            "Cohere embeddings billed units: n/a",
241                        ),
242                    };
243
244                    if response.embeddings.len() != documents.len() {
245                        return Err(EmbeddingError::DocumentError(
246                            format!(
247                                "Expected {} embeddings, got {}",
248                                documents.len(),
249                                response.embeddings.len()
250                            )
251                            .into(),
252                        ));
253                    }
254
255                    Ok(response
256                        .embeddings
257                        .into_iter()
258                        .zip(documents.into_iter())
259                        .map(|(embedding, document)| embeddings::Embedding {
260                            document,
261                            vec: embedding,
262                        })
263                        .collect())
264                }
265                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
266            }
267        } else {
268            Err(EmbeddingError::ProviderError(response.text().await?))
269        }
270    }
271}
272
273impl EmbeddingModel {
274    pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
275        Self {
276            client,
277            model: model.to_string(),
278            input_type: input_type.to_string(),
279            ndims,
280        }
281    }
282}
283
284// ================================================================
285// Cohere Completion API
286// ================================================================
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
299
300#[derive(Debug, Deserialize)]
301pub struct CompletionResponse {
302    pub text: String,
303    pub generation_id: String,
304    #[serde(default)]
305    pub citations: Vec<Citation>,
306    #[serde(default)]
307    pub documents: Vec<Document>,
308    #[serde(default)]
309    pub is_search_required: Option<bool>,
310    #[serde(default)]
311    pub search_queries: Vec<SearchQuery>,
312    #[serde(default)]
313    pub search_results: Vec<SearchResult>,
314    pub finish_reason: String,
315    #[serde(default)]
316    pub tool_calls: Vec<ToolCall>,
317    #[serde(default)]
318    pub chat_history: Vec<ChatHistory>,
319}
320
321impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
322    fn from(response: CompletionResponse) -> Self {
323        let CompletionResponse {
324            text, tool_calls, ..
325        } = &response;
326
327        let model_response = if !tool_calls.is_empty() {
328            completion::ModelChoice::ToolCall(
329                tool_calls.first().unwrap().name.clone(),
330                tool_calls.first().unwrap().parameters.clone(),
331            )
332        } else {
333            completion::ModelChoice::Message(text.clone())
334        };
335
336        completion::CompletionResponse {
337            choice: model_response,
338            raw_response: response,
339        }
340    }
341}
342
343#[derive(Debug, Deserialize)]
344pub struct Citation {
345    pub start: u32,
346    pub end: u32,
347    pub text: String,
348    pub document_ids: Vec<String>,
349}
350
351#[derive(Debug, Deserialize)]
352pub struct Document {
353    pub id: String,
354    #[serde(flatten)]
355    pub additional_prop: HashMap<String, serde_json::Value>,
356}
357
358#[derive(Debug, Deserialize)]
359pub struct SearchQuery {
360    pub text: String,
361    pub generation_id: String,
362}
363
364#[derive(Debug, Deserialize)]
365pub struct SearchResult {
366    pub search_query: SearchQuery,
367    pub connector: Connector,
368    pub document_ids: Vec<String>,
369    #[serde(default)]
370    pub error_message: Option<String>,
371    #[serde(default)]
372    pub continue_on_failure: bool,
373}
374
375#[derive(Debug, Deserialize)]
376pub struct Connector {
377    pub id: String,
378}
379
380#[derive(Debug, Deserialize)]
381pub struct ToolCall {
382    pub name: String,
383    pub parameters: serde_json::Value,
384}
385
386#[derive(Debug, Deserialize)]
387pub struct ChatHistory {
388    pub role: String,
389    pub message: String,
390}
391
392#[derive(Debug, Deserialize, Serialize)]
393pub struct Parameter {
394    pub description: String,
395    pub r#type: String,
396    pub required: bool,
397}
398
399#[derive(Debug, Deserialize, Serialize)]
400pub struct ToolDefinition {
401    pub name: String,
402    pub description: String,
403    pub parameter_definitions: HashMap<String, Parameter>,
404}
405
406impl From<completion::ToolDefinition> for ToolDefinition {
407    fn from(tool: completion::ToolDefinition) -> Self {
408        fn convert_type(r#type: &serde_json::Value) -> String {
409            fn convert_type_str(r#type: &str) -> String {
410                match r#type {
411                    "string" => "string".to_owned(),
412                    "number" => "number".to_owned(),
413                    "integer" => "integer".to_owned(),
414                    "boolean" => "boolean".to_owned(),
415                    "array" => "array".to_owned(),
416                    "object" => "object".to_owned(),
417                    _ => "string".to_owned(),
418                }
419            }
420            match r#type {
421                serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
422                serde_json::Value::Array(types) => convert_type_str(
423                    types
424                        .iter()
425                        .find(|t| t.as_str() != Some("null"))
426                        .and_then(|t| t.as_str())
427                        .unwrap_or("string"),
428                ),
429                _ => "string".to_owned(),
430            }
431        }
432
433        let maybe_required = tool
434            .parameters
435            .get("required")
436            .and_then(|v| v.as_array())
437            .map(|required| {
438                required
439                    .iter()
440                    .filter_map(|v| v.as_str())
441                    .collect::<Vec<_>>()
442            })
443            .unwrap_or_default();
444
445        Self {
446            name: tool.name,
447            description: tool.description,
448            parameter_definitions: tool
449                .parameters
450                .get("properties")
451                .expect("Tool properties should exist")
452                .as_object()
453                .expect("Tool properties should be an object")
454                .iter()
455                .map(|(argname, argdef)| {
456                    (
457                        argname.clone(),
458                        Parameter {
459                            description: argdef
460                                .get("description")
461                                .expect("Argument description should exist")
462                                .as_str()
463                                .expect("Argument description should be a string")
464                                .to_string(),
465                            r#type: convert_type(
466                                argdef.get("type").expect("Argument type should exist"),
467                            ),
468                            required: maybe_required.contains(&argname.as_str()),
469                        },
470                    )
471                })
472                .collect::<HashMap<_, _>>(),
473        }
474    }
475}
476
477#[derive(Deserialize, Serialize)]
478pub struct Message {
479    pub role: String,
480    pub message: String,
481}
482
483impl From<completion::Message> for Message {
484    fn from(message: completion::Message) -> Self {
485        Self {
486            role: match message.role.as_str() {
487                "system" => "SYSTEM".to_owned(),
488                "user" => "USER".to_owned(),
489                "assistant" => "CHATBOT".to_owned(),
490                _ => "USER".to_owned(),
491            },
492            message: message.content,
493        }
494    }
495}
496
497#[derive(Clone)]
498pub struct CompletionModel {
499    client: Client,
500    pub model: String,
501}
502
503impl CompletionModel {
504    pub fn new(client: Client, model: &str) -> Self {
505        Self {
506            client,
507            model: model.to_string(),
508        }
509    }
510}
511
512impl completion::CompletionModel for CompletionModel {
513    type Response = CompletionResponse;
514
515    async fn completion(
516        &self,
517        completion_request: completion::CompletionRequest,
518    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
519        let request = json!({
520            "model": self.model,
521            "preamble": completion_request.preamble,
522            "message": completion_request.prompt,
523            "documents": completion_request.documents,
524            "chat_history": completion_request.chat_history.into_iter().map(Message::from).collect::<Vec<_>>(),
525            "temperature": completion_request.temperature,
526            "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
527        });
528
529        let response = self
530            .client
531            .post("/v1/chat")
532            .json(
533                &if let Some(ref params) = completion_request.additional_params {
534                    json_utils::merge(request.clone(), params.clone())
535                } else {
536                    request.clone()
537                },
538            )
539            .send()
540            .await?;
541
542        if response.status().is_success() {
543            match response.json::<ApiResponse<CompletionResponse>>().await? {
544                ApiResponse::Ok(completion) => Ok(completion.into()),
545                ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
546            }
547        } else {
548            Err(CompletionError::ProviderError(response.text().await?))
549        }
550    }
551}