use std::time::Duration;
use async_trait::async_trait;
use futures::stream::{self, BoxStream};
use reqwest::Client;
use serde_json::{Value, json};
use crate::domain::error::{ProviderError, Result, StygianError};
use crate::ports::{AIProvider, ProviderCapabilities};
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
const DEFAULT_MODEL: &str = "qwen2.5:32b";
/// Configuration for an Ollama-backed provider: where the server lives,
/// which model to query, and how long to wait for a reply.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
// Base URL of the Ollama HTTP API (default: `http://localhost:11434`).
pub base_url: String,
// Model name sent with each generate request (default: `qwen2.5:32b`).
pub model: String,
// Request timeout applied to the underlying HTTP client (default: 300 s).
pub timeout: Duration,
}
impl OllamaConfig {
    /// Creates a config pointing at the default local Ollama server
    /// (`http://localhost:11434`) with the default model and a 300 s timeout.
    #[must_use]
    pub fn new() -> Self {
        Self {
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            timeout: Duration::from_secs(300),
        }
    }

    /// Overrides the base URL of the Ollama server.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Overrides the model name sent with each request.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Overrides the HTTP request timeout. Added for builder parity with the
    /// other `with_*` methods; previously the 300 s default was the only
    /// option even though the field was public.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
impl Default for OllamaConfig {
fn default() -> Self {
Self::new()
}
}
/// AI provider backed by a local Ollama server's `/api/generate` endpoint.
///
/// Derives added: public types should be `Debug`, and `Clone` is cheap here
/// because `reqwest::Client` is an `Arc`-backed handle, so clones share one
/// connection pool.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    // Connection settings (base URL, model, timeout).
    config: OllamaConfig,
    // HTTP client, built once with the configured timeout and reused.
    client: Client,
}
impl OllamaProvider {
    /// Creates a provider with the default configuration (localhost, default model).
    pub fn new() -> Self {
        Self::with_config(OllamaConfig::new())
    }

    /// Creates a provider from an explicit configuration.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be built; this indicates a broken
    /// environment (e.g. TLS backend failure), not a recoverable condition.
    pub fn with_config(config: OllamaConfig) -> Self {
        #[allow(clippy::expect_used)]
        let client = Client::builder()
            .timeout(config.timeout)
            .build()
            .expect("Failed to build HTTP client");
        Self { config, client }
    }

    /// Full URL of Ollama's generate endpoint.
    fn api_url(&self) -> String {
        // Tolerate a trailing slash in the configured base URL so we never
        // produce "http://host//api/generate".
        format!("{}/api/generate", self.config.base_url.trim_end_matches('/'))
    }

    /// Builds the JSON request body: an extraction prompt embedding the
    /// schema, with streaming disabled and Ollama's JSON output mode enabled.
    fn build_body(&self, content: &str, schema: &Value) -> Value {
        let prompt = format!(
            "Extract structured data from the following content according to this JSON schema.\n\
            Return ONLY valid JSON matching the schema, with no markdown, no code blocks, no extra text.\n\
            Schema: {}\n\nContent:\n{}",
            serde_json::to_string(schema).unwrap_or_default(),
            content
        );
        json!({
            "model": self.config.model,
            "prompt": prompt,
            "stream": false,
            "format": "json"
        })
    }

    /// Extracts the model's text from the Ollama envelope and parses it as JSON.
    ///
    /// # Errors
    /// Returns `ProviderError::ApiError` when the `response` field is missing
    /// or its content is not valid JSON.
    fn parse_response(response: &Value) -> Result<Value> {
        let text = response
            .get("response")
            .and_then(Value::as_str)
            .ok_or_else(|| {
                StygianError::Provider(ProviderError::ApiError(
                    "No response field in Ollama output".to_string(),
                ))
            })?;
        // Models occasionally wrap output in Markdown code fences despite the
        // prompt's instructions; strip them so otherwise-valid JSON still
        // parses. Bare valid JSON is unaffected.
        let mut body = text.trim();
        if let Some(rest) = body
            .strip_prefix("```json")
            .or_else(|| body.strip_prefix("```"))
        {
            body = rest.strip_suffix("```").unwrap_or(rest).trim();
        }
        serde_json::from_str(body).map_err(|e| {
            StygianError::Provider(ProviderError::ApiError(format!(
                "Failed to parse Ollama JSON response: {e}"
            )))
        })
    }

