reagent-rs 0.2.4

A Rust library for building AI agents with MCP & custom tools
Documentation
use std::{pin::Pin, sync::Arc};

use futures::Stream;

use crate::services::llm::models::{
    chat::{ChatRequest, ChatResponse, ChatStreamChunk},
    embedding::{EmbeddingsRequest, EmbeddingsResponse},
    errors::ModelClientError,
};

use super::providers::{
    anthropic::AnthropicClient,
    mistral::MistralClient,
    ollama::OllamaClient,
    openai::OpenAiClient,
    openrouter::OpenRouterClient,
};

/// LLM backend that a [`ModelClient`] talks to.
///
/// Defaults to [`Provider::Ollama`]. The enum is a plain tag with no payload,
/// so it is `Copy` and can be compared (`==`) or used as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Provider {
    /// Ollama backend (the default).
    #[default]
    Ollama,
    /// OpenAI backend.
    OpenAi,
    /// Mistral backend.
    Mistral,
    /// Anthropic backend.
    Anthropic,
    /// OpenRouter backend.
    OpenRouter,
}

/// Settings used to construct a [`ModelClient`].
///
/// Only `provider` is mandatory in spirit; every other field is optional and
/// `ClientConfig::default()` leaves all of them as `None`. The whole value is
/// handed to the selected provider client's constructor, which decides how
/// each field is interpreted.
#[derive(Default, Clone, Debug)]
pub struct ClientConfig {
    /// Which backend to build a client for.
    pub provider: Provider,
    /// Optional base URL; interpretation is left to the provider client.
    pub base_url: Option<String>,
    /// Optional API key for the provider.
    pub api_key: Option<String>,
    /// Optional organization identifier for the provider.
    pub organization: Option<String>,
    /// Optional additional HTTP-style headers (name -> value).
    pub extra_headers: Option<::std::collections::HashMap<String, String>>,
}

/// Concrete backend client held by a [`ModelClient`].
///
/// There is one variant per [`Provider`]. `ModelClient` matches on this enum
/// to forward each request to the corresponding provider client.
#[derive(Clone, Debug)]
enum ClientInner {
    /// Client for [`Provider::Ollama`].
    Ollama(OllamaClient),
    /// Client for [`Provider::OpenAi`].
    OpenAi(OpenAiClient),
    /// Client for [`Provider::Mistral`].
    Mistral(MistralClient),
    /// Client for [`Provider::Anthropic`].
    Anthropic(AnthropicClient),
    /// Client for [`Provider::OpenRouter`].
    OpenRouter(OpenRouterClient),
}

/// Provider-agnostic client for chat, streaming chat and embeddings.
///
/// Built from a [`ClientConfig`] through its `TryFrom` implementation.
/// The backend sits behind an [`Arc`], so cloning a `ModelClient` shares the
/// backend (a reference-count bump) while the stored configuration is cloned.
#[derive(Debug, Clone)]
pub struct ModelClient {
    /// Configuration this client was built from.
    config: ClientConfig,
    /// Provider-specific backend, shared between clones.
    inner: Arc<ClientInner>,
}

impl ModelClient {
    pub fn get_config(&self) -> ClientConfig {
        self.config.clone()
    }

    pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, ModelClientError> {
        match &*self.inner {
            ClientInner::Ollama(c) => c.chat(req).await,
            ClientInner::OpenAi(c) => c.chat(req).await,
            ClientInner::Mistral(c) => c.chat(req).await,
            ClientInner::Anthropic(c) => c.chat(req).await,
            ClientInner::OpenRouter(c) => c.chat(req).await,
        }
    }

    pub async fn chat_stream(
        &self,
        req: ChatRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatStreamChunk, ModelClientError>> + Send + 'static>>, ModelClientError> {
        match &*self.inner {
            ClientInner::Ollama(c) => c.chat_stream(req).await,
            ClientInner::OpenAi(c) => c.chat_stream(req).await,
            ClientInner::Mistral(c) => c.chat_stream(req).await,
            ClientInner::Anthropic(c) => c.chat_stream(req).await,
            ClientInner::OpenRouter(c) => c.chat_stream(req).await,
        }
    }

    pub async fn embeddings(&self, req: EmbeddingsRequest) -> Result<EmbeddingsResponse, ModelClientError> {
        match &*self.inner {
            ClientInner::Ollama(c) => c.embeddings(req).await,
            ClientInner::OpenAi(c) => c.embeddings(req).await,
            ClientInner::Mistral(c) => c.embeddings(req).await,
            ClientInner::Anthropic(c) => c.embeddings(req).await,
            ClientInner::OpenRouter(c) => c.embeddings(req).await,
        }
    }
}

impl TryFrom<ClientConfig> for ModelClient {
    type Error = ModelClientError;

    fn try_from(cfg: ClientConfig) -> Result<Self, Self::Error> {
        let config = cfg.clone();
        let inner = match cfg.provider {
            Provider::Ollama => ClientInner::Ollama(OllamaClient::new(cfg)?),
            Provider::OpenAi => ClientInner::OpenAi(OpenAiClient::new(cfg)?),
            Provider::Mistral => ClientInner::Mistral(MistralClient::new(cfg)?),
            Provider::Anthropic => ClientInner::Anthropic(AnthropicClient::new(cfg)?),
            Provider::OpenRouter => ClientInner::OpenRouter(OpenRouterClient::new(cfg)?),
        };
        Ok(Self { 
            config,
            inner: Arc::new(inner) 
        })
    }
}