use async_trait::async_trait;
use crate::error::TokenOptError;
use crate::ports::SummarizationPort;
/// Summarization adapter backed by an Ollama server's HTTP API.
///
/// Implements [`SummarizationPort`] by POSTing to `{base_url}/api/generate`.
#[derive(Debug, Clone)]
pub struct OllamaSummarizer {
    // Server root URL; the constructors strip any trailing `/` characters so
    // endpoint paths can be appended with a single `/`.
    base_url: String,
    // Ollama model name sent with every request (e.g. "llama3.2:3b").
    model: String,
    // HTTP client used for all requests made by this summarizer.
    client: reqwest::Client,
}
impl OllamaSummarizer {
    /// Creates a summarizer for the Ollama server at `base_url` that runs `model`,
    /// using a default [`reqwest::Client`].
    ///
    /// Every trailing `/` is stripped from `base_url`, so `"http://host:11434/"`
    /// and `"http://host:11434"` behave identically.
    #[must_use]
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self::with_client(base_url, model, reqwest::Client::new())
    }

    /// Creates a summarizer that reuses a caller-supplied [`reqwest::Client`]
    /// (for example one configured with timeouts or a proxy).
    ///
    /// `base_url` is normalized exactly as in [`OllamaSummarizer::new`].
    #[must_use]
    pub fn with_client(
        base_url: impl Into<String>,
        model: impl Into<String>,
        client: reqwest::Client,
    ) -> Self {
        Self {
            base_url: Self::normalize_base_url(base_url.into()),
            model: model.into(),
            client,
        }
    }

    /// Removes every trailing `/` from `url`, reusing its allocation.
    fn normalize_base_url(mut url: String) -> String {
        // `trim_end_matches` yields a slice ending on a char boundary, so
        // truncating to its length cannot panic.
        let trimmed_len = url.trim_end_matches('/').len();
        url.truncate(trimmed_len);
        url
    }
}
/// JSON body sent to Ollama's `POST /api/generate` endpoint.
///
/// Field names are serialized as-is and therefore form the wire format.
#[derive(Debug, serde::Serialize)]
struct GenerateRequest<'a> {
    /// Model name to run.
    model: &'a str,
    /// Text to process; `summarize` passes its `text` argument here.
    prompt: &'a str,
    /// System prompt for this request; `summarize` passes its `system_prompt` here.
    system: &'a str,
    /// Set to `false` by `summarize`, which then decodes the reply as one JSON object.
    stream: bool,
}
/// Subset of Ollama's `/api/generate` reply that this adapter reads.
///
/// No `deny_unknown_fields` is set, so any other fields in the reply are ignored.
#[derive(Debug, serde::Deserialize)]
struct GenerateResponse {
    /// Generated text, returned unchanged by `summarize`.
    response: String,
}
#[async_trait]
impl SummarizationPort for OllamaSummarizer {
    /// Sends `text` to Ollama's `/api/generate` endpoint, with `system_prompt` as
    /// the system prompt, and returns the generated text.
    ///
    /// Streaming is disabled (`stream: false`), so the reply is decoded as a
    /// single JSON object.
    ///
    /// # Errors
    ///
    /// Returns `TokenOptError::InferenceError` when the request cannot be sent,
    /// when the server answers with a non-success status (the response body, if
    /// readable, is included in the message), or when the reply cannot be decoded.
    async fn summarize(&self, system_prompt: &str, text: &str) -> Result<String, TokenOptError> {
        let endpoint = format!("{}/api/generate", self.base_url);
        let payload = GenerateRequest {
            model: &self.model,
            prompt: text,
            system: system_prompt,
            stream: false,
        };

        let send_result = self.client.post(&endpoint).json(&payload).send().await;
        let http_response = match send_result {
            Ok(resp) => resp,
            Err(e) => {
                return Err(TokenOptError::InferenceError(format!(
                    "Ollama request failed: {e}"
                )));
            }
        };

        let status = http_response.status();
        if !status.is_success() {
            // Reading the body is best effort; it only enriches the error message.
            let body = match http_response.text().await {
                Ok(contents) => contents,
                Err(_) => "<no body>".to_string(),
            };
            return Err(TokenOptError::InferenceError(format!(
                "Ollama returned status {status}: {body}"
            )));
        }

        http_response
            .json::<GenerateResponse>()
            .await
            .map(|decoded| decoded.response)
            .map_err(|e| {
                TokenOptError::InferenceError(format!("Failed to parse Ollama response: {e}"))
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn url_trailing_slash_stripped() {
        let s = OllamaSummarizer::new("http://localhost:11434/", "model");
        assert_eq!(s.base_url, "http://localhost:11434");
    }

    #[test]
    fn url_multiple_trailing_slashes_stripped() {
        let s = OllamaSummarizer::new("http://localhost:11434///", "model");
        assert_eq!(s.base_url, "http://localhost:11434");
    }

    #[test]
    fn url_without_trailing_slash() {
        let s = OllamaSummarizer::new("http://localhost:11434", "model");
        assert_eq!(s.base_url, "http://localhost:11434");
    }

    #[test]
    fn model_stored() {
        let s = OllamaSummarizer::new("http://localhost:11434", "llama3.2:3b");
        assert_eq!(s.model, "llama3.2:3b");
    }

    #[test]
    fn with_client_normalizes_url_and_stores_model() {
        let s = OllamaSummarizer::with_client(
            "http://localhost:11434/",
            "llama3.2:3b",
            reqwest::Client::new(),
        );
        assert_eq!(s.base_url, "http://localhost:11434");
        assert_eq!(s.model, "llama3.2:3b");
    }

    /// Returns a loopback URL on a port that was free a moment ago.
    ///
    /// The listener is dropped immediately, so connections should be refused.
    /// This avoids the flakiness of a hard-coded port that another process may own.
    fn unused_local_url() -> String {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral loopback port");
        let port = listener.local_addr().expect("read bound address").port();
        drop(listener);
        format!("http://127.0.0.1:{port}")
    }

    #[tokio::test]
    async fn summarize_connection_refused() {
        let s = OllamaSummarizer::new(unused_local_url(), "test");
        let result = s.summarize("system", "text").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("request failed") || err.contains("error"),
            "Unexpected error: {err}"
        );
    }
}