Skip to main content

agentkit_provider_ollama/
lib.rs

1//! Ollama model adapter for the agentkit agent loop.
2//!
3//! This crate provides [`OllamaAdapter`] and [`OllamaConfig`] for connecting
4//! the agent loop to a local [Ollama](https://ollama.ai) instance via its
5//! OpenAI-compatible chat completions endpoint. It is built on the generic
6//! [`agentkit_adapter_completions`] crate.
7//!
8//! No API key is required — Ollama runs locally and does not authenticate
9//! requests by default.
10//!
11//! # Quick start
12//!
13//! ```rust,ignore
14//! use agentkit_loop::{Agent, SessionConfig};
15//! use agentkit_provider_ollama::{OllamaAdapter, OllamaConfig};
16//!
17//! #[tokio::main]
18//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//!     // Ollama must be running locally (e.g. `ollama serve`).
20//!     let config = OllamaConfig::new("llama3.1:8b");
21//!     let adapter = OllamaAdapter::new(config)?;
22//!
23//!     let agent = Agent::builder()
24//!         .model(adapter)
25//!         .build()?;
26//!
27//!     let mut driver = agent
28//!         .start(SessionConfig::new("demo"))
29//!         .await?;
30//!     Ok(())
31//! }
32//! ```
33
34use agentkit_adapter_completions::{
35    CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
36};
37use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
38use async_trait::async_trait;
39use serde::Serialize;
40use thiserror::Error;
41
/// Chat completions URL of an Ollama server listening on its default local port.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434/v1/chat/completions";

