rig/providers/cohere/
client.rs

1use crate::{
2    Embed,
3    client::{VerifyClient, VerifyError},
4    embeddings::EmbeddingsBuilder,
5    http_client::{self, HttpClientExt},
6    wasm_compat::*,
7};
8
9use super::{CompletionModel, EmbeddingModel};
10use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
11use bytes::Bytes;
12use reqwest_eventsource::{CannotCloneRequestError, EventSource};
13use serde::Deserialize;
14
15#[derive(Debug, Deserialize)]
16pub struct ApiErrorResponse {
17    pub message: String,
18}
19
20#[derive(Debug, Deserialize)]
21#[serde(untagged)]
22pub enum ApiResponse<T> {
23    Ok(T),
24    Err(ApiErrorResponse),
25}
26
// ================================================================
// Main Cohere Client
// ================================================================

/// Base URL the builders fall back to when no explicit one is supplied.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
31
32pub struct ClientBuilder<'a, T = reqwest::Client> {
33    api_key: &'a str,
34    base_url: &'a str,
35    http_client: T,
36}
37
38impl<'a> ClientBuilder<'a, reqwest::Client> {
39    pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
40        ClientBuilder {
41            api_key,
42            base_url: COHERE_API_BASE_URL,
43            http_client: Default::default(),
44        }
45    }
46}
47
48impl<'a, T> ClientBuilder<'a, T> {
49    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
50        ClientBuilder {
51            api_key,
52            base_url: COHERE_API_BASE_URL,
53            http_client,
54        }
55    }
56
57    pub fn with_client<U>(api_key: &str, http_client: U) -> ClientBuilder<'_, U> {
58        ClientBuilder {
59            api_key,
60            base_url: COHERE_API_BASE_URL,
61            http_client,
62        }
63    }
64
65    pub fn base_url(mut self, base_url: &'a str) -> ClientBuilder<'a, T> {
66        self.base_url = base_url;
67        self
68    }
69
70    pub fn build(self) -> Client<T> {
71        Client {
72            base_url: self.base_url.to_string(),
73            api_key: self.api_key.to_string(),
74            http_client: self.http_client,
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct Client<T = reqwest::Client> {
81    base_url: String,
82    api_key: String,
83    http_client: T,
84}
85
86impl<T> std::fmt::Debug for Client<T>
87where
88    T: HttpClientExt + std::fmt::Debug,
89{
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.debug_struct("Client")
92            .field("base_url", &self.base_url)
93            .field("http_client", &self.http_client)
94            .field("api_key", &"<REDACTED>")
95            .finish()
96    }
97}
98
99impl Client<reqwest::Client> {
100    /// Create a new Cohere client. For more control, use the `builder` method.
101    ///
102    /// # Panics
103    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
104    pub fn new(api_key: &str) -> Self {
105        ClientBuilder::new(api_key).build()
106    }
107}
108
109impl<T> Client<T>
110where
111    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
112{
113    fn req(
114        &self,
115        method: http_client::Method,
116        path: &str,
117    ) -> http_client::Result<http_client::Builder> {
118        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
119
120        http_client::with_bearer_auth(
121            http_client::Builder::new().method(method).uri(url),
122            &self.api_key,
123        )
124    }
125
126    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
127        self.req(http_client::Method::POST, path)
128    }
129
130    pub(crate) async fn send<U, V>(
131        &self,
132        req: http_client::Request<U>,
133    ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
134    where
135        U: Into<Bytes> + Send,
136        V: From<Bytes> + Send + 'static,
137    {
138        self.http_client.send(req).await
139    }
140
141    pub fn embeddings<D: Embed>(
142        &self,
143        model: &str,
144        input_type: &str,
145    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
146        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
147    }
148
149    /// Note: default embedding dimension of 0 will be used if model is not known.
150    /// If this is the case, it's better to use function `embedding_model_with_ndims`
151    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel<T> {
152        let ndims = match model {
153            super::EMBED_ENGLISH_V3
154            | super::EMBED_MULTILINGUAL_V3
155            | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
156            super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
157            super::EMBED_ENGLISH_V2 => 4096,
158            super::EMBED_MULTILINGUAL_V2 => 768,
159            _ => 0,
160        };
161        EmbeddingModel::new(self.clone(), model, input_type, ndims)
162    }
163
164    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
165    pub fn embedding_model_with_ndims(
166        &self,
167        model: &str,
168        input_type: &str,
169        ndims: usize,
170    ) -> EmbeddingModel<T> {
171        EmbeddingModel::new(self.clone(), model, input_type, ndims)
172    }
173}
174
175impl Client<reqwest::Client> {
176    pub(crate) async fn eventsource(
177        &self,
178        req: reqwest::RequestBuilder,
179    ) -> Result<EventSource, CannotCloneRequestError> {
180        reqwest_eventsource::EventSource::new(req)
181    }
182
183    pub(crate) fn client(&self) -> &reqwest::Client {
184        &self.http_client
185    }
186}
187
188impl ProviderClient for Client<reqwest::Client> {
189    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
190    /// Panics if the environment variable is not set.
191    fn from_env() -> Self {
192        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
193        Self::new(&api_key)
194    }
195
196    fn from_val(input: crate::client::ProviderValue) -> Self {
197        let crate::client::ProviderValue::Simple(api_key) = input else {
198            panic!("Incorrect provider value type")
199        };
200        Self::new(&api_key)
201    }
202}
203
204impl CompletionClient for Client<reqwest::Client> {
205    type CompletionModel = CompletionModel<reqwest::Client>;
206
207    fn completion_model(&self, model: &str) -> Self::CompletionModel {
208        CompletionModel::new(self.clone(), model)
209    }
210}
211
212impl EmbeddingsClient for Client<reqwest::Client> {
213    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
214
215    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
216        self.embedding_model(model, "search_document")
217    }
218
219    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
220        self.embedding_model_with_ndims(model, "search_document", ndims)
221    }
222
223    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
224        self.embeddings(model, "search_document")
225    }
226}
227
228impl VerifyClient for Client<reqwest::Client> {
229    #[cfg_attr(feature = "worker", worker::send)]
230    async fn verify(&self) -> Result<(), VerifyError> {
231        let response = self
232            .http_client
233            .get("/v1/models")
234            .send()
235            .await
236            .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
237
238        match response.status() {
239            reqwest::StatusCode::OK => Ok(()),
240            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
241            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
242                Err(VerifyError::ProviderError(response.text().await.map_err(
243                    |e| VerifyError::HttpError(http_client::Error::Instance(e.into())),
244                )?))
245            }
246            _ => {
247                response
248                    .error_for_status()
249                    .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
250                Ok(())
251            }
252        }
253    }
254}
255
256impl_conversion_traits!(
257    AsTranscription,
258    AsImageGeneration,
259    AsAudioGeneration for Client<T>
260);