//! Ollama HTTP client, speaking the server's REST API
//! (`POST /api/generate` for completions, `GET /api/tags` for model listing).
//! Network functionality is gated behind the `llm-online` feature; without it,
//! the network-facing methods return `LlmError::FeatureNotAvailable`.
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::client::LlmClient;
use super::error::{LlmError, LlmResult};
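/// HTTP client for generating completions against an Ollama server.
///
/// [`OllamaClient::new`] sets conservative defaults (120 s request timeout,
/// 4096-token context window, temperature 0.1), which the `with_*` builder
/// methods below can override. A minimal usage sketch; the import path and
/// the surrounding async context are elided:
///
/// ```ignore
/// let client = OllamaClient::new("http://localhost:11434", "llama3.2")
///     .with_timeout(60)
///     .with_temperature(0.2);
/// let reply = client.complete("Say hello").await?;
/// ```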
#[derive(Debug, Clone)]
pub struct OllamaClient {
base_url: String,
model: String,
timeout_seconds: u64,
max_context_tokens: usize,
temperature: f32,
#[cfg(feature = "llm-online")]
client: reqwest::Client,
}
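/// Request body for Ollama's `POST /api/generate` endpoint.
/// (`dead_code` is allowed because this is only constructed when the
/// `llm-online` feature or the test harness is enabled.)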
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct GenerateRequest<'a> {
model: &'a str,
prompt: &'a str,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<GenerateOptions>,
}
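/// Sampling options forwarded to Ollama under the request's `options` key.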
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct GenerateOptions {
temperature: f32,
num_ctx: usize,
}
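/// Response body returned by `/api/generate` when `stream` is `false`.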
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct GenerateResponse {
response: String,
#[serde(default)]
done: bool,
#[serde(default)]
total_duration: Option<u64>,
#[serde(default)]
prompt_eval_count: Option<usize>,
#[serde(default)]
eval_count: Option<usize>,
}
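/// Response body of `GET /api/tags`, listing locally installed models.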
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct TagsResponse {
models: Vec<ModelInfo>,
}
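/// A single model entry from `/api/tags`.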
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ModelInfo {
name: String,
#[serde(default)]
size: u64,
}
impl OllamaClient {
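/// Creates a client for the Ollama server at `base_url` (for example
/// `http://localhost:11434`) using the given model name.
///
/// Defaults: 120 s request timeout, 4096-token context, temperature 0.1.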
#[cfg(feature = "llm-online")]
pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
model: model.into(),
timeout_seconds: 120,
max_context_tokens: 4096,
temperature: 0.1,
client: reqwest::Client::new(),
}
}
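/// Offline constructor with the same defaults, built without an HTTP client.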
#[cfg(not(feature = "llm-online"))]
pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
model: model.into(),
timeout_seconds: 120,
max_context_tokens: 4096,
temperature: 0.1,
}
}
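/// Sets the timeout, in seconds, applied to completion requests.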
pub fn with_timeout(mut self, seconds: u64) -> Self {
self.timeout_seconds = seconds;
self
}
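/// Sets the context window (Ollama's `num_ctx` option), in tokens.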
pub fn with_max_context(mut self, tokens: usize) -> Self {
self.max_context_tokens = tokens;
self
}
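/// Sets the sampling temperature, clamping it to the `0.0..=2.0` range.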
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature.clamp(0.0, 2.0);
self
}
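/// Returns the configured server base URL.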
pub fn base_url(&self) -> &str {
&self.base_url
}
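/// Lists the models installed on the server via `GET /api/tags`.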
#[cfg(feature = "llm-online")]
pub async fn list_models(&self) -> LlmResult<Vec<String>> {
let url = format!("{}/api/tags", self.base_url);
let response = self
.client
.get(&url)
.timeout(std::time::Duration::from_secs(10))
.send()
.await
.map_err(|e| LlmError::ConnectionError(e.to_string()))?;
if !response.status().is_success() {
return Err(LlmError::ConnectionError(format!(
"Failed to list models: HTTP {}",
response.status()
)));
}
let tags: TagsResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(e.to_string()))?;
Ok(tags.models.into_iter().map(|m| m.name).collect())
}
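/// Offline stub: listing models requires the `llm-online` feature.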
#[cfg(not(feature = "llm-online"))]
pub async fn list_models(&self) -> LlmResult<Vec<String>> {
Err(LlmError::FeatureNotAvailable(
"Online LLM".to_string(),
"llm-online".to_string(),
))
}
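/// Checks whether the configured model appears in the server's model list.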
#[cfg(feature = "llm-online")]
pub async fn model_available(&self) -> LlmResult<bool> {
let models = self.list_models().await?;
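// Prefix match so a bare name like "llama3.2" also matches tagged
// variants such as "llama3.2:latest".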
Ok(models.iter().any(|m| m.starts_with(&self.model)))
}
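/// Offline stub: checking model availability requires the `llm-online` feature.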
#[cfg(not(feature = "llm-online"))]
pub async fn model_available(&self) -> LlmResult<bool> {
Err(LlmError::FeatureNotAvailable(
"Online LLM".to_string(),
"llm-online".to_string(),
))
}
}
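// Online implementation: issues real HTTP requests to the Ollama server.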
#[cfg(feature = "llm-online")]
#[async_trait]
impl LlmClient for OllamaClient {
async fn complete(&self, prompt: &str) -> LlmResult<String> {
let url = format!("{}/api/generate", self.base_url);
let request = GenerateRequest {
model: &self.model,
prompt,
stream: false,
options: Some(GenerateOptions {
temperature: self.temperature,
num_ctx: self.max_context_tokens,
}),
};
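// `stream` is false, so the server replies with a single JSON object
// rather than a stream of newline-delimited chunks.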
tracing::debug!("Sending request to Ollama: {}", url);
let response = self
.client
.post(&url)
.json(&request)
.timeout(std::time::Duration::from_secs(self.timeout_seconds))
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout(self.timeout_seconds)
} else if e.is_connect() {
LlmError::ConnectionError(format!(
"Failed to connect to Ollama at {}: {}",
self.base_url, e
))
} else {
LlmError::ConnectionError(e.to_string())
}
})?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
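// Surface HTTP 429 as a rate-limit error with a fixed 60-second retry hint.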
if status.as_u16() == 429 {
return Err(LlmError::RateLimited(60));
}
return Err(LlmError::ConnectionError(format!(
"Ollama API error (HTTP {}): {}",
status, error_text
)));
}
let gen_response: GenerateResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(e.to_string()))?;
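// Ollama reports total_duration in nanoseconds; log it as milliseconds.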
if let Some(duration) = gen_response.total_duration {
tracing::debug!(
"Ollama completion took {} ms, {} prompt tokens, {} completion tokens",
duration / 1_000_000,
gen_response.prompt_eval_count.unwrap_or(0),
gen_response.eval_count.unwrap_or(0)
);
}
Ok(gen_response.response)
}
fn model_name(&self) -> &str {
&self.model
}
fn max_tokens(&self) -> usize {
self.max_context_tokens
}
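// Readiness is approximated by whether the tags endpoint responds.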
async fn is_ready(&self) -> bool {
self.list_models().await.is_ok()
}
}
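// Offline stub implementation, compiled when the `llm-online` feature is off.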
#[cfg(not(feature = "llm-online"))]
#[async_trait]
impl LlmClient for OllamaClient {
async fn complete(&self, _prompt: &str) -> LlmResult<String> {
Err(LlmError::FeatureNotAvailable(
"Online LLM".to_string(),
"llm-online".to_string(),
))
}
fn model_name(&self) -> &str {
&self.model
}
fn max_tokens(&self) -> usize {
self.max_context_tokens
}
async fn is_ready(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ollama_client_new() {
let client = OllamaClient::new("http://localhost:11434", "llama3.2");
assert_eq!(client.base_url(), "http://localhost:11434");
assert_eq!(client.model_name(), "llama3.2");
assert_eq!(client.max_tokens(), 4096);
}
#[test]
fn test_ollama_client_builder() {
let client = OllamaClient::new("http://remote:11434", "mistral")
.with_timeout(60)
.with_max_context(8192)
.with_temperature(0.5);
assert_eq!(client.timeout_seconds, 60);
assert_eq!(client.max_context_tokens, 8192);
assert!((client.temperature - 0.5).abs() < f32::EPSILON);
}
#[test]
fn test_temperature_clamp() {
let client = OllamaClient::new("http://localhost:11434", "llama3.2").with_temperature(5.0);
assert!((client.temperature - 2.0).abs() < f32::EPSILON);
let client = OllamaClient::new("http://localhost:11434", "llama3.2").with_temperature(-1.0);
assert!(client.temperature.abs() < f32::EPSILON);
}
#[test]
fn test_generate_request_serialize() {
let request = GenerateRequest {
model: "llama3.2",
prompt: "Test prompt",
stream: false,
options: Some(GenerateOptions {
temperature: 0.1,
num_ctx: 4096,
}),
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("llama3.2"));
assert!(json.contains("Test prompt"));
assert!(json.contains("temperature"));
}
#[test]
fn test_generate_response_deserialize() {
let json = r#"{
"response": "Generated text",
"done": true,
"total_duration": 1500000000,
"prompt_eval_count": 50,
"eval_count": 100
}"#;
let response: GenerateResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.response, "Generated text");
assert!(response.done);
assert_eq!(response.total_duration, Some(1_500_000_000));
assert_eq!(response.prompt_eval_count, Some(50));
assert_eq!(response.eval_count, Some(100));
}
#[test]
fn test_generate_response_minimal() {
let json = r#"{"response": "Text", "done": true}"#;
let response: GenerateResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.response, "Text");
assert!(response.total_duration.is_none());
}
}