1use crate::{
2 Embed,
3 client::{VerifyClient, VerifyError},
4 embeddings::EmbeddingsBuilder,
5 http_client::{self, HttpClientExt},
6 wasm_compat::*,
7};
8
9use super::{CompletionModel, EmbeddingModel};
10use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
11use bytes::Bytes;
12use serde::Deserialize;
13
14#[derive(Debug, Deserialize)]
15pub struct ApiErrorResponse {
16 pub message: String,
17}
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(ApiErrorResponse),
24}
25
/// Default root URL for Cohere API requests; overridable via `ClientBuilder::base_url`.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
30
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32 api_key: &'a str,
33 base_url: &'a str,
34 http_client: T,
35}
36
37impl<'a> ClientBuilder<'a, reqwest::Client> {
38 pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
39 ClientBuilder {
40 api_key,
41 base_url: COHERE_API_BASE_URL,
42 http_client: Default::default(),
43 }
44 }
45}
46
47impl<'a, T> ClientBuilder<'a, T> {
48 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
49 ClientBuilder {
50 api_key,
51 base_url: COHERE_API_BASE_URL,
52 http_client,
53 }
54 }
55
56 pub fn with_client<U>(api_key: &str, http_client: U) -> ClientBuilder<'_, U> {
57 ClientBuilder {
58 api_key,
59 base_url: COHERE_API_BASE_URL,
60 http_client,
61 }
62 }
63
64 pub fn base_url(mut self, base_url: &'a str) -> ClientBuilder<'a, T> {
65 self.base_url = base_url;
66 self
67 }
68
69 pub fn build(self) -> Client<T> {
70 Client {
71 base_url: self.base_url.to_string(),
72 api_key: self.api_key.to_string(),
73 http_client: self.http_client,
74 }
75 }
76}
77
78#[derive(Clone)]
79pub struct Client<T = reqwest::Client> {
80 base_url: String,
81 api_key: String,
82 http_client: T,
83}
84
85impl<T> std::fmt::Debug for Client<T>
86where
87 T: HttpClientExt + std::fmt::Debug,
88{
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("Client")
91 .field("base_url", &self.base_url)
92 .field("http_client", &self.http_client)
93 .field("api_key", &"<REDACTED>")
94 .finish()
95 }
96}
97
98impl Client<reqwest::Client> {
99 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
100 ClientBuilder::new(api_key)
101 }
102
103 pub fn new(api_key: &str) -> Self {
108 ClientBuilder::new(api_key).build()
109 }
110
111 pub fn from_env() -> Self {
112 <Self as ProviderClient>::from_env()
113 }
114}
115
116impl<T> Client<T>
117where
118 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
119{
120 pub(crate) fn req(
121 &self,
122 method: http_client::Method,
123 path: &str,
124 ) -> http_client::Result<http_client::Builder> {
125 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
126
127 http_client::with_bearer_auth(
128 http_client::Builder::new().method(method).uri(url),
129 &self.api_key,
130 )
131 }
132
133 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
134 self.req(http_client::Method::POST, path)
135 }
136
137 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
138 self.req(http_client::Method::GET, path)
139 }
140
141 pub fn http_client(&self) -> T {
142 self.http_client.clone()
143 }
144
145 pub(crate) async fn send<U, V>(
146 &self,
147 req: http_client::Request<U>,
148 ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
149 where
150 U: Into<Bytes> + Send,
151 V: From<Bytes> + Send + 'static,
152 {
153 self.http_client.send(req).await
154 }
155
156 pub fn embeddings<D: Embed>(
157 &self,
158 model: &str,
159 input_type: &str,
160 ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
161 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
162 }
163
164 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel<T> {
167 let ndims = match model {
168 super::EMBED_ENGLISH_V3
169 | super::EMBED_MULTILINGUAL_V3
170 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
171 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
172 super::EMBED_ENGLISH_V2 => 4096,
173 super::EMBED_MULTILINGUAL_V2 => 768,
174 _ => 0,
175 };
176 EmbeddingModel::new(self.clone(), model, input_type, ndims)
177 }
178
179 pub fn embedding_model_with_ndims(
181 &self,
182 model: &str,
183 input_type: &str,
184 ndims: usize,
185 ) -> EmbeddingModel<T> {
186 EmbeddingModel::new(self.clone(), model, input_type, ndims)
187 }
188}
189
190impl<T> ProviderClient for Client<T>
191where
192 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
193{
194 fn from_env() -> Self {
197 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
198 ClientBuilder::new_with_client(&api_key, T::default()).build()
199 }
200
201 fn from_val(input: crate::client::ProviderValue) -> Self {
202 let crate::client::ProviderValue::Simple(api_key) = input else {
203 panic!("Incorrect provider value type")
204 };
205 ClientBuilder::new_with_client(&api_key, T::default()).build()
206 }
207}
208
209impl<T> CompletionClient for Client<T>
210where
211 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
212{
213 type CompletionModel = CompletionModel<T>;
214
215 fn completion_model(&self, model: &str) -> Self::CompletionModel {
216 CompletionModel::new(self.clone(), model)
217 }
218}
219
220impl<T> EmbeddingsClient for Client<T>
221where
222 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
223{
224 type EmbeddingModel = EmbeddingModel<T>;
225
226 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
227 self.embedding_model(model, "search_document")
228 }
229
230 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
231 self.embedding_model_with_ndims(model, "search_document", ndims)
232 }
233
234 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
235 self.embeddings(model, "search_document")
236 }
237}
238
239impl<T> VerifyClient for Client<T>
240where
241 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
242{
243 #[cfg_attr(feature = "worker", worker::send)]
244 async fn verify(&self) -> Result<(), VerifyError> {
245 let req = self
246 .get("/models")?
247 .body(http_client::NoBody)
248 .map_err(|e| VerifyError::HttpError(e.into()))?;
249
250 let response = self.http_client.send::<_, Vec<u8>>(req).await?;
251 let status = response.status();
252 let body = http_client::text(response).await?;
253
254 match status {
255 reqwest::StatusCode::OK => Ok(()),
256 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
257 reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(VerifyError::ProviderError(body)),
258 _ => Err(VerifyError::ProviderError(body)),
259 }
260 }
261}
262
263impl_conversion_traits!(
264 AsTranscription,
265 AsImageGeneration,
266 AsAudioGeneration for Client<T>
267);