use std::env;
use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use super::error::BackendError;
use super::openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
use super::{Backend, Capability, ChatRequest, ChatResponse, ChatStream};
/// Name of the environment variable that `OllamaBackend::from_env` reads for an
/// optional base-URL override (the value is handed to `with_base_url`).
const OLLAMA_HOST_ENV: &str = "OLLAMA_HOST";
/// Name of the environment variable that `OllamaBackend::from_env` reads for an
/// optional API key (the value is handed to `with_api_key`).
const API_KEY_ENV: &str = "OLLAMA_API_KEY";
/// Backend for Ollama, implemented as a thin wrapper around an
/// [`OpenAICompatibleBackend`] built from `OpenAICompatConfig::ollama()`.
///
/// Requests are forwarded to the wrapped backend; only capability reporting
/// (`Backend::supports`) is answered partly by this type itself.
pub struct OllamaBackend {
// The wrapped OpenAI-compatible backend that performs the actual work.
inner: OpenAICompatibleBackend,
}
impl OllamaBackend {
    /// Builds a backend configured from the process environment.
    ///
    /// * `OLLAMA_API_KEY` — optional API key, forwarded to [`Self::with_api_key`].
    /// * `OLLAMA_HOST` — optional base-URL override, forwarded to
    ///   [`Self::with_base_url`].
    ///
    /// A variable that is unset, not valid Unicode, or blank (empty or
    /// whitespace only) is ignored, so an exported-but-empty variable no longer
    /// replaces the configured defaults with an empty string.
    ///
    /// NOTE(review): the host value is passed to `with_base_url` unchanged.
    /// `OLLAMA_HOST` is commonly set without a scheme (e.g. `127.0.0.1:11434`);
    /// confirm the wrapped backend accepts that form.
    pub fn from_env() -> Self {
        let backend = Self::with_api_key(Self::non_blank_env(API_KEY_ENV));
        match Self::non_blank_env(OLLAMA_HOST_ENV) {
            Some(host) => backend.with_base_url(host),
            None => backend,
        }
    }

    /// Reads the environment variable `key`, returning `None` when it is
    /// unset, not valid Unicode, or contains only whitespace.
    fn non_blank_env(key: &str) -> Option<String> {
        env::var(key).ok().filter(|value| !value.trim().is_empty())
    }

    /// Creates a backend from `OpenAICompatConfig::ollama()` with an optional
    /// API key.
    pub fn with_api_key(api_key: Option<String>) -> Self {
        Self {
            inner: OpenAICompatibleBackend::new(OpenAICompatConfig::ollama(), api_key),
        }
    }

    /// Creates a backend without an API key and without consulting the
    /// environment.
    pub fn local() -> Self {
        Self::with_api_key(None)
    }

    /// Returns the backend with its base URL replaced by `base_url`
    /// (builder style; consumes and returns `self`).
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.inner = self.inner.with_base_url(base_url);
        self
    }

    /// Returns the backend with its default model replaced by `model`
    /// (builder style; consumes and returns `self`).
    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
        self.inner = self.inner.with_default_model(model);
        self
    }

    /// Borrows the wrapped OpenAI-compatible backend.
    pub fn inner(&self) -> &OpenAICompatibleBackend {
        &self.inner
    }
}
/// `Default` is environment-driven: it builds the backend via
/// [`OllamaBackend::from_env`], so the result depends on the process
/// environment at the time of the call.
impl Default for OllamaBackend {
/// Equivalent to [`OllamaBackend::from_env`].
fn default() -> Self {
Self::from_env()
}
}
// `Backend` implementation: every method forwards to the wrapped
// OpenAI-compatible backend except `supports`, which answers two capabilities
// itself.
#[async_trait]
impl Backend for OllamaBackend {
    /// Backend name, as reported by the wrapped backend.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Default model, as reported by the wrapped backend.
    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    /// Runs a non-streaming chat completion through the wrapped backend.
    async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError> {
        self.inner.complete(request).await
    }

    /// Opens a streaming chat completion through the wrapped backend.
    async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError> {
        self.inner.stream(request).await
    }

    /// Token count for `text` under `model`, as computed by the wrapped backend.
    fn count_tokens(&self, model: &str, text: &str) -> usize {
        self.inner.count_tokens(model, text)
    }

