rig/providers/cohere/
client.rs

1use crate::{embeddings::EmbeddingsBuilder, Embed};
2
3use super::{CompletionModel, EmbeddingModel};
4use crate::client::{impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient};
5use serde::Deserialize;
6
7#[derive(Debug, Deserialize)]
8pub struct ApiErrorResponse {
9    pub message: String,
10}
11
12#[derive(Debug, Deserialize)]
13#[serde(untagged)]
14pub enum ApiResponse<T> {
15    Ok(T),
16    Err(ApiErrorResponse),
17}
18
19// ================================================================
20// Main Cohere Client
21// ================================================================
/// Base URL used by [`Client::new`] when no explicit base URL is supplied.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
23
24#[derive(Clone, Debug)]
25pub struct Client {
26    base_url: String,
27    http_client: reqwest::Client,
28}
29
30impl Client {
31    pub fn new(api_key: &str) -> Self {
32        Self::from_url(api_key, COHERE_API_BASE_URL)
33    }
34
35    pub fn from_url(api_key: &str, base_url: &str) -> Self {
36        Self {
37            base_url: base_url.to_string(),
38            http_client: reqwest::Client::builder()
39                .default_headers({
40                    let mut headers = reqwest::header::HeaderMap::new();
41                    headers.insert(
42                        "Authorization",
43                        format!("Bearer {api_key}")
44                            .parse()
45                            .expect("Bearer token should parse"),
46                    );
47                    headers
48                })
49                .build()
50                .expect("Cohere reqwest client should build"),
51        }
52    }
53
54    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
55        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56        self.http_client.post(url)
57    }
58
59    pub fn embeddings<D: Embed>(
60        &self,
61        model: &str,
62        input_type: &str,
63    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
64        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
65    }
66
67    /// Note: default embedding dimension of 0 will be used if model is not known.
68    /// If this is the case, it's better to use function `embedding_model_with_ndims`
69    pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
70        let ndims = match model {
71            super::EMBED_ENGLISH_V3
72            | super::EMBED_MULTILINGUAL_V3
73            | super::EMBED_ENGLISH_LIGHT_V2 => 1024,
74            super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
75            super::EMBED_ENGLISH_V2 => 4096,
76            super::EMBED_MULTILINGUAL_V2 => 768,
77            _ => 0,
78        };
79        EmbeddingModel::new(self.clone(), model, input_type, ndims)
80    }
81
82    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
83    pub fn embedding_model_with_ndims(
84        &self,
85        model: &str,
86        input_type: &str,
87        ndims: usize,
88    ) -> EmbeddingModel {
89        EmbeddingModel::new(self.clone(), model, input_type, ndims)
90    }
91}
92
93impl ProviderClient for Client {
94    /// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
95    /// Panics if the environment variable is not set.
96    fn from_env() -> Self {
97        let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
98        Self::new(&api_key)
99    }
100}
101
102impl CompletionClient for Client {
103    type CompletionModel = CompletionModel;
104
105    fn completion_model(&self, model: &str) -> Self::CompletionModel {
106        CompletionModel::new(self.clone(), model)
107    }
108}
109
110impl EmbeddingsClient for Client {
111    type EmbeddingModel = EmbeddingModel;
112
113    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
114        self.embedding_model(model, "search_document")
115    }
116
117    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
118        self.embedding_model_with_ndims(model, "search_document", ndims)
119    }
120
121    fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<Self::EmbeddingModel, D> {
122        self.embeddings(model, "search_document")
123    }
124}
125
126impl_conversion_traits!(
127    AsTranscription,
128    AsImageGeneration,
129    AsAudioGeneration for Client
130);