/// Settings used to reach a local Ollama server.
///
/// Ollama does not authenticate requests by default, so there is no API key
/// field. Construct a value with [`OllamaConfig::new`] and the `with_*`
/// builder methods, or load it from the process environment with
/// [`OllamaConfig::from_env`].
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_provider_ollama::OllamaConfig;
///
/// let config = OllamaConfig::new("llama3.1:8b")
///     .with_temperature(0.0)
///     .with_num_predict(4096);
/// ```
#[derive(Clone, Debug)]
pub struct OllamaConfig {
    /// Name of the model as Ollama knows it (for instance `"mistral"` or `"llama3.1:8b"`).
    pub model: String,
    /// Full chat completions endpoint URL; [`OllamaConfig::new`] sets
    /// `http://localhost:11434/v1/chat/completions`.
    pub base_url: String,
    /// Sampling temperature; `0.0` is deterministic, larger values add variety.
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens (Ollama's `num_predict`, comparable to
    /// `max_completion_tokens` elsewhere).
    pub num_predict: Option<u32>,
    /// Restricts sampling of the next token to the K most probable candidates.
    pub top_k: Option<u32>,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: Option<f32>,
    /// Permits or forbids several tool calls within one turn; when `None`
    /// the option is not sent at all.
    pub parallel_tool_calls: Option<bool>,
    /// Ask for server-sent-event streaming; [`OllamaConfig::new`] enables it.
    pub streaming: bool,
    /// Set when the model's chat template insists on strictly alternating
    /// `user`/`assistant` roles (Mistral, Mixtral and Llama templates, for
    /// example). The adapter then folds neighbouring user messages together
    /// before the request is sent.
    pub strict_alternating_roles: bool,
}
84
85impl OllamaConfig {
86    /// Creates a new configuration with the given model name.
87    pub fn new(model: impl Into<String>) -> Self {
88        Self {
89            model: model.into(),
90            base_url: DEFAULT_ENDPOINT.into(),
91            temperature: None,
92            num_predict: None,
93            top_k: None,
94            top_p: None,
95            parallel_tool_calls: None,
96            streaming: true,
97            strict_alternating_roles: false,
98        }
99    }
100
101    /// Overrides the default chat completions endpoint URL.
102    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
103        self.base_url = url.into();
104        self
105    }
106
107    /// Sets the sampling temperature (0.0 for deterministic output).
108    pub fn with_temperature(mut self, v: f32) -> Self {
109        self.temperature = Some(v);
110        self
111    }
112
113    /// Sets the maximum number of tokens to generate.
114    pub fn with_num_predict(mut self, v: u32) -> Self {
115        self.num_predict = Some(v);
116        self
117    }
118
119    /// Limits the next token selection to the top K most probable tokens.
120    pub fn with_top_k(mut self, v: u32) -> Self {
121        self.top_k = Some(v);
122        self
123    }
124
125    /// Sets the nucleus sampling parameter.
126    pub fn with_top_p(mut self, v: f32) -> Self {
127        self.top_p = Some(v);
128        self
129    }
130
131    /// Sets whether the model may emit multiple tool calls in a single turn.
132    pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
133        self.parallel_tool_calls = Some(flag);
134        self
135    }
136
137    /// Toggles SSE streaming of model responses. Default: true.
138    pub fn with_streaming(mut self, flag: bool) -> Self {
139        self.streaming = flag;
140        self
141    }
142
143    /// Enable strict `user`/`assistant` role alternation for chat
144    /// templates that require it (notably Mistral, Mixtral, Llama). The
145    /// adapter merges adjacent user-role messages before sending. Same
146    /// rejection mode as vLLM-served Mistral; see
147    /// <https://github.com/vllm-project/vllm/issues/6862>.
148    pub fn with_strict_alternating_roles(mut self, flag: bool) -> Self {
149        self.strict_alternating_roles = flag;
150        self
151    }
152
153    /// Builds a configuration from environment variables.
154    ///
155    /// | Variable | Required | Default |
156    /// |---|---|---|
157    /// | `OLLAMA_MODEL` | yes | -- |
158    /// | `OLLAMA_BASE_URL` | no | `http://localhost:11434/v1/chat/completions` |
159    pub fn from_env() -> Result<Self, OllamaError> {
160        let model =
161            std::env::var("OLLAMA_MODEL").map_err(|_| OllamaError::MissingEnv("OLLAMA_MODEL"))?;
162
163        let mut config = Self::new(model);
164
165        if let Ok(url) = std::env::var("OLLAMA_BASE_URL") {
166            config = config.with_base_url(url);
167        }
168
169        Ok(config)
170    }
171}
172
173/// Request parameters serialized into the Ollama request body.
174#[derive(Clone, Debug, Serialize)]
175pub struct OllamaRequestConfig {
176    pub model: String,
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub temperature: Option<f32>,
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub num_predict: Option<u32>,
181    #[serde(skip_serializing_if = "Option::is_none")]
182    pub top_k: Option<u32>,
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub top_p: Option<f32>,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub parallel_tool_calls: Option<bool>,
187}
188
189/// The Ollama provider, implementing [`CompletionsProvider`].
190#[derive(Clone, Debug)]
191pub struct OllamaProvider {
192    base_url: String,
193    streaming: bool,
194    strict_alternating_roles: bool,
195    request_config: OllamaRequestConfig,
196}
197
198impl From<OllamaConfig> for OllamaProvider {
199    fn from(config: OllamaConfig) -> Self {
200        Self {
201            base_url: config.base_url,
202            streaming: config.streaming,
203            strict_alternating_roles: config.strict_alternating_roles,
204            request_config: OllamaRequestConfig {
205                model: config.model,
206                temperature: config.temperature,
207                num_predict: config.num_predict,
208                top_k: config.top_k,
209                top_p: config.top_p,
210                parallel_tool_calls: config.parallel_tool_calls,
211            },
212        }
213    }
214}
215
216impl CompletionsProvider for OllamaProvider {
217    type Config = OllamaRequestConfig;
218
219    fn provider_name(&self) -> &str {
220        "Ollama"
221    }
222    fn endpoint_url(&self) -> &str {
223        &self.base_url
224    }
225    fn config(&self) -> &OllamaRequestConfig {
226        &self.request_config
227    }
228
229    fn preprocess_request(
230        &self,
231        builder: agentkit_http::HttpRequestBuilder,
232    ) -> agentkit_http::HttpRequestBuilder {
233        builder.header(
234            "User-Agent",
235            concat!("agentkit-provider-ollama/", env!("CARGO_PKG_VERSION")),
236        )
237    }
238
239    fn streaming(&self) -> bool {
240        self.streaming
241    }
242
243    fn requires_alternating_roles(&self) -> bool {
244        self.strict_alternating_roles
245    }
246}
247
248/// Model adapter that connects the agentkit agent loop to a local Ollama instance.
249///
250/// # Example
251///
252/// ```rust,no_run
253/// use agentkit_loop::Agent;
254/// use agentkit_provider_ollama::{OllamaAdapter, OllamaConfig};
255///
256/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
257/// let adapter = OllamaAdapter::new(
258///     OllamaConfig::new("llama3.1:8b").with_temperature(0.0),
259/// )?;
260///
261/// let agent = Agent::builder()
262///     .model(adapter)
263///     .build()?;
264/// # Ok(())
265/// # }
266/// ```
267#[derive(Clone)]
268pub struct OllamaAdapter(CompletionsAdapter<OllamaProvider>);
269
270/// An active session with a local Ollama instance.
271pub type OllamaSession = CompletionsSession<OllamaProvider>;
272
273/// A completed turn from a local Ollama instance.
274pub type OllamaTurn = CompletionsTurn;
275
276impl OllamaAdapter {
277    /// Creates a new adapter from the given configuration.
278    pub fn new(config: OllamaConfig) -> Result<Self, OllamaError> {
279        let provider = OllamaProvider::from(config);
280        Ok(Self(CompletionsAdapter::new(provider)?))
281    }
282}
283
284#[async_trait]
285impl ModelAdapter for OllamaAdapter {
286    type Session = OllamaSession;
287
288    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
289        self.0.start_session(config).await
290    }
291}
292
293/// Errors produced by the Ollama adapter.
294#[derive(Debug, Error)]
295pub enum OllamaError {
296    /// A required environment variable is not set.
297    #[error("missing environment variable {0}")]
298    MissingEnv(&'static str),
299
300    /// An error from the generic completions adapter.
301    #[error(transparent)]
302    Completions(#[from] CompletionsError),
303}