1use crate::{
2 Embed,
3 client::{VerifyClient, VerifyError},
4 embeddings::EmbeddingsBuilder,
5 http_client::{self, HttpClientExt},
6 wasm_compat::*,
7};
8
9use super::{CompletionModel, EmbeddingModel};
10use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
11use bytes::Bytes;
12use reqwest_eventsource::{CannotCloneRequestError, EventSource};
13use serde::Deserialize;
14
15#[derive(Debug, Deserialize)]
16pub struct ApiErrorResponse {
17 pub message: String,
18}
19
20#[derive(Debug, Deserialize)]
21#[serde(untagged)]
22pub enum ApiResponse<T> {
23 Ok(T),
24 Err(ApiErrorResponse),
25}
26
/// Base URL used by [`ClientBuilder::new`] and [`ClientBuilder::new_with_client`]
/// unless overridden with [`ClientBuilder::base_url`].
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
31
32pub struct ClientBuilder<'a, T = reqwest::Client> {
33 api_key: &'a str,
34 base_url: &'a str,
35 http_client: T,
36}
37
38impl<'a> ClientBuilder<'a, reqwest::Client> {
39 pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
40 ClientBuilder {
41 api_key,
42 base_url: COHERE_API_BASE_URL,
43 http_client: Default::default(),
44 }
45 }
46}
47
48impl<'a, T> ClientBuilder<'a, T> {
49 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
50 ClientBuilder {
51 api_key,
52 base_url: COHERE_API_BASE_URL,
53 http_client,
54 }
55 }
56
57 pub fn with_client<U>(api_key: &str, http_client: U) -> ClientBuilder<'_, U> {
58 ClientBuilder {
59 api_key,
60 base_url: COHERE_API_BASE_URL,
61 http_client,
62 }
63 }
64
65 pub fn base_url(mut self, base_url: &'a str) -> ClientBuilder<'a, T> {
66 self.base_url = base_url;
67 self
68 }
69
70 pub fn build(self) -> Client<T> {
71 Client {
72 base_url: self.base_url.to_string(),
73 api_key: self.api_key.to_string(),
74 http_client: self.http_client,
75 }
76 }
77}
78
79#[derive(Clone)]
80pub struct Client<T = reqwest::Client> {
81 base_url: String,
82 api_key: String,
83 http_client: T,
84}
85
86impl<T> std::fmt::Debug for Client<T>
87where
88 T: HttpClientExt + std::fmt::Debug,
89{
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 f.debug_struct("Client")
92 .field("base_url", &self.base_url)
93 .field("http_client", &self.http_client)
94 .field("api_key", &"<REDACTED>")
95 .finish()
96 }
97}
98
99impl Client<reqwest::Client> {
100 pub fn new(api_key: &str) -> Self {
105 ClientBuilder::new(api_key).build()
106 }
107}
108
109impl<T> Client<T>
110where
111 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
112{
113 fn req(
114 &self,
115 method: http_client::Method,
116 path: &str,
117 ) -> http_client::Result<http_client::Builder> {
118 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
119
120 http_client::with_bearer_auth(
121 http_client::Builder::new().method(method).uri(url),
122 &self.api_key,
123 )
124 }
125
126 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
127 self.req(http_client::Method::POST, path)
128 }
129
130 pub(crate) async fn send<U, V>(
131 &self,
132 req: http_client::Request<U>,
133 ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
134 where
135 U: Into<Bytes> + Send,
136 V: From<Bytes> + Send + 'static,
137 {
138 self.http_client.send(req).await
139 }
140
141 pub fn embeddings<D: Embed>(
142 &self,
143 model: &str,
144 input_type: &str,
145 ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
146 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
147 }
148
149 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel<T> {
152 let ndims = match model {
153 super::EMBED_ENGLISH_V3
154 | super::EMBED_MULTILINGUAL_V3
155 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
156 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
157 super::EMBED_ENGLISH_V2 => 4096,
158 super::EMBED_MULTILINGUAL_V2 => 768,
159 _ => 0,
160 };
161 EmbeddingModel::new(self.clone(), model, input_type, ndims)
162 }
163
164 pub fn embedding_model_with_ndims(
166 &self,
167 model: &str,
168 input_type: &str,
169 ndims: usize,
170 ) -> EmbeddingModel<T> {
171 EmbeddingModel::new(self.clone(), model, input_type, ndims)
172 }
173}
174
175impl Client<reqwest::Client> {
176 pub(crate) async fn eventsource(
177 &self,
178 req: reqwest::RequestBuilder,
179 ) -> Result<EventSource, CannotCloneRequestError> {
180 reqwest_eventsource::EventSource::new(req)
181 }
182
183 pub(crate) fn client(&self) -> &reqwest::Client {
184 &self.http_client
185 }
186}
187
188impl ProviderClient for Client<reqwest::Client> {
189 fn from_env() -> Self {
192 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
193 Self::new(&api_key)
194 }
195
196 fn from_val(input: crate::client::ProviderValue) -> Self {
197 let crate::client::ProviderValue::Simple(api_key) = input else {
198 panic!("Incorrect provider value type")
199 };
200 Self::new(&api_key)
201 }
202}
203
204impl CompletionClient for Client<reqwest::Client> {
205 type CompletionModel = CompletionModel<reqwest::Client>;
206
207 fn completion_model(&self, model: &str) -> Self::CompletionModel {
208 CompletionModel::new(self.clone(), model)
209 }
210}
211
212impl EmbeddingsClient for Client<reqwest::Client> {
213 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
214
215 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
216 self.embedding_model(model, "search_document")
217 }
218
219 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
220 self.embedding_model_with_ndims(model, "search_document", ndims)
221 }
222
223 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
224 self.embeddings(model, "search_document")
225 }
226}
227
228impl VerifyClient for Client<reqwest::Client> {
229 #[cfg_attr(feature = "worker", worker::send)]
230 async fn verify(&self) -> Result<(), VerifyError> {
231 let response = self
232 .http_client
233 .get("/v1/models")
234 .send()
235 .await
236 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
237
238 match response.status() {
239 reqwest::StatusCode::OK => Ok(()),
240 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
241 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
242 Err(VerifyError::ProviderError(response.text().await.map_err(
243 |e| VerifyError::HttpError(http_client::Error::Instance(e.into())),
244 )?))
245 }
246 _ => {
247 response
248 .error_for_status()
249 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
250 Ok(())
251 }
252 }
253 }
254}
255
// Presumably generates impls of the listed capability-conversion traits
// (`AsTranscription`, `AsImageGeneration`, `AsAudioGeneration`) for
// `Client<T>`; see `impl_conversion_traits!` in `crate::client` to confirm
// what the generated impls do. Completion and embeddings are not listed here
// because this file implements `CompletionClient` and `EmbeddingsClient`
// explicitly.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);