llmkit-ollama 0.1.0

Ollama (local Llama/Mistral) provider adapter for llmkit-rs
Documentation
//! [`OllamaProvider`] — implements [`LlmProvider`] against a local Ollama server.

use std::time::Instant;

use async_trait::async_trait;
use llmkit_core::{
    ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
    LlmResult, TokenUsage,
};

use crate::types::{ChatResponseBody, EmbeddingsRequestBody, EmbeddingsResponseBody};
use crate::{chat, stream};

/// Default local Ollama server address, used by [`OllamaProvider::new`].
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
/// Model assigned by [`OllamaProvider::new`]; individual requests may override it.
const DEFAULT_MODEL: &str = "llama3.1";

/// Ollama provider for local models (Llama, Mistral, …).
///
/// Build one with [`OllamaProvider::new`] / [`OllamaProvider::from_env`] and refine it with
/// the builder methods (`model`, `base_url`, `with_client`).
#[derive(Clone, Debug)]
pub struct OllamaProvider {
    /// HTTP client used for every request this provider sends.
    http: reqwest::Client,
    /// Server prefix; endpoint paths such as `/api/chat` are appended to it directly.
    base_url: String,
    /// Model used when a request does not name one.
    model: String,
}

impl OllamaProvider {
    /// Construct against the default local server (`http://localhost:11434`) with the
    /// default model (`llama3.1`).
    pub fn new() -> Self {
        Self {
            http: reqwest::Client::new(),
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
        }
    }

    /// Construct from the `OLLAMA_HOST` environment variable, if set.
    ///
    /// `OLLAMA_HOST` is frequently written without a scheme (e.g. `0.0.0.0`,
    /// `localhost:11434`) or with a trailing slash, so the value goes through the same
    /// normalization as [`OllamaProvider::base_url`]. An unset, non-UTF-8, or blank
    /// variable leaves the default server in place.
    pub fn from_env() -> Self {
        let provider = Self::new();
        match std::env::var("OLLAMA_HOST") {
            Ok(host) => provider.base_url(host),
            Err(_) => provider,
        }
    }

    /// Set the default model, used when a request does not specify one.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Override the base URL.
    ///
    /// The value is normalized so that endpoint paths can be appended safely:
    /// surrounding whitespace and trailing `/` are removed, and a value without a
    /// `scheme://` prefix gets `http://` (plus port `11434` if it names no port).
    /// A blank value is ignored and the current base URL is kept.
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        if let Some(url) = normalize_base_url(&base_url.into()) {
            self.base_url = url;
        }
        self
    }

    /// Provide a custom [`reqwest::Client`] (e.g. one configured with timeouts).
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.http = client;
        self
    }

    /// Model for `req`: the request's own `model` when set, otherwise the provider default.
    fn resolved_model(&self, req: &ChatRequest) -> String {
        req.model.clone().unwrap_or_else(|| self.model.clone())
    }
}

/// Normalize a user-supplied server address into a prefix with no trailing slash,
/// or `None` when the input is blank.
///
/// - Surrounding whitespace and trailing `/` are stripped, so appending `/api/...`
///   never yields a `//` path segment.
/// - Input that already contains `://` is otherwise kept as-is.
/// - Scheme-less input gets `http://`, and if its host part names no port, `:11434`
///   is appended (`0.0.0.0` → `http://0.0.0.0:11434`). Any path suffix
///   (e.g. a reverse-proxy mount like `host/ollama`) is preserved after the port.
fn normalize_base_url(raw: &str) -> Option<String> {
    let trimmed = raw.trim().trim_end_matches('/');
    if trimmed.is_empty() {
        return None;
    }
    if trimmed.contains("://") {
        return Some(trimmed.to_owned());
    }

    // Separate `host[:port]` from an optional `/path` suffix.
    let (authority, path) = trimmed.split_at(trimmed.find('/').unwrap_or(trimmed.len()));
    // Bracketed IPv6 literals (`[::1]`) contain colons of their own; only `]:` marks a port.
    let has_port = match authority.strip_prefix('[') {
        Some(rest) => rest.contains("]:"),
        None => authority.contains(':'),
    };
    let port = if has_port { "" } else { ":11434" };
    Some(format!("http://{authority}{port}{path}"))
}

impl Default for OllamaProvider {
    /// Same as [`OllamaProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, false);

        let start = Instant::now();
        let resp = self
            .http
            .post(format!("{}/api/chat", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(map_reqwest_err)?;

        let resp = check_status(resp).await?;
        let parsed: ChatResponseBody = resp.json().await.map_err(map_reqwest_err)?;
        chat::map_response(parsed, start.elapsed().as_millis() as u64)
    }

    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, true);

        let resp = self
            .http
            .post(format!("{}/api/chat", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(map_reqwest_err)?;

        let resp = check_status(resp).await?;
        Ok(stream::parse(resp))
    }

    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let model = req.model.clone().unwrap_or_else(|| self.model.clone());
        let body = EmbeddingsRequestBody { model, input: req.input };

        let resp = self
            .http
            .post(format!("{}/api/embed", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(map_reqwest_err)?;

        let resp = check_status(resp).await?;
        let parsed: EmbeddingsResponseBody = resp.json().await.map_err(map_reqwest_err)?;

        Ok(EmbedResponse {
            provider: "ollama".into(),
            model: parsed.model,
            embeddings: parsed.embeddings,
            usage: TokenUsage::new(parsed.prompt_eval_count.unwrap_or(0), 0),
        })
    }

    fn name(&self) -> &'static str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }
}

/// Translate a [`reqwest::Error`] into an [`LlmError`].
///
/// Categories are tested in a fixed order: timeout, then connection failure, then body
/// decoding. An error matching several is reported as the first that applies, and
/// anything else becomes a generic transport failure.
fn map_reqwest_err(e: reqwest::Error) -> LlmError {
    match e {
        err if err.is_timeout() => LlmError::Timeout,
        err if err.is_connect() => {
            LlmError::Transport(format!("cannot reach Ollama server: {err}"))
        }
        err if err.is_decode() => LlmError::Serialization(err.to_string()),
        err => LlmError::Transport(err.to_string()),
    }
}

/// Pass a 2xx response through unchanged; turn any other status into an [`LlmError`].
///
/// The response body is read on a best-effort basis and used as the error message.
/// Surrounding whitespace is trimmed. If the body is empty or cannot be read, the
/// status line (e.g. `404 Not Found`) is used instead, so errors never carry a blank message.
///
/// Mapping:
/// - `404` → [`LlmError::InvalidRequest`], prefixed to point at a missing model or endpoint.
/// - `400` → [`LlmError::InvalidRequest`].
/// - any other status → [`LlmError::Provider`] carrying the numeric status code.
async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
    let status = resp.status();
    if status.is_success() {
        return Ok(resp);
    }
    let code = status.as_u16();
    // A failed body read must not mask the HTTP error itself, so fall back to "".
    let body = resp.text().await.unwrap_or_default();
    let message = match body.trim() {
        "" => status.to_string(),
        text => text.to_owned(),
    };
    Err(match code {
        404 => LlmError::InvalidRequest(format!("model not found or endpoint missing: {message}")),
        400 => LlmError::InvalidRequest(message),
        _ => LlmError::Provider { status: code, message },
    })
}