use anyhow::Result;
use anyhow::bail;
use serde::Deserialize;
use serde::Serialize;
use crate::kv_cache_dtype::KvCacheDtype;
use crate::pooling_type::PoolingType;
use crate::validates::Validates;
/// Parameters controlling inference, (de)serializable via serde.
///
/// Deserialization is strict:
/// - `deny_unknown_fields` rejects any key that is not one of the fields below.
/// - No field carries `#[serde(default)]`, so every field must be present in the input.
///
/// [`InferenceParameters::default`] provides a baseline set of values. Run
/// [`Validates::validate`] after deserializing to reject values the type system cannot rule out.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct InferenceParameters {
/// Batch size in tokens (default 512). NOTE(review): presumably the max tokens per decode batch — confirm.
pub batch_n_tokens: usize,
/// Context size (default 8192). NOTE(review): presumably measured in tokens — confirm.
pub context_size: u32,
/// Whether embeddings are enabled (default `false`).
pub enable_embeddings: bool,
/// Target size for image resizing (default 1024). Must be greater than zero; `validate` rejects 0.
pub image_resize_to_fit: u32,
/// Data type for the K cache (default `KvCacheDtype::Q8_0`).
pub k_cache_dtype: KvCacheDtype,
/// Data type for the V cache (default `KvCacheDtype::Q8_0`).
pub v_cache_dtype: KvCacheDtype,
/// Min-p sampling threshold (default 0.05).
pub min_p: f32,
/// Number of layers to offload to the GPU (default 0).
pub n_gpu_layers: u32,
/// Frequency penalty (default 0.0).
pub penalty_frequency: f32,
/// Penalty window length; signed, default -1.
/// NOTE(review): -1 is presumably a sentinel (e.g. "whole context") interpreted by the backend — confirm.
pub penalty_last_n: i32,
/// Presence penalty (default 0.8).
pub penalty_presence: f32,
/// Repetition penalty (default 1.1).
pub penalty_repeat: f32,
/// Pooling strategy (default `PoolingType::Last`). NOTE(review): presumably only relevant when embeddings are enabled — confirm.
pub pooling_type: PoolingType,
/// Sampling temperature (default 0.8).
pub temperature: f32,
/// Top-k sampling cutoff (default 80); signed, so the backend may accept non-positive sentinels — confirm.
pub top_k: i32,
/// Top-p (nucleus) sampling threshold (default 0.8).
pub top_p: f32,
}
impl Validates<Self> for InferenceParameters {
    /// Checks the parameters and hands them back unchanged when they are acceptable.
    ///
    /// # Errors
    ///
    /// Fails when `image_resize_to_fit` is zero, because a zero resize target cannot be used.
    fn validate(self) -> Result<Self> {
        if self.image_resize_to_fit > 0 {
            Ok(self)
        } else {
            bail!("image_resize_to_fit must be greater than zero")
        }
    }
}
impl Default for InferenceParameters {
fn default() -> Self {
Self {
batch_n_tokens: 512,
context_size: 8192,
enable_embeddings: false,
image_resize_to_fit: 1024,
k_cache_dtype: KvCacheDtype::Q8_0,
v_cache_dtype: KvCacheDtype::Q8_0,
min_p: 0.05,
n_gpu_layers: 0,
penalty_frequency: 0.0,
penalty_last_n: -1,
penalty_presence: 0.8,
penalty_repeat: 1.1,
pooling_type: PoolingType::Last,
temperature: 0.8,
top_k: 80,
top_p: 0.8,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must validate, and `validate` must return them unchanged.
    #[test]
    fn validate_succeeds_with_default_params() {
        let params = InferenceParameters::default();
        let validated = params
            .clone()
            .validate()
            .expect("default params should be valid");
        assert_eq!(validated, params);
    }

    /// The smallest non-zero resize target is accepted (boundary of the zero check).
    #[test]
    fn validate_accepts_image_resize_to_fit_of_one() {
        let params = InferenceParameters {
            image_resize_to_fit: 1,
            ..InferenceParameters::default()
        };
        assert!(params.validate().is_ok());
    }

    /// A zero resize target is rejected, and the error names the offending field.
    #[test]
    fn validate_fails_when_image_resize_to_fit_is_zero() {
        let params = InferenceParameters {
            image_resize_to_fit: 0,
            ..InferenceParameters::default()
        };
        let err = params.validate().unwrap_err();
        assert!(
            err.to_string().contains("image_resize_to_fit"),
            "unexpected error message: {err}"
        );
    }
}