use async_trait::async_trait;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use crate::error::LlmError;
use crate::traits::ModelListingCapability;
use crate::types::*;
use super::types::*;
use super::utils::*;
#[derive(Clone)]
pub struct OllamaModelsCapability {
pub base_url: String,
pub http_client: reqwest::Client,
pub http_config: HttpConfig,
}
impl OllamaModelsCapability {
pub const fn new(
base_url: String,
http_client: reqwest::Client,
http_config: HttpConfig,
) -> Self {
Self {
base_url,
http_client,
http_config,
}
}
fn convert_model_info(&self, model: &OllamaModel) -> ModelInfo {
ModelInfo {
id: model.name.clone(),
name: Some(model.name.clone()),
description: Some(format!("Ollama model: {}", model.details.family)),
capabilities: vec!["chat".to_string(), "completion".to_string()],
context_window: None, max_output_tokens: None, created: Some(0), owned_by: "ollama".to_string(),
input_cost_per_token: None, output_cost_per_token: None, }
}
pub async fn pull_model(&self, model_name: String) -> Result<(), LlmError> {
validate_model_name(&model_name)?;
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/pull", self.base_url);
let body = serde_json::json!({
"model": model_name,
"stream": false
});
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"Pull model request failed: {status} - {error_text}"
)));
}
Ok(())
}
pub async fn pull_model_stream(
&self,
model_name: String,
) -> Result<impl futures_util::Stream<Item = Result<PullProgress, LlmError>>, LlmError> {
validate_model_name(&model_name)?;
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/pull", self.base_url);
let body = serde_json::json!({
"model": model_name,
"stream": true
});
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"Pull model stream request failed: {status} - {error_text}"
)));
}
let stream = response
.bytes_stream()
.map(|chunk_result| match chunk_result {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines() {
if let Ok(Some(json_value)) = parse_streaming_line(line)
&& let Ok(progress) = serde_json::from_value::<PullProgress>(json_value)
{
return Ok(progress);
}
}
Ok(PullProgress {
status: "processing".to_string(),
digest: None,
total: None,
completed: None,
})
}
Err(e) => Err(LlmError::StreamError(format!("Stream error: {e}"))),
});
Ok(stream)
}
pub async fn delete_model(&self, model_name: String) -> Result<(), LlmError> {
validate_model_name(&model_name)?;
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/delete", self.base_url);
let body = serde_json::json!({
"model": model_name
});
let response = self
.http_client
.delete(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"Delete model request failed: {status} - {error_text}"
)));
}
Ok(())
}
pub async fn copy_model(&self, source: String, destination: String) -> Result<(), LlmError> {
validate_model_name(&source)?;
validate_model_name(&destination)?;
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/copy", self.base_url);
let body = serde_json::json!({
"source": source,
"destination": destination
});
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"Copy model request failed: {status} - {error_text}"
)));
}
Ok(())
}
pub async fn show_model(&self, model_name: String) -> Result<ModelDetails, LlmError> {
validate_model_name(&model_name)?;
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/show", self.base_url);
let body = serde_json::json!({
"model": model_name
});
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"Show model request failed: {status} - {error_text}"
)));
}
let model_info: serde_json::Value = response.json().await?;
Ok(ModelDetails {
modelfile: model_info
.get("modelfile")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
parameters: model_info
.get("parameters")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
template: model_info
.get("template")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
details: model_info
.get("details")
.cloned()
.unwrap_or(serde_json::Value::Null),
})
}
pub async fn list_running_models(&self) -> Result<Vec<RunningModelInfo>, LlmError> {
let headers = build_headers(&self.http_config.headers)?;
let url = format!("{}/api/ps", self.base_url);
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"List running models request failed: {status} - {error_text}"
)));
}
let running_response: OllamaRunningModelsResponse = response.json().await?;
let running_models = running_response
.models
.into_iter()
.map(|model| RunningModelInfo {
name: model.name,
model: model.model,
size: model.size,
digest: model.digest,
expires_at: model.expires_at,
size_vram: model.size_vram,
})
.collect();
Ok(running_models)
}
}
#[async_trait]
impl ModelListingCapability for OllamaModelsCapability {
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
let headers = build_headers(&self.http_config.headers)?;
let url = crate::utils::url::join_url(&self.base_url, "api/tags");
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::HttpError(format!(
"List models request failed: {status} - {error_text}"
)));
}
let models_response: OllamaModelsResponse = response.json().await?;
let models = models_response
.models
.into_iter()
.map(|model| self.convert_model_info(&model))
.collect();
Ok(models)
}
async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
let models = self.list_models().await?;
for model in models {
if model.id == model_id {
return Ok(model);
}
}
Err(LlmError::NotFound(format!("Model '{model_id}' not found")))
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullProgress {
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub digest: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub completed: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetails {
pub modelfile: String,
pub parameters: String,
pub template: String,
pub details: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunningModelInfo {
pub name: String,
pub model: String,
pub size: u64,
pub digest: String,
pub expires_at: String,
pub size_vram: u64,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_convert_model_info() {
let capability = OllamaModelsCapability::new(
"http://localhost:11434".to_string(),
reqwest::Client::new(),
HttpConfig::default(),
);
let ollama_model = OllamaModel {
name: "llama3.2:latest".to_string(),
model: "llama3.2:latest".to_string(),
modified_at: "2023-01-01T00:00:00Z".to_string(),
size: 1_000_000,
digest: "sha256:abc123".to_string(),
details: OllamaModelDetails {
parent_model: "".to_string(),
format: "gguf".to_string(),
family: "llama".to_string(),
families: vec!["llama".to_string()],
parameter_size: "3.2B".to_string(),
quantization_level: "Q4_K_M".to_string(),
},
};
let model_info = capability.convert_model_info(&ollama_model);
assert_eq!(model_info.id, "llama3.2:latest");
assert_eq!(model_info.owned_by, "ollama");
assert_eq!(model_info.name, Some("llama3.2:latest".to_string()));
}
}