use crate::error::LlmError;
use crate::stream::{ChatStream, ChatStreamEvent};
use crate::types::{ResponseMetadata, Usage};
use crate::utils::streaming::{JsonEventConverter, StreamFactory};
use serde::Deserialize;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
/// One JSON chunk of an Ollama streaming response.
///
/// Every field is optional so that chunks from both the chat endpoint (which
/// carry `message`) and the generate endpoint (which, per the Ollama API
/// documentation, carry `response`) deserialize into the same type.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)] // Mirrors the wire format; not every field is read yet.
struct OllamaStreamResponse {
    /// Model name echoed back by the server.
    model: Option<String>,
    /// Chat-style message fragment.
    message: Option<OllamaMessage>,
    /// Plain-text fragment used by the generate (completion) endpoint.
    response: Option<String>,
    /// `true` on the final chunk of the stream.
    done: Option<bool>,
    /// Timing figure reported by the server; currently unused.
    total_duration: Option<u64>,
    /// Timing figure reported by the server; currently unused.
    load_duration: Option<u64>,
    /// Number of prompt tokens evaluated, when reported.
    prompt_eval_count: Option<u32>,
    /// Number of tokens generated, when reported.
    eval_count: Option<u32>,
}

/// Message fragment nested inside a chat streaming chunk.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)] // Mirrors the wire format; not every field is read yet.
struct OllamaMessage {
    role: Option<String>,
    /// Text delta for this chunk.
    content: Option<String>,
    tool_calls: Option<Vec<super::types::OllamaToolCall>>,
    thinking: Option<String>,
}

/// Converts Ollama JSON stream chunks into [`ChatStreamEvent`]s.
#[derive(Clone)]
pub struct OllamaEventConverter {
    /// Whether the stream-start event has already been emitted.
    ///
    /// Held behind an `Arc`, so clones of the converter share the same flag.
    stream_started: Arc<Mutex<bool>>,
}

impl Default for OllamaEventConverter {
    fn default() -> Self {
        Self::new()
    }
}

impl OllamaEventConverter {
    /// Creates a converter that has not yet emitted a stream-start event.
    pub fn new() -> Self {
        Self {
            stream_started: Arc::new(Mutex::new(false)),
        }
    }

    /// Translates one parsed chunk into zero or more stream events.
    ///
    /// Events are added in this order: a stream-start event (first chunk
    /// only), a content delta (when the chunk carries non-empty text), and a
    /// usage update (when the final chunk reports token counts).
    async fn convert_ollama_response_async(
        &self,
        response: OllamaStreamResponse,
    ) -> Vec<ChatStreamEvent> {
        use crate::utils::streaming::EventBuilder;

        let mut builder = EventBuilder::new();

        if self.needs_stream_start().await {
            let metadata = self.create_stream_start_metadata(&response);
            builder = builder.add_stream_start(metadata);
        }

        if let Some(content) = self.extract_content(&response) {
            builder = builder.add_content_delta(content, None);
        }

        if let Some(usage) = self.extract_usage(&response) {
            builder = builder.add_usage_update(usage);
        }

        builder.build()
    }

    /// Returns `true` exactly once: on the first call for this converter
    /// (or any of its clones, since they share the flag).
    async fn needs_stream_start(&self) -> bool {
        let mut started = self.stream_started.lock().await;
        // Set the flag and report whether it was previously unset.
        !std::mem::replace(&mut *started, true)
    }

    /// Returns the non-empty text carried by a chunk, if any.
    ///
    /// Chat chunks carry text in `message.content`; generate chunks carry it
    /// in the top-level `response` field. The chat field is preferred and the
    /// generate field is used as a fallback so completion streams are not
    /// silently emptied.
    fn extract_content(&self, response: &OllamaStreamResponse) -> Option<String> {
        let chat_text = response
            .message
            .as_ref()
            .and_then(|message| message.content.as_deref())
            .filter(|text| !text.is_empty());

        chat_text
            .or_else(|| {
                response
                    .response
                    .as_deref()
                    .filter(|text| !text.is_empty())
            })
            .map(str::to_owned)
    }

    /// Builds token usage from the final chunk.
    ///
    /// Returns `None` unless `done` is `true` and both token counts are
    /// present.
    fn extract_usage(&self, response: &OllamaStreamResponse) -> Option<Usage> {
        if response.done != Some(true) {
            return None;
        }

        let (prompt_tokens, completion_tokens) =
            response.prompt_eval_count.zip(response.eval_count)?;

        Some(Usage {
            prompt_tokens,
            completion_tokens,
            // Saturate rather than overflow on pathological server values.
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
            cached_tokens: None,
            reasoning_tokens: None,
        })
    }

