// rig/providers/cohere/client.rs

1use crate::{
2    Embed,
3    client::{VerifyClient, VerifyError},
4    embeddings::EmbeddingsBuilder,
5    http_client::{self, HttpClientExt},
6    wasm_compat::*,
7};
8
9use super::{CompletionModel, EmbeddingModel};
10use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
11use bytes::Bytes;
12use serde::Deserialize;
13
14#[derive(Debug, Deserialize)]
15pub struct ApiErrorResponse {
16    pub message: String,
17}
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22    Ok(T),
23    Err(ApiErrorResponse),
24}
25
// ================================================================
// Main Cohere Client
// ================================================================
/// Base URL that every `ClientBuilder` constructor in this file starts from; it
/// carries no trailing slash and no API version segment.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
30
/// Builder for [`Client`].
///
/// Borrows the API key and base URL for `'a` and owns an HTTP client of type `T`
/// (`reqwest::Client` unless another type is supplied). The borrowed strings are
/// copied into owned `String`s when `build` is called.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // API key that `build` copies into the client.
    api_key: &'a str,
    // Base URL that `build` copies into the client.
    base_url: &'a str,
    // HTTP client moved into the built `Client`.
    http_client: T,
}
36
37impl<'a> ClientBuilder<'a, reqwest::Client> {
38    pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
39        ClientBuilder {
40            api_key,
41            base_url: COHERE_API_BASE_URL,
42            http_client: Default::default(),
43        }
44    }
45}
46
47impl<'a, T> ClientBuilder<'a, T> {
48    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
49        ClientBuilder {
50            api_key,
51            base_url: COHERE_API_BASE_URL,
52            http_client,
53        }
54    }
55
56    pub fn with_client<U>(api_key: &str, http_client: U) -> ClientBuilder<'_, U> {
57        ClientBuilder {
58            api_key,
59            base_url: COHERE_API_BASE_URL,
60            http_client,
61        }
62    }
63
64    pub fn base_url(mut self, base_url: &'a str) -> ClientBuilder<'a, T> {
65        self.base_url = base_url;
66        self
67    }
68
69    pub fn build(self) -> Client<T> {
70        Client {
71            base_url: self.base_url.to_string(),
72            api_key: self.api_key.to_string(),
73            http_client: self.http_client,
74        }
75    }
76}
77
/// Cohere API client, generic over the HTTP client type `T`
/// (`reqwest::Client` by default).
///
/// Cloning copies the base URL and API key strings and clones the HTTP client.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    // Prefix that request paths are appended to.
    base_url: String,
    // Secret passed to the bearer-auth helper on every request; redacted in `Debug`.
    api_key: String,
    // Transport used to send requests.
    http_client: T,
}
84
85impl<T> std::fmt::Debug for Client<T>
86where
87    T: HttpClientExt + std::fmt::Debug,
88{
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("Client")
91            .field("base_url", &self.base_url)
92            .field("http_client", &self.http_client)
93            .field("api_key", &"<REDACTED>")
94            .finish()
95    }
96}
97
98impl Client<reqwest::Client> {
99    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
100        ClientBuilder::new(api_key)
101    }
102
103    /// Create a new Cohere client. For more control, use the `builder` method.
104    ///
105    /// # Panics
106    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
107    pub fn new(api_key: &str) -> Self {
108        ClientBuilder::new(api_key).build()
109    }
110
111    pub fn from_env() -> Self {
112        <Self as ProviderClient>::from_env()
113    }
114}
115
116impl<T> Client<T>
117where
118    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
119{
120    pub(crate) fn req(
121        &self,
122        method: http_client::Method,
123        path: &str,
124    ) -> http_client::Result<http_client::Builder> {
125        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
126
127        http_client::with_bearer_auth(
128            http_client::Builder::new().method(method).uri(url),
129            &self.api_key,
130        )
131    }
132
133    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
134        self.req(http_client::Method::POST, path)
135    }
136
137    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
138        self.req(http_client::Method::GET, path)
139    }
140
141    pub fn http_client(&self) -> T {
142        self.http_client.clone()
143    }
144
145    pub(crate) async fn send<U, V>(
146        &self,
147        req: http_client::Request<U>,
148    ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
149    where
150        U: Into<Bytes> + Send,
151        V: From<Bytes> + Send + 'static,
152    {
153        self.http_client.send(req).await
154    }
155
156    pub fn embeddings<D: Embed>(
157        &self,
158        model: &str,
159        input_type: &str,
160    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
161        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
162    }
163
164    /// Note: default embedding dimension of 0 will be used if model is not known.
165    /// If this is the case, it's better to use function `embedding_model_with_ndims`
166    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel<T> {
167        let ndims = match model {
168            super::EMBED_ENGLISH_V3
169            | super::EMBED_MULTILINGUAL_V3
170            | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
171            super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
172            super::EMBED_ENGLISH_V2 => 4096,
173            super::EMBED_MULTILINGUAL_V2 => 768,
174            _ => 0,
175        };
176        EmbeddingModel::new(self.clone(), model, input_type, ndims)
177    }
178
179    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
180    pub fn embedding_model_with_ndims(
181        &self,
182        model: &str,
183        input_type: &str,
184        ndims: usize,
185    ) -> EmbeddingModel<T> {
186        EmbeddingModel::new(self.clone(), model, input_type, ndims)
187    }
188}
189
190impl<T> ProviderClient for Client<T>
191where
192    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
193{
194    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
195    /// Panics if the environment variable is not set.
196    fn from_env() -> Self {
197        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
198        ClientBuilder::new_with_client(&api_key, T::default()).build()
199    }
200
201    fn from_val(input: crate::client::ProviderValue) -> Self {
202        let crate::client::ProviderValue::Simple(api_key) = input else {
203            panic!("Incorrect provider value type")
204        };
205        ClientBuilder::new_with_client(&api_key, T::default()).build()
206    }
207}
208
209impl<T> CompletionClient for Client<T>
210where
211    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
212{
213    type CompletionModel = CompletionModel<T>;
214
215    fn completion_model(&self, model: &str) -> Self::CompletionModel {
216        CompletionModel::new(self.clone(), model)
217    }
218}
219
220impl<T> EmbeddingsClient for Client<T>
221where
222    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
223{
224    type EmbeddingModel = EmbeddingModel<T>;
225
226    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
227        self.embedding_model(model, "search_document")
228    }
229
230    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
231        self.embedding_model_with_ndims(model, "search_document", ndims)
232    }
233
234    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
235        self.embeddings(model, "search_document")
236    }
237}
238
239impl<T> VerifyClient for Client<T>
240where
241    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
242{
243    #[cfg_attr(feature = "worker", worker::send)]
244    async fn verify(&self) -> Result<(), VerifyError> {
245        let req = self
246            .get("/models")?
247            .body(http_client::NoBody)
248            .map_err(|e| VerifyError::HttpError(e.into()))?;
249
250        let response = self.http_client.send::<_, Vec<u8>>(req).await?;
251        let status = response.status();
252        let body = http_client::text(response).await?;
253
254        match status {
255            reqwest::StatusCode::OK => Ok(()),
256            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
257            reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(VerifyError::ProviderError(body)),
258            _ => Err(VerifyError::ProviderError(body)),
259        }
260    }
261}
262
// NOTE(review): presumably generates the listed `As*` conversion trait impls for
// `Client<T>`; the macro is imported from `crate::client` and its definition is
// not visible in this file — confirm there what each generated impl returns.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);