use std::time::Duration;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
/// Base URL used by [`OllamaClient::new`] when no server address is supplied.
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";

/// Default model name; not referenced inside this module (presumably the
/// generation model callers fall back to — confirm against call sites).
pub const DEFAULT_MODEL: &str = "llama3";

/// Embedding model name; not referenced inside this module (presumably passed
/// to `generate_embedding` by callers — confirm against call sites).
pub const EMBEDDING_MODEL: &str = "nomic-embed-text";
/// Async HTTP client for talking to one Ollama server.
#[derive(Clone, Debug)]
pub struct OllamaClient {
    /// Underlying `reqwest` client reused for every request this type issues.
    client: reqwest::Client,
    /// Server root; the constructors store it with trailing `/` removed.
    base_url: String,
    /// Per-request timeout for text generation calls.
    generate_timeout: Duration,
    /// Per-request timeout for embedding calls.
    embedding_timeout: Duration,
}
/// JSON body sent to `POST /api/embeddings`.
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingRequest {
    /// Name of the model that should produce the embedding.
    pub model: String,
    /// Text to embed.
    pub prompt: String,
}
/// JSON reply from `POST /api/embeddings`; only the vector is kept.
#[derive(Clone, Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Embedding vector returned by the server.
    pub embedding: Vec<f32>,
}
/// JSON body sent to `POST /api/generate`.
#[derive(Clone, Debug, Serialize)]
pub struct GenerateRequest {
    /// Name of the model to run.
    pub model: String,
    /// Prompt text.
    pub prompt: String,
    /// Streaming flag; `OllamaClient::generate` always sends `false` and parses
    /// a single JSON reply.
    pub stream: bool,
    /// Optional output format (e.g. `"json"`); omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
}
/// JSON reply from a non-streaming `POST /api/generate`.
#[derive(Clone, Debug, Deserialize)]
pub struct GenerateResponse {
    /// Generated text.
    pub response: String,
}
/// JSON reply from `GET /api/tags`.
#[derive(Clone, Debug, Deserialize)]
pub struct TagsResponse {
    /// Models the server reports as available locally.
    pub models: Vec<ModelInfo>,
}
/// One entry of the `/api/tags` model list.
#[derive(Clone, Debug, Deserialize)]
pub struct ModelInfo {
    /// Model name including tag, e.g. `llama3:latest`.
    pub name: String,
    /// Size as reported by the server (presumably bytes — confirm with Ollama docs).
    pub size: u64,
}
/// JSON body sent to `POST /api/pull`.
#[derive(Clone, Debug, Serialize)]
struct PullRequest {
    /// Name of the model to download.
    name: String,
    /// Streaming flag; `pull_model` always sends `false`.
    stream: bool,
}
/// Default per-request timeout for text generation: two minutes.
const DEFAULT_GENERATE_TIMEOUT_SECS: u64 = 2 * 60;
/// Default per-request timeout for embedding requests: one minute.
const DEFAULT_EMBEDDING_TIMEOUT_SECS: u64 = 60;
impl OllamaClient {
    /// Timeout for the lightweight `/api/tags` probes (`health_check`, `list_models`).
    const TAGS_TIMEOUT: Duration = Duration::from_secs(5);
    /// Generous timeout for `/api/pull`, since downloading a model can be slow.
    const PULL_TIMEOUT: Duration = Duration::from_secs(600);

