use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::daemon::ollama_template::ChatTemplate;
use crate::MullamaError;
/// Manifest document listing a config reference and a list of layers.
///
/// Field names are mapped to camelCase on the wire, so `schema_version` is
/// read from / written to `schemaVersion` and `media_type` to `mediaType`;
/// the single-word fields `config` and `layers` keep their names.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct OllamaManifest {
    /// Manifest schema version number.
    pub schema_version: u32,
    /// Media type of the manifest itself; optional in the document.
    pub media_type: Option<String>,
    /// Descriptor of the config entry.
    pub config: LayerRef,
    /// Descriptors of every layer listed in the manifest.
    pub layers: Vec<Layer>,
}
/// Descriptor used for the manifest's `config` entry: a media type, a digest
/// string and a size.
///
/// Serialized with camelCase keys, i.e. `media_type` appears as `mediaType`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LayerRef {
    /// Media type string of the referenced content.
    pub media_type: String,
    /// Digest string identifying the referenced content.
    pub digest: String,
    /// Size of the referenced content (presumably in bytes — confirm against the registry format).
    pub size: u64,
}
/// One entry of a manifest's `layers` list: a media type, a digest string and
/// a size.
///
/// Serialized with camelCase keys, i.e. `media_type` appears as `mediaType`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Layer {
    /// Media type string describing what kind of layer this is.
    pub media_type: String,
    /// Digest string identifying the layer content.
    pub digest: String,
    /// Size of the layer (presumably in bytes — confirm against the registry format).
    pub size: u64,
}
/// A model reference split into its `namespace`, `repository` and `tag` parts.
#[derive(Debug, Clone)]
pub struct OllamaModelRef {
    /// Namespace of the model; `"library"` when the name carries no `/`.
    pub namespace: String,
    /// Repository (model) name inside the namespace.
    pub repository: String,
    /// Tag of the model; `"latest"` when the name carries no (or an empty) tag.
    pub tag: String,
}

impl OllamaModelRef {
    /// Tag used whenever the reference does not name one.
    const DEFAULT_TAG: &'static str = "latest";
    /// Namespace used whenever the reference does not name one.
    const DEFAULT_NAMESPACE: &'static str = "library";

    /// Parses a model name of the form `[ollama:][namespace/]repository[:tag]`.
    ///
    /// * A leading `ollama:` prefix is dropped.
    /// * The tag is everything after the first remaining `:`. A missing tag,
    ///   or an empty one as in `"llama3:"`, becomes `"latest"`.
    /// * The namespace is everything before the first `/`; without a `/` the
    ///   namespace is `"library"`.
    ///
    /// Parsing never fails; no further validation of the parts is performed.
    pub fn parse(name: &str) -> Self {
        let name = name.strip_prefix("ollama:").unwrap_or(name);

        let (name_part, tag) = match name.split_once(':') {
            // A trailing ':' must not yield an empty tag: it would render as
            // "name:" and could never match a real tag.
            Some((head, tag)) if !tag.is_empty() => (head, tag),
            Some((head, _)) => (head, Self::DEFAULT_TAG),
            None => (name, Self::DEFAULT_TAG),
        };

        let (namespace, repository) = match name_part.split_once('/') {
            Some((ns, repo)) => (ns, repo),
            None => (Self::DEFAULT_NAMESPACE, name_part),
        };

        Self {
            namespace: namespace.to_string(),
            repository: repository.to_string(),
            tag: tag.to_string(),
        }
    }

    /// Returns `repository:tag` for the `"library"` namespace and
    /// `namespace/repository:tag` for every other namespace.
    pub fn display_name(&self) -> String {
        if self.namespace == Self::DEFAULT_NAMESPACE {
            format!("{}:{}", self.repository, self.tag)
        } else {
            format!("{}/{}:{}", self.namespace, self.repository, self.tag)
        }
    }

