use crate::chat::Message as ChatMessage;
use crate::client::ModelClient;
use crate::error::{OllamaError, Result};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_stream::{Stream, StreamExt};
/// One model entry from the `models` array of [`ListModelsResponse`], as
/// returned by [`ModelClient::list_models`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub name: String,
    pub model: String,
    pub modified_at: String,
    pub size: u64,
    pub digest: String,
    /// Format, family, parameter-size and quantization metadata.
    pub details: ModelDetails,
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_model: Option<String>,
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
}
/// Descriptive metadata shared by [`ModelInfo`], [`ShowModelResponse`] and
/// [`RunningModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetails {
    pub parent_model: String,
    pub format: String,
    pub family: String,
    /// Decodes to `None` when the key is missing or `null`.
    pub families: Option<Vec<String>>,
    pub parameter_size: String,
    pub quantization_level: String,
}
/// Body of the `api/tags` reply; [`ModelClient::list_models`] unwraps
/// `models` and returns it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<ModelInfo>,
}
/// Request body sent by [`ModelClient::show_model`] to `api/show`.
///
/// `Default` yields an empty `model` and no `verbose` flag, so callers can
/// write `ShowModelRequest { model: name, ..Default::default() }`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ShowModelRequest {
    pub model: String,
    /// Left out of the request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbose: Option<bool>,
}
/// Reply decoded by [`ModelClient::show_model`] from `api/show`.
///
/// The three text fields fall back to an empty string when the server omits
/// them; `details` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShowModelResponse {
    #[serde(default)]
    pub modelfile: String,
    #[serde(default)]
    pub parameters: String,
    #[serde(default)]
    pub template: String,
    pub details: ModelDetails,
    /// Arbitrary JSON values keyed by string; skipped when serializing `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_info: Option<HashMap<String, serde_json::Value>>,
    /// Skipped when serializing `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<Vec<String>>,
}
/// Request body sent by [`ModelClient::copy_model`] to `api/copy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CopyModelRequest {
    pub source: String,
    pub destination: String,
}
/// Request body that [`ModelClient::delete_model`] sends with an HTTP
/// `DELETE` to `api/delete`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteModelRequest {
    pub model: String,
}
/// Request body sent by [`ModelClient::pull_model`] to `api/pull`.
///
/// `pull_model` always overwrites `stream` with `Some(true)` before sending,
/// so the value set here is ignored.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PullModelRequest {
    pub model: String,
    /// Left out of the request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub insecure: Option<bool>,
    /// Left out of the request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// Request body sent by [`ModelClient::push_model`] to `api/push`.
///
/// Derives `Default` like its counterpart [`PullModelRequest`], so callers can
/// write `PushModelRequest { model: name, ..Default::default() }`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PushModelRequest {
    pub model: String,
    /// Left out of the request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub insecure: Option<bool>,
    /// Sent as given (`push_model` does not override it); left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// Request body sent by [`ModelClient::create_model`] to `api/create`.
///
/// Every field except `model` is optional and is left out of the request body
/// when `None`. `Default` is derived so callers can set only what they need:
/// `CreateModelRequest { model: name, from: Some(base), ..Default::default() }`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CreateModelRequest {
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub from: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub files: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub adapters: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
    /// A single string or a list of strings; see [`License`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<License>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<HashMap<String, serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<ChatMessage>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantize: Option<String>,
}
/// License text for [`CreateModelRequest::license`].
///
/// Untagged: serializes as a bare JSON string (`Single`) or a JSON array of
/// strings (`Multiple`). When deserializing, `Single` is tried first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum License {
    Single(String),
    Multiple(Vec<String>),
}
/// One entry of the `api/ps` reply, returned by `list_running_models`
/// (compiled only with the `local` feature).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunningModel {
    pub name: String,
    pub model: String,
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
    pub expires_at: String,
    pub size_vram: u64,
}
/// Body of the `api/ps` reply; `list_running_models` unwraps `models` and
/// returns it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListRunningModelsResponse {
    pub models: Vec<RunningModel>,
}
/// Response body carrying a single `version` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionResponse {
    pub version: String,
}
/// One progress record yielded by `pull_model`, `push_model` and
/// `create_model`.
///
/// Only `status` is captured. Because `deny_unknown_fields` is not set, any
/// other keys in the record are ignored during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusResponse {
    pub status: String,
}
impl ModelClient {
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let url = self
.base_url
.join("api/tags")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.get(url)
.send()
.await
.map_err(OllamaError::RequestError)?;
let models: ListModelsResponse = self.handle_response(response, None).await?;
Ok(models.models)
}
pub async fn show_model(&self, request: ShowModelRequest) -> Result<ShowModelResponse> {
let url = self
.base_url
.join("api/show")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
self.handle_response(response, Some(&request.model)).await
}
pub async fn copy_model(&self, request: CopyModelRequest) -> Result<()> {
let url = self
.base_url
.join("api/copy")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
self.handle_void_response(response).await
}
pub async fn delete_model(&self, request: DeleteModelRequest) -> Result<()> {
let url = self
.base_url
.join("api/delete")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.delete(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
self.handle_void_response(response).await
}
pub async fn pull_model(
&self,
mut request: PullModelRequest,
) -> Result<impl Stream<Item = Result<StatusResponse>> + '_> {
let url = self
.base_url
.join("api/pull")
.map_err(OllamaError::UrlError)?;
request.stream = Some(true);
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
let stream = response.bytes_stream().map(|item| match item {
Ok(bytes) => match serde_json::from_slice::<StatusResponse>(&bytes) {
Ok(response) => Ok(response),
Err(e) => Err(OllamaError::JsonError(e)),
},
Err(e) => Err(OllamaError::RequestError(e)),
});
Ok(stream)
}
pub async fn push_model(
&self,
request: PushModelRequest,
) -> Result<impl Stream<Item = Result<StatusResponse>> + '_> {
let url = self
.base_url
.join("api/push")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
let stream = response.bytes_stream().map(|item| match item {
Ok(bytes) => match serde_json::from_slice::<StatusResponse>(&bytes) {
Ok(response) => Ok(response),
Err(e) => Err(OllamaError::JsonError(e)),
},
Err(e) => Err(OllamaError::RequestError(e)),
});
Ok(stream)
}
pub async fn create_model(
&self,
request: CreateModelRequest,
) -> Result<impl Stream<Item = Result<StatusResponse>> + '_> {
let url = self
.base_url
.join("api/create")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
let stream = response.bytes_stream().map(|item| match item {
Ok(bytes) => match serde_json::from_slice::<StatusResponse>(&bytes) {
Ok(response) => Ok(response),
Err(e) => Err(OllamaError::JsonError(e)),
},
Err(e) => Err(OllamaError::RequestError(e)),
});
Ok(stream)
}
#[cfg(feature = "local")]
pub async fn list_running_models(&self) -> Result<Vec<RunningModel>> {
let url = self
.base_url
.join("api/ps")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.get(url)
.send()
.await
.map_err(OllamaError::RequestError)?;
let models: ListRunningModelsResponse = self.handle_response(response, None).await?;
Ok(models.models)
}
}