use crate::{
EvalMetrics, LoraConfig, MemoryStats, ModelConfig, Result, StepMetrics, TrainingConfig,
};
use std::path::Path;
/// Core abstraction over a trainable causal-LM model, generic over a
/// backend-specific tensor type.
///
/// Deprecated since 0.1.0 — superseded by `CausalLMModel` in `pmetal_models`.
#[deprecated(
since = "0.1.0",
note = "Use CausalLMModel from pmetal_models instead."
)]
pub trait PMetalModel: Send + Sync {
/// Backend tensor type used for inputs, outputs, and parameters.
type Tensor;
/// Run a forward pass over `input_ids`, optionally applying
/// `attention_mask`, and return the model output.
fn forward(
&self,
input_ids: &Self::Tensor,
attention_mask: Option<&Self::Tensor>,
) -> Result<crate::ModelOutput<Self::Tensor>>;
/// Borrow the model's configuration.
fn config(&self) -> &ModelConfig;
/// Named tensors that should receive gradient updates.
fn trainable_parameters(&self) -> Vec<(&str, Self::Tensor)>;
/// Total parameter count (trainable and frozen).
fn num_parameters(&self) -> usize;
/// Parameter count restricted to trainable parameters.
fn num_trainable_parameters(&self) -> usize;
/// Attach LoRA adapters according to `lora_config`.
fn apply_lora(&mut self, lora_config: &LoraConfig) -> Result<()>;
/// Fold previously applied LoRA weights back into the base weights.
fn merge_lora(&mut self) -> Result<()>;
/// Report the model's current memory usage.
fn memory_footprint(&self) -> MemoryStats;
/// Persist model weights to `path`.
fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>;
/// Load model weights from `path` into this instance.
fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<()>;
}
/// An indexable collection of training examples.
pub trait Dataset: Send + Sync {
/// The type of a single example.
type Item;
/// Number of examples in the dataset.
fn len(&self) -> usize;
/// True when the dataset contains no examples.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Fetch the example at `index`, or `None` if out of range.
fn get(&self, index: usize) -> Option<Self::Item>;
/// Iterate over all examples in order.
fn iter(&self) -> impl Iterator<Item = Self::Item>;
}
/// Drives a training loop over a model and dataset.
///
/// Deprecated since 0.1.0 — superseded by the concrete trainers in
/// `pmetal_trainer`.
#[deprecated(
since = "0.1.0",
note = "Use concrete trainers from pmetal_trainer instead."
)]
// The `Model` bound references the deprecated `PMetalModel` trait, which
// would otherwise trigger a deprecation warning here.
#[allow(deprecated)]
pub trait Trainer {
/// Model type being trained.
type Model: PMetalModel;
/// Value produced by a completed training run.
type Output;
/// Construct a trainer that owns `model` and applies `config`.
fn new(model: Self::Model, config: TrainingConfig) -> Result<Self>
where
Self: Sized;
/// Run training to completion over `dataset`.
fn train<D: Dataset>(&mut self, dataset: &D) -> Result<Self::Output>;
/// Evaluate the current model on `dataset` without updating it.
fn evaluate<D: Dataset>(&self, dataset: &D) -> Result<EvalMetrics>;
/// Write a resumable checkpoint to `path`.
fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()>;
/// Restore trainer state from a checkpoint at `path`.
fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()>;
/// Optimizer step count so far.
fn current_step(&self) -> usize;
/// Most recent loss value, or `None` if no step has run yet.
fn current_loss(&self) -> Option<f64>;
}
/// Converts tensors to and from a quantized representation.
pub trait Quantizer {
/// Full-precision tensor type.
type Tensor;
/// Quantized tensor type produced by [`Quantizer::quantize`].
type QuantizedTensor;
/// Quantize `tensor`, grouping values into blocks of `block_size`.
fn quantize(&self, tensor: &Self::Tensor, block_size: usize) -> Result<Self::QuantizedTensor>;
/// Reconstruct a full-precision tensor from `quantized`.
fn dequantize(&self, quantized: &Self::QuantizedTensor) -> Result<Self::Tensor>;
}
/// Applies gradient updates to model parameters.
pub trait Optimizer {
/// Tensor type for parameters and gradients.
type Tensor;
/// Update `params` in place using the matching entries of `grads`.
fn step(&mut self, params: &mut [Self::Tensor], grads: &[Self::Tensor]) -> Result<()>;
/// Reset any accumulated gradient state.
fn zero_grad(&mut self);
/// Current learning rate.
fn learning_rate(&self) -> f64;
/// Override the learning rate (e.g. from a scheduler).
fn set_learning_rate(&mut self, lr: f64);
}
/// Produces a learning rate as a function of the training step.
pub trait LrScheduler {
/// Learning rate to use at `step`.
fn get_lr(&self, step: usize) -> f64;
/// Advance the scheduler's internal step counter.
fn step(&mut self);
}
/// Hooks invoked at points in the training loop.
///
/// Every hook has a no-op default, so implementors override only the
/// events they care about.
pub trait TrainingCallback: Send + Sync {
/// Called once before training begins.
fn on_train_start(&mut self) {}
/// Called once after training finishes.
fn on_train_end(&mut self) {}
/// Called at the start of each epoch.
fn on_epoch_start(&mut self, _epoch: usize) {}
/// Called at the end of each epoch with its evaluation metrics.
fn on_epoch_end(&mut self, _epoch: usize, _metrics: &EvalMetrics) {}
/// Called before each optimizer step.
fn on_step_start(&mut self, _step: usize) {}
/// Called after each optimizer step with the step's loss.
fn on_step_end(&mut self, _step: usize, _loss: f64) {}
/// Richer variant of [`TrainingCallback::on_step_end`]; the default
/// forwards `metrics.step` and `metrics.loss` to `on_step_end`.
fn on_step_end_with_metrics(&mut self, metrics: &StepMetrics) {
self.on_step_end(metrics.step, metrics.loss);
}
/// Called after a checkpoint or model save to `_path`.
fn on_save(&mut self, _path: &Path) {}
/// Called on learning-rate events, described by the `_event` string.
fn on_lr_event(&mut self, _event: &str) {}
/// Return `true` to request early termination of training.
fn should_stop(&self) -> bool {
false
}
}
/// Self-validation for configuration types.
pub trait ConfigValidator {
/// Check internal consistency, returning an error describing the
/// first problem found.
fn validate(&self) -> Result<()>;
}
/// Deserialization helpers for configuration types.
///
/// Any `DeserializeOwned` type gains YAML/JSON loading from strings or
/// files; parse failures are surfaced as `PMetalError::Config`.
pub trait ConfigLoader: Sized + serde::de::DeserializeOwned {
    /// Read the file at `path` and parse it as YAML.
    fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_yaml(&std::fs::read_to_string(path)?)
    }
    /// Parse a YAML document from `yaml`.
    fn from_yaml(yaml: &str) -> Result<Self> {
        match serde_yaml::from_str(yaml) {
            Ok(value) => Ok(value),
            Err(e) => Err(crate::PMetalError::Config(e.to_string())),
        }
    }
    /// Read the file at `path` and parse it as JSON.
    fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_json(&std::fs::read_to_string(path)?)
    }
    /// Parse a JSON document from `json`.
    fn from_json(json: &str) -> Result<Self> {
        match serde_json::from_str(json) {
            Ok(value) => Ok(value),
            Err(e) => Err(crate::PMetalError::Config(e.to_string())),
        }
    }
}
/// Default hyperparameter values shared across configuration types.
pub mod defaults {
/// Default optimizer learning rate.
pub const LEARNING_RATE: f64 = 2e-4;
/// Default learning rate for embedding parameters.
pub const EMBEDDING_LR: f64 = 5e-5;
/// Default training batch size.
pub const BATCH_SIZE: usize = 4;
/// Default number of training epochs.
pub const EPOCHS: usize = 3;
/// Default number of learning-rate warmup steps.
pub const WARMUP_STEPS: usize = 100;
/// Default weight-decay coefficient.
pub const WEIGHT_DECAY: f64 = 0.01;
/// Default gradient-clipping norm.
pub const MAX_GRAD_NORM: f64 = 1.0;
/// Default RNG seed.
pub const SEED: u64 = 42;
/// Default interval (in steps) between log entries.
pub const LOGGING_STEPS: usize = 10;
/// Default maximum sequence length.
pub const MAX_SEQ_LEN: usize = 2048;
/// Default LoRA rank.
pub const LORA_R: usize = 16;
/// Default LoRA alpha scaling factor.
pub const LORA_ALPHA: f32 = 32.0;
/// Default beta coefficient (used by preference-style losses).
pub const BETA: f64 = 0.1;
/// Default label-smoothing factor (disabled).
pub const LABEL_SMOOTHING: f64 = 0.0;
/// Default epsilon for RMSNorm.
pub const RMS_NORM_EPS: f32 = 1e-5;
/// Default RoPE base frequency (theta).
pub const ROPE_THETA: f32 = 10000.0;
}