use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Instant;
use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
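/// Request body for Ollama's non-streaming `/api/generate` endpoint.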
#[derive(Debug, Serialize)]
struct OllamaRequest {
model: String,
prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
stream: bool,
options: OllamaOptions,
}
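/// Generation options forwarded to Ollama under the request's `options` field.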
#[derive(Debug, Serialize)]
struct OllamaOptions {
temperature: f32,
num_predict: u32,
}
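/// The subset of the `/api/generate` response fields this provider reads.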
#[derive(Debug, Deserialize)]
struct OllamaApiResponse {
response: String,
model: String,
#[serde(default)]
eval_count: Option<u32>,
}
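/// [`LlmProvider`] backed by an Ollama server.
///
/// Defaults to `http://localhost:11434`; the request timeout is read from the
/// `VEX_LLM_TIMEOUT` environment variable (in seconds, defaulting to 30).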
#[derive(Debug)]
pub struct OllamaProvider {
base_url: String,
model: String,
client: reqwest::Client,
default_timeout: std::time::Duration,
}
impl OllamaProvider {
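/// Creates a provider for `model` that targets the default local Ollama endpoint.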
pub fn new(model: &str) -> Self {
let timeout = std::env::var("VEX_LLM_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(30);
Self {
base_url: "http://localhost:11434".to_string(),
model: model.to_string(),
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(timeout))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
default_timeout: std::time::Duration::from_secs(timeout),
}
}
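/// Creates a provider for `model` that targets a custom Ollama endpoint.
/// A warning is logged when the URL does not point at a local address.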
pub fn with_url(base_url: &str, model: &str) -> Self {
let timeout = std::env::var("VEX_LLM_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(30);
let url = base_url.to_lowercase();
let is_local = url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1");
// Warn when pointing at a non-local endpoint, since prompts will be sent over the network.
if !is_local {
tracing::warn!(url = %base_url, "Potentially unsafe non-local URL in OllamaProvider");
}
Self {
base_url: base_url.to_string(),
model: model.to_string(),
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(timeout))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
default_timeout: std::time::Duration::from_secs(timeout),
}
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
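// Probe the lightweight `/api/tags` endpoint to see whether the server is reachable.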
async fn is_available(&self) -> bool {
let url = format!("{}/api/tags", self.base_url);
// Treat the server as available only if the tags endpoint answers with a success status.
match self.client.get(&url).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
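// Issue a single non-streaming generate request and map the result into the
// provider-agnostic response type.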
async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
let start = Instant::now();
let url = format!("{}/api/generate", self.base_url);
let ollama_request = OllamaRequest {
model: self.model.clone(),
prompt: request.prompt,
system: Some(request.system),
stream: false,
options: OllamaOptions {
temperature: request.temperature,
num_predict: request.max_tokens,
},
};
let request_timeout = request.timeout.unwrap_or(self.default_timeout);
let response = tokio::time::timeout(
request_timeout,
self.client.post(&url).json(&ollama_request).send(),
)
.await
.map_err(|_| LlmError::Timeout(request_timeout.as_millis() as u64))?
.map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(LlmError::RequestFailed(format!(
"Status: {}",
response.status()
)));
}
let api_response: OllamaApiResponse = response
.json()
.await
.map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
Ok(LlmResponse {
content: api_response.response,
model: api_response.model,
tokens_used: api_response.eval_count,
latency_ms: start.elapsed().as_millis() as u64,
trace_root: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
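// Ignored by default: requires a reachable Ollama server (with the requested model pulled) at localhost:11434.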
#[tokio::test]
#[ignore]
async fn test_ollama_available() {
let provider = OllamaProvider::new("llama2");
if provider.is_available().await {
let response = provider.ask("Say hello in one word").await.unwrap();
assert!(!response.is_empty());
}
}
}