reagent-rs 0.2.9

A Rust library for building AI agents with MCP, custom tools and skills
Documentation
use std::{pin::Pin, sync::Arc};

use futures::Stream;

use crate::{
    services::llm::{
        models::{
            chat::{ChatRequest, ChatResponse, ChatStreamChunk},
            embedding::{EmbeddingsRequest, EmbeddingsResponse},
            errors::InferenceClientError,
        },
        SchemaSpec, StructuredOuputFormat,
    },
    ClientConfig,
};

use super::providers::{
    anthropic::AnthropicClient, mistral::MistralClient, ollama::OllamaClient, openai::OpenAiClient,
    openrouter::OpenRouterClient,
};

#[derive(Debug, Clone, Default)]
pub enum Provider {
    #[default]
    Ollama,
    OpenAi,
    Mistral,
    Anthropic,
    OpenRouter,
}

#[derive(Debug, Clone)]
enum ClientInner {
    Ollama(OllamaClient),
    OpenAi(OpenAiClient),
    Mistral(MistralClient),
    Anthropic(AnthropicClient),
    OpenRouter(OpenRouterClient),
}

#[derive(Clone, Debug)]
pub struct InferenceClient {
    config: ClientConfig,
    inner: Arc<ClientInner>,
}

impl InferenceClient {
    pub fn get_config(&self) -> &ClientConfig {
        &self.config
    }

    pub fn structured_output_format(
        &self,
        spec: &SchemaSpec,
    ) -> Result<serde_json::Value, InferenceClientError> {
        match self.get_config().provider {
            Some(Provider::Ollama) => Ok(OllamaClient::format(spec)),
            Some(Provider::OpenAi) => Ok(OpenAiClient::format(spec)),
            Some(Provider::OpenRouter) => Ok(OpenRouterClient::format(spec)),
            _ => Err(InferenceClientError::Unsupported(
                "Structured outputs not yet supported for this provider".into(),
            )),
        }
    }

    pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, InferenceClientError> {
        match &*self.inner {
            ClientInner::Ollama(c) => c.chat(req).await,
            ClientInner::OpenAi(c) => c.chat(req).await,
            ClientInner::Mistral(c) => c.chat(req).await,
            ClientInner::Anthropic(c) => c.chat(req).await,
            ClientInner::OpenRouter(c) => c.chat(req).await,
        }
    }

    pub async fn chat_stream(
        &self,
        req: ChatRequest,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ChatStreamChunk, InferenceClientError>> + Send + 'static>>,
        InferenceClientError,
    > {
        match &*self.inner {
            ClientInner::Ollama(c) => c.chat_stream(req).await,
            ClientInner::OpenAi(c) => c.chat_stream(req).await,
            ClientInner::Mistral(c) => c.chat_stream(req).await,
            ClientInner::Anthropic(c) => c.chat_stream(req).await,
            ClientInner::OpenRouter(c) => c.chat_stream(req).await,
        }
    }

    pub async fn embeddings(
        &self,
        req: EmbeddingsRequest,
    ) -> Result<EmbeddingsResponse, InferenceClientError> {
        match &*self.inner {
            ClientInner::Ollama(c) => c.embeddings(req).await,
            ClientInner::OpenAi(c) => c.embeddings(req).await,
            ClientInner::Mistral(c) => c.embeddings(req).await,
            ClientInner::Anthropic(c) => c.embeddings(req).await,
            ClientInner::OpenRouter(c) => c.embeddings(req).await,
        }
    }
}

impl TryFrom<ClientConfig> for InferenceClient {
    type Error = InferenceClientError;

    fn try_from(cfg: ClientConfig) -> Result<Self, Self::Error> {
        let config = cfg.clone();
        let Some(provider) = cfg.provider.clone() else {
            return Err(InferenceClientError::Config("Provider not defined".into()));
        };
        let inner = match provider {
            Provider::Ollama => ClientInner::Ollama(OllamaClient::new(cfg)?),
            Provider::OpenAi => ClientInner::OpenAi(OpenAiClient::new(cfg)?),
            Provider::Mistral => ClientInner::Mistral(MistralClient::new(cfg)?),
            Provider::Anthropic => ClientInner::Anthropic(AnthropicClient::new(cfg)?),
            Provider::OpenRouter => ClientInner::OpenRouter(OpenRouterClient::new(cfg)?),
        };
        Ok(Self {
            config,
            inner: Arc::new(inner),
        })
    }
}