use std::time::Duration;
use async_trait::async_trait;
use futures::{StreamExt, stream};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::llm::{
BoxStream, CallOptions, ChatModel, Content, ContentPart, LlmError, Message, Role,
ToolDefinition,
};
use juncture_tracing::spans::attrs;
/// Base URL assigned to a new [`ChatOllama`] until `with_base_url` replaces it.
const OLLAMA_BASE_URL: &str = "http://localhost:11434";
/// Chat model client that sends conversations to an Ollama server over HTTP.
#[derive(Clone, Debug)]
pub struct ChatOllama {
/// HTTP client used to send every request.
client: Client,
/// Model name sent with each request unless the call options override it.
model: String,
/// Server root URL; `/api/chat` is appended to it when a request is built.
base_url: String,
/// Sampling temperature used when the call options do not supply one.
temperature: Option<f32>,
/// `top_p` value used when the call options do not supply one.
top_p: Option<f32>,
/// Set to `false` by `new`; not read by any code visible in this file.
#[allow(dead_code, reason = "configured but not directly accessed")]
stream: bool,
}
impl ChatOllama {
    /// Creates a client for `model` that targets [`OLLAMA_BASE_URL`].
    ///
    /// No sampling parameters are set. On non-wasm targets the HTTP client is
    /// built with a 300 second request timeout; on wasm the default client is
    /// used.
    ///
    /// # Panics
    ///
    /// Panics on non-wasm targets if the HTTP client cannot be built.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            client: {
                #[cfg(not(target_family = "wasm"))]
                {
                    Client::builder()
                        .timeout(Duration::from_secs(300))
                        .build()
                        .expect("Failed to create HTTP client")
                }
                #[cfg(target_family = "wasm")]
                {
                    Client::new()
                }
            },
            model: model.into(),
            base_url: OLLAMA_BASE_URL.to_string(),
            temperature: None,
            top_p: None,
            stream: false,
        }
    }

    /// Replaces the server base URL.
    ///
    /// Trailing `/` characters are removed because the endpoint is built as
    /// `{base_url}/api/chat`; keeping them would produce `//api/chat`.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        let mut url = url.into();
        let trimmed_len = url.trim_end_matches('/').len();
        url.truncate(trimmed_len);
        self.base_url = url;
        self
    }

    /// Sets the sampling temperature used when a call does not override it.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the `top_p` value used when a call does not override it.
    #[must_use]
    pub const fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl ChatModel for ChatOllama {
    /// Sends `messages` to `{base_url}/api/chat` with streaming disabled and
    /// returns the assistant reply as a single message with no tool calls.
    ///
    /// `options` may override the model name, temperature and `top_p`
    /// configured on the client.
    ///
    /// # Errors
    ///
    /// * A transport failure while sending the request or reading the body is
    ///   converted into [`LlmError`] by `?`.
    /// * A non-success HTTP status yields [`LlmError::InvalidResponse`]
    ///   containing the status code and response body.
    /// * A body that does not deserialize yields
    ///   [`LlmError::InvalidResponse`].
    async fn invoke(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<Message, LlmError> {
        let model = options
            .and_then(|o| o.model_override.as_ref())
            .unwrap_or(&self.model);
        #[cfg(not(target_family = "wasm"))]
        let span = tracing::info_span!(
            "juncture.llm.call",
            "juncture.llm.model" = %model,
            "juncture.llm.provider" = "ollama",
            "juncture.tokens.input" = tracing::field::Empty,
            "juncture.tokens.output" = tracing::field::Empty,
            "juncture.llm.has_tool_calls" = false,
            "juncture.llm.stop_reason" = tracing::field::Empty,
        );

        // The whole call lives in one future so the span can be attached with
        // `Instrument`. Holding a `Span::enter` guard across `.await` would
        // leave the span entered while this task is suspended.
        let call = async move {
            let api_messages: Vec<_> = messages
                .iter()
                .map(|m| OllamaMessage {
                    role: match m.role {
                        Role::System => "system",
                        Role::Human => "user",
                        Role::Ai => "assistant",
                        Role::Tool => "tool",
                    }
                    .to_string(),
                    content: extract_text_content(&m.content),
                    images: extract_images(&m.content),
                })
                .collect();
            let request = OllamaRequest {
                model: model.clone(),
                messages: api_messages,
                stream: false,
                options: Some(OllamaOptions {
                    // Per-call options win over the client-level defaults.
                    temperature: options.and_then(|o| o.temperature).or(self.temperature),
                    top_p: options.and_then(|o| o.top_p).or(self.top_p),
                }),
            };
            #[cfg(not(target_family = "wasm"))]
            let start = std::time::Instant::now();
            let response = self
                .client
                .post(format!("{}/api/chat", self.base_url))
                .header("content-type", "application/json")
                .json(&request)
                .send()
                .await?;
            let status = response.status();
            // Read the body before checking the status so it can be included
            // in the error message.
            let response_text = response.text().await?;
            if !status.is_success() {
                return Err(LlmError::InvalidResponse(format!(
                    "HTTP {}: {}",
                    status.as_u16(),
                    response_text
                )));
            }
            let api_response: OllamaResponse = serde_json::from_str(&response_text)
                .map_err(|e| LlmError::InvalidResponse(format!("Failed to parse response: {e}")))?;
            tracing::Span::current().record(attrs::LLM_HAS_TOOL_CALLS, false);
            tracing::Span::current().record(attrs::LLM_STOP_REASON, "unknown");
            tracing::debug!(
                name: "juncture.llm.calls",
                provider = "ollama",
                model = %model,
            );
            #[cfg(not(target_family = "wasm"))]
            tracing::debug!(
                name: "juncture.llm.duration_ms",
                duration_ms = start.elapsed().as_millis(),
                model = %model,
            );
            #[cfg(not(target_family = "wasm"))]
            {
                let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                let _ = juncture_core::pregel::try_report_llm_duration(duration_ms);
            }
            let _ = juncture_core::pregel::try_report_llm_call();
            Ok::<Message, LlmError>(Message::ai_with_tool_calls(
                api_response.message.content,
                Vec::new(),
            ))
        };
        #[cfg(not(target_family = "wasm"))]
        let call = tracing::Instrument::instrument(call, span);
        call.await
    }

    /// Sends `messages` to `{base_url}/api/chat` with streaming enabled and
    /// yields one [`crate::llm::MessageChunk`] per response line that carries
    /// non-empty content.
    ///
    /// The request is issued once, when the stream is first polled. The
    /// response body is split on `\n`, and each non-blank line is parsed as
    /// JSON. The stream ends on a line with `done: true` or when the body
    /// ends.
    ///
    /// Failures are yielded as a single `Err` item, after which the stream
    /// ends:
    ///
    /// * transport errors become [`LlmError::NetworkError`];
    /// * a non-success HTTP status becomes [`LlmError::InvalidResponse`].
    ///
    /// Lines that do not deserialize are skipped and logged at `warn` level.
    #[allow(
        clippy::redundant_clone,
        clippy::uninlined_format_args,
        clippy::too_many_lines,
        reason = "Complex newline-delimited JSON stream parsing logic with full Ollama protocol handling"
    )]
    fn stream(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> BoxStream<'_, Result<crate::llm::MessageChunk, LlmError>> {
        let model = options
            .and_then(|o| o.model_override.as_ref())
            .unwrap_or(&self.model);
        #[cfg(not(target_family = "wasm"))]
        let span = tracing::info_span!(
            "juncture.llm.call",
            "juncture.llm.model" = %model,
            "juncture.llm.provider" = "ollama",
        );
        let api_messages: Vec<_> = messages
            .iter()
            .map(|m| OllamaMessage {
                role: match m.role {
                    Role::System => "system",
                    Role::Human => "user",
                    Role::Ai => "assistant",
                    Role::Tool => "tool",
                }
                .to_string(),
                content: extract_text_content(&m.content),
                images: extract_images(&m.content),
            })
            .collect();
        let request = OllamaRequest {
            model: model.clone(),
            messages: api_messages,
            stream: true,
            options: Some(OllamaOptions {
                // Per-call options win over the client-level defaults.
                temperature: options.and_then(|o| o.temperature).or(self.temperature),
                top_p: options.and_then(|o| o.top_p).or(self.top_p),
            }),
        };
        let base_url = self.base_url.clone();
        let client = self.client.clone();

        // Phase 1: issue the request exactly once and hand back the body
        // stream, or the error that prevented us from getting one.
        let connect = async move {
            let response = match client
                .post(format!("{}/api/chat", base_url))
                .header("content-type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(r) => r,
                Err(e) => return Err(LlmError::NetworkError(e)),
            };
            let status = response.status();
            if !status.is_success() {
                return match response.text().await {
                    Ok(response_text) => Err(LlmError::InvalidResponse(format!(
                        "HTTP {}: {}",
                        status.as_u16(),
                        response_text
                    ))),
                    Err(e) => Err(LlmError::NetworkError(e)),
                };
            }
            Ok::<_, LlmError>(response.bytes_stream())
        };
        // Attach the span to the connect future; entering it here would only
        // cover synchronous construction of the stream.
        #[cfg(not(target_family = "wasm"))]
        let connect = tracing::Instrument::instrument(connect, span);

        // Phase 2: turn the body into chunks. The body stream and the
        // partial-line buffer are carried in the `unfold` state so that each
        // poll resumes the SAME response instead of starting a new request.
        let chunks = stream::once(connect).flat_map(|connected| match connected {
            Ok(byte_stream) => stream::unfold(
                (byte_stream, Vec::<u8>::new(), false),
                |(mut byte_stream, mut buffer, finished)| async move {
                    if finished {
                        return None;
                    }
                    // Set once the body has ended; the body stream must not
                    // be polled again after that.
                    let mut eof = false;
                    loop {
                        // Drain every complete line already buffered before
                        // reading more bytes from the network.
                        while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
                            let line_bytes = buffer.drain(..=newline_pos).collect::<Vec<_>>();
                            let line =
                                String::from_utf8_lossy(&line_bytes[..line_bytes.len() - 1]);
                            let line = line.trim();
                            if line.is_empty() {
                                continue;
                            }
                            let parsed = match serde_json::from_str::<OllamaStreamResponse>(line)
                            {
                                Ok(parsed) => parsed,
                                Err(error) => {
                                    // Best effort: keep streaming, but leave a
                                    // trace of the dropped line.
                                    tracing::warn!(
                                        error = %error,
                                        "skipping unparseable Ollama stream line"
                                    );
                                    continue;
                                }
                            };
                            let done = parsed.done;
                            let content = parsed.message.content;
                            if content.is_empty() {
                                if done {
                                    return None;
                                }
                                continue;
                            }
                            let chunk = crate::llm::MessageChunk {
                                content,
                                tool_call_chunks: Vec::new(),
                                usage_delta: None,
                            };
                            // A `done` line (or the last line of the body) may
                            // still carry content: emit it, then finish.
                            return Some((Ok(chunk), (byte_stream, buffer, done || eof)));
                        }
                        if eof {
                            return None;
                        }
                        match byte_stream.next().await {
                            Some(Ok(bytes)) => buffer.extend_from_slice(&bytes),
                            Some(Err(e)) => {
                                return Some((
                                    Err(LlmError::NetworkError(e)),
                                    (byte_stream, buffer, true),
                                ));
                            }
                            None => {
                                if buffer.is_empty() {
                                    return None;
                                }
                                // The body ended without a trailing newline;
                                // terminate the last line so it is parsed too.
                                buffer.push(b'\n');
                                eof = true;
                            }
                        }
                    }
                },
            )
            .left_stream(),
            Err(error) => stream::iter(std::iter::once(Err::<
                crate::llm::MessageChunk,
                LlmError,
            >(error)))
            .right_stream(),
        });
        Box::pin(chunks)
    }

    /// Returns a clone of this client; the supplied tool definitions are
    /// discarded because this implementation never sends tools to the server.
    fn bind_tools(&self, _tools: Vec<ToolDefinition>) -> Self {
        self.clone()
    }

    /// Returns the model name configured on this client.
    fn model_name(&self) -> &str {
        &self.model
    }
}
/// Flattens message content into the plain string sent as a message's
/// `content` field.
///
/// Plain text is returned unchanged. For multi-part content the text of every
/// text and thinking part is joined with a single space, in order; image
/// parts contribute nothing.
#[allow(
    clippy::match_same_arms,
    reason = "Explicit handling for different content types"
)]
fn extract_text_content(content: &Content) -> String {
    let parts = match content {
        Content::Text(text) => return text.clone(),
        Content::MultiPart(parts) => parts,
    };

    let mut joined = String::new();
    // Tracked with a flag rather than `joined.is_empty()` so that an empty
    // leading fragment is still followed by a separator.
    let mut is_first = true;
    for part in parts.iter() {
        let fragment = match part {
            ContentPart::Text { text } => text.as_str(),
            ContentPart::Thinking { text, .. } => text.as_str(),
            ContentPart::Image(_) => continue,
        };
        if is_first {
            is_first = false;
        } else {
            joined.push(' ');
        }
        joined.push_str(fragment);
    }
    joined
}
/// Collects the base64 payload of every image part in `content`, in order.
///
/// Plain-text content yields an empty list. Images whose source is a URL are
/// left out, as are all non-image parts.
fn extract_images(content: &Content) -> Vec<String> {
    let Content::MultiPart(parts) = content else {
        return Vec::new();
    };

    let mut encoded = Vec::new();
    for part in parts.iter() {
        if let ContentPart::Image(img) = part {
            if let crate::llm::ImageSource::Base64(b64) = &img.source {
                encoded.push(b64.clone());
            }
        }
    }
    encoded
}
/// JSON body posted to the `/api/chat` endpoint.
#[derive(Debug, Serialize)]
struct OllamaRequest {
/// Name of the model that should answer.
model: String,
/// Conversation so far, in order.
messages: Vec<OllamaMessage>,
/// `true` when built by `stream`, `false` when built by `invoke`.
stream: bool,
/// Sampling options; the key is left out of the JSON when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<OllamaOptions>,
}
/// One conversation message as serialized into the request body.
#[derive(Debug, Serialize)]
struct OllamaMessage {
/// Role string; this file writes `system`, `user`, `assistant` or `tool`.
role: String,
/// Text content produced by `extract_text_content`.
content: String,
/// Base64 image payloads from `extract_images`; left out of the JSON when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
images: Vec<String>,
}
/// Sampling parameters sent under the request's `options` key.
///
/// Each field is left out of the JSON when it is `None`.
#[derive(Debug, Serialize)]
struct OllamaOptions {
/// Sampling temperature, if one was configured or supplied per call.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
/// `top_p` value, if one was configured or supplied per call.
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
}
/// Response body parsed by `invoke` from a non-streaming `/api/chat` call.
#[derive(Debug, Deserialize)]
#[allow(dead_code, reason = "deserialization target, fields read indirectly")]
struct OllamaResponse {
/// The reply message; `invoke` returns its `content`.
message: OllamaResponseMessage,
/// Defaults to `false` when the key is absent; not read by `invoke`.
#[serde(default)]
#[allow(dead_code, reason = "deserialization target, fields read indirectly")]
done: bool,
}
/// Message object nested in both the streaming and non-streaming responses.
#[derive(Debug, Deserialize)]
#[allow(dead_code, reason = "deserialization target, fields read indirectly")]
struct OllamaResponseMessage {
/// Role reported by the server; deserialized but not read in this file.
#[allow(dead_code, reason = "deserialization target, fields read indirectly")]
role: String,
/// Text of the message, or of the fragment when streaming.
content: String,
}
/// One line of the streaming `/api/chat` response, as parsed by `stream`.
#[derive(Debug, Deserialize)]
struct OllamaStreamResponse {
/// Message fragment carried by this line.
message: OllamaResponseMessage,
/// `true` on the line that ends the stream; defaults to `false` when absent.
#[serde(default)]
done: bool,
}