// ollama-sdk 0.4.1
//
// An idiomatic, unofficial Rust client for the Ollama API with support for streaming, tool calling, and custom transports.
// Documentation
use std::sync::Arc;

#[cfg(feature = "metrics")]
use metrics::counter;
#[cfg(feature = "tracing")]
use tracing::instrument;

use crate::parser::GenericStreamParser;
use crate::tools::{DynTool, ToolRegistry};
use crate::transport::Transport;
use crate::types::chat::{
    ChatRequest, ChatResponse, ChatStream, ChatStreamEvent, SimpleChatRequest, StreamingChatRequest,
};
use crate::types::generate::{
    GenerateRequest, GenerateResponse, GenerateStream, GenerateStreamEvent, SimpleGenerateRequest,
    StreamingGenerateRequest,
};
use crate::types::{HttpRequest, ListModelsResponse, ListRunningModelsResponse};
use crate::{Error, OllamaClientBuilder, Result};

/// A client for interacting with the Ollama API.
///
/// Use [`OllamaClient::builder()`] to create a client builder with a default `reqwest` transport.
///
/// The transport is held behind an [`Arc`], so cloning the client shares the
/// same transport instance rather than creating a new one.
#[derive(Clone)]
pub struct OllamaClient {
    /// Transport used to send every HTTP request (streaming and non-streaming).
    pub(crate) transport: Arc<dyn Transport + Send + Sync>,
    /// Registry of tools added via `register_tool` / removed via `unregister_tool`.
    pub(crate) tool_registry: ToolRegistry,
}

impl OllamaClient {
    /// Returns a new [`OllamaClientBuilder`].
    pub fn builder() -> OllamaClientBuilder {
        OllamaClientBuilder::new()
    }

    /// Registers a dynamic tool with the client's tool registry.
    ///
    /// This allows the client to use the registered tool in tool-calling scenarios.
    ///
    /// # Arguments
    ///
    /// * `tool` - An instance of [`DynTool`] to be registered.
    ///
    /// # Errors
    ///
    /// Returns an [`Error::Tool`](variant@Error::Tool) if a tool with the same name is already registered.
    #[cfg_attr(feature = "tracing", instrument(skip(self, tool)))]
    pub fn register_tool(&mut self, tool: DynTool) -> Result<()> {
        self.tool_registry.register_tool(tool)
    }

    /// Unregisters a tool from the client's tool registry by its name.
    ///
    /// # Arguments
    ///
    /// * `name` - The name of the tool to unregister.
    ///
    /// # Errors
    ///
    /// Returns an [`Error::Tool`](variant@Error::Tool) if no tool with the given name is found.
    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
    pub fn unregister_tool(&mut self, name: &str) -> Result<()> {
        self.tool_registry.unregister_tool(name)
    }

    /// Sends a streaming chat request to the Ollama API.
    ///
    /// This method returns a [`ChatStream`] which can be used to asynchronously
    /// receive chat responses as they are generated by the model.
    ///
    /// # Arguments
    ///
    /// * `request` - The [`StreamingChatRequest`] containing the chat messages and model.
    #[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
    pub async fn chat_stream(&self, request: StreamingChatRequest) -> Result<ChatStream> {
        #[cfg(feature = "metrics")]
        counter!("ollama_client.chat_requests_total", "type" => "streaming").increment(1);

        let chat_request = ChatRequest::from(request);
        let request = HttpRequest::new("/api/chat").post().body(chat_request)?;

        let byte_stream = self.transport.send_http_stream_request(request).await?;
        let parser = GenericStreamParser::<_, ChatResponse, ChatStreamEvent>::new(byte_stream);

        Ok(ChatStream {
            inner: Box::pin(parser),
        })
    }

    /// Sends a non-streaming chat request to the Ollama API.
    ///
    /// This method waits for the complete response from the model before returning
    /// a [`ChatResponse`].
    ///
    /// # Arguments
    ///
    /// * `request` - The [`SimpleChatRequest`] containing the chat messages and model.
    #[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
    pub async fn chat_simple(&self, request: SimpleChatRequest) -> Result<ChatResponse> {
        #[cfg(feature = "metrics")]
        counter!("ollama_client.chat_requests_total", "type" => "non_streaming").increment(1);

        let chat_request = ChatRequest::from(request);
        let request = HttpRequest::new("/api/chat").post().body(chat_request)?;

        let response = self.transport.send_http_request(request).await?;

        match response.body {
            Some(bytes) => ChatResponse::from_bytes(bytes),
            None => Err(Error::Protocol("Missing response body".into())),
        }
    }

    /// Sends a streaming generate request to the Ollama API.
    ///
    /// This method returns a [`GenerateStream`] which can be used to asynchronously
    /// receive generation responses as they are produced by the model.
    ///
    /// # Arguments
    ///
    /// * `request` - The [`StreamingGenerateRequest`] containing the prompt and model.
    #[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
    pub async fn generate_stream(
        &self,
        request: StreamingGenerateRequest,
    ) -> Result<GenerateStream> {
        #[cfg(feature = "metrics")]
        counter!("ollama_client.generate_requests_total", "type" => "streaming").increment(1);

        let generate_request = GenerateRequest::from(request);
        let request = HttpRequest::new("/api/generate")
            .post()
            .body(generate_request)?;

        let byte_stream = self.transport.send_http_stream_request(request).await?;
        let parser =
            GenericStreamParser::<_, GenerateResponse, GenerateStreamEvent>::new(byte_stream);

        Ok(GenerateStream {
            inner: Box::pin(parser),
        })
    }

    /// Sends a non-streaming generate request to the Ollama API.
    ///
    /// This method waits for the complete response from the model before returning
    /// a [`GenerateResponse`].
    ///
    /// # Arguments
    ///
    /// * `request` - The [`SimpleGenerateRequest`] containing the prompt and model.
    #[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
    pub async fn generate_simple(
        &self,
        request: SimpleGenerateRequest,
    ) -> Result<GenerateResponse> {
        #[cfg(feature = "metrics")]
        counter!("ollama_client.generate_requests_total", "type" => "non_streaming").increment(1);

        let generate_request = GenerateRequest::from(request);
        let request = HttpRequest::new("/api/generate")
            .post()
            .body(generate_request)?;

        let response = self.transport.send_http_request(request).await?;

        match response.body {
            Some(bytes) => GenerateResponse::from_bytes(bytes),
            None => Err(Error::Protocol("Missing response body".into())),
        }
    }

    /// Lists all available models on the Ollama server.
    ///
    /// Returns a [`ListModelsResponse`] which consists of list of
    /// [`OllamaModel`](crate::types::OllamaModel).
    pub async fn list_models(&self) -> Result<ListModelsResponse> {
        let request = HttpRequest::new("/api/tags");

        let response = self.transport.send_http_request(request).await?;

        match response.body {
            Some(bytes) => ListModelsResponse::from_bytes(bytes),
            None => Err(Error::Protocol("Missing response body".into())),
        }
    }

    /// Lists all models that are currently running on the Ollama server.
    ///
    /// Returns a [`ListRunningModelsResponse`] which consists of list of
    /// [`OllamaRunningModel`](crate::types::OllamaRunningModel).
    pub async fn list_running_models(&self) -> Result<ListRunningModelsResponse> {
        let request = HttpRequest::new("/api/ps");

        let response = self.transport.send_http_request(request).await?;

        match response.body {
            Some(bytes) => ListRunningModelsResponse::from_bytes(bytes),
            None => Err(Error::Protocol("Missing response body".into())),
        }
    }
}