use std::collections::HashMap;
use serde::{Deserialize, Serialize};
pub use jammi_lora::{BackboneDtype, LoraInitMode};
/// Selects the fine-tuning method.
///
/// Variants are (de)serialized by serde under their snake_case names. `Lora` is
/// the only variant; the `Display` and `FromStr` impls in this file both use the
/// string `lora` for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FineTuneMethod {
/// LoRA fine-tuning.
Lora,
}
impl std::fmt::Display for FineTuneMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Lora => write!(f, "lora"),
}
}
}
impl std::str::FromStr for FineTuneMethod {
type Err = jammi_db::error::JammiError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"lora" => Ok(Self::Lora),
other => Err(jammi_db::error::JammiError::FineTune(format!(
"Unknown fine-tuning method '{other}'. Supported: lora"
))),
}
}
}
/// Loss selection for embedding fine-tuning (stored in
/// `FineTuneConfig::embedding_loss`).
///
/// Variants are (de)serialized by serde under snake_case names.
/// NOTE(review): serde's snake_case conversion splits on every capital letter, so
/// `CoSent` / `AnglE` presumably serialize as `co_sent` / `angl_e` — confirm this
/// is the intended wire format.
///
/// The `Default` impl in this file selects `CoSent`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingLoss {
/// CoSENT loss; carries no parameters.
CoSent,
/// Triplet loss with its `margin` hyperparameter (how it is applied is defined
/// by the training code, not in this file).
Triplet { margin: f64 },
/// Multiple-negatives-ranking loss with its `temperature` hyperparameter (how
/// it is applied is defined by the training code, not in this file).
MultipleNegativesRanking { temperature: f64 },
/// AnglE loss; carries no parameters.
AnglE,
/// Cosine-MSE loss; carries no parameters.
CosineMse,
}
impl Default for EmbeddingLoss {
fn default() -> Self {
Self::CoSent
}
}
/// Loss selection for regression fine-tuning (stored in
/// `FineTuneConfig::regression_loss`).
///
/// Variants are (de)serialized by serde under snake_case names. The `Default`
/// impl in this file selects `BetaNll { beta: 0.5 }`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RegressionLoss {
/// Gaussian negative log-likelihood; carries no parameters.
GaussianNll,
/// Beta-NLL loss. `FineTuneConfig::validate` requires `beta` to lie in
/// `[0.0, 1.0]`.
BetaNll {
beta: f64,
},
/// CRPS loss; carries no parameters.
Crps,
/// Pinball loss. `FineTuneConfig::validate` requires
/// `FineTuneConfig::quantile_levels` to be non-empty, strictly inside `(0, 1)`
/// and strictly ascending when this variant is selected.
Pinball,
}
impl Default for RegressionLoss {
fn default() -> Self {
Self::BetaNll { beta: 0.5 }
}
}
/// Loss selection for classification fine-tuning (stored in
/// `FineTuneConfig::classification_loss`).
///
/// Variants are (de)serialized by serde under snake_case names. `CrossEntropy` is
/// the only variant and is also what the `Default` impl in this file returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ClassificationLoss {
/// Cross-entropy loss; carries no parameters.
CrossEntropy,
}
impl Default for ClassificationLoss {
fn default() -> Self {
Self::CrossEntropy
}
}
/// Selects which loss early stopping watches (stored in
/// `FineTuneConfig::early_stopping_metric`).
///
/// Variants are (de)serialized by serde under snake_case names. The `Default`
/// impl in this file selects `ValLoss`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EarlyStoppingMetric {
/// Validation loss.
ValLoss,
/// Training loss.
TrainLoss,
}
impl Default for EarlyStoppingMetric {
fn default() -> Self {
Self::ValLoss
}
}
/// Learning-rate schedule selector (stored in `FineTuneConfig::lr_schedule`).
///
/// Variants are (de)serialized by serde under snake_case names. The `Default`
/// impl in this file selects `CosineDecay`. The exact shape of each schedule is
/// defined by the training code, not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LrSchedule {
/// Constant schedule.
Constant,
/// Cosine-decay schedule.
CosineDecay,
/// Linear-decay schedule.
LinearDecay,
}
impl Default for LrSchedule {
fn default() -> Self {
Self::CosineDecay
}
}
/// Hard-negative mining settings (stored in `FineTuneConfig::hard_negatives`).
///
/// The `Default` impl in this file disables mining and sets every count to `1`.
/// `FineTuneConfig::validate` only inspects `k` and `refresh_every`, and only
/// when `mine` is `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct HardNegativeConfig {
/// Whether mining is enabled; gates the `k` / `refresh_every` validation.
pub mine: bool,
/// Must be > 0 when `mine` is set. Presumably the number of negatives mined
/// per example — confirm against the mining code.
pub k: usize,
/// Not validated in this file. Presumably a neighbourhood radius excluded
/// from mining — TODO confirm.
pub exclude_hops: usize,
/// Must be > 0 when `mine` is set. Presumably how often mined negatives are
/// recomputed; the unit (epochs/steps) is not visible here — TODO confirm.
pub refresh_every: usize,
}
impl Default for HardNegativeConfig {
fn default() -> Self {
Self {
mine: false,
k: 1,
exclude_hops: 1,
refresh_every: 1,
}
}
}
/// Hyperparameters and options for a fine-tuning run.
///
/// Fields without a `#[serde(default …)]` attribute and without an `Option` type
/// have no fallback declared here when deserializing. `validate` performs the
/// range checks noted on individual fields; fields that do not mention
/// `validate` are not checked in this file.
///
/// NOTE(review): `init_lora_weights` and `backbone_dtype` use `#[serde(default)]`
/// (i.e. the type's own `Default`), while `FineTuneConfig::default()` hard-codes
/// `LoraInitMode::ZerosB` / `BackboneDtype::F32`. Confirm those agree with the
/// `jammi_lora` defaults, otherwise a deserialized config and a default-built one
/// differ.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FineTuneConfig {
/// LoRA rank. `validate` rejects `0`; `Default` uses `8`.
pub lora_rank: usize,
/// LoRA alpha. `validate` requires a value > 0; `Default` uses `16.0`.
pub lora_alpha: f64,
/// LoRA dropout. `validate` requires `[0.0, 1.0)`; `Default` uses `0.05`.
pub lora_dropout: f64,
/// Learning rate. `validate` requires a value > 0; `Default` uses `2e-4`.
pub learning_rate: f64,
/// Number of epochs. `validate` rejects `0`; `Default` uses `3`.
pub epochs: usize,
/// Batch size. `validate` rejects `0`; `Default` uses `8`.
pub batch_size: usize,
/// Maximum sequence length; `Default` uses `512`. Presumably measured in
/// tokens — TODO confirm.
pub max_seq_length: usize,
/// Embedding loss override; `None` in `Default`. How `None` is resolved is
/// decided by the caller, not in this file.
pub embedding_loss: Option<EmbeddingLoss>,
/// Classification loss override; `None` in `Default`.
pub classification_loss: Option<ClassificationLoss>,
/// Regression loss override; `None` in `Default` and when absent from the
/// serialized form. `validate` checks `BetaNll::beta` and the `Pinball`
/// quantile requirements.
#[serde(default)]
pub regression_loss: Option<RegressionLoss>,
/// Quantile levels; empty by default. Only validated when `regression_loss`
/// is `Pinball`: must then be non-empty, strictly inside `(0, 1)` and strictly
/// ascending.
#[serde(default)]
pub quantile_levels: Vec<f64>,
/// Gradient accumulation steps. `validate` rejects `0`; `Default` uses `1`.
pub gradient_accumulation_steps: usize,
/// Validation split fraction. `validate` requires `[0.0, 1.0)`; `Default`
/// uses `0.1`.
pub validation_fraction: f64,
/// Early-stopping patience. `validate` rejects `0`; `Default` uses `3`.
pub early_stopping_patience: usize,
/// Warmup steps; `Default` uses `100`.
pub warmup_steps: usize,
/// Learning-rate schedule; `Default` uses `CosineDecay`.
pub lr_schedule: LrSchedule,
/// Metric watched by early stopping; falls back to
/// `EarlyStoppingMetric::default()` when absent.
#[serde(default)]
pub early_stopping_metric: EarlyStoppingMetric,
/// Target module names; empty by default. Presumably the modules LoRA is
/// applied to, with empty meaning "let the backend choose" — TODO confirm.
#[serde(default)]
pub target_modules: Vec<String>,
/// Optional layer index list; `None` by default. Presumably restricts which
/// layers receive adapters — TODO confirm.
#[serde(default)]
pub layers_to_transform: Option<Vec<usize>>,
/// rsLoRA toggle; `false` by default. Effect is defined outside this file.
#[serde(default)]
pub use_rslora: bool,
/// Map from name to rank; empty by default. Presumably per-module rank
/// overrides keyed by module name/pattern — TODO confirm.
#[serde(default)]
pub rank_pattern: HashMap<String, usize>,
/// LoRA weight initialisation mode; see the struct-level review note about
/// the serde fallback versus `FineTuneConfig::default()`.
#[serde(default)]
pub init_lora_weights: jammi_lora::LoraInitMode,
/// Backbone dtype; see the struct-level review note about the serde fallback
/// versus `FineTuneConfig::default()`.
#[serde(default)]
pub backbone_dtype: jammi_lora::BackboneDtype,
/// Weight decay; falls back to `default_weight_decay()` (`0.01`) when absent.
#[serde(default = "default_weight_decay")]
pub weight_decay: f64,
/// Maximum gradient norm; falls back to `default_max_grad_norm()` (`1.0`) when
/// absent.
#[serde(default = "default_max_grad_norm")]
pub max_grad_norm: f64,
/// `false` by default. Meaning is defined by the training code — TODO confirm
/// what is cached.
#[serde(default)]
pub cached: bool,
/// Hard-negative mining settings; `validate` checks `k` and `refresh_every`
/// when mining is enabled.
#[serde(default)]
pub hard_negatives: HardNegativeConfig,
/// Matryoshka dimensions; empty by default. `validate` rejects any `0` entry.
#[serde(default)]
pub matryoshka_dims: Vec<usize>,
/// Seed; falls back to `DEFAULT_FINE_TUNE_SEED` when absent.
#[serde(default = "default_fine_tune_seed")]
pub seed: u64,
}
/// Seed used when a fine-tuning config does not specify one.
pub const DEFAULT_FINE_TUNE_SEED: u64 = 42;