    /// Reports whether `capability` is available for `model`.
    ///
    /// `Vision` is decided per model by `is_known_multimodal`, `LockedParams`
    /// is always reported as unsupported, and every other capability is
    /// answered by the wrapped backend.
    fn supports(&self, capability: Capability, model: &str) -> bool {
        if matches!(capability, Capability::Vision) {
            return is_known_multimodal(model);
        }
        if matches!(capability, Capability::LockedParams) {
            return false;
        }
        self.inner.supports(capability, model)
    }
}
/// Reports whether `model` belongs to a model family this module treats as
/// vision-capable.
///
/// The check is a case-insensitive substring match against a fixed list of
/// family markers, so size tags such as `:7b` do not affect the result.
fn is_known_multimodal(model: &str) -> bool {
    // `bakllava` is already matched by `llava`; it stays listed so the table
    // names every family explicitly.
    const VISION_MARKERS: [&str; 7] = [
        "llava",
        "bakllava",
        "llama3.2-vision",
        "llama-3.2-vision",
        "qwen2-vl",
        "qwen2.5-vl",
        "minicpm-v",
    ];

    let lowered = model.to_lowercase();
    VISION_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Module-level shorthand for [`OllamaBackend::from_env`]: builds a backend
/// configured from the process environment.
pub fn from_env() -> OllamaBackend {
OllamaBackend::from_env()
}
/// Module-level shorthand for [`OllamaBackend::local`]: builds a backend
/// without an API key.
pub fn local() -> OllamaBackend {
OllamaBackend::local()
}
/// Module-level shorthand for [`OllamaBackend::with_api_key`]: builds a backend
/// with the given optional API key.
pub fn with_api_key(api_key: Option<String>) -> OllamaBackend {
OllamaBackend::with_api_key(api_key)
}
// Pinned, boxed, `Send` stream whose items are either a chat chunk or a
// `BackendError`. Nothing in this file refers to the alias, which is why the
// `dead_code` lint is silenced.
// NOTE(review): presumably mirrors the shape of `ChatStream` — confirm whether
// it is still needed; removing it would also orphan the `Pin` and `Stream`
// imports.
#[allow(dead_code)]
type OllamaChatStream =
Pin<Box<dyn Stream<Item = Result<crate::backends::ChatChunk, BackendError>> + Send>>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::Message;

    // ---- construction -----------------------------------------------------

    /// `local()` yields a backend named "ollama" with the expected default model.
    #[test]
    fn local_constructs_ollama_backend() {
        let b = OllamaBackend::local();
        assert_eq!(b.name(), "ollama");
        assert_eq!(b.default_model(), "llama3.1:8b");
    }

    /// The module-level `local()` factory builds an Ollama backend.
    #[test]
    fn module_factory_local_works() {
        let b = local();
        assert_eq!(b.name(), "ollama");
    }

    /// The module-level `from_env()` factory builds an Ollama backend
    /// regardless of what the environment contains.
    #[test]
    fn module_factory_from_env_works() {
        let b = from_env();
        assert_eq!(b.name(), "ollama");
    }

    /// The module-level `with_api_key()` factory accepts an explicit key.
    #[test]
    fn module_factory_with_api_key_explicit() {
        let b = with_api_key(Some("proxy-token".into()));
        assert_eq!(b.name(), "ollama");
    }

    /// `with_default_model` replaces the reported default model.
    #[test]
    fn with_default_model_overrides() {
        let b = OllamaBackend::local().with_default_model("qwen2.5:14b");
        assert_eq!(b.default_model(), "qwen2.5:14b");
    }

    /// `with_base_url` can be chained; this only checks that construction
    /// succeeds.
    #[test]
    fn with_base_url_overrides_for_test_fixtures() {
        let _b = OllamaBackend::local().with_base_url("http://remote-host:11435");
    }

    /// `inner()` exposes the wrapped backend.
    #[test]
    fn inner_accessor_returns_compat_backend() {
        let b = OllamaBackend::local();
        assert_eq!(b.inner().name(), "ollama");
    }

    /// `Default` goes through `from_env` and still yields an Ollama backend.
    #[test]
    fn default_constructs_via_from_env() {
        let b = OllamaBackend::default();
        assert_eq!(b.name(), "ollama");
    }

    // ---- vision capability ------------------------------------------------

    #[test]
    fn supports_vision_for_llava_family() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "llava"));
        assert!(b.supports(Capability::Vision, "llava:7b"));
        assert!(b.supports(Capability::Vision, "llava:13b"));
        assert!(b.supports(Capability::Vision, "llava-llama3:8b"));
    }

    #[test]
    fn supports_vision_for_bakllava() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "bakllava"));
        assert!(b.supports(Capability::Vision, "bakllava:7b"));
    }

    #[test]
    fn supports_vision_for_llama_3_2_vision() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "llama3.2-vision:11b"));
        assert!(b.supports(Capability::Vision, "llama3.2-vision:90b"));
        assert!(b.supports(Capability::Vision, "llama-3.2-vision"));
    }

    #[test]
    fn supports_vision_for_qwen_vl() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "qwen2-vl:7b"));
        assert!(b.supports(Capability::Vision, "qwen2.5-vl:7b"));
    }

    #[test]
    fn supports_vision_for_minicpm_v() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "minicpm-v:8b"));
    }

    #[test]
    fn does_not_support_vision_for_text_only_models() {
        let b = OllamaBackend::local();
        assert!(!b.supports(Capability::Vision, "llama3.1:8b"));
        assert!(!b.supports(Capability::Vision, "llama3.1:70b"));
        assert!(!b.supports(Capability::Vision, "mistral-small:24b"));
        assert!(!b.supports(Capability::Vision, "qwen2.5:14b"));
        assert!(!b.supports(Capability::Vision, "phi-4"));
        assert!(!b.supports(Capability::Vision, "deepseek-r1:32b"));
    }

    /// Model names are matched without regard to letter case.
    #[test]
    fn vision_is_case_insensitive() {
        let b = OllamaBackend::local();
        assert!(b.supports(Capability::Vision, "LLaVA:7b"));
        assert!(b.supports(Capability::Vision, "LLAMA3.2-VISION:11b"));
    }

    #[test]
    fn is_known_multimodal_recognises_all_documented_families() {
        for model in &[
            "llava",
            "llava:7b",
            "bakllava:7b",
            "llama3.2-vision:11b",
            "llama-3.2-vision",
            "qwen2-vl:7b",
            "qwen2.5-vl:7b",
            "minicpm-v:8b",
        ] {
            assert!(is_known_multimodal(model), "{model} should be multimodal");
        }
    }

    #[test]
    fn is_known_multimodal_rejects_text_only_families() {
        for model in &[
            "llama3.1:8b",
            "mistral:7b",
            "qwen2.5:14b",
            "phi-4",
            "deepseek-r1:32b",
            "gemma3:12b",
        ] {
            assert!(!is_known_multimodal(model), "{model} should not be multimodal");
        }
    }

    // ---- other capabilities -----------------------------------------------

    /// Capabilities not overridden by `OllamaBackend` come from the wrapped
    /// backend, which is expected to report these three as supported.
    #[test]
    fn supports_streaming_tooluse_structured_via_base() {
        let b = OllamaBackend::local();
        let any_model = "llama3.1:8b";
        assert!(b.supports(Capability::Streaming, any_model));
        assert!(b.supports(Capability::ToolUse, any_model));
        assert!(b.supports(Capability::StructuredOutput, any_model));
    }

    #[test]
    fn does_not_support_anthropic_or_gemini_only_caps() {
        let b = OllamaBackend::local();
        let any_model = "llama3.1:8b";
        assert!(!b.supports(Capability::PromptCaching, any_model));
        assert!(!b.supports(Capability::SafetySettings, any_model));
        assert!(!b.supports(Capability::LockedParams, any_model));
    }

    /// Eight characters are expected to count as two tokens.
    /// NOTE(review): presumably a ~4-characters-per-token estimate in the
    /// wrapped backend — confirm there.
    #[test]
    fn count_tokens_uses_estimate_for_ollama_models() {
        let b = OllamaBackend::local();
        assert_eq!(b.count_tokens("llama3.1:8b", "ABCDEFGH"), 2);
    }

    // ---- network paths (expected to fail fast against unusable ports) ------

    /// Streaming against an unreachable address must surface a `Generic`
    /// error whose message mentions the transport.
    #[tokio::test]
    async fn stream_delegates_to_base_real_sse_implementation() {
        let b = OllamaBackend::local().with_base_url("http://127.0.0.1:1");
        match b.stream(ChatRequest::default()).await {
            Err(BackendError::Generic { ref message, .. }) => {
                // "streaming transport failure" also contains "transport", so
                // a single substring check covers both accepted messages.
                assert!(
                    message.contains("transport"),
                    "unexpected message: {message}"
                );
            }
            Err(other) => panic!("expected Generic, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    /// Without an API key the request may fail, but never with an auth error.
    #[tokio::test]
    async fn complete_without_api_key_does_not_return_auth_error() {
        let b = OllamaBackend::local()
            .with_base_url("http://127.0.0.1:0")
            .with_default_model("llama3.1:8b");
        let err = b
            .complete(ChatRequest {
                messages: vec![Message::user("hi")],
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(
            !matches!(err, BackendError::Auth { .. }),
            "Ollama must not require an API key; got Auth error: {err:?}"
        );
    }
}