rig-core 0.35.0

An opinionated library for building LLM-powered applications.
Documentation
use crate::{
    Embed,
    client::{
        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
        ProviderClient,
    },
    embeddings::EmbeddingsBuilder,
    http_client::{self, HttpClientExt},
    wasm_compat::*,
};

use super::{CompletionModel, EmbeddingModel};
use serde::Deserialize;

// ================================================================
// Main Cohere Client
// ================================================================

/// Provider extension marker for Cohere. It is a unit struct and carries no runtime state.
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereExt;

/// Builder marker used when constructing a Cohere [`Client`].
#[derive(Clone, Copy, Debug, Default)]
pub struct CohereBuilder;

/// Credential type for Cohere: the API key is supplied as bearer authentication.
type CohereApiKey = BearerAuth;

/// Cohere client, generic over the HTTP backend (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<CohereExt, H>;

/// Builder for a Cohere [`Client`], generic over the HTTP backend.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<CohereBuilder, CohereApiKey, H>;

/// Registers [`CohereExt`] as a provider and ties it to [`CohereBuilder`].
impl Provider for CohereExt {
    type Builder = CohereBuilder;

    /// Path requested when verifying a client.
    ///
    /// The builder's `BASE_URL` is the bare host (`https://api.cohere.ai`) with no API
    /// version segment. A plain `/models` would therefore hit an unversioned route.
    /// Cohere documents model listing as `GET /v1/models`, so the version is included here.
    const VERIFY_PATH: &'static str = "/v1/models";
}

/// Capability table for Cohere: completion and embeddings are available; the rest are not.
impl<H> Capabilities<H> for CohereExt {
    // Supported capabilities.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;

    // Unsupported capabilities.
    type Transcription = Nothing;
    type ModelListing = Nothing;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

/// Relies entirely on the trait's provided (default) methods.
impl DebugExt for CohereExt {}

impl ProviderBuilder for CohereBuilder {
    type Extension<H>
        = CohereExt
    where
        H: HttpClientExt;
    type ApiKey = CohereApiKey;

    const BASE_URL: &'static str = "https://api.cohere.ai";

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(CohereExt)
    }
}

/// Convenience constructors for a Cohere [`Client`] using the default HTTP backend.
impl ProviderClient for Client {
    type Input = CohereApiKey;

    /// Create a client from the `COHERE_API_KEY` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `COHERE_API_KEY` is unset or not valid Unicode, or if the client
    /// cannot be constructed from the key.
    fn from_env() -> Self
    where
        Self: Sized,
    {
        let key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
        // `expect` rather than a bare `unwrap` so the panic says what was being built.
        Self::new(key).expect("failed to build Cohere client from COHERE_API_KEY")
    }

    /// Create a client from an explicitly supplied API key.
    ///
    /// # Panics
    ///
    /// Panics if the client cannot be constructed from `input`.
    fn from_val(input: Self::Input) -> Self
    where
        Self: Sized,
    {
        Self::new(input).expect("failed to build Cohere client")
    }
}

/// Error body returned by the Cohere API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error message text reported by the API.
    pub message: String,
}

/// A Cohere response body: either the expected payload `T` or an [`ApiErrorResponse`].
///
/// Deserialization is untagged. Serde tries `Ok(T)` first and falls back to `Err` only
/// when the body does not parse as `T`, so the variant order below is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

impl<T> Client<T>
where
    T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Start an [`EmbeddingsBuilder`] for documents of type `D` using the named model.
    ///
    /// The model is created exactly as by [`Self::embedding_model`], so the same
    /// dimension-lookup fallback applies. `input_type` is forwarded unchanged.
    pub fn embeddings<D: Embed>(
        &self,
        model: impl Into<String>,
        input_type: &str,
    ) -> EmbeddingsBuilder<EmbeddingModel<T>, D> {
        let embedding_model = self.embedding_model(model, input_type);
        EmbeddingsBuilder::new(embedding_model)
    }

    /// Create an embedding model, looking up its dimension count from the model name.
    ///
    /// Unknown model names get the default dimension count of `0`. For such models, prefer
    /// [`Self::embedding_model_with_ndims`] and pass the dimension explicitly.
    pub fn embedding_model(&self, model: impl Into<String>, input_type: &str) -> EmbeddingModel<T> {
        let name: String = model.into();
        let dims = super::model_dimensions_from_identifier(&name).unwrap_or_default();
        self.embedding_model_with_ndims(name, input_type, dims)
    }

    /// Create an embedding model with an explicitly specified embedding dimension count.
    pub fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        input_type: &str,
        ndims: usize,
    ) -> EmbeddingModel<T> {
        let client = self.clone();
        EmbeddingModel::new(client, model, input_type, ndims)
    }
}
#[cfg(test)]
mod tests {
    use crate::providers::cohere::Client;

    /// Both construction paths, `Client::new` and the builder, succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}