use std::sync::Arc;
#[cfg(feature = "metrics")]
use metrics::counter;
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::parser::GenericStreamParser;
use crate::tools::{DynTool, ToolRegistry};
use crate::transport::Transport;
use crate::types::chat::{
ChatRequest, ChatResponse, ChatStream, ChatStreamEvent, SimpleChatRequest, StreamingChatRequest,
};
use crate::types::generate::{
GenerateRequest, GenerateResponse, GenerateStream, GenerateStreamEvent, SimpleGenerateRequest,
StreamingGenerateRequest,
};
use crate::types::{HttpRequest, ListModelsResponse, ListRunningModelsResponse};
use crate::{Error, OllamaClientBuilder, Result};
/// Client for the Ollama HTTP endpoints `/api/chat`, `/api/generate`,
/// `/api/tags`, and `/api/ps`.
///
/// All requests go through a pluggable [`Transport`]. [`OllamaClient::builder`]
/// returns an [`OllamaClientBuilder`] for configuring one.
///
/// Cloning shares the transport (it is held in an `Arc`). The `tool_registry`
/// field is cloned via `ToolRegistry`'s own `Clone` impl. Whether clones share
/// or copy registered tools depends on `ToolRegistry`, which is not visible
/// here. NOTE(review): confirm the intended semantics.
#[derive(Clone)]
pub struct OllamaClient {
/// Transport used for both one-shot (`send_http_request`) and streaming
/// (`send_http_stream_request`) HTTP calls.
pub(crate) transport: Arc<dyn Transport + Send + Sync>,
/// Tools added via `register_tool` and removed via `unregister_tool`.
pub(crate) tool_registry: ToolRegistry,
}
impl OllamaClient {
pub fn builder() -> OllamaClientBuilder {
OllamaClientBuilder::new()
}
#[cfg_attr(feature = "tracing", instrument(skip(self, tool)))]
pub fn register_tool(&mut self, tool: DynTool) -> Result<()> {
self.tool_registry.register_tool(tool)
}
#[cfg_attr(feature = "tracing", instrument(skip(self)))]
pub fn unregister_tool(&mut self, name: &str) -> Result<()> {
self.tool_registry.unregister_tool(name)
}
#[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
pub async fn chat_stream(&self, request: StreamingChatRequest) -> Result<ChatStream> {
#[cfg(feature = "metrics")]
counter!("ollama_client.chat_requests_total", "type" => "streaming").increment(1);
let chat_request = ChatRequest::from(request);
let request = HttpRequest::new("/api/chat").post().body(chat_request)?;
let byte_stream = self.transport.send_http_stream_request(request).await?;
let parser = GenericStreamParser::<_, ChatResponse, ChatStreamEvent>::new(byte_stream);
Ok(ChatStream {
inner: Box::pin(parser),
})
}
#[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
pub async fn chat_simple(&self, request: SimpleChatRequest) -> Result<ChatResponse> {
#[cfg(feature = "metrics")]
counter!("ollama_client.chat_requests_total", "type" => "non_streaming").increment(1);
let chat_request = ChatRequest::from(request);
let request = HttpRequest::new("/api/chat").post().body(chat_request)?;
let response = self.transport.send_http_request(request).await?;
match response.body {
Some(bytes) => ChatResponse::from_bytes(bytes),
None => Err(Error::Protocol("Missing response body".into())),
}
}
#[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
pub async fn generate_stream(
&self,
request: StreamingGenerateRequest,
) -> Result<GenerateStream> {
#[cfg(feature = "metrics")]
counter!("ollama_client.generate_requests_total", "type" => "streaming").increment(1);
let generate_request = GenerateRequest::from(request);
let request = HttpRequest::new("/api/generate")
.post()
.body(generate_request)?;
let byte_stream = self.transport.send_http_stream_request(request).await?;
let parser =
GenericStreamParser::<_, GenerateResponse, GenerateStreamEvent>::new(byte_stream);
Ok(GenerateStream {
inner: Box::pin(parser),
})
}
#[cfg_attr(feature = "tracing", instrument(skip(self, request)))]
pub async fn generate_simple(
&self,
request: SimpleGenerateRequest,
) -> Result<GenerateResponse> {
#[cfg(feature = "metrics")]
counter!("ollama_client.generate_requests_total", "type" => "non_streaming").increment(1);
let generate_request = GenerateRequest::from(request);
let request = HttpRequest::new("/api/generate")
.post()
.body(generate_request)?;
let response = self.transport.send_http_request(request).await?;
match response.body {
Some(bytes) => GenerateResponse::from_bytes(bytes),
None => Err(Error::Protocol("Missing response body".into())),
}
}
pub async fn list_models(&self) -> Result<ListModelsResponse> {
let request = HttpRequest::new("/api/tags");
let response = self.transport.send_http_request(request).await?;
match response.body {
Some(bytes) => ListModelsResponse::from_bytes(bytes),
None => Err(Error::Protocol("Missing response body".into())),
}
}
pub async fn list_running_models(&self) -> Result<ListRunningModelsResponse> {
let request = HttpRequest::new("/api/ps");
let response = self.transport.send_http_request(request).await?;
match response.body {
Some(bytes) => ListRunningModelsResponse::from_bytes(bytes),
None => Err(Error::Protocol("Missing response body".into())),
}
}
}