rig_core/providers/cohere/
client.rs1use crate::{
2 Embed,
3 client::{
4 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
5 ProviderClient,
6 },
7 embeddings::EmbeddingsBuilder,
8 http_client::{self, HttpClientExt},
9 wasm_compat::*,
10};
11
12use super::{CompletionModel, EmbeddingModel};
13use serde::Deserialize;
14
/// Zero-sized marker type for the Cohere provider.
///
/// It stores no data; the `Provider`, `Capabilities` and `DebugExt`
/// implementations in this file are what attach behaviour to it.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;
21
/// Zero-sized marker type used on the builder side of the Cohere client.
///
/// It stores no data; the `ProviderBuilder` implementation in this file
/// supplies the base URL, the API key type and the `build` step.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;
24
25type CohereApiKey = BearerAuth;
26
27pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
28pub type ClientBuilder<H = crate::markers::Missing> =
29 client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
30
31impl Provider for CohereExt {
32 type Builder = CohereBuilder;
33 const VERIFY_PATH: &'static str = "/models";
34}
35
36impl<H> Capabilities<H> for CohereExt {
37 type Completion = Capable<CompletionModel<H>>;
38 type Embeddings = Capable<EmbeddingModel<H>>;
39 type Transcription = Nothing;
40 type ModelListing = Nothing;
41 #[cfg(feature = "image")]
42 type ImageGeneration = Nothing;
43
44 #[cfg(feature = "audio")]
45 type AudioGeneration = Nothing;
46}
47
48impl DebugExt for CohereExt {}
49
50impl ProviderBuilder for CohereBuilder {
51 type Extension<H>
52 = CohereExt
53 where
54 H: HttpClientExt;
55 type ApiKey = CohereApiKey;
56
57 const BASE_URL: &'static str = "https://api.cohere.ai";
58
59 fn build<H>(
60 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
61 ) -> http_client::Result<Self::Extension<H>>
62 where
63 H: HttpClientExt,
64 {
65 Ok(CohereExt)
66 }
67}
68
69impl ProviderClient for Client {
70 type Input = CohereApiKey;
71 type Error = crate::client::ProviderClientError;
72
73 fn from_env() -> Result<Self, Self::Error>
74 where
75 Self: Sized,
76 {
77 let key = crate::client::required_env_var("COHERE_API_KEY")?;
78 Self::new(key).map_err(Into::into)
79 }
80
81 fn from_val(input: Self::Input) -> Result<Self, Self::Error>
82 where
83 Self: Sized,
84 {
85 Self::new(input).map_err(Into::into)
86 }
87}
88
89#[derive(Debug, Deserialize)]
90pub struct ApiErrorResponse {
91 pub message: String,
92}
93
94#[derive(Debug, Deserialize)]
95#[serde(untagged)]
96pub enum ApiResponse<T> {
97 Ok(T),
98 Err(ApiErrorResponse),
99}
100
101impl<T> Client<T>
102where
103 T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
104{
105 pub fn embeddings<D: Embed>(
106 &self,
107 model: impl Into<String>,
108 input_type: &str,
109 ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
110 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
111 }
112
113 pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
116 let model = model.into();
117 let ndims = super::model_dimensions_from_identifier(&model).unwrap_or_default();
118
119 EmbeddingModel::new(self.clone(), model, input_type, ndims)
120 }
121
122 pub fn embedding_model_with_ndims(
124 &self,
125 model: impl Into<String>,
126 input_type: &str,
127 ndims: usize,
128 ) -> EmbeddingModel<T> {
129 EmbeddingModel::new(self.clone(), model, input_type, ndims)
130 }
131}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Smoke test: both construction paths succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        // Direct constructor.
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        // Builder path.
        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}