use crate::{
Embed,
client::{
self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
ProviderClient,
},
embeddings::EmbeddingsBuilder,
http_client::{self, HttpClientExt},
wasm_compat::*,
};
use super::{CompletionModel, EmbeddingModel};
use serde::Deserialize;
/// Zero-sized marker type for the Cohere provider.
///
/// It is the extension parameter of the generic `client::Client` (see the
/// [`Client`] alias) and carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct CohereExt;
/// Zero-sized marker type used as the builder parameter of the generic
/// `client::ClientBuilder` for Cohere (see the [`ClientBuilder`] alias).
#[derive(Debug, Default, Clone, Copy)]
pub struct CohereBuilder;
/// API key type used by the Cohere client and builder; an alias for `BearerAuth`.
type CohereApiKey = BearerAuth;
/// Cohere client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;
/// Builder for the Cohere [`Client`], generic over the HTTP backend `H`
/// (defaults to `reqwest::Client`).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CohereBuilder, CohereApiKey, H>;
impl Provider for CohereExt {
type Builder = CohereBuilder;
const VERIFY_PATH: &'static str = "/models";
fn build<H>(
_: &client::ClientBuilder<Self::Builder, CohereApiKey, H>,
) -> http_client::Result<Self> {
Ok(Self)
}
}
// Capability table for the Cohere provider: completion and embeddings are
// marked `Capable` (backed by `CompletionModel<H>` / `EmbeddingModel<H>`);
// every other capability is declared as `Nothing`.
impl<H> Capabilities<H> for CohereExt {
type Completion = Capable<CompletionModel<H>>;
type Embeddings = Capable<EmbeddingModel<H>>;
type Transcription = Nothing;
// Only declared when the `image` feature is enabled.
#[cfg(feature = "image")]
type ImageGeneration = Nothing;
// Only declared when the `audio` feature is enabled.
#[cfg(feature = "audio")]
type AudioGeneration = Nothing;
}
// Empty impl: `CohereExt` overrides nothing and relies on whatever defaults `DebugExt` provides.
impl DebugExt for CohereExt {}
/// Wires [`CohereBuilder`] into the generic client builder as a `ProviderBuilder`.
impl ProviderBuilder for CohereBuilder {
    /// Provider extension associated with this builder.
    type Output = CohereExt;
    /// Credential type accepted by the builder.
    type ApiKey = CohereApiKey;

    /// Base URL declared for the Cohere API.
    const BASE_URL: &'static str = "https://api.cohere.ai";

    /// Finalisation hook for the client builder.
    ///
    /// Cohere needs no extra configuration here, so the builder is handed
    /// back unchanged and this always returns `Ok`.
    fn finish<H>(
        &self,
        pending: client::ClientBuilder<CohereBuilder, CohereApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<CohereBuilder, CohereApiKey, H>> {
        Ok(pending)
    }
}
/// Convenience constructors for the default (`reqwest`-backed) Cohere [`Client`].
impl ProviderClient for Client {
    /// Value accepted by [`ProviderClient::from_val`]: the Cohere API key.
    type Input = CohereApiKey;

    /// Builds a client using the API key stored in the `COHERE_API_KEY`
    /// environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `COHERE_API_KEY` cannot be read from the environment, or if
    /// constructing the client from that key fails.
    fn from_env() -> Self
    where
        Self: Sized,
    {
        let key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
        // The trait signature returns `Self`, so a construction failure can
        // only surface as a panic; attach context instead of a bare `unwrap`.
        Self::new(key).expect("failed to build Cohere client from COHERE_API_KEY")
    }

    /// Builds a client from an explicitly supplied API key.
    ///
    /// # Panics
    ///
    /// Panics if constructing the client from `input` fails.
    fn from_val(input: Self::Input) -> Self
    where
        Self: Sized,
    {
        Self::new(input).expect("failed to build Cohere client from the provided API key")
    }
}
/// Error payload deserialized from a response body that carries a `message`
/// string field.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
/// Error message text taken from the response body.
pub message: String,
}
/// Response body that is either a successful payload `T` or an
/// [`ApiErrorResponse`].
///
/// The enum is `#[serde(untagged)]`, so serde attempts the variants in
/// declaration order: the body is first tried as `T` and only falls back to
/// the error shape if that fails.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
/// The body deserialized as the expected payload type.
Ok(T),
/// The body deserialized as an error payload.
Err(ApiErrorResponse),
}
/// Embedding-related constructors on the Cohere [`Client`].
impl<T> Client<T>
where
    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Creates an [`EmbeddingsBuilder`] for documents of type `D`, backed by
    /// the embedding model returned from [`Client::embedding_model`].
    ///
    /// `model` and `input_type` are forwarded unchanged to
    /// [`Client::embedding_model`].
    pub fn embeddings<D: Embed>(
        &self,
        model: impl Into<String>,
        input_type: &str,
    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
        let embedder = self.embedding_model(model, input_type);
        EmbeddingsBuilder::new(embedder)
    }

    /// Creates an [`EmbeddingModel`] for `model`, looking up its dimension
    /// count from the model identifier.
    ///
    /// When `model_dimensions_from_identifier` yields no value for the
    /// identifier, the dimension count falls back to its default value
    /// (`0` for `usize`). Use [`Client::embedding_model_with_ndims`] to set
    /// it explicitly.
    pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
        let identifier: String = model.into();
        let dimensions = super::model_dimensions_from_identifier(&identifier).unwrap_or_default();
        self.embedding_model_with_ndims(identifier, input_type, dimensions)
    }

    /// Creates an [`EmbeddingModel`] for `model` with an explicitly supplied
    /// dimension count `ndims`.
    ///
    /// The model receives its own clone of this client.
    pub fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        input_type: &str,
        ndims: usize,
    ) -> EmbeddingModel<T> {
        let client = self.clone();
        EmbeddingModel::new(client, model, input_type, ndims)
    }
}