rig/providers/cohere/
client.rs1use crate::{
2 Embed,
3 client::{VerifyClient, VerifyError},
4 embeddings::EmbeddingsBuilder,
5};
6
7use super::{CompletionModel, EmbeddingModel};
8use crate::client::{
9 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits,
10};
11use serde::Deserialize;
12
13#[derive(Debug, Deserialize)]
14pub struct ApiErrorResponse {
15 pub message: String,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20pub enum ApiResponse<T> {
21 Ok(T),
22 Err(ApiErrorResponse),
23}
24
/// Base URL that [`ClientBuilder::new`] uses unless `base_url` overrides it.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
30pub struct ClientBuilder<'a> {
31 api_key: &'a str,
32 base_url: &'a str,
33 http_client: Option<reqwest::Client>,
34}
35
36impl<'a> ClientBuilder<'a> {
37 pub fn new(api_key: &'a str) -> Self {
38 Self {
39 api_key,
40 base_url: COHERE_API_BASE_URL,
41 http_client: None,
42 }
43 }
44
45 pub fn base_url(mut self, base_url: &'a str) -> Self {
46 self.base_url = base_url;
47 self
48 }
49
50 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
51 self.http_client = Some(client);
52 self
53 }
54
55 pub fn build(self) -> Result<Client, ClientBuilderError> {
56 let http_client = if let Some(http_client) = self.http_client {
57 http_client
58 } else {
59 reqwest::Client::builder().build()?
60 };
61
62 Ok(Client {
63 base_url: self.base_url.to_string(),
64 api_key: self.api_key.to_string(),
65 http_client,
66 })
67 }
68}
69
70#[derive(Clone)]
71pub struct Client {
72 base_url: String,
73 api_key: String,
74 http_client: reqwest::Client,
75}
76
77impl std::fmt::Debug for Client {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Client")
80 .field("base_url", &self.base_url)
81 .field("http_client", &self.http_client)
82 .field("api_key", &"<REDACTED>")
83 .finish()
84 }
85}
86
87impl Client {
88 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
99 ClientBuilder::new(api_key)
100 }
101
102 pub fn new(api_key: &str) -> Self {
107 Self::builder(api_key)
108 .build()
109 .expect("Cohere client should build")
110 }
111
112 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
113 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
114 self.http_client.post(url).bearer_auth(&self.api_key)
115 }
116
117 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
118 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
119 self.http_client.get(url).bearer_auth(&self.api_key)
120 }
121
122 pub fn embeddings<D: Embed>(
123 &self,
124 model: &str,
125 input_type: &str,
126 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
127 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
128 }
129
130 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
133 let ndims = match model {
134 super::EMBED_ENGLISH_V3
135 | super::EMBED_MULTILINGUAL_V3
136 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
137 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
138 super::EMBED_ENGLISH_V2 => 4096,
139 super::EMBED_MULTILINGUAL_V2 => 768,
140 _ => 0,
141 };
142 EmbeddingModel::new(self.clone(), model, input_type, ndims)
143 }
144
145 pub fn embedding_model_with_ndims(
147 &self,
148 model: &str,
149 input_type: &str,
150 ndims: usize,
151 ) -> EmbeddingModel {
152 EmbeddingModel::new(self.clone(), model, input_type, ndims)
153 }
154}
155
156impl ProviderClient for Client {
157 fn from_env() -> Self {
160 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
161 Self::new(&api_key)
162 }
163
164 fn from_val(input: crate::client::ProviderValue) -> Self {
165 let crate::client::ProviderValue::Simple(api_key) = input else {
166 panic!("Incorrect provider value type")
167 };
168 Self::new(&api_key)
169 }
170}
171
172impl CompletionClient for Client {
173 type CompletionModel = CompletionModel;
174
175 fn completion_model(&self, model: &str) -> Self::CompletionModel {
176 CompletionModel::new(self.clone(), model)
177 }
178}
179
180impl EmbeddingsClient for Client {
181 type EmbeddingModel = EmbeddingModel;
182
183 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
184 self.embedding_model(model, "search_document")
185 }
186
187 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
188 self.embedding_model_with_ndims(model, "search_document", ndims)
189 }
190
191 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
192 self.embeddings(model, "search_document")
193 }
194}
195
196impl VerifyClient for Client {
197 #[cfg_attr(feature = "worker", worker::send)]
198 async fn verify(&self) -> Result<(), VerifyError> {
199 let response = self.get("/v1/models").send().await?;
200 match response.status() {
201 reqwest::StatusCode::OK => Ok(()),
202 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
203 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
204 Err(VerifyError::ProviderError(response.text().await?))
205 }
206 _ => {
207 response.error_for_status()?;
208 Ok(())
209 }
210 }
211 }
212}
213
214impl_conversion_traits!(
215 AsTranscription,
216 AsImageGeneration,
217 AsAudioGeneration for Client
218);