use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, LlmRole, StreamChunk,
};
use crate::ollama::{validate_endpoint_url, validate_model_ref, OllamaConfig};
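/// Adapter that talks to an Ollama server over its HTTP API.
///
/// Requests go to `{endpoint_url}/api/chat` using the blocking `ureq` client,
/// executed on Tokio's blocking thread pool so the async runtime is never stalled.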
#[derive(Debug, Clone)]
pub struct OllamaHttpAdapter {
config: OllamaConfig,
}
impl OllamaHttpAdapter {
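    /// Builds an adapter from `config`, rejecting invalid endpoint URLs up front.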
pub fn new(config: OllamaConfig) -> Result<Self, LlmError> {
        validate_endpoint_url(&config.endpoint_url)?;
Ok(Self { config })
}
}
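// Wire-format request types for Ollama's `/api/chat` endpoint. Borrowed string
// slices avoid copying message content during serialization.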
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
model: &'a str,
messages: Vec<OllamaMessage<'a>>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct OllamaMessage<'a> {
role: &'a str,
content: &'a str,
}
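// Minimal deserialization targets for `/api/chat` responses; any fields Ollama
// returns beyond these are ignored.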
#[derive(Debug, Deserialize)]
struct ChatResponse {
#[serde(default)]
message: MessageField,
}
#[derive(Debug, Default, Deserialize)]
struct MessageField {
#[serde(default)]
content: String,
}
#[derive(Debug, Deserialize)]
struct StreamLine {
#[serde(default)]
message: MessageField,
#[serde(default)]
done: bool,
done_reason: Option<String>,
}
#[async_trait]
impl LlmAdapter for OllamaHttpAdapter {
fn adapter_id(&self) -> &'static str {
"ollama"
}
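    // Non-streaming completion. The configured model always overrides whatever
    // the caller put in the request, and the blocking HTTP call runs on Tokio's
    // blocking pool.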
async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
validate_model_ref(&self.config.model)?;
let req = LlmRequest { model: self.config.model.clone(), ..req };
let config = self.config.clone();
let timeout_ms = req.timeout_ms;
let result = tokio::task::spawn_blocking(move || call_ollama(&config, &req, timeout_ms))
.await
.map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;
result
}
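    // Streaming completion; validation and the HTTP call happen lazily inside
    // the returned stream.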
fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
let req = LlmRequest { model: self.config.model.clone(), ..req };
validate_model_ref_and_stream(self.config.clone(), req)
}
}
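/// Performs a single non-streaming `/api/chat` call and maps the result into
/// an `LlmResponse`.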
fn call_ollama(
config: &OllamaConfig,
req: &LlmRequest,
timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
let url = format!("{}/api/chat", config.endpoint_url);
let messages: Vec<OllamaMessage<'_>> = req
.messages
.iter()
.map(|m| OllamaMessage {
role: m.role.as_str(),
content: &m.content,
})
.collect();
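    // The model ref may carry an '@'-suffixed qualifier; only the bare model
    // name before it is sent to Ollama.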
let ollama_model = req.model.split('@').next().unwrap_or(&req.model);
let body = ChatRequest {
model: ollama_model,
messages,
stream: false,
};
let timeout = Duration::from_millis(timeout_ms);
let agent = ureq::AgentBuilder::new().timeout(timeout).build();
let raw_response = agent
.post(&url)
.send_json(
serde_json::to_value(&body)
.map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?,
)
.map_err(|err| map_ureq_error(err, timeout_ms))?;
let status = raw_response.status();
if status != 200 {
return Err(LlmError::Transport(format!("HTTP {status}")));
}
let response_text = raw_response
.into_string()
.map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;
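    // Cap how much response data we are willing to keep; anything larger is
    // treated as a transport failure rather than stored.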
    const MAX_RESPONSE_BYTES: usize = 16 * 1024 * 1024;
    if response_text.len() > MAX_RESPONSE_BYTES {
return Err(LlmError::Transport(format!(
"ollama response body exceeds 16 MiB limit ({} bytes); refusing to store",
response_text.len()
)));
}
let parsed: ChatResponse = serde_json::from_str(&response_text)
.map_err(|e| LlmError::Parse(format!("ollama response parse: {e}")))?;
let text = parsed.message.content;
let raw_hash = blake3_hex(response_text.as_bytes());
Ok(LlmResponse {
text,
parsed_json: None,
model: config.model.clone(),
usage: None,
raw_hash,
})
}
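/// Wraps model validation plus the blocking streaming call in an async stream,
/// so `stream_boxed` can return without doing any work eagerly.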
fn validate_model_ref_and_stream(config: OllamaConfig, req: LlmRequest) -> BoxStream<'static> {
Box::pin(async_stream::stream! {
if let Err(e) = validate_model_ref(&config.model) {
yield Err(e);
return;
}
let timeout_ms = req.timeout_ms;
let result = tokio::task::spawn_blocking(move || {
call_ollama_streaming(&config, &req, timeout_ms)
})
.await;
match result {
Ok(chunks) => {
for chunk in chunks {
yield chunk;
}
}
Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
}
})
}
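/// Issues a streaming `/api/chat` call. Note that the whole response body is
/// read before any chunk is produced, so callers see chunk boundaries but not
/// true incremental delivery.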
fn call_ollama_streaming(
config: &OllamaConfig,
req: &LlmRequest,
timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
let url = format!("{}/api/chat", config.endpoint_url);
let messages: Vec<OllamaMessage<'_>> = req
.messages
.iter()
.map(|m| OllamaMessage {
role: m.role.as_str(),
content: &m.content,
})
.collect();
let ollama_model = req.model.split('@').next().unwrap_or(&req.model);
let body = ChatRequest {
model: ollama_model,
messages,
stream: true,
};
let timeout = Duration::from_millis(timeout_ms);
let agent = ureq::AgentBuilder::new().timeout(timeout).build();
let body_value = match serde_json::to_value(&body) {
Ok(v) => v,
Err(e) => {
return vec![Err(LlmError::Transport(format!(
"request serialization failed: {e}"
)))]
}
};
let raw_response = match agent.post(&url).send_json(body_value) {
Ok(r) => r,
Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
};
let status = raw_response.status();
if status != 200 {
return vec![Err(LlmError::Transport(format!("HTTP {status}")))];
}
let body_text = match raw_response.into_string() {
Ok(s) => s,
Err(e) => {
return vec![Err(LlmError::Transport(format!(
"reading streaming response body: {e}"
)))]
}
};
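    // Ollama streams newline-delimited JSON: one object per line, with the
    // final object carrying `done: true` and an optional `done_reason`.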
body_text
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| {
let parsed: StreamLine = serde_json::from_str(line)
.map_err(|e| LlmError::Parse(format!("ollama stream line parse: {e}")))?;
Ok(StreamChunk {
delta: parsed.message.content,
finish_reason: if parsed.done {
parsed.done_reason
} else {
None
},
})
})
.collect()
}
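/// Converts a `ureq` error into the adapter's error type, classifying timeouts
/// separately so callers can tell them apart from other transport failures.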
fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
match err {
ureq::Error::Transport(t) => {
let msg = t.to_string();
if is_timeout_message(&msg) {
LlmError::Timeout { timeout_ms }
} else {
LlmError::Transport(msg)
}
}
ureq::Error::Status(code, _) => LlmError::Transport(format!("HTTP {code}")),
}
}
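// ureq reports timeouts as transport errors, so fall back to matching the
// error text.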
fn is_timeout_message(msg: &str) -> bool {
let lower = msg.to_ascii_lowercase();
lower.contains("timed out") || lower.contains("deadline exceeded") || lower.contains("timeout")
}
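// Map adapter roles onto the role strings Ollama's chat API expects.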
impl LlmRole {
fn as_str(self) -> &'static str {
match self {
LlmRole::User => "user",
LlmRole::Assistant => "assistant",
LlmRole::Tool => "tool",
}
}
}