use crate::error::LlmError;
use crate::params::mapper::{ParameterConstraints, ParameterMapper};
use crate::types::{CommonParams, ProviderParams, ProviderType};
use serde_json::{Value, json};
use std::collections::HashMap;
pub struct OllamaParameterMapper;
impl ParameterMapper for OllamaParameterMapper {
fn provider_type(&self) -> ProviderType {
ProviderType::Ollama
}
fn map_common_params(&self, params: &CommonParams) -> Value {
let mut ollama_params = json!({});
if !params.model.is_empty() {
ollama_params["model"] = json!(params.model);
}
if let Some(temperature) = params.temperature {
ollama_params["temperature"] = json!(temperature);
}
if let Some(max_tokens) = params.max_tokens {
ollama_params["num_predict"] = json!(max_tokens);
}
if let Some(top_p) = params.top_p {
ollama_params["top_p"] = json!(top_p);
}
if let Some(stop_sequences) = ¶ms.stop_sequences {
ollama_params["stop"] = json!(stop_sequences);
}
if let Some(seed) = params.seed {
ollama_params["seed"] = json!(seed);
}
ollama_params
}
fn merge_provider_params(&self, mut base: Value, provider_params: &ProviderParams) -> Value {
if let Some(keep_alive) = provider_params.get::<String>("keep_alive") {
base["keep_alive"] = json!(keep_alive);
}
if let Some(raw) = provider_params.get::<bool>("raw") {
base["raw"] = json!(raw);
}
if let Some(format) = provider_params.get::<String>("format") {
base["format"] = json!(format);
}
if let Some(stop) = provider_params.get::<Vec<String>>("stop") {
base["stop"] = json!(stop);
}
if let Some(numa) = provider_params.get::<bool>("numa") {
base["numa"] = json!(numa);
}
if let Some(num_ctx) = provider_params.get::<u32>("num_ctx") {
base["num_ctx"] = json!(num_ctx);
}
if let Some(num_batch) = provider_params.get::<u32>("num_batch") {
base["num_batch"] = json!(num_batch);
}
if let Some(num_gpu) = provider_params.get::<u32>("num_gpu") {
base["num_gpu"] = json!(num_gpu);
}
if let Some(main_gpu) = provider_params.get::<u32>("main_gpu") {
base["main_gpu"] = json!(main_gpu);
}
if let Some(use_mmap) = provider_params.get::<bool>("use_mmap") {
base["use_mmap"] = json!(use_mmap);
}
if let Some(num_thread) = provider_params.get::<u32>("num_thread") {
base["num_thread"] = json!(num_thread);
}
for (key, value) in &provider_params.params {
if base.get(key).is_none() {
base[key] = value.clone();
}
}
base
}
fn validate_params(&self, params: &Value) -> Result<(), LlmError> {
if let Some(temp) = params
.get("temperature")
.and_then(serde_json::Value::as_f64)
&& !(0.0..=2.0).contains(&temp)
{
return Err(LlmError::InvalidParameter(
"Temperature must be between 0.0 and 2.0".to_string(),
));
}
if let Some(top_p) = params.get("top_p").and_then(serde_json::Value::as_f64)
&& !(0.0..=1.0).contains(&top_p)
{
return Err(LlmError::InvalidParameter(
"top_p must be between 0.0 and 1.0".to_string(),
));
}
if let Some(num_predict) = params
.get("num_predict")
.and_then(serde_json::Value::as_u64)
&& num_predict == 0
{
return Err(LlmError::InvalidParameter(
"num_predict must be greater than 0".to_string(),
));
}
if let Some(num_ctx) = params.get("num_ctx").and_then(serde_json::Value::as_u64)
&& num_ctx == 0
{
return Err(LlmError::InvalidParameter(
"num_ctx must be greater than 0".to_string(),
));
}
if let Some(num_batch) = params.get("num_batch").and_then(serde_json::Value::as_u64)
&& num_batch == 0
{
return Err(LlmError::InvalidParameter(
"num_batch must be greater than 0".to_string(),
));
}
if let Some(num_gpu) = params.get("num_gpu").and_then(serde_json::Value::as_u64)
&& num_gpu > 64
{
return Err(LlmError::InvalidParameter(
"num_gpu should not exceed 64".to_string(),
));
}
if let Some(num_thread) = params.get("num_thread").and_then(serde_json::Value::as_u64)
&& (num_thread == 0 || num_thread > 256)
{
return Err(LlmError::InvalidParameter(
"num_thread must be between 1 and 256".to_string(),
));
}
Ok(())
}
fn get_param_constraints(&self) -> ParameterConstraints {
ParameterConstraints {
temperature_min: 0.0,
temperature_max: 2.0,
max_tokens_min: 1,
max_tokens_max: 100_000,
top_p_min: 0.0,
top_p_max: 1.0,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct OllamaProviderParams {
pub keep_alive: Option<String>,
pub raw: Option<bool>,
pub format: Option<String>,
pub stop: Option<Vec<String>>,
pub numa: Option<bool>,
pub num_ctx: Option<u32>,
pub num_batch: Option<u32>,
pub num_gpu: Option<u32>,
pub main_gpu: Option<u32>,
pub use_mmap: Option<bool>,
pub num_thread: Option<u32>,
pub options: Option<HashMap<String, Value>>,
}
impl OllamaProviderParams {
pub fn new() -> Self {
Self::default()
}
pub fn keep_alive<S: Into<String>>(mut self, duration: S) -> Self {
self.keep_alive = Some(duration.into());
self
}
pub const fn raw(mut self, raw: bool) -> Self {
self.raw = Some(raw);
self
}
pub fn format<S: Into<String>>(mut self, format: S) -> Self {
self.format = Some(format.into());
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub const fn numa(mut self, numa: bool) -> Self {
self.numa = Some(numa);
self
}
pub const fn num_ctx(mut self, num_ctx: u32) -> Self {
self.num_ctx = Some(num_ctx);
self
}
pub const fn num_batch(mut self, num_batch: u32) -> Self {
self.num_batch = Some(num_batch);
self
}
pub const fn num_gpu(mut self, num_gpu: u32) -> Self {
self.num_gpu = Some(num_gpu);
self
}
pub const fn main_gpu(mut self, main_gpu: u32) -> Self {
self.main_gpu = Some(main_gpu);
self
}
pub const fn use_mmap(mut self, use_mmap: bool) -> Self {
self.use_mmap = Some(use_mmap);
self
}
pub const fn num_thread(mut self, num_thread: u32) -> Self {
self.num_thread = Some(num_thread);
self
}
pub fn option<K: Into<String>>(mut self, key: K, value: Value) -> Self {
let mut options = self.options.unwrap_or_default();
options.insert(key.into(), value);
self.options = Some(options);
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ollama_parameter_mapper() {
let mapper = OllamaParameterMapper;
assert_eq!(mapper.provider_type(), ProviderType::Ollama);
let common_params = CommonParams {
model: "llama3.2".to_string(),
temperature: Some(0.7),
max_tokens: Some(1000),
top_p: Some(0.9),
stop_sequences: Some(vec!["\\n".to_string()]),
seed: Some(42),
};
let mapped_params = mapper.map_common_params(&common_params);
assert_eq!(mapped_params["model"], "llama3.2");
assert!((mapped_params["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
assert_eq!(mapped_params["num_predict"], 1000);
assert!((mapped_params["top_p"].as_f64().unwrap() - 0.9).abs() < 0.001);
assert_eq!(mapped_params["seed"], 42);
}
#[test]
fn test_ollama_provider_params() {
let params = OllamaProviderParams::new()
.keep_alive("10m".to_string())
.raw(true)
.numa(false)
.num_ctx(2048)
.num_gpu(1)
.num_thread(8);
assert_eq!(params.keep_alive, Some("10m".to_string()));
assert_eq!(params.raw, Some(true));
assert_eq!(params.numa, Some(false));
assert_eq!(params.num_ctx, Some(2048));
assert_eq!(params.num_gpu, Some(1));
assert_eq!(params.num_thread, Some(8));
}
}