use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::{FerrumError, Result, TokenId};
/// Parameters controlling token sampling for a generation request.
///
/// Neutral starting values come from the `Default` impl; range checks are
/// performed by `SamplingParams::validate`, which is not invoked automatically
/// on construction or deserialization in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingParams {
/// Token budget for generation (default: 512).
pub max_tokens: usize,
/// Sampling temperature; `validate` rejects negative values (default: 1.0).
pub temperature: f32,
/// Top-p value; `validate` requires it to lie in `(0, 1]` (default: 1.0).
pub top_p: f32,
/// Optional top-k value; when set, `validate` requires it to be non-zero.
pub top_k: Option<usize>,
/// Repetition penalty; `validate` requires it to be strictly positive (default: 1.0).
pub repetition_penalty: f32,
/// Presence penalty (default: 0.0). Not range-checked by `validate`.
pub presence_penalty: f32,
/// Frequency penalty (default: 0.0). Not range-checked by `validate`.
pub frequency_penalty: f32,
/// Stop sequences (default: empty).
// NOTE(review): presumably generation halts when one of these strings is produced — confirm in the sampler.
pub stop_sequences: Vec<String>,
/// Optional RNG seed (default: `None`).
pub seed: Option<u64>,
/// Optional min-p value; when set, `validate` requires it to lie in `(0, 1]`.
pub min_p: Option<f32>,
/// Optional tail-free-sampling value; when set, `validate` requires it to lie in `(0, 1]`.
pub tfs: Option<f32>,
/// Optional typical-p value; when set, `validate` requires it to lie in `(0, 1]`.
pub typical_p: Option<f32>,
/// Optional Mirostat settings (default: `None`). Not checked by `validate`.
pub mirostat: Option<MirostatParams>,
/// Requested output format. Because of `#[serde(default)]`, this field may be
/// omitted from serialized input, in which case `ResponseFormat::default()` is used.
#[serde(default)]
pub response_format: ResponseFormat,
}
/// Output format requested for a generation.
///
/// Serialized with serde's adjacently tagged representation: the variant name
/// is stored under the `"type"` key and, for variants carrying data, the
/// payload is stored under the `"schema"` key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "schema")]
pub enum ResponseFormat {
/// Plain text; this is the variant returned by the `Default` impl.
Text,
/// JSON object output.
// NOTE(review): presumably "any syntactically valid JSON object" — confirm against the consumer.
JsonObject,
/// JSON output described by the carried string.
// NOTE(review): the string is presumably a serialized JSON Schema document — confirm.
JsonSchema(String),
}
impl Default for ResponseFormat {
fn default() -> Self {
Self::Text
}
}
impl Default for SamplingParams {
fn default() -> Self {
Self {
max_tokens: 512,
temperature: 1.0,
top_p: 1.0,
top_k: None,
repetition_penalty: 1.0,
presence_penalty: 0.0,
frequency_penalty: 0.0,
stop_sequences: vec![],
seed: None,
min_p: None,
tfs: None,
typical_p: None,
mirostat: None,
response_format: ResponseFormat::default(),
}
}
}
impl SamplingParams {
    /// Builds the "greedy" parameter set: temperature `0.0`, `top_p` `1.0`,
    /// no `top_k`, and every other field taken from [`Default`].
    // NOTE(review): a temperature of 0.0 is presumably treated as argmax
    // decoding by the sampler — confirm at the call site.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_p: 1.0,
            top_k: None,
            ..Default::default()
        }
    }

    /// Builds default parameters with only `temperature` overridden.
    ///
    /// The value is stored as given; call [`SamplingParams::validate`] to
    /// check it.
    pub fn with_temperature(temperature: f32) -> Self {
        Self {
            temperature,
            ..Default::default()
        }
    }

    /// Checks that the sampling parameters are within their accepted ranges.
    ///
    /// Checks run in this order and the first failure is returned:
    /// `temperature >= 0`, `top_p` in `(0, 1]`, `top_k != 0` (when set),
    /// `repetition_penalty > 0`, then `min_p`, `tfs` and `typical_p` in
    /// `(0, 1]` (each only when set).
    ///
    /// NaN is rejected for every checked float. `presence_penalty`,
    /// `frequency_penalty`, `max_tokens`, `stop_sequences` and `mirostat`
    /// are not inspected.
    ///
    /// # Errors
    ///
    /// Returns `FerrumError::invalid_request` with a message naming the first
    /// offending parameter.
    pub fn validate(&self) -> Result<()> {
        let reject = |message: &str| -> Result<()> {
            Err(FerrumError::invalid_request(message.to_string()))
        };
        // Written in positive form on purpose: any comparison against NaN is
        // false, so a NaN input fails the predicate and is rejected.
        let in_unit_interval = |value: f32| value > 0.0 && value <= 1.0;

        // `NaN < 0.0` is false, so NaN must be tested explicitly here.
        if self.temperature.is_nan() || self.temperature < 0.0 {
            return reject("Temperature must be non-negative");
        }
        if !in_unit_interval(self.top_p) {
            return reject("top_p must be in range (0, 1]");
        }
        if self.top_k == Some(0) {
            return reject("top_k must be positive");
        }
        // `NaN <= 0.0` is false as well, hence the explicit NaN test.
        if self.repetition_penalty.is_nan() || self.repetition_penalty <= 0.0 {
            return reject("Repetition penalty must be positive");
        }

        // Optional parameters sharing the same (0, 1] constraint; `None`
        // means "not set" and is always accepted.
        let optional_fractions = [
            (self.min_p, "min_p must be in range (0, 1]"),
            (self.tfs, "tfs must be in range (0, 1]"),
            (self.typical_p, "typical_p must be in range (0, 1]"),
        ];
        for &(setting, message) in optional_fractions.iter() {
            if let Some(value) = setting {
                if !in_unit_interval(value) {
                    return reject(message);
                }
            }
        }

        Ok(())
    }
}
/// Settings for Mirostat sampling, carried in `SamplingParams::mirostat`.
///
/// None of these fields are range-checked by `SamplingParams::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirostatParams {
/// Mode selector.
// NOTE(review): presumably the Mirostat algorithm version (e.g. 1 or 2) — confirm accepted values.
pub mode: u8,
/// The `tau` parameter.
// NOTE(review): presumably the target surprise/entropy — confirm.
pub tau: f32,
/// The `eta` parameter.
// NOTE(review): presumably the learning rate — confirm.
pub eta: f32,
}
/// A named collection of sampling parameter sets.
///
/// The `Default` impl registers the `"greedy"`, `"creative"` and `"precise"`
/// presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingPresets {
/// Preset name mapped to its parameters.
pub presets: HashMap<String, SamplingParams>,
}
impl Default for SamplingPresets {
fn default() -> Self {
let mut presets = HashMap::new();
presets.insert("greedy".to_string(), SamplingParams::greedy());
presets.insert(
"creative".to_string(),
SamplingParams {
temperature: 1.2,
top_p: 0.9,
top_k: Some(50),
repetition_penalty: 1.1,
..Default::default()
},
);
presets.insert(
"precise".to_string(),
SamplingParams {
temperature: 0.3,
top_p: 0.95,
top_k: Some(20),
repetition_penalty: 1.05,
..Default::default()
},
);
Self { presets }
}
}
/// Priority level with four tiers.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Low < Normal < High < Critical`; the explicit discriminants (0 through 3)
/// increase in the same order. The `Default` impl yields `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
/// Lowest tier (discriminant 0).
Low = 0,
/// Default tier (discriminant 1).
Normal = 1,
/// Above-normal tier (discriminant 2).
High = 2,
/// Highest tier (discriminant 3).
Critical = 3,
}
impl Default for Priority {
fn default() -> Self {
Priority::Normal
}
}
/// Reason a generation finished.
///
/// The per-variant notes below are inferred from the variant names only; the
/// code that produces these values is not in this file.
// NOTE(review): `EOS` does not follow Rust's acronym casing (`Eos`), but renaming it
// would change the public API and the serialized variant name, so it is left as is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
/// Presumably: the token limit was reached — confirm against the producer.
Length,
/// Presumably: a stop sequence was matched — confirm against the producer.
Stop,
/// Presumably: the end-of-sequence token was generated — confirm against the producer.
EOS,
/// Presumably: the request was cancelled — confirm against the producer.
Cancelled,
/// Presumably: generation failed with an error — confirm against the producer.
Error,
/// Presumably: output was blocked by a content filter — confirm against the producer.
ContentFilter,
}
/// Optional ids of the special tokens of a vocabulary.
///
/// Every field is `None` in the `Default` impl. The expansions of the field
/// name abbreviations below are the conventional ones and are not confirmed
/// by code in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecialTokens {
/// Id of the BOS (presumably beginning-of-sequence) token, if any.
pub bos_token: Option<TokenId>,
/// Id of the EOS (presumably end-of-sequence) token, if any.
pub eos_token: Option<TokenId>,
/// Id of the UNK (presumably unknown-token) token, if any.
pub unk_token: Option<TokenId>,
/// Id of the PAD (presumably padding) token, if any.
pub pad_token: Option<TokenId>,
/// Id of the SEP (presumably separator) token, if any.
pub sep_token: Option<TokenId>,
/// Id of the CLS (presumably classification) token, if any.
pub cls_token: Option<TokenId>,
/// Id of the MASK token, if any.
pub mask_token: Option<TokenId>,
}
impl Default for SpecialTokens {
fn default() -> Self {
Self {
bos_token: None,
eos_token: None,
unk_token: None,
pad_token: None,
sep_token: None,
cls_token: None,
mask_token: None,
}
}
}