    /// Builds the metadata attached to the stream-start event.
    ///
    /// The `created` timestamp is the local time at which the first chunk is
    /// processed, not a value supplied by the server.
    fn create_stream_start_metadata(&self, response: &OllamaStreamResponse) -> ResponseMetadata {
        ResponseMetadata {
            id: None,
            model: response.model.clone(),
            created: Some(chrono::Utc::now()),
            provider: "ollama".to_string(),
            request_id: None,
        }
    }
}
impl JsonEventConverter for OllamaEventConverter {
    /// Parses one JSON chunk and converts it into stream events.
    ///
    /// A chunk that fails to deserialize yields a single
    /// `LlmError::ParseError`; otherwise every produced event is wrapped in
    /// `Ok`.
    fn convert_json<'a>(
        &'a self,
        json_data: &'a str,
    ) -> Pin<Box<dyn Future<Output = Vec<Result<ChatStreamEvent, LlmError>>> + Send + Sync + 'a>>
    {
        Box::pin(async move {
            // Bail out early with a single error item when the chunk is malformed.
            let response = match serde_json::from_str::<OllamaStreamResponse>(json_data) {
                Ok(response) => response,
                Err(e) => {
                    let message = format!("Failed to parse Ollama JSON: {e}");
                    return vec![Err(LlmError::ParseError(message))];
                }
            };

            let events = self.convert_ollama_response_async(response).await;
            events.into_iter().map(Ok).collect::<Vec<_>>()
        })
    }
}
/// Sends streaming requests to an Ollama server and exposes the responses as
/// `ChatStream`s.
#[derive(Clone)]
pub struct OllamaStreaming {
/// HTTP client used to issue the streaming POST requests.
http_client: reqwest::Client,
}
impl OllamaStreaming {
    /// Creates a streaming helper that sends requests through `http_client`.
    pub fn new(http_client: reqwest::Client) -> Self {
        Self { http_client }
    }

    /// POSTs a chat request to `url` and returns the resulting event stream.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::HttpError` when the request cannot be sent and
    /// `LlmError::ApiError` when the server answers with a non-success status.
    pub async fn create_chat_stream(
        self,
        url: String,
        headers: reqwest::header::HeaderMap,
        body: crate::providers::ollama::types::OllamaChatRequest,
    ) -> Result<ChatStream, LlmError> {
        let request = self.http_client.post(&url).headers(headers).json(&body);
        Self::open_stream(request).await
    }

    /// POSTs a generate (completion) request to `url` and returns the
    /// resulting event stream.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::HttpError` when the request cannot be sent and
    /// `LlmError::ApiError` when the server answers with a non-success status.
    pub async fn create_completion_stream(
        self,
        url: String,
        headers: reqwest::header::HeaderMap,
        body: crate::providers::ollama::types::OllamaGenerateRequest,
    ) -> Result<ChatStream, LlmError> {
        let request = self.http_client.post(&url).headers(headers).json(&body);
        Self::open_stream(request).await
    }

    /// Sends a prepared request, checks the HTTP status, and wraps the body
    /// in a JSON event stream driven by a fresh [`OllamaEventConverter`].
    ///
    /// Shared by both public entry points so the error mapping stays in one
    /// place.
    async fn open_stream(request: reqwest::RequestBuilder) -> Result<ChatStream, LlmError> {
        let response = request
            .send()
            .await
            .map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: the error body is informational only, so a failure
            // to read it must not mask the HTTP status.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError {
                code: status.as_u16(),
                message: format!("Ollama API error {status}: {error_text}"),
                details: None,
            });
        }

        StreamFactory::create_json_stream(response, OllamaEventConverter::new()).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A chat chunk carrying text must surface as a `ContentDelta` with that text.
    #[tokio::test]
    async fn test_ollama_streaming_conversion() {
        let json_data =
            r#"{"model":"llama2","message":{"role":"assistant","content":"Hello"},"done":false}"#;
        let converter = OllamaEventConverter::new();

        let result = converter.convert_json(json_data).await;
        assert!(!result.is_empty());

        let first_delta = result.iter().find_map(|event| match event {
            Ok(ChatStreamEvent::ContentDelta { delta, .. }) => Some(delta),
            _ => None,
        });

        match first_delta {
            Some(delta) => assert_eq!(delta, "Hello"),
            None => panic!("Expected ContentDelta event in results: {:?}", result),
        }
    }

    /// The final chunk's token counts must surface as a `UsageUpdate`.
    #[tokio::test]
    async fn test_ollama_stream_end() {
        let json_data = r#"{"model":"llama2","done":true,"prompt_eval_count":10,"eval_count":20}"#;
        let converter = OllamaEventConverter::new();

        let result = converter.convert_json(json_data).await;
        assert!(!result.is_empty());

        let reported = result.iter().find_map(|event| match event {
            Ok(ChatStreamEvent::UsageUpdate { usage }) => Some(usage),
            _ => None,
        });

        let Some(usage) = reported else {
            panic!("Expected UsageUpdate event in results: {:?}", result);
        };
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 20);
    }
}