    /// Maps an HTTP error status to a domain error. 404 from Ollama usually
    /// means the requested model has not been pulled.
    fn map_http_error(status: u16, body: &str) -> StygianError {
        match status {
            404 => StygianError::Provider(ProviderError::ModelUnavailable(format!(
                "Model not found in Ollama: {body}"
            ))),
            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
        }
    }
}
impl Default for OllamaProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AIProvider for OllamaProvider {
    /// Sends `content` to Ollama's generate endpoint and parses the model's
    /// output as JSON intended to conform to `schema`.
    ///
    /// # Errors
    /// Returns a `ProviderError` when the request fails (e.g. Ollama is not
    /// running), the server responds with a non-success status, or the
    /// response body is not the expected JSON.
    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
        let body = self.build_body(&content, &schema);
        let url = self.api_url();
        // `.json(&body)` serializes the payload and sets the
        // `Content-Type: application/json` header itself, so no explicit
        // header is needed.
        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await
            .map_err(|e| {
                StygianError::Provider(ProviderError::ApiError(format!(
                    "Ollama request failed (is Ollama running?): {e}"
                )))
            })?;
        let status = response.status().as_u16();
        let text = response
            .text()
            .await
            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
        // Accept any 2xx as success per HTTP semantics, not just 200.
        if !(200..300).contains(&status) {
            return Err(Self::map_http_error(status, &text));
        }
        let json_val: Value = serde_json::from_str(&text)
            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
        Self::parse_response(&json_val)
    }

    /// Streaming facade: performs a single blocking extraction and yields the
    /// result as a one-item stream. True incremental streaming is not
    /// implemented for this provider yet.
    async fn stream_extract(
        &self,
        content: String,
        schema: Value,
    ) -> Result<BoxStream<'static, Result<Value>>> {
        let result = self.extract(content, schema).await;
        Ok(Box::pin(stream::once(async move { result })))
    }

    /// Static capability flags advertised by this provider.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            streaming: true,
            vision: false,
            tool_use: false,
            json_mode: true,
        }
    }

    fn name(&self) -> &'static str {
        "ollama"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The provider reports its canonical name.
    #[test]
    fn test_name() {
        let provider = OllamaProvider::new();
        assert_eq!(provider.name(), "ollama");
    }

    /// `Default` yields the stock model and base URL.
    #[test]
    fn test_default() {
        let provider = OllamaProvider::default();
        assert_eq!(provider.config.base_url, DEFAULT_BASE_URL);
        assert_eq!(provider.config.model, DEFAULT_MODEL);
    }

    /// JSON mode is advertised; vision is not.
    #[test]
    fn test_capabilities_json_mode() {
        let caps = OllamaProvider::new().capabilities();
        assert!(caps.json_mode);
        assert!(!caps.vision);
    }

    /// The generate endpoint path is appended to the default base URL.
    #[test]
    fn test_api_url() {
        assert_eq!(
            OllamaProvider::new().api_url(),
            "http://localhost:11434/api/generate"
        );
    }

    /// Request bodies disable streaming and request JSON-formatted output.
    #[test]
    fn test_build_body_stream_false() {
        let provider = OllamaProvider::new();
        let schema = json!({"type": "object"});
        let body = provider.build_body("c", &schema);
        assert_eq!(body.get("stream"), Some(&json!(false)));
        assert_eq!(body.get("format").and_then(Value::as_str), Some("json"));
    }

    /// A well-formed `response` field parses into structured JSON.
    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let raw = json!({"response": "{\"score\": 42}"});
        let parsed = OllamaProvider::parse_response(&raw)?;
        assert_eq!(parsed.get("score").and_then(Value::as_u64), Some(42));
        Ok(())
    }

    /// HTTP 404 maps to `ModelUnavailable`.
    #[test]
    fn test_map_http_error_404() {
        let err = OllamaProvider::map_http_error(404, "not found");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::ModelUnavailable(_))
        ));
    }

    /// Builder methods override model and base URL independently.
    #[test]
    fn test_config_builder() {
        let config = OllamaConfig::new()
            .with_base_url("http://192.168.1.10:11434")
            .with_model("llama3:latest");
        assert_eq!(config.model, "llama3:latest");
        assert_eq!(config.base_url, "http://192.168.1.10:11434");
    }
}