/// serde fallback for `FineTuneConfig::seed`.
fn default_fine_tune_seed() -> u64 {
    DEFAULT_FINE_TUNE_SEED
}

/// serde fallback for `FineTuneConfig::weight_decay`.
fn default_weight_decay() -> f64 {
    const WEIGHT_DECAY: f64 = 0.01;
    WEIGHT_DECAY
}

/// serde fallback for `FineTuneConfig::max_grad_norm`.
fn default_max_grad_norm() -> f64 {
    const MAX_GRAD_NORM: f64 = 1.0;
    MAX_GRAD_NORM
}
impl Default for FineTuneConfig {
fn default() -> Self {
Self {
lora_rank: 8,
lora_alpha: 16.0,
lora_dropout: 0.05,
learning_rate: 2e-4,
epochs: 3,
batch_size: 8,
max_seq_length: 512,
embedding_loss: None,
classification_loss: None,
regression_loss: None,
quantile_levels: Vec::new(),
gradient_accumulation_steps: 1,
validation_fraction: 0.1,
early_stopping_patience: 3,
warmup_steps: 100,
lr_schedule: LrSchedule::CosineDecay,
early_stopping_metric: EarlyStoppingMetric::ValLoss,
target_modules: Vec::new(),
layers_to_transform: None,
use_rslora: false,
rank_pattern: HashMap::new(),
init_lora_weights: jammi_lora::LoraInitMode::ZerosB,
backbone_dtype: jammi_lora::BackboneDtype::F32,
weight_decay: 0.01,
max_grad_norm: 1.0,
cached: false,
hard_negatives: HardNegativeConfig::default(),
matryoshka_dims: Vec::new(),
seed: DEFAULT_FINE_TUNE_SEED,
}
}
}
impl FineTuneConfig {
pub fn validate(&self) -> jammi_db::error::Result<()> {
use jammi_db::error::JammiError;
if self.lora_rank == 0 {
return Err(JammiError::FineTune("lora_rank must be > 0".into()));
}
if self.lora_alpha <= 0.0 {
return Err(JammiError::FineTune("lora_alpha must be > 0".into()));
}
if !(0.0..1.0).contains(&self.lora_dropout) {
return Err(JammiError::FineTune(
"lora_dropout must be in [0.0, 1.0)".into(),
));
}
if self.learning_rate <= 0.0 {
return Err(JammiError::FineTune("learning_rate must be > 0".into()));
}
if self.epochs == 0 {
return Err(JammiError::FineTune("epochs must be > 0".into()));
}
if self.batch_size == 0 {
return Err(JammiError::FineTune("batch_size must be > 0".into()));
}
if self.gradient_accumulation_steps == 0 {
return Err(JammiError::FineTune(
"gradient_accumulation_steps must be > 0".into(),
));
}
if !(0.0..1.0).contains(&self.validation_fraction) {
return Err(JammiError::FineTune(
"validation_fraction must be in [0.0, 1.0)".into(),
));
}
if self.early_stopping_patience == 0 {
return Err(JammiError::FineTune(
"early_stopping_patience must be > 0".into(),
));
}
if self.hard_negatives.mine {
if self.hard_negatives.k == 0 {
return Err(JammiError::FineTune(
"hard_negatives.k must be > 0 when mining is enabled".into(),
));
}
if self.hard_negatives.refresh_every == 0 {
return Err(JammiError::FineTune(
"hard_negatives.refresh_every must be > 0 when mining is enabled".into(),
));
}
}
if self.matryoshka_dims.contains(&0) {
return Err(JammiError::FineTune(
"matryoshka_dims entries must all be > 0".into(),
));
}
if let Some(RegressionLoss::BetaNll { beta }) = self.regression_loss {
if !(0.0..=1.0).contains(&beta) {
return Err(JammiError::FineTune(
"regression_loss BetaNll beta must be in [0.0, 1.0]".into(),
));
}
}
if matches!(self.regression_loss, Some(RegressionLoss::Pinball)) {
if self.quantile_levels.is_empty() {
return Err(JammiError::FineTune(
"Pinball regression loss requires at least one quantile level".into(),
));
}
if self
.quantile_levels
.iter()
.any(|&q| !(0.0..1.0).contains(&q) || q <= 0.0)
{
return Err(JammiError::FineTune(
"quantile_levels must lie strictly in (0, 1)".into(),
));
}
if self.quantile_levels.windows(2).any(|w| w[1] <= w[0]) {
return Err(JammiError::FineTune(
"quantile_levels must be strictly ascending".into(),
));
}
}
Ok(())
}
}