// reagent-rs 0.2.4
//
// A Rust library for building AI agents with MCP & custom tools.
use async_stream::try_stream;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::de::DeserializeOwned;
use std::{fmt, pin::Pin};
use tracing::{debug, error, info_span, trace, instrument, Instrument};

use crate::services::llm::models::chat::ChatStreamChunk;
use crate::services::llm::models::{
    chat::{ChatRequest, ChatResponse},
    embedding::{EmbeddingsRequest, EmbeddingsResponse},
    errors::ModelClientError,
};

/// HTTP client for an Ollama server.
///
/// Holds the `reqwest` client used to send requests together with the
/// server root that endpoint paths (e.g. `/api/chat`) are appended to.
#[derive(Clone, Debug)]
pub struct OllamaClient {
    /// Underlying `reqwest` HTTP client.
    pub client: Client,
    /// Server root URL, e.g. `http://localhost:11434`.
    pub base_url: String,
}

impl OllamaClient {
    /// Server root used when the configuration does not provide `base_url`.
    const DEFAULT_BASE_URL: &'static str = "http://localhost:11434";

    /// Builds a client from `cfg`, falling back to `http://localhost:11434`
    /// when `cfg.base_url` is `None`.
    ///
    /// # Errors
    ///
    /// Never returns `Err` at present; the `Result` is kept as part of the
    /// public signature.
    pub fn new(cfg: crate::services::llm::client::ClientConfig) -> Result<Self, ModelClientError> {
        let base_url = cfg
            .base_url
            .unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_owned());
        Ok(Self {
            client: Client::new(),
            base_url,
        })
    }

    /// Joins the base URL with an endpoint path such as `/api/chat`.
    ///
    /// A trailing `/` on the base URL is ignored so that a configured
    /// `http://host:11434/` does not yield `http://host:11434//api/chat`.
    fn endpoint_url(&self, endpoint: &str) -> String {
        format!("{}{}", self.base_url.trim_end_matches('/'), endpoint)
    }

    /// Passes `response` through unchanged when its status is 2xx.
    ///
    /// Otherwise it reads the body (best-effort) and turns status and body
    /// into `ModelClientError::Api`.
    async fn ensure_success(
        response: reqwest::Response,
    ) -> Result<reqwest::Response, ModelClientError> {
        let status = response.status();
        debug!(%status, "received response");

        if status.is_success() {
            return Ok(response);
        }

        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Failed to read error body".into());
        error!(%status, body = %error_text, "request failed");
        Err(ModelClientError::Api(format!(
            "Request failed: {status} - {error_text}"
        )))
    }

    /// Decodes one newline-delimited JSON record.
    ///
    /// Returns `Ok(None)` for blank or whitespace-only lines (e.g. a lone `\r`)
    /// so callers can skip them.
    fn parse_ndjson_line<R: DeserializeOwned>(line: &[u8]) -> Result<Option<R>, ModelClientError> {
        if line.iter().all(u8::is_ascii_whitespace) {
            return Ok(None);
        }
        serde_json::from_slice(line)
            .map(Some)
            .map_err(|e| ModelClientError::Serialization(e.to_string()))
    }

    /// POSTs `request_body` as JSON to `endpoint` and deserializes the JSON
    /// response body into `R`.
    ///
    /// # Errors
    ///
    /// * `ModelClientError::Api` if the request cannot be sent, the status is
    ///   not 2xx, or the body cannot be read.
    /// * `ModelClientError::Serialization` if the body is not valid JSON for `R`.
    //
    // `endpoint = %endpoint` must be spelled out: with `skip_all`, a bare
    // `fields(endpoint)` only declares an empty field that is never recorded.
    #[instrument(name = "ollama.post", skip_all, fields(endpoint = %endpoint))]
    async fn post<T, R>(&self, endpoint: &str, request_body: &T) -> Result<R, ModelClientError>
    where
        T: serde::Serialize + fmt::Debug,
        R: DeserializeOwned + fmt::Debug,
    {
        let url = self.endpoint_url(endpoint);
        let span = info_span!("http.request", %url);
        async {
            let response = self
                .client
                .post(&url)
                .json(request_body)
                .send()
                .await
                .map_err(|e| ModelClientError::Api(e.to_string()))?;
            let response = Self::ensure_success(response).await?;

            let response_text = response
                .text()
                .await
                .map_err(|e| ModelClientError::Api(format!("Failed to read response text: {e}")))?;

            match serde_json::from_str::<R>(&response_text) {
                Ok(parsed) => {
                    trace!(?parsed, "deserialized response");
                    Ok(parsed)
                }
                Err(e) => {
                    error!(%e, raw = %response_text, "deserialization error");
                    Err(ModelClientError::Serialization(format!(
                        "Error decoding response body: {e}. Raw JSON was: '{response_text}'"
                    )))
                }
            }
        }
        .instrument(span)
        .await
    }

    /// POSTs `body` as JSON to `endpoint` and streams the response.
    ///
    /// The response body is read as newline-delimited JSON. Each non-blank
    /// line is yielded as one `R`. A final record not terminated by a newline
    /// is still decoded when the byte stream ends.
    ///
    /// # Errors
    ///
    /// Returns `ModelClientError::Api` if the request cannot be sent or the
    /// status is not 2xx. The error message includes the response body.
    ///
    /// Once the stream is running, its items are:
    /// * `ModelClientError::Request` on a transport error while reading chunks.
    /// * `ModelClientError::Serialization` for a line that is not valid JSON for `R`.
    #[instrument(name = "ollama.post_stream", skip_all, fields(endpoint = %endpoint))]
    async fn post_stream<T, R>(
        &self,
        endpoint: &str,
        body: &T,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<R, ModelClientError>> + Send + 'static>>,
        ModelClientError,
    >
    where
        T: serde::Serialize + fmt::Debug,
        R: serde::de::DeserializeOwned + fmt::Debug + Send + 'static,
    {
        let url = self.endpoint_url(endpoint);
        let resp = self
            .client
            .post(&url)
            .json(body)
            .send()
            .await
            .map_err(|e| ModelClientError::Api(e.to_string()))?;
        let resp = Self::ensure_success(resp).await?;

        let byte_stream = resp.bytes_stream();
        let s = try_stream! {
            // Invariant: between chunks, `buf` holds only the unterminated tail
            // of the current line, so it never contains a '\n'.
            let mut buf = Vec::<u8>::new();
            futures::pin_mut!(byte_stream);
            while let Some(chunk) = byte_stream.next().await {
                let chunk = chunk.map_err(|e| ModelClientError::Request(e.to_string()))?;
                // Only the freshly appended bytes can contain a newline.
                let mut search_from = buf.len();
                buf.extend_from_slice(&chunk);

                let mut line_start = 0;
                while let Some(offset) = buf[search_from..].iter().position(|&b| b == b'\n') {
                    let line_end = search_from + offset;
                    if let Some(parsed) = Self::parse_ndjson_line::<R>(&buf[line_start..line_end])? {
                        yield parsed;
                    }
                    line_start = line_end + 1;
                    search_from = line_start;
                }
                // Drop consumed lines in one shift, keeping the partial tail.
                buf.drain(..line_start);
            }

            // The server may omit the newline after the last record; decode it
            // instead of silently discarding it.
            if let Some(parsed) = Self::parse_ndjson_line::<R>(&buf)? {
                yield parsed;
            }
        };
        Ok(Box::pin(s))
    }

    /// Sends a chat request to `/api/chat` and returns the decoded response.
    ///
    /// # Errors
    ///
    /// Returns `ModelClientError::Api` on transport or non-2xx failures and
    /// `ModelClientError::Serialization` if the response cannot be decoded.
    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ModelClientError> {
        self.post("/api/chat", &request).await
    }

    /// Sends a chat request to `/api/chat` and returns a stream of chunks.
    ///
    /// The chunks are decoded from the newline-delimited JSON response.
    ///
    /// # Errors
    ///
    /// Returns `ModelClientError::Api` if the request fails before streaming
    /// starts. Errors during streaming are delivered as stream items.
    pub async fn chat_stream(
        &self,
        req: ChatRequest,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ChatStreamChunk, ModelClientError>> + Send + 'static>>,
        ModelClientError,
    > {
        self.post_stream("/api/chat", &req).await
    }

    /// Sends an embeddings request to `/api/embeddings` and returns the decoded response.
    ///
    /// # Errors
    ///
    /// Returns `ModelClientError::Api` on transport or non-2xx failures and
    /// `ModelClientError::Serialization` if the response cannot be decoded.
    pub async fn embeddings(
        &self,
        request: EmbeddingsRequest,
    ) -> Result<EmbeddingsResponse, ModelClientError> {
        self.post("/api/embeddings", &request).await
    }
}