rig/providers/
cohere.rs

1//! Cohere API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::cohere;
6//!
7//! let client = cohere::Client::new("YOUR_API_KEY");
8//!
9//! let command_r = client.completion_model(cohere::COMMAND_R);
10//! ```
11use std::collections::HashMap;
12
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError},
16    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17    extractor::ExtractorBuilder,
18    json_utils, message, Embed, OneOrMany,
19};
20
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
25// ================================================================
26// Main Cohere Client
27// ================================================================
/// Base URL used by [`Client::new`]; override it with [`Client::from_url`].
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
30#[derive(Clone)]
31pub struct Client {
32    base_url: String,
33    http_client: reqwest::Client,
34}
35
36impl Client {
37    pub fn new(api_key: &str) -> Self {
38        Self::from_url(api_key, COHERE_API_BASE_URL)
39    }
40
41    pub fn from_url(api_key: &str, base_url: &str) -> Self {
42        Self {
43            base_url: base_url.to_string(),
44            http_client: reqwest::Client::builder()
45                .default_headers({
46                    let mut headers = reqwest::header::HeaderMap::new();
47                    headers.insert(
48                        "Authorization",
49                        format!("Bearer {}", api_key)
50                            .parse()
51                            .expect("Bearer token should parse"),
52                    );
53                    headers
54                })
55                .build()
56                .expect("Cohere reqwest client should build"),
57        }
58    }
59
60    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
61    /// Panics if the environment variable is not set.
62    pub fn from_env() -> Self {
63        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
64        Self::new(&api_key)
65    }
66
67    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69        self.http_client.post(url)
70    }
71
72    /// Note: default embedding dimension of 0 will be used if model is not known.
73    /// If this is the case, it's better to use function `embedding_model_with_ndims`
74    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
75        let ndims = match model {
76            EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
77            EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
78            EMBED_ENGLISH_V2 => 4096,
79            EMBED_MULTILINGUAL_V2 => 768,
80            _ => 0,
81        };
82        EmbeddingModel::new(self.clone(), model, input_type, ndims)
83    }
84
85    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
86    pub fn embedding_model_with_ndims(
87        &self,
88        model: &str,
89        input_type: &str,
90        ndims: usize,
91    ) -> EmbeddingModel {
92        EmbeddingModel::new(self.clone(), model, input_type, ndims)
93    }
94
95    pub fn embeddings<D: Embed>(
96        &self,
97        model: &str,
98        input_type: &str,
99    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
100        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
101    }
102
103    pub fn completion_model(&self, model: &str) -> CompletionModel {
104        CompletionModel::new(self.clone(), model)
105    }
106
107    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
108        AgentBuilder::new(self.completion_model(model))
109    }
110
111    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
112        &self,
113        model: &str,
114    ) -> ExtractorBuilder<T, CompletionModel> {
115        ExtractorBuilder::new(self.completion_model(model))
116    }
117}
118
119#[derive(Debug, Deserialize)]
120struct ApiErrorResponse {
121    message: String,
122}
123
124#[derive(Debug, Deserialize)]
125#[serde(untagged)]
126enum ApiResponse<T> {
127    Ok(T),
128    Err(ApiErrorResponse),
129}
130
131// ================================================================
132// Cohere Embedding API
133// ================================================================
/// `embed-english-v3.0` embedding model (1024 dimensions in `Client::embedding_model`).
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model (384 dimensions in `Client::embedding_model`).
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model (1024 dimensions in `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model (384 dimensions in `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model (4096 dimensions in `Client::embedding_model`).
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model (1024 dimensions in `Client::embedding_model`).
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model (768 dimensions in `Client::embedding_model`).
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
148
149#[derive(Deserialize)]
150pub struct EmbeddingResponse {
151    #[serde(default)]
152    pub response_type: Option<String>,
153    pub id: String,
154    pub embeddings: Vec<Vec<f64>>,
155    pub texts: Vec<String>,
156    #[serde(default)]
157    pub meta: Option<Meta>,
158}
159
160#[derive(Deserialize)]
161pub struct Meta {
162    pub api_version: ApiVersion,
163    pub billed_units: BilledUnits,
164    #[serde(default)]
165    pub warnings: Vec<String>,
166}
167
168#[derive(Deserialize)]
169pub struct ApiVersion {
170    pub version: String,
171    #[serde(default)]
172    pub is_deprecated: Option<bool>,
173    #[serde(default)]
174    pub is_experimental: Option<bool>,
175}
176
177#[derive(Deserialize, Debug)]
178pub struct BilledUnits {
179    #[serde(default)]
180    pub input_tokens: u32,
181    #[serde(default)]
182    pub output_tokens: u32,
183    #[serde(default)]
184    pub search_units: u32,
185    #[serde(default)]
186    pub classifications: u32,
187}
188
189impl std::fmt::Display for BilledUnits {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        write!(
192            f,
193            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
194            self.input_tokens, self.output_tokens, self.search_units, self.classifications
195        )
196    }
197}
198
199#[derive(Clone)]
200pub struct EmbeddingModel {
201    client: Client,
202    pub model: String,
203    pub input_type: String,
204    ndims: usize,
205}
206
207impl embeddings::EmbeddingModel for EmbeddingModel {
208    const MAX_DOCUMENTS: usize = 96;
209
210    fn ndims(&self) -> usize {
211        self.ndims
212    }
213
214    #[cfg_attr(feature = "worker", worker::send)]
215    async fn embed_texts(
216        &self,
217        documents: impl IntoIterator<Item = String>,
218    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
219        let documents = documents.into_iter().collect::<Vec<_>>();
220
221        let response = self
222            .client
223            .post("/v1/embed")
224            .json(&json!({
225                "model": self.model,
226                "texts": documents,
227                "input_type": self.input_type,
228            }))
229            .send()
230            .await?;
231
232        if response.status().is_success() {
233            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
234                ApiResponse::Ok(response) => {
235                    match response.meta {
236                        Some(meta) => tracing::info!(target: "rig",
237                            "Cohere embeddings billed units: {}",
238                            meta.billed_units,
239                        ),
240                        None => tracing::info!(target: "rig",
241                            "Cohere embeddings billed units: n/a",
242                        ),
243                    };
244
245                    if response.embeddings.len() != documents.len() {
246                        return Err(EmbeddingError::DocumentError(
247                            format!(
248                                "Expected {} embeddings, got {}",
249                                documents.len(),
250                                response.embeddings.len()
251                            )
252                            .into(),
253                        ));
254                    }
255
256                    Ok(response
257                        .embeddings
258                        .into_iter()
259                        .zip(documents.into_iter())
260                        .map(|(embedding, document)| embeddings::Embedding {
261                            document,
262                            vec: embedding,
263                        })
264                        .collect())
265                }
266                ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
267            }
268        } else {
269            Err(EmbeddingError::ProviderError(response.text().await?))
270        }
271    }
272}
273
274impl EmbeddingModel {
275    pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
276        Self {
277            client,
278            model: model.to_string(),
279            input_type: input_type.to_string(),
280            ndims,
281        }
282    }
283}
284
285// ================================================================
286// Cohere Completion API
287// ================================================================
/// `command-r-plus` completion model
// The identifier must match Cohere's model name exactly; a misspelling
// (previously "comman-r-plus") makes every request target an unknown model.
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
300
301#[derive(Debug, Deserialize)]
302pub struct CompletionResponse {
303    pub text: String,
304    pub generation_id: String,
305    #[serde(default)]
306    pub citations: Vec<Citation>,
307    #[serde(default)]
308    pub documents: Vec<Document>,
309    #[serde(default)]
310    pub is_search_required: Option<bool>,
311    #[serde(default)]
312    pub search_queries: Vec<SearchQuery>,
313    #[serde(default)]
314    pub search_results: Vec<SearchResult>,
315    pub finish_reason: String,
316    #[serde(default)]
317    pub tool_calls: Vec<ToolCall>,
318    #[serde(default)]
319    pub chat_history: Vec<ChatHistory>,
320}
321
322impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
323    fn from(response: CompletionResponse) -> Self {
324        let CompletionResponse {
325            text, tool_calls, ..
326        } = &response;
327
328        let model_response = if !tool_calls.is_empty() {
329            tool_calls
330                .iter()
331                .map(|tool_call| {
332                    completion::AssistantContent::tool_call(
333                        tool_call.name.clone(),
334                        tool_call.name.clone(),
335                        tool_call.parameters.clone(),
336                    )
337                })
338                .collect::<Vec<_>>()
339        } else {
340            vec![completion::AssistantContent::text(text.clone())]
341        };
342
343        completion::CompletionResponse {
344            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
345            raw_response: response,
346        }
347    }
348}
349
350#[derive(Debug, Deserialize)]
351pub struct Citation {
352    pub start: u32,
353    pub end: u32,
354    pub text: String,
355    pub document_ids: Vec<String>,
356}
357
358#[derive(Debug, Deserialize)]
359pub struct Document {
360    pub id: String,
361    #[serde(flatten)]
362    pub additional_prop: HashMap<String, serde_json::Value>,
363}
364
365#[derive(Debug, Deserialize)]
366pub struct SearchQuery {
367    pub text: String,
368    pub generation_id: String,
369}
370
371#[derive(Debug, Deserialize)]
372pub struct SearchResult {
373    pub search_query: SearchQuery,
374    pub connector: Connector,
375    pub document_ids: Vec<String>,
376    #[serde(default)]
377    pub error_message: Option<String>,
378    #[serde(default)]
379    pub continue_on_failure: bool,
380}
381
382#[derive(Debug, Deserialize)]
383pub struct Connector {
384    pub id: String,
385}
386
387#[derive(Debug, Deserialize, Serialize)]
388pub struct ToolCall {
389    pub name: String,
390    pub parameters: serde_json::Value,
391}
392
393#[derive(Debug, Deserialize)]
394pub struct ChatHistory {
395    pub role: String,
396    pub message: String,
397}
398
399#[derive(Debug, Deserialize, Serialize)]
400pub struct Parameter {
401    pub description: String,
402    pub r#type: String,
403    pub required: bool,
404}
405
406#[derive(Debug, Deserialize, Serialize)]
407pub struct ToolDefinition {
408    pub name: String,
409    pub description: String,
410    pub parameter_definitions: HashMap<String, Parameter>,
411}
412
413impl From<completion::ToolDefinition> for ToolDefinition {
414    fn from(tool: completion::ToolDefinition) -> Self {
415        fn convert_type(r#type: &serde_json::Value) -> String {
416            fn convert_type_str(r#type: &str) -> String {
417                match r#type {
418                    "string" => "string".to_owned(),
419                    "number" => "number".to_owned(),
420                    "integer" => "integer".to_owned(),
421                    "boolean" => "boolean".to_owned(),
422                    "array" => "array".to_owned(),
423                    "object" => "object".to_owned(),
424                    _ => "string".to_owned(),
425                }
426            }
427            match r#type {
428                serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
429                serde_json::Value::Array(types) => convert_type_str(
430                    types
431                        .iter()
432                        .find(|t| t.as_str() != Some("null"))
433                        .and_then(|t| t.as_str())
434                        .unwrap_or("string"),
435                ),
436                _ => "string".to_owned(),
437            }
438        }
439
440        let maybe_required = tool
441            .parameters
442            .get("required")
443            .and_then(|v| v.as_array())
444            .map(|required| {
445                required
446                    .iter()
447                    .filter_map(|v| v.as_str())
448                    .collect::<Vec<_>>()
449            })
450            .unwrap_or_default();
451
452        Self {
453            name: tool.name,
454            description: tool.description,
455            parameter_definitions: tool
456                .parameters
457                .get("properties")
458                .expect("Tool properties should exist")
459                .as_object()
460                .expect("Tool properties should be an object")
461                .iter()
462                .map(|(argname, argdef)| {
463                    (
464                        argname.clone(),
465                        Parameter {
466                            description: argdef
467                                .get("description")
468                                .expect("Argument description should exist")
469                                .as_str()
470                                .expect("Argument description should be a string")
471                                .to_string(),
472                            r#type: convert_type(
473                                argdef.get("type").expect("Argument type should exist"),
474                            ),
475                            required: maybe_required.contains(&argname.as_str()),
476                        },
477                    )
478                })
479                .collect::<HashMap<_, _>>(),
480        }
481    }
482}
483
484#[derive(Deserialize, Serialize)]
485#[serde(tag = "role", rename_all = "UPPERCASE")]
486pub enum Message {
487    User {
488        message: String,
489        tool_calls: Vec<ToolCall>,
490    },
491
492    Chatbot {
493        message: String,
494        tool_calls: Vec<ToolCall>,
495    },
496
497    Tool {
498        tool_results: Vec<ToolResult>,
499    },
500
501    /// According to the documentation, this message type should not be used
502    System {
503        content: String,
504        tool_calls: Vec<ToolCall>,
505    },
506}
507
508#[derive(Deserialize, Serialize)]
509pub struct ToolResult {
510    pub call: ToolCall,
511    pub outputs: Vec<serde_json::Value>,
512}
513
514impl TryFrom<message::Message> for Vec<Message> {
515    type Error = message::MessageError;
516
517    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
518        match message {
519            message::Message::User { content } => content
520                .into_iter()
521                .map(|content| {
522                    Ok(Message::User {
523                        message: match content {
524                            message::UserContent::Text(message::Text { text }) => text,
525                            _ => {
526                                return Err(message::MessageError::ConversionError(
527                                    "Only text content is supported by Cohere".to_owned(),
528                                ))
529                            }
530                        },
531                        tool_calls: vec![],
532                    })
533                })
534                .collect::<Result<Vec<_>, _>>(),
535            _ => Err(message::MessageError::ConversionError(
536                "Only user messages are supported by Cohere".to_owned(),
537            )),
538        }
539    }
540}
541
542#[derive(Clone)]
543pub struct CompletionModel {
544    client: Client,
545    pub model: String,
546}
547
548impl CompletionModel {
549    pub fn new(client: Client, model: &str) -> Self {
550        Self {
551            client,
552            model: model.to_string(),
553        }
554    }
555}
556
557impl completion::CompletionModel for CompletionModel {
558    type Response = CompletionResponse;
559
560    #[cfg_attr(feature = "worker", worker::send)]
561    async fn completion(
562        &self,
563        completion_request: completion::CompletionRequest,
564    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
565        let chat_history = completion_request
566            .chat_history
567            .into_iter()
568            .map(Vec::<Message>::try_from)
569            .collect::<Result<Vec<Vec<_>>, _>>()?
570            .into_iter()
571            .flatten()
572            .collect::<Vec<_>>();
573
574        let message = match completion_request.prompt {
575            message::Message::User { content } => Ok(content
576                .into_iter()
577                .map(|content| match content {
578                    message::UserContent::Text(message::Text { text }) => Ok(text),
579                    _ => Err(CompletionError::RequestError(
580                        "Only text content is supported by Cohere".into(),
581                    )),
582                })
583                .collect::<Result<Vec<_>, _>>()?
584                .join("\n")),
585
586            _ => Err(CompletionError::RequestError(
587                "Only user messages are supported by Cohere".into(),
588            )),
589        }?;
590
591        let request = json!({
592            "model": self.model,
593            "preamble": completion_request.preamble,
594            "message": message,
595            "documents": completion_request.documents,
596            "chat_history": chat_history,
597            "temperature": completion_request.temperature,
598            "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
599        });
600
601        let response = self
602            .client
603            .post("/v1/chat")
604            .json(
605                &if let Some(ref params) = completion_request.additional_params {
606                    json_utils::merge(request.clone(), params.clone())
607                } else {
608                    request.clone()
609                },
610            )
611            .send()
612            .await?;
613
614        if response.status().is_success() {
615            match response.json::<ApiResponse<CompletionResponse>>().await? {
616                ApiResponse::Ok(completion) => Ok(completion.into()),
617                ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
618            }
619        } else {
620            Err(CompletionError::ProviderError(response.text().await?))
621        }
622    }
623}