rig/providers/cohere/client.rs

1use crate::{
2    Embed,
3    client::{
4        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
5        ProviderClient,
6    },
7    embeddings::EmbeddingsBuilder,
8    http_client::{self, HttpClientExt},
9    wasm_compat::*,
10};
11
12use super::{CompletionModel, EmbeddingModel};
13use serde::Deserialize;
14
15// ================================================================
16// Main Cohere Client
17// ================================================================
18
/// Stateless provider-extension marker for Cohere.
///
/// It is the `Ext` type parameter of [`Client`] and the value produced by
/// [`CohereBuilder`]'s `build`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;

/// Stateless builder marker for Cohere; the provider parameter of [`ClientBuilder`].
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;
24
25type CohereApiKey = BearerAuth;
26
27pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
28pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
29
30impl Provider for CohereExt {
31    type Builder = CohereBuilder;
32    const VERIFY_PATH: &'static str = "/models";
33}
34
35impl<H> Capabilities<H> for CohereExt {
36    type Completion = Capable<CompletionModel<H>>;
37    type Embeddings = Capable<EmbeddingModel<H>>;
38    type Transcription = Nothing;
39    type ModelListing = Nothing;
40    #[cfg(feature = "image")]
41    type ImageGeneration = Nothing;
42
43    #[cfg(feature = "audio")]
44    type AudioGeneration = Nothing;
45}
46
47impl DebugExt for CohereExt {}
48
49impl ProviderBuilder for CohereBuilder {
50    type Extension<H>
51        = CohereExt
52    where
53        H: HttpClientExt;
54    type ApiKey = CohereApiKey;
55
56    const BASE_URL: &'static str = "https://api.cohere.ai";
57
58    fn build<H>(
59        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
60    ) -> http_client::Result<Self::Extension<H>>
61    where
62        H: HttpClientExt,
63    {
64        Ok(CohereExt)
65    }
66}
67
68impl ProviderClient for Client {
69    type Input = CohereApiKey;
70
71    fn from_env() -> Self
72    where
73        Self: Sized,
74    {
75        let key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
76        Self::new(key).unwrap()
77    }
78
79    fn from_val(input: Self::Input) -> Self
80    where
81        Self: Sized,
82    {
83        Self::new(input).unwrap()
84    }
85}
86
87#[derive(Debug, Deserialize)]
88pub struct ApiErrorResponse {
89    pub message: String,
90}
91
92#[derive(Debug, Deserialize)]
93#[serde(untagged)]
94pub enum ApiResponse<T> {
95    Ok(T),
96    Err(ApiErrorResponse),
97}
98
99impl<T> Client<T>
100where
101    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
102{
103    pub fn embeddings<D: Embed>(
104        &self,
105        model: impl Into<String>,
106        input_type: &str,
107    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
108        EmbeddingsBuilder::new(self.embedding_model(model, input_type))
109    }
110
111    /// Note: default embedding dimension of 0 will be used if model is not known.
112    /// If this is the case, it's better to use function `embedding_model_with_ndims`
113    pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
114        let model = model.into();
115        let ndims = super::model_dimensions_from_identifier(&model).unwrap_or_default();
116
117        EmbeddingModel::new(self.clone(), model, input_type, ndims)
118    }
119
120    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
121    pub fn embedding_model_with_ndims(
122        &self,
123        model: impl Into<String>,
124        input_type: &str,
125        ndims: usize,
126    ) -> EmbeddingModel<T> {
127        EmbeddingModel::new(self.clone(), model, input_type, ndims)
128    }
129}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Both construction paths — `Client::new` and the builder — must succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}