    /// Returns `namespace/repository` (always including the namespace, never the tag).
    pub fn registry_path(&self) -> String {
        format!("{}/{}", self.namespace, self.repository)
    }
}
/// Optional model parameters.
///
/// Every field is an `Option`; a field that is `None` is left out of the
/// serialized output (`skip_serializing_if`), and the derived `Default`
/// produces a value with every field unset.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct OllamaParameters {
// --- Fields copied into `crate::SamplerParams` by `to_sampler_params` ---
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
/// Stored in `SamplerParams::penalty_repeat`.
#[serde(skip_serializing_if = "Option::is_none")]
pub repeat_penalty: Option<f32>,
/// Stored in `SamplerParams::penalty_last_n`.
#[serde(skip_serializing_if = "Option::is_none")]
pub repeat_last_n: Option<i32>,
/// Stored in `SamplerParams::penalty_freq`.
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Stored in `SamplerParams::penalty_present`.
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Stored in `SamplerParams::penalize_nl`.
#[serde(skip_serializing_if = "Option::is_none")]
pub penalize_newline: Option<bool>,
/// Converted with `as u32` in `to_sampler_params`, so only the low 32 bits survive.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
// --- Fields not forwarded by `to_sampler_params` ---
/// Stop strings; returned by `stop_sequences` (empty list when unset).
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// `mirostat_enabled` reports `true` only for the values 1 and 2.
#[serde(skip_serializing_if = "Option::is_none")]
pub mirostat: Option<u8>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mirostat_tau: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mirostat_eta: Option<f32>,
/// Returned unchanged by `context_size`.
#[serde(skip_serializing_if = "Option::is_none")]
pub num_ctx: Option<u32>,
/// Source of `max_tokens`; negative values are reported there as `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub num_predict: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_keep: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_batch: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_thread: Option<i32>,
/// Returned unchanged by `gpu_layers`.
#[serde(skip_serializing_if = "Option::is_none")]
pub num_gpu: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub main_gpu: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_mmap: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_mlock: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub low_vram: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub numa: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub vocab_only: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tfs_z: Option<f32>,
}
impl OllamaParameters {
    /// Builds a `SamplerParams` that starts from `SamplerParams::default()`
    /// and overrides each sampler field for which this struct holds a value.
    ///
    /// Unset (`None`) parameters leave the corresponding default untouched.
    pub fn to_sampler_params(&self) -> crate::SamplerParams {
        /// Overwrites `slot` only when an override value is present.
        fn apply<T>(slot: &mut T, value: Option<T>) {
            if let Some(inner) = value {
                *slot = inner;
            }
        }

        let mut params = crate::SamplerParams::default();

        apply(&mut params.temperature, self.temperature);
        apply(&mut params.top_k, self.top_k);
        apply(&mut params.top_p, self.top_p);
        apply(&mut params.min_p, self.min_p);
        apply(&mut params.typical_p, self.typical_p);
        apply(&mut params.penalty_repeat, self.repeat_penalty);
        apply(&mut params.penalty_last_n, self.repeat_last_n);
        apply(&mut params.penalty_freq, self.frequency_penalty);
        apply(&mut params.penalty_present, self.presence_penalty);
        apply(&mut params.penalize_nl, self.penalize_newline);
        // Plain `as` cast on purpose: the i64 seed is reduced to its low 32 bits.
        apply(&mut params.seed, self.seed.map(|raw| raw as u32));

        params
    }

    /// Returns a copy of the configured stop strings, or an empty list when
    /// none are set.
    pub fn stop_sequences(&self) -> Vec<String> {
        match &self.stop {
            Some(list) => list.clone(),
            None => Vec::new(),
        }
    }

    /// Returns `num_predict` as a token limit, or `None` when it is unset or
    /// negative.
    pub fn max_tokens(&self) -> Option<u32> {
        // `try_from` fails exactly for negative values; every non-negative
        // i32 fits in a u32.
        self.num_predict
            .and_then(|limit| u32::try_from(limit).ok())
    }

    /// Returns the configured `num_ctx`, if any.
    pub fn context_size(&self) -> Option<u32> {
        self.num_ctx
    }

    /// Returns the configured `num_gpu`, if any.
    pub fn gpu_layers(&self) -> Option<i32> {
        self.num_gpu
    }

