rig/providers/cohere/
client.rs1use crate::{embeddings::EmbeddingsBuilder, Embed};
2
3use super::{CompletionModel, EmbeddingModel};
4use crate::client::{impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient};
5use serde::Deserialize;
6
7#[derive(Debug, Deserialize)]
8pub struct ApiErrorResponse {
9 pub message: String,
10}
11
12#[derive(Debug, Deserialize)]
13#[serde(untagged)]
14pub enum ApiResponse<T> {
15 Ok(T),
16 Err(ApiErrorResponse),
17}
18
/// Root URL of the public Cohere API, used by `Client::new`.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
23
24#[derive(Clone, Debug)]
25pub struct Client {
26 base_url: String,
27 http_client: reqwest::Client,
28}
29
30impl Client {
31 pub fn new(api_key: &str) -> Self {
32 Self::from_url(api_key, COHERE_API_BASE_URL)
33 }
34
35 pub fn from_url(api_key: &str, base_url: &str) -> Self {
36 Self {
37 base_url: base_url.to_string(),
38 http_client: reqwest::Client::builder()
39 .default_headers({
40 let mut headers = reqwest::header::HeaderMap::new();
41 headers.insert(
42 "Authorization",
43 format!("Bearer {api_key}")
44 .parse()
45 .expect("Bearer token should parse"),
46 );
47 headers
48 })
49 .build()
50 .expect("Cohere reqwest client should build"),
51 }
52 }
53
54 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
55 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56 self.http_client.post(url)
57 }
58
59 pub fn embeddings<D: Embed>(
60 &self,
61 model: &str,
62 input_type: &str,
63 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
64 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
65 }
66
67 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
70 let ndims = match model {
71 super::EMBED_ENGLISH_V3
72 | super::EMBED_MULTILINGUAL_V3
73 | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
74 super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
75 super::EMBED_ENGLISH_V2 => 4096,
76 super::EMBED_MULTILINGUAL_V2 => 768,
77 _ => 0,
78 };
79 EmbeddingModel::new(self.clone(), model, input_type, ndims)
80 }
81
82 pub fn embedding_model_with_ndims(
84 &self,
85 model: &str,
86 input_type: &str,
87 ndims: usize,
88 ) -> EmbeddingModel {
89 EmbeddingModel::new(self.clone(), model, input_type, ndims)
90 }
91}
92
93impl ProviderClient for Client {
94 fn from_env() -> Self {
97 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
98 Self::new(&api_key)
99 }
100}
101
102impl CompletionClient for Client {
103 type CompletionModel = CompletionModel;
104
105 fn completion_model(&self, model: &str) -> Self::CompletionModel {
106 CompletionModel::new(self.clone(), model)
107 }
108}
109
110impl EmbeddingsClient for Client {
111 type EmbeddingModel = EmbeddingModel;
112
113 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
114 self.embedding_model(model, "search_document")
115 }
116
117 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
118 self.embedding_model_with_ndims(model, "search_document", ndims)
119 }
120
121 fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
122 self.embeddings(model, "search_document")
123 }
124}
125
// Invokes `impl_conversion_traits!` (imported from `crate::client`) for `Client`
// with the `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration` traits.
// NOTE(review): presumably this supplies the capability-conversion impls that
// Cohere does not back with real providers; confirm against the macro definition.
impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);