Skip to main content

rig/providers/cohere/
client.rs

1use crate::{
2    Embed,
3    client::{
4        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
5        ProviderClient,
6    },
7    embeddings::EmbeddingsBuilder,
8    http_client::{self, HttpClientExt},
9    wasm_compat::*,
10};
11
12use super::{CompletionModel, EmbeddingModel};
13use serde::Deserialize;
14
15// ================================================================
16// Main Cohere Client
17// ================================================================
18
/// Provider marker type for Cohere.
///
/// Holds no data; the Cohere-specific behaviour comes from the trait impls
/// attached to it in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;

/// Builder marker type for Cohere clients.
///
/// Holds no data; it supplies the Cohere base URL and build hook through its
/// `ProviderBuilder` impl in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;
24
25type CohereApiKey = BearerAuth;
26
27pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
28pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
29
30impl Provider for CohereExt {
31    type Builder = CohereBuilder;
32
33    const VERIFY_PATH: &'static str = "/models";
34
35    fn build<H>(
36        _: &client::ClientBuilder<Self::Builder, CohereApiKey, H>,
37    ) -> http_client::Result<Self> {
38        Ok(Self)
39    }
40}
41
42impl<H> Capabilities<H> for CohereExt {
43    type Completion = Capable<CompletionModel<H>>;
44    type Embeddings = Capable<EmbeddingModel<H>>;
45    type Transcription = Nothing;
46    type ModelListing = Nothing;
47    #[cfg(feature = "image")]
48    type ImageGeneration = Nothing;
49
50    #[cfg(feature = "audio")]
51    type AudioGeneration = Nothing;
52}
53
54impl DebugExt for CohereExt {}
55
56impl ProviderBuilder for CohereBuilder {
57    type Output = CohereExt;
58    type ApiKey = CohereApiKey;
59
60    const BASE_URL: &'static str = "https://api.cohere.ai";
61
62    fn finish<H>(
63        &self,
64        builder: client::ClientBuilder<Self, CohereApiKey, H>,
65    ) -> http_client::Result<client::ClientBuilder<Self, CohereApiKey, H>> {
66        Ok(builder)
67    }
68}
69
70impl ProviderClient for Client {
71    type Input = CohereApiKey;
72
73    fn from_env() -> Self
74    where
75        Self: Sized,
76    {
77        let key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
78        Self::new(key).unwrap()
79    }
80
81    fn from_val(input: Self::Input) -> Self
82    where
83        Self: Sized,
84    {
85        Self::new(input).unwrap()
86    }
87}
88
89#[derive(Debug, Deserialize)]
90pub struct ApiErrorResponse {
91    pub message: String,
92}
93
94#[derive(Debug, Deserialize)]
95#[serde(untagged)]
96pub enum ApiResponse<T> {
97    Ok(T),
98    Err(ApiErrorResponse),
99}
100
101impl<T> Client<T>
102where
103    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
104{
105    pub fn embeddings<D: Embed>(
106        &self,
107        model: impl Into<String>,
108        input_type: &str,
109    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
110        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
111    }
112
113    /// Note: default embedding dimension of 0 will be used if model is not known.
114    /// If this is the case, it's better to use function `embedding_model_with_ndims`
115    pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
116        let model = model.into();
117        let ndims = super::model_dimensions_from_identifier(&model).unwrap_or_default();
118
119        EmbeddingModel::new(self.clone(), model, input_type, ndims)
120    }
121
122    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
123    pub fn embedding_model_with_ndims(
124        &self,
125        model: impl Into<String>,
126        input_type: &str,
127        ndims: usize,
128    ) -> EmbeddingModel<T> {
129        EmbeddingModel::new(self.clone(), model, input_type, ndims)
130    }
131}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Both construction paths (`Client::new` and the builder) must succeed
    /// with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}