    /// Returns `true` when `mirostat` is set to 1 or 2; any other value, or
    /// no value, counts as disabled.
    pub fn mirostat_enabled(&self) -> bool {
        matches!(self.mirostat, Some(1 | 2))
    }
}
/// A single message: a `role` string paired with its `content` text.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OllamaMessage {
/// Role label of the message author (free-form string; no validation happens here).
pub role: String,
/// Text body of the message.
pub content: String,
}
/// Record describing one model, as stored in `OllamaModelIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModel {
/// Model name; combined with `tag` as `name:tag` to form the index key.
pub name: String,
/// Model tag; second half of the `name:tag` index key.
pub tag: String,
/// Path to the model's GGUF file (presumably the downloaded weights — confirm with the pull code).
pub gguf_path: PathBuf,
/// Optional path to a projector file, if the model has one.
pub projector_path: Option<PathBuf>,
/// Optional template text; `get_stop_sequences` feeds it to
/// `ChatTemplate::from_ollama_template` to obtain additional stop strings.
pub template: Option<String>,
/// Optional system prompt text.
pub system_prompt: Option<String>,
/// Model parameters; `parameters.stop` contributes to `get_stop_sequences`.
pub parameters: OllamaParameters,
/// Messages stored with the model.
pub messages: Vec<OllamaMessage>,
/// Optional license text.
pub license: Option<String>,
/// When the model was pulled, kept as a string (format not enforced here).
pub pulled_at: String,
/// Total size of the model (presumably in bytes — confirm with the pull code).
pub total_size: u64,
}
impl OllamaModel {
pub fn get_stop_sequences(&self) -> Vec<String> {
let mut stops = Vec::new();
if let Some(ref param_stops) = self.parameters.stop {
stops.extend(param_stops.clone());
}
if let Some(ref template) = self.template {
let chat_template = ChatTemplate::from_ollama_template(template);
stops.extend(chat_template.stop_sequences);
}
let mut seen = std::collections::HashSet::new();
stops.retain(|s| seen.insert(s.clone()));
stops
}
}
/// Index of known models, (de)serialized as JSON by `load` and `save`.
///
/// The derived `Default` is an empty index.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OllamaModelIndex {
/// Models keyed by `name:tag` (the key format `insert` builds).
pub models: HashMap<String, OllamaModel>,
}
impl OllamaModelIndex {
    /// Loads the index from the JSON file at `path`.
    ///
    /// A file that does not exist yields an empty index. The file is read
    /// directly instead of being probed with `Path::exists` first: `exists`
    /// reports `false` for *any* metadata error, which would make an
    /// unreadable index look empty (and a later `save` would then overwrite
    /// it), and a separate check also races with concurrent deletion.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::OllamaError` when the file exists but cannot be
    /// read, or when its content is not a valid serialized index.
    pub fn load(path: &Path) -> Result<Self, MullamaError> {
        match fs::read_to_string(path) {
            Ok(content) => serde_json::from_str(&content).map_err(|e| {
                MullamaError::OllamaError(format!("Failed to parse Ollama index: {}", e))
            }),
            // Only a missing file means "no index yet".
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()),
            Err(e) => Err(MullamaError::OllamaError(format!(
                "Failed to read Ollama index: {}",
                e
            ))),
        }
    }

    /// Serializes the index as pretty-printed JSON and writes it to `path`,
    /// replacing any existing file content.
    ///
    /// # Errors
    ///
    /// Returns `MullamaError::OllamaError` when serialization or the write
    /// fails.
    pub fn save(&self, path: &Path) -> Result<(), MullamaError> {
        let content = serde_json::to_string_pretty(self).map_err(|e| {
            MullamaError::OllamaError(format!("Failed to serialize Ollama index: {}", e))
        })?;
        fs::write(path, content)
            .map_err(|e| MullamaError::OllamaError(format!("Failed to write Ollama index: {}", e)))
    }

    /// Looks up a model by its exact index key.
    ///
    /// Keys have the form `name:tag` (see `insert`); no default tag is
    /// appended, so `name` must already include the tag.
    pub fn get(&self, name: &str) -> Option<&OllamaModel> {
        self.models.get(name)
    }

    /// Stores `model` under the key `name:tag`, replacing any entry that
    /// already uses that key.
    pub fn insert(&mut self, model: OllamaModel) {
        let key = format!("{}:{}", model.name, model.tag);
        self.models.insert(key, model);
    }

    /// Returns references to all indexed models.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn list(&self) -> Vec<&OllamaModel> {
        self.models.values().collect()
    }
}