rig/providers/cohere/client.rs

1use crate::{Embed, embeddings::EmbeddingsBuilder};
2
3use super::{CompletionModel, EmbeddingModel};
4use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
5use serde::Deserialize;
6
7#[derive(Debug, Deserialize)]
8pub struct ApiErrorResponse {
9    pub message: String,
10}
11
12#[derive(Debug, Deserialize)]
13#[serde(untagged)]
14pub enum ApiResponse<T> {
15    Ok(T),
16    Err(ApiErrorResponse),
17}
18
// ----------------------------------------------------------------
// Cohere client
// ----------------------------------------------------------------

/// Root URL of the Cohere REST API; the endpoint used by [`Client::new`].
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
23
24#[derive(Clone)]
25pub struct Client {
26    base_url: String,
27    api_key: String,
28    http_client: reqwest::Client,
29}
30
31impl std::fmt::Debug for Client {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("Client")
34            .field("base_url", &self.base_url)
35            .field("http_client", &self.http_client)
36            .field("api_key", &"<REDACTED>")
37            .finish()
38    }
39}
40
41impl Client {
42    pub fn new(api_key: &str) -> Self {
43        Self::from_url(api_key, COHERE_API_BASE_URL)
44    }
45
46    pub fn from_url(api_key: &str, base_url: &str) -> Self {
47        Self {
48            base_url: base_url.to_string(),
49            api_key: api_key.to_string(),
50            http_client: reqwest::Client::builder()
51                .build()
52                .expect("Cohere reqwest client should build"),
53        }
54    }
55
56    /// Use your own `reqwest::Client`.
57    /// The API key will be automatically attached upon trying to make a request, so you shouldn't need to add it as a default header.
58    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
59        self.http_client = client;
60
61        self
62    }
63
64    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
65        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
66        self.http_client.post(url).bearer_auth(&self.api_key)
67    }
68
69    pub fn embeddings<D: Embed>(
70        &self,
71        model: &str,
72        input_type: &str,
73    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
74        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
75    }
76
77    /// Note: default embedding dimension of 0 will be used if model is not known.
78    /// If this is the case, it's better to use function `embedding_model_with_ndims`
79    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
80        let ndims = match model {
81            super::EMBED_ENGLISH_V3
82            | super::EMBED_MULTILINGUAL_V3
83            | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
84            super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
85            super::EMBED_ENGLISH_V2 => 4096,
86            super::EMBED_MULTILINGUAL_V2 => 768,
87            _ => 0,
88        };
89        EmbeddingModel::new(self.clone(), model, input_type, ndims)
90    }
91
92    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
93    pub fn embedding_model_with_ndims(
94        &self,
95        model: &str,
96        input_type: &str,
97        ndims: usize,
98    ) -> EmbeddingModel {
99        EmbeddingModel::new(self.clone(), model, input_type, ndims)
100    }
101}
102
103impl ProviderClient for Client {
104    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
105    /// Panics if the environment variable is not set.
106    fn from_env() -> Self {
107        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
108        Self::new(&api_key)
109    }
110}
111
112impl CompletionClient for Client {
113    type CompletionModel = CompletionModel;
114
115    fn completion_model(&self, model: &str) -> Self::CompletionModel {
116        CompletionModel::new(self.clone(), model)
117    }
118}
119
120impl EmbeddingsClient for Client {
121    type EmbeddingModel = EmbeddingModel;
122
123    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
124        self.embedding_model(model, "search_document")
125    }
126
127    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
128        self.embedding_model_with_ndims(model, "search_document", ndims)
129    }
130
131    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
132        self.embeddings(model, "search_document")
133    }
134}
135
136impl_conversion_traits!(
137    AsTranscription,
138    AsImageGeneration,
139    AsAudioGeneration for Client
140);