rig/providers/cohere/
client.rs

1use crate::{
2    Embed,
3    client::{VerifyClient, VerifyError},
4    embeddings::EmbeddingsBuilder,
5};
6
7use super::{CompletionModel, EmbeddingModel};
8use crate::client::{
9    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits,
10};
11use serde::Deserialize;
12
13#[derive(Debug, Deserialize)]
14pub struct ApiErrorResponse {
15    pub message: String,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20pub enum ApiResponse<T> {
21    Ok(T),
22    Err(ApiErrorResponse),
23}
24
// ================================================================
// Main Cohere Client
// ================================================================

/// Default base URL for the Cohere API, used unless the builder's `base_url`
/// overrides it. Endpoint paths such as `/v1/models` are appended to it.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
30pub struct ClientBuilder<'a> {
31    api_key: &'a str,
32    base_url: &'a str,
33    http_client: Option<reqwest::Client>,
34}
35
36impl<'a> ClientBuilder<'a> {
37    pub fn new(api_key: &'a str) -> Self {
38        Self {
39            api_key,
40            base_url: COHERE_API_BASE_URL,
41            http_client: None,
42        }
43    }
44
45    pub fn base_url(mut self, base_url: &'a str) -> Self {
46        self.base_url = base_url;
47        self
48    }
49
50    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
51        self.http_client = Some(client);
52        self
53    }
54
55    pub fn build(self) -> Result<Client, ClientBuilderError> {
56        let http_client = if let Some(http_client) = self.http_client {
57            http_client
58        } else {
59            reqwest::Client::builder().build()?
60        };
61
62        Ok(Client {
63            base_url: self.base_url.to_string(),
64            api_key: self.api_key.to_string(),
65            http_client,
66        })
67    }
68}
69
70#[derive(Clone)]
71pub struct Client {
72    base_url: String,
73    api_key: String,
74    http_client: reqwest::Client,
75}
76
77impl std::fmt::Debug for Client {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("Client")
80            .field("base_url", &self.base_url)
81            .field("http_client", &self.http_client)
82            .field("api_key", &"<REDACTED>")
83            .finish()
84    }
85}
86
87impl Client {
88    /// Create a new Cohere client builder.
89    ///
90    /// # Example
91    /// ```
92    /// use rig::providers::cohere::{ClientBuilder, self};
93    ///
94    /// // Initialize the Cohere client
95    /// let cohere_client = Client::builder("your-cohere-api-key")
96    ///    .build()
97    /// ```
98    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
99        ClientBuilder::new(api_key)
100    }
101
102    /// Create a new Cohere client. For more control, use the `builder` method.
103    ///
104    /// # Panics
105    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
106    pub fn new(api_key: &str) -> Self {
107        Self::builder(api_key)
108            .build()
109            .expect("Cohere client should build")
110    }
111
112    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
113        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
114        self.http_client.post(url).bearer_auth(&self.api_key)
115    }
116
117    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
118        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
119        self.http_client.get(url).bearer_auth(&self.api_key)
120    }
121
122    pub fn embeddings<D: Embed>(
123        &self,
124        model: &str,
125        input_type: &str,
126    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
127        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
128    }
129
130    /// Note: default embedding dimension of 0 will be used if model is not known.
131    /// If this is the case, it's better to use function `embedding_model_with_ndims`
132    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
133        let ndims = match model {
134            super::EMBED_ENGLISH_V3
135            | super::EMBED_MULTILINGUAL_V3
136            | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
137            super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
138            super::EMBED_ENGLISH_V2 => 4096,
139            super::EMBED_MULTILINGUAL_V2 => 768,
140            _ => 0,
141        };
142        EmbeddingModel::new(self.clone(), model, input_type, ndims)
143    }
144
145    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
146    pub fn embedding_model_with_ndims(
147        &self,
148        model: &str,
149        input_type: &str,
150        ndims: usize,
151    ) -> EmbeddingModel {
152        EmbeddingModel::new(self.clone(), model, input_type, ndims)
153    }
154}
155
156impl ProviderClient for Client {
157    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
158    /// Panics if the environment variable is not set.
159    fn from_env() -> Self {
160        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
161        Self::new(&api_key)
162    }
163
164    fn from_val(input: crate::client::ProviderValue) -> Self {
165        let crate::client::ProviderValue::Simple(api_key) = input else {
166            panic!("Incorrect provider value type")
167        };
168        Self::new(&api_key)
169    }
170}
171
172impl CompletionClient for Client {
173    type CompletionModel = CompletionModel;
174
175    fn completion_model(&self, model: &str) -> Self::CompletionModel {
176        CompletionModel::new(self.clone(), model)
177    }
178}
179
180impl EmbeddingsClient for Client {
181    type EmbeddingModel = EmbeddingModel;
182
183    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
184        self.embedding_model(model, "search_document")
185    }
186
187    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
188        self.embedding_model_with_ndims(model, "search_document", ndims)
189    }
190
191    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
192        self.embeddings(model, "search_document")
193    }
194}
195
196impl VerifyClient for Client {
197    #[cfg_attr(feature = "worker", worker::send)]
198    async fn verify(&self) -> Result<(), VerifyError> {
199        let response = self.get("/v1/models").send().await?;
200        match response.status() {
201            reqwest::StatusCode::OK => Ok(()),
202            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
203            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
204                Err(VerifyError::ProviderError(response.text().await?))
205            }
206            _ => {
207                response.error_for_status()?;
208                Ok(())
209            }
210        }
211    }
212}
213
214impl_conversion_traits!(
215    AsTranscription,
216    AsImageGeneration,
217    AsAudioGeneration for Client
218);