rig/providers/cohere/
client.rs1use crate::{Embed, embeddings::EmbeddingsBuilder};
2
3use super::{CompletionModel, EmbeddingModel};
4use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits};
5use serde::Deserialize;
6
7#[derive(Debug, Deserialize)]
8pub struct ApiErrorResponse {
9 pub message: String,
10}
11
12#[derive(Debug, Deserialize)]
13#[serde(untagged)]
14pub enum ApiResponse<T> {
15 Ok(T),
16 Err(ApiErrorResponse),
17}
18
/// Default root URL of the Cohere API; `Client::new` points clients here.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
23
24#[derive(Clone)]
25pub struct Client {
26 base_url: String,
27 api_key: String,
28 http_client: reqwest::Client,
29}
30
31impl std::fmt::Debug for Client {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("Client")
34 .field("base_url", &self.base_url)
35 .field("http_client", &self.http_client)
36 .field("api_key", &"<REDACTED>")
37 .finish()
38 }
39}
40
41impl Client {
42 pub fn new(api_key: &str) -> Self {
43 Self::from_url(api_key, COHERE_API_BASE_URL)
44 }
45
46 pub fn from_url(api_key: &str, base_url: &str) -> Self {
47 Self {
48 base_url: base_url.to_string(),
49 api_key: api_key.to_string(),
50 http_client: reqwest::Client::builder()
51 .build()
52 .expect("Cohere reqwest client should build"),
53 }
54 }
55
56 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
59 self.http_client = client;
60
61 self
62 }
63
64 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
65 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
66 self.http_client.post(url).bearer_auth(&self.api_key)
67 }
68
69 pub fn embeddings<D: Embed>(
70 &self,
71 model: &str,
72 input_type: &str,
73 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
74 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
75 }
76
77 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
80 let ndims = match model {
81 super::EMBED_ENGLISH_V3
82 | super::EMBED_MULTILINGUAL_V3
83 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
84 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
85 super::EMBED_ENGLISH_V2 => 4096,
86 super::EMBED_MULTILINGUAL_V2 => 768,
87 _ => 0,
88 };
89 EmbeddingModel::new(self.clone(), model, input_type, ndims)
90 }
91
92 pub fn embedding_model_with_ndims(
94 &self,
95 model: &str,
96 input_type: &str,
97 ndims: usize,
98 ) -> EmbeddingModel {
99 EmbeddingModel::new(self.clone(), model, input_type, ndims)
100 }
101}
102
103impl ProviderClient for Client {
104 fn from_env() -> Self {
107 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
108 Self::new(&api_key)
109 }
110}
111
112impl CompletionClient for Client {
113 type CompletionModel = CompletionModel;
114
115 fn completion_model(&self, model: &str) -> Self::CompletionModel {
116 CompletionModel::new(self.clone(), model)
117 }
118}
119
120impl EmbeddingsClient for Client {
121 type EmbeddingModel = EmbeddingModel;
122
123 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
124 self.embedding_model(model, "search_document")
125 }
126
127 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
128 self.embedding_model_with_ndims(model, "search_document", ndims)
129 }
130
131 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
132 self.embeddings(model, "search_document")
133 }
134}
135
136impl_conversion_traits!(
137 AsTranscription,
138 AsImageGeneration,
139 AsAudioGeneration for Client
140);