    /// Creates a client for [`DEFAULT_OLLAMA_URL`] using the default timeouts.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        Self::with_config(
            DEFAULT_OLLAMA_URL,
            DEFAULT_GENERATE_TIMEOUT_SECS,
            DEFAULT_EMBEDDING_TIMEOUT_SECS,
        )
    }

    /// Creates a client for `base_url` using the default timeouts.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn with_url(base_url: &str) -> Result<Self> {
        Self::with_config(
            base_url,
            DEFAULT_GENERATE_TIMEOUT_SECS,
            DEFAULT_EMBEDDING_TIMEOUT_SECS,
        )
    }

    /// Creates a client with explicit per-request timeouts (in seconds).
    ///
    /// Trailing `/` characters are stripped from `base_url` so endpoint paths
    /// can be appended directly.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn with_config(
        base_url: &str,
        generate_timeout_secs: u64,
        embedding_timeout_secs: u64,
    ) -> Result<Self> {
        let client = reqwest::Client::builder()
            .build()
            .context("Failed to create HTTP client")?;
        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            generate_timeout: Duration::from_secs(generate_timeout_secs),
            embedding_timeout: Duration::from_secs(embedding_timeout_secs),
        })
    }

    /// Server root URL, without a trailing slash.
    #[must_use]
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Probes `GET /api/tags` and reports whether the server answered with a
    /// success status.
    ///
    /// # Errors
    /// Returns an error (rather than `Ok(false)`) when the request itself
    /// fails, e.g. the server is unreachable or the probe times out.
    pub async fn health_check(&self) -> Result<bool> {
        let response = self
            .client
            .get(self.endpoint("tags"))
            .timeout(Self::TAGS_TIMEOUT)
            .send()
            .await
            .with_context(|| self.connect_error(None))?;
        Ok(response.status().is_success())
    }

    /// Lists the models reported by `GET /api/tags`.
    ///
    /// # Errors
    /// Returns an error if the server is unreachable, replies with a
    /// non-success status, or sends a body that cannot be parsed.
    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let response = self
            .client
            .get(self.endpoint("tags"))
            .timeout(Self::TAGS_TIMEOUT)
            .send()
            .await
            .with_context(|| self.connect_error(None))?;
        // Check the status first so an HTTP error is reported as such instead
        // of surfacing as a confusing JSON parse failure.
        let response = Self::ensure_success(response, "Ollama list models request failed").await?;
        let tags: TagsResponse = response
            .json()
            .await
            .context("Failed to parse model list response")?;
        Ok(tags.models)
    }

    /// Embeds `text` with `model` via `POST /api/embeddings`.
    ///
    /// # Errors
    /// Returns an error if the server is unreachable or the embedding timeout
    /// elapses. It also errors on a non-success status (the message includes the
    /// response body) or when the reply cannot be parsed.
    pub async fn generate_embedding(&self, model: &str, text: &str) -> Result<Vec<f32>> {
        let request = EmbeddingRequest {
            model: model.to_string(),
            prompt: text.to_string(),
        };
        let response = self
            .client
            .post(self.endpoint("embeddings"))
            .timeout(self.embedding_timeout)
            .json(&request)
            .send()
            .await
            .with_context(|| self.connect_error(Some(self.embedding_timeout)))?;
        let response = Self::ensure_success(response, "Ollama embedding request failed").await?;
        let embedding_response: EmbeddingResponse = response
            .json()
            .await
            .context("Failed to parse embedding response")?;
        Ok(embedding_response.embedding)
    }

    /// Runs a non-streaming completion of `prompt` with `model` via
    /// `POST /api/generate`.
    ///
    /// When `json_format` is true the request asks for `"format": "json"`.
    ///
    /// # Errors
    /// Returns an error if the server is unreachable or the generate timeout
    /// elapses. It also errors on a non-success status (the message includes the
    /// response body) or when the reply cannot be parsed.
    pub async fn generate(&self, model: &str, prompt: &str, json_format: bool) -> Result<String> {
        let request = GenerateRequest {
            model: model.to_string(),
            prompt: prompt.to_string(),
            stream: false,
            format: json_format.then(|| "json".to_string()),
        };
        let response = self
            .client
            .post(self.endpoint("generate"))
            .timeout(self.generate_timeout)
            .json(&request)
            .send()
            .await
            .with_context(|| self.connect_error(Some(self.generate_timeout)))?;
        let response = Self::ensure_success(response, "Ollama generate request failed").await?;
        let generate_response: GenerateResponse = response
            .json()
            .await
            .context("Failed to parse generate response")?;
        Ok(generate_response.response)
    }

    /// Asks the server to download model `name` via `POST /api/pull`.
    ///
    /// The request is sent with `stream: false` and a 600-second timeout.
    ///
    /// # Errors
    /// Returns an error if the request fails or times out. It also errors if
    /// the server replies with a non-success status; the message includes the
    /// response body.
    pub async fn pull_model(&self, name: &str) -> Result<()> {
        let request = PullRequest {
            name: name.to_string(),
            stream: false,
        };
        let failure = format!("Failed to pull model '{name}'");
        let response = self
            .client
            .post(self.endpoint("pull"))
            .timeout(Self::PULL_TIMEOUT)
            .json(&request)
            .send()
            .await
            .with_context(|| failure.clone())?;
        Self::ensure_success(response, &failure).await?;
        Ok(())
    }

    /// Builds the full URL for the API endpoint `/api/{path}`.
    fn endpoint(&self, path: &str) -> String {
        format!("{}/api/{path}", self.base_url)
    }

    /// Error context for a failed `send()`. The configured timeout is
    /// mentioned when one is given.
    fn connect_error(&self, timeout: Option<Duration>) -> String {
        match timeout {
            Some(timeout) => format!(
                "Cannot connect to Ollama at {} (timeout: {}s). Start with: ollama serve",
                self.base_url,
                timeout.as_secs()
            ),
            None => format!(
                "Cannot connect to Ollama at {}. Start with: ollama serve",
                self.base_url
            ),
        }
    }

    /// Passes `response` through when its status is a success. Otherwise it
    /// fails with `"{failure} ({status}): {body}"`.
    async fn ensure_success(
        response: reqwest::Response,
        failure: &str,
    ) -> Result<reqwest::Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        // Best effort: an unreadable body becomes "" so the status is still reported.
        let body = response.text().await.unwrap_or_default();
        anyhow::bail!("{failure} ({status}): {body}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to compact JSON, panicking if serialization fails.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).expect("Failed to serialize")
    }

    /// Parses `json` into `T`, panicking if deserialization fails.
    fn from_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        serde_json::from_str(json).expect("Failed to deserialize")
    }

    #[test]
    fn test_embedding_request_serialization() {
        let json = to_json(&EmbeddingRequest {
            model: "llama3".into(),
            prompt: "test prompt".into(),
        });
        for fragment in [r#""model":"llama3""#, r#""prompt":"test prompt""#] {
            assert!(json.contains(fragment));
        }
    }

    #[test]
    fn test_embedding_response_deserialization() {
        let parsed: EmbeddingResponse = from_json(r#"{"embedding": [0.1, 0.2, 0.3]}"#);
        assert_eq!(parsed.embedding, [0.1_f32, 0.2, 0.3]);
    }

    #[test]
    fn test_generate_request_serialization() {
        let json = to_json(&GenerateRequest {
            model: "llama3".into(),
            prompt: "Hello".into(),
            stream: false,
            format: Some("json".into()),
        });
        for fragment in [
            r#""model":"llama3""#,
            r#""stream":false"#,
            r#""format":"json""#,
        ] {
            assert!(json.contains(fragment));
        }
    }

    #[test]
    fn test_generate_request_no_format() {
        let json = to_json(&GenerateRequest {
            model: "llama3".into(),
            prompt: "Hello".into(),
            stream: false,
            format: None,
        });
        assert!(!json.contains("format"));
    }

    #[test]
    fn test_generate_response_deserialization() {
        let parsed: GenerateResponse = from_json(r#"{"response": "Hello, world!"}"#);
        assert_eq!(parsed.response, "Hello, world!");
    }

    #[test]
    fn test_tags_response_deserialization() {
        let parsed: TagsResponse =
            from_json(r#"{"models": [{"name": "llama3:latest", "size": 4000000000}]}"#);
        assert_eq!(parsed.models.len(), 1);
        let first = &parsed.models[0];
        assert_eq!(first.name, "llama3:latest");
        assert_eq!(first.size, 4_000_000_000);
    }

    #[test]
    fn test_client_creation() {
        let client = OllamaClient::new().expect("Failed to create client");
        assert_eq!(client.base_url(), DEFAULT_OLLAMA_URL);
    }

    #[test]
    fn test_client_custom_url() {
        let client =
            OllamaClient::with_url("http://custom:8080/").expect("Failed to create client");
        assert_eq!(client.base_url(), "http://custom:8080");
    }
}