use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;
use crate::{
base::error::{TestError, TestsResult},
ollama::{chat::ChatMessage, response::OllamaResponse},
};
/// HTTP client for an Ollama server's REST API.
///
/// Wraps a reusable `reqwest::Client` together with the server base URL and
/// the timeout used for chat requests.
#[derive(Debug)]
pub struct OllamaClient {
/// Underlying reqwest client used to send every request.
client: Client,
/// Base URL of the server; API paths such as `/api/chat` are appended to it.
host: String,
/// Timeout handed to reqwest when sending chat requests; `new` sets it to 60 seconds.
timeout: Duration,
}
impl OllamaClient {
pub fn new(host: impl Into<String>) -> Self {
Self {
client: Client::new(),
host: host.into(),
timeout: Duration::from_secs(60),
}
}
pub async fn chat_with_history(
&self,
id: &str,
model: &str,
messages: &[ChatMessage],
tools: &[Value],
) -> TestsResult<OllamaResponse> {
let url = format!("{}/api/chat", self.host);
let body = json!({
"model": model,
"messages": messages,
"tools": tools,
"stream": false,
"session_id": id.to_string(),
"options": {
"temperature": 0.1,
"seed": 42,
"num_ctx": 4096,
"num_predict": 1024,
}
});
let response = self
.client
.post(&url)
.timeout(self.timeout)
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
TestError::Timeout(format!("Ollama request timed out after {:?}", self.timeout))
} else {
TestError::Ollama(format!("Ollama request failed: {}", e))
}
})?
.json::<OllamaResponse>()
.await
.map_err(|e| TestError::Ollama(format!("Failed to parse Ollama response: {}", e)))?;
Ok(response)
}
pub async fn unload_except(&self, keep: &str) -> TestsResult<()> {
let loaded = self.loaded_models().await?;
for model in loaded {
if model != keep {
let _ = self.unload_model(&model).await;
}
}
Ok(())
}
async fn loaded_models(&self) -> TestsResult<Vec<String>> {
let url = format!("{}/api/ps", self.host);
let json: Value = self
.client
.get(&url)
.send()
.await
.map_err(|e| TestError::Ollama(format!("Failed to list models: {}", e)))?
.json()
.await
.map_err(|e| TestError::Ollama(format!("Failed to parse response: {}", e)))?;
Ok(json["models"]
.as_array()
.map(|a| {
a.iter()
.filter_map(|m| m["name"].as_str().map(String::from))
.collect()
})
.unwrap_or_default())
}
async fn unload_model(&self, model: &str) -> TestsResult<()> {
let url = format!("{}/api/generate", self.host);
self.client
.post(&url)
.json(&json!({"model": model, "keep_alive": 0}))
.send()
.await
.map_err(|e| TestError::Ollama(format!("Failed to unload model: {}", e)))?;
Ok(())
}
}