rig/providers/cohere/
client.rs1use crate::{Embed, embeddings::EmbeddingsBuilder};
2
3use super::{CompletionModel, EmbeddingModel};
4use crate::client::{
5 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits,
6};
7use serde::Deserialize;
8
9#[derive(Debug, Deserialize)]
10pub struct ApiErrorResponse {
11 pub message: String,
12}
13
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration
/// order and returns the first one that deserializes. A body is therefore only
/// treated as an error when it does NOT deserialize as `T`.
///
/// Keep `Ok` first. If `T` could also accept an object shaped like
/// `{ "message": ... }` (for example, all of its fields are optional), an error
/// body would be read as success.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
20
21const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
25
26pub struct ClientBuilder<'a> {
27 api_key: &'a str,
28 base_url: &'a str,
29 http_client: Option<reqwest::Client>,
30}
31
32impl<'a> ClientBuilder<'a> {
33 pub fn new(api_key: &'a str) -> Self {
34 Self {
35 api_key,
36 base_url: COHERE_API_BASE_URL,
37 http_client: None,
38 }
39 }
40
41 pub fn base_url(mut self, base_url: &'a str) -> Self {
42 self.base_url = base_url;
43 self
44 }
45
46 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
47 self.http_client = Some(client);
48 self
49 }
50
51 pub fn build(self) -> Result<Client, ClientBuilderError> {
52 let http_client = if let Some(http_client) = self.http_client {
53 http_client
54 } else {
55 reqwest::Client::builder().build()?
56 };
57
58 Ok(Client {
59 base_url: self.base_url.to_string(),
60 api_key: self.api_key.to_string(),
61 http_client,
62 })
63 }
64}
65
66#[derive(Clone)]
67pub struct Client {
68 base_url: String,
69 api_key: String,
70 http_client: reqwest::Client,
71}
72
73impl std::fmt::Debug for Client {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("Client")
76 .field("base_url", &self.base_url)
77 .field("http_client", &self.http_client)
78 .field("api_key", &"<REDACTED>")
79 .finish()
80 }
81}
82
83impl Client {
84 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
95 ClientBuilder::new(api_key)
96 }
97
98 pub fn new(api_key: &str) -> Self {
103 Self::builder(api_key)
104 .build()
105 .expect("Cohere client should build")
106 }
107
108 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
109 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
110 self.http_client.post(url).bearer_auth(&self.api_key)
111 }
112
113 pub fn embeddings<D: Embed>(
114 &self,
115 model: &str,
116 input_type: &str,
117 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
118 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
119 }
120
121 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
124 let ndims = match model {
125 super::EMBED_ENGLISH_V3
126 | super::EMBED_MULTILINGUAL_V3
127 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
128 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
129 super::EMBED_ENGLISH_V2 => 4096,
130 super::EMBED_MULTILINGUAL_V2 => 768,
131 _ => 0,
132 };
133 EmbeddingModel::new(self.clone(), model, input_type, ndims)
134 }
135
136 pub fn embedding_model_with_ndims(
138 &self,
139 model: &str,
140 input_type: &str,
141 ndims: usize,
142 ) -> EmbeddingModel {
143 EmbeddingModel::new(self.clone(), model, input_type, ndims)
144 }
145}
146
147impl ProviderClient for Client {
148 fn from_env() -> Self {
151 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
152 Self::new(&api_key)
153 }
154
155 fn from_val(input: crate::client::ProviderValue) -> Self {
156 let crate::client::ProviderValue::Simple(api_key) = input else {
157 panic!("Incorrect provider value type")
158 };
159 Self::new(&api_key)
160 }
161}
162
163impl CompletionClient for Client {
164 type CompletionModel = CompletionModel;
165
166 fn completion_model(&self, model: &str) -> Self::CompletionModel {
167 CompletionModel::new(self.clone(), model)
168 }
169}
170
171impl EmbeddingsClient for Client {
172 type EmbeddingModel = EmbeddingModel;
173
174 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
175 self.embedding_model(model, "search_document")
176 }
177
178 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
179 self.embedding_model_with_ndims(model, "search_document", ndims)
180 }
181
182 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
183 self.embeddings(model, "search_document")
184 }
185}
186
// `impl_conversion_traits!` is imported from `crate::client`.
// NOTE(review): presumably it generates the `AsTranscription`,
// `AsImageGeneration` and `AsAudioGeneration` conversions for `Client` as
// "unsupported" defaults, since this client only exposes completion and
// embedding models here. Confirm against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);