rig/providers/cohere/
client.rs1use crate::{
2 Embed,
3 client::{
4 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
5 ProviderClient,
6 },
7 embeddings::EmbeddingsBuilder,
8 http_client::{self, HttpClientExt},
9 wasm_compat::*,
10};
11
12use super::{CompletionModel, EmbeddingModel};
13use serde::Deserialize;
14
/// Marker type identifying the Cohere provider.
///
/// It holds no state. The generic client is parameterised by it (see the
/// [`Client`] alias) so that Cohere-specific capabilities, the verification
/// path and debug behaviour can be selected through trait impls.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;
21
/// Builder-side marker for the Cohere provider.
///
/// Through its `ProviderBuilder` impl it supplies the default base URL and
/// produces the [`CohereExt`] extension when a client is built.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;
24
25type CohereApiKey = BearerAuth;
26
27pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
28pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
29
30impl Provider for CohereExt {
31 type Builder = CohereBuilder;
32 const VERIFY_PATH: &'static str = "/models";
33}
34
35impl<H> Capabilities<H> for CohereExt {
36 type Completion = Capable<CompletionModel<H>>;
37 type Embeddings = Capable<EmbeddingModel<H>>;
38 type Transcription = Nothing;
39 type ModelListing = Nothing;
40 #[cfg(feature = "image")]
41 type ImageGeneration = Nothing;
42
43 #[cfg(feature = "audio")]
44 type AudioGeneration = Nothing;
45}
46
47impl DebugExt for CohereExt {}
48
49impl ProviderBuilder for CohereBuilder {
50 type Extension<H>
51 = CohereExt
52 where
53 H: HttpClientExt;
54 type ApiKey = CohereApiKey;
55
56 const BASE_URL: &'static str = "https://api.cohere.ai";
57
58 fn build<H>(
59 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
60 ) -> http_client::Result<Self::Extension<H>>
61 where
62 H: HttpClientExt,
63 {
64 Ok(CohereExt)
65 }
66}
67
68impl ProviderClient for Client {
69 type Input = CohereApiKey;
70 type Error = crate::client::ProviderClientError;
71
72 fn from_env() -> Result<Self, Self::Error>
73 where
74 Self: Sized,
75 {
76 let key = crate::client::required_env_var("COHERE_API_KEY")?;
77 Self::new(key).map_err(Into::into)
78 }
79
80 fn from_val(input: Self::Input) -> Result<Self, Self::Error>
81 where
82 Self: Sized,
83 {
84 Self::new(input).map_err(Into::into)
85 }
86}
87
88#[derive(Debug, Deserialize)]
89pub struct ApiErrorResponse {
90 pub message: String,
91}
92
93#[derive(Debug, Deserialize)]
94#[serde(untagged)]
95pub enum ApiResponse<T> {
96 Ok(T),
97 Err(ApiErrorResponse),
98}
99
100impl<T> Client<T>
101where
102 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
103{
104 pub fn embeddings<D: Embed>(
105 &self,
106 model: impl Into<String>,
107 input_type: &str,
108 ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
109 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
110 }
111
112 pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
115 let model = model.into();
116 let ndims = super::model_dimensions_from_identifier(&model).unwrap_or_default();
117
118 EmbeddingModel::new(self.clone(), model, input_type, ndims)
119 }
120
121 pub fn embedding_model_with_ndims(
123 &self,
124 model: impl Into<String>,
125 input_type: &str,
126 ndims: usize,
127 ) -> EmbeddingModel<T> {
128 EmbeddingModel::new(self.clone(), model, input_type, ndims)
129 }
130}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Both construction paths (`Client::new` and the builder) succeed with a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}