Skip to main content

rig_core/providers/cohere/
client.rs

1use crate::{
2    Embed,
3    client::{
4        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
5        ProviderClient,
6    },
7    embeddings::EmbeddingsBuilder,
8    http_client::{self, HttpClientExt},
9    wasm_compat::*,
10};
11
12use super::{CompletionModel, EmbeddingModel};
13use serde::Deserialize;
14
// ================================================================
// Main Cohere Client
// ================================================================

/// Stateless marker type identifying the Cohere provider extension.
///
/// It carries no data; all Cohere-specific behaviour is attached through the
/// trait impls in this module (`Provider`, `Capabilities`, `DebugExt`).
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;

/// Stateless marker type used as the `ProviderBuilder` for Cohere clients.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;

25type CohereApiKey = BearerAuth;
26
27pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
28pub type ClientBuilder<H = crate::markers::Missing> =
29    client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
30
31impl Provider for CohereExt {
32    type Builder = CohereBuilder;
33    const VERIFY_PATH: &'static str = "/models";
34}
35
36impl<H> Capabilities<H> for CohereExt {
37    type Completion = Capable<CompletionModel<H>>;
38    type Embeddings = Capable<EmbeddingModel<H>>;
39    type Transcription = Nothing;
40    type ModelListing = Nothing;
41    #[cfg(feature = "image")]
42    type ImageGeneration = Nothing;
43
44    #[cfg(feature = "audio")]
45    type AudioGeneration = Nothing;
46    type Rerank = Nothing;
47}
48
49impl DebugExt for CohereExt {}
50
51impl ProviderBuilder for CohereBuilder {
52    type Extension<H>
53        = CohereExt
54    where
55        H: HttpClientExt;
56    type ApiKey = CohereApiKey;
57
58    const BASE_URL: &'static str = "https://api.cohere.ai";
59
60    fn build<H>(
61        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
62    ) -> http_client::Result<Self::Extension<H>>
63    where
64        H: HttpClientExt,
65    {
66        Ok(CohereExt)
67    }
68}
69
70impl ProviderClient for Client {
71    type Input = CohereApiKey;
72    type Error = crate::client::ProviderClientError;
73
74    fn from_env() -> Result<Self, Self::Error>
75    where
76        Self: Sized,
77    {
78        let key = crate::client::required_env_var("COHERE_API_KEY")?;
79        Self::new(key).map_err(Into::into)
80    }
81
82    fn from_val(input: Self::Input) -> Result<Self, Self::Error>
83    where
84        Self: Sized,
85    {
86        Self::new(input).map_err(Into::into)
87    }
88}
89
90#[derive(Debug, Deserialize)]
91pub struct ApiErrorResponse {
92    pub message: String,
93}
94
95#[derive(Debug, Deserialize)]
96#[serde(untagged)]
97pub enum ApiResponse<T> {
98    Ok(T),
99    Err(ApiErrorResponse),
100}
101
102impl<T> Client<T>
103where
104    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
105{
106    pub fn embeddings<D: Embed>(
107        &self,
108        model: impl Into<String>,
109        input_type: &str,
110    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
111        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
112    }
113
114    /// Note: default embedding dimension of 0 will be used if model is not known.
115    /// If this is the case, it's better to use function `embedding_model_with_ndims`
116    pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
117        let model = model.into();
118        let ndims = super::model_dimensions_from_identifier(&model).unwrap_or_default();
119
120        EmbeddingModel::new(self.clone(), model, input_type, ndims)
121    }
122
123    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
124    pub fn embedding_model_with_ndims(
125        &self,
126        model: impl Into<String>,
127        input_type: &str,
128        ndims: usize,
129    ) -> EmbeddingModel<T> {
130        EmbeddingModel::new(self.clone(), model, input_type, ndims)
131    }
132}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Checks that both construction paths — `Client::new` and the builder —
    /// accept a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}