// pmetal-core 0.1.2
//
// Core types, traits, and configuration for PMetal LLM fine-tuning.
//! Core trait definitions.
//!
//! Note: Some traits in this module are deprecated. See individual trait documentation
//! for current recommended patterns.

use crate::{EvalMetrics, LoraConfig, MemoryStats, ModelConfig, Result, TrainingConfig};
use std::path::Path;

/// Core model trait for all LLM architectures.
///
/// # Deprecated
///
/// This trait is deprecated and not implemented by any models.
/// Use the following traits from `pmetal_models` instead:
/// - [`CausalLMModel`] - For forward pass and generation
/// - [`LoraCapable`] - For LoRA adapter management
/// - [`Quantizable`] - For model quantization
///
/// For trainable models, use [`TrainableModel`] from `pmetal_lora`.
///
/// [`CausalLMModel`]: pmetal_models::CausalLMModel
/// [`LoraCapable`]: pmetal_models::LoraCapable
/// [`Quantizable`]: pmetal_models::Quantizable
/// [`TrainableModel`]: pmetal_lora::TrainableModel
#[deprecated(
    since = "0.1.0",
    note = "Use CausalLMModel from pmetal_models instead."
)]
pub trait PMetalModel: Send + Sync {
    /// The tensor type used by this model.
    ///
    /// Kept as an associated type so each backend can supply its own tensor
    /// representation without changing the trait surface.
    type Tensor;

    /// Run forward pass.
    ///
    /// # Arguments
    /// * `input_ids` - Token IDs tensor of shape `[batch_size, seq_len]`
    /// * `attention_mask` - Optional attention mask tensor
    ///
    /// # Returns
    /// Model output containing logits and optional hidden states.
    ///
    /// # Errors
    /// Implementation-defined; any failure during the forward pass.
    fn forward(
        &self,
        input_ids: &Self::Tensor,
        attention_mask: Option<&Self::Tensor>,
    ) -> Result<crate::ModelOutput<Self::Tensor>>;

    /// Get model configuration.
    fn config(&self) -> &ModelConfig;

    /// Get all trainable parameters with their names.
    ///
    /// NOTE(review): returns owned `Self::Tensor` values; whether these are
    /// cheap handles or deep copies depends on the backend's tensor type.
    fn trainable_parameters(&self) -> Vec<(&str, Self::Tensor)>;

    /// Get total number of parameters.
    fn num_parameters(&self) -> usize;

    /// Get number of trainable parameters.
    fn num_trainable_parameters(&self) -> usize;

    /// Apply LoRA adapters to the model.
    ///
    /// # Arguments
    /// * `lora_config` - Configuration for LoRA adaptation
    fn apply_lora(&mut self, lora_config: &LoraConfig) -> Result<()>;

    /// Merge LoRA weights into base model.
    fn merge_lora(&mut self) -> Result<()>;

    /// Get memory footprint statistics.
    fn memory_footprint(&self) -> MemoryStats;

    /// Save model to disk.
    fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>;

    /// Load model from disk.
    ///
    /// Takes `&mut self`: weights are loaded into an already-constructed
    /// model rather than building a new instance.
    fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<()>;
}

/// Dataset trait for training data.
///
/// Implementors provide both random access ([`Dataset::get`]) and sequential
/// iteration ([`Dataset::iter`]); `Send + Sync` allows datasets to be shared
/// across worker threads.
pub trait Dataset: Send + Sync {
    /// The item type yielded by this dataset.
    type Item;

    /// Get the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Check if the dataset is empty.
    ///
    /// Default implementation derives emptiness from `len`; override only if
    /// a cheaper check exists.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get a sample by index.
    ///
    /// Expected to return `None` for out-of-range indices (implementor
    /// contract; not enforced here).
    fn get(&self, index: usize) -> Option<Self::Item>;

    /// Get an iterator over the dataset.
    ///
    /// Uses return-position `impl Trait` in a trait (RPITIT, Rust 1.75+),
    /// which makes this trait not object-safe.
    fn iter(&self) -> impl Iterator<Item = Self::Item>;
}

/// Trainer trait for different training methods.
///
/// # Deprecated
///
/// This trait is deprecated and not implemented by any trainers.
/// Use the concrete trainer implementations from `pmetal_trainer` instead:
/// - [`SftTrainer`] - For supervised fine-tuning
/// - [`LoraTrainer`] - For LoRA fine-tuning
/// - [`DpoTrainer`] - For Direct Preference Optimization
/// - [`GrpoTrainer`] - For Group Relative Policy Optimization
/// - [`TrainingLoop`] - For the main training loop abstraction
///
/// [`SftTrainer`]: pmetal_trainer::SftTrainer
/// [`LoraTrainer`]: pmetal_trainer::LoraTrainer
/// [`DpoTrainer`]: pmetal_trainer::DpoTrainer
/// [`GrpoTrainer`]: pmetal_trainer::GrpoTrainer
/// [`TrainingLoop`]: pmetal_trainer::TrainingLoop
#[deprecated(
    since = "0.1.0",
    note = "Use concrete trainers from pmetal_trainer instead."
)]
// `allow(deprecated)` is needed because the `Model` bound below references
// the deprecated `PMetalModel` trait.
#[allow(deprecated)]
pub trait Trainer {
    /// The model type this trainer works with.
    type Model: PMetalModel;

    /// Training output/result type.
    type Output;

    /// Create a new trainer with the given model and configuration.
    ///
    /// `Self: Sized` is required because constructors returning `Self` by
    /// value cannot appear in a trait object's vtable.
    fn new(model: Self::Model, config: TrainingConfig) -> Result<Self>
    where
        Self: Sized;

    /// Run training on the given dataset.
    fn train<D: Dataset>(&mut self, dataset: &D) -> Result<Self::Output>;

    /// Evaluate the model on a dataset.
    fn evaluate<D: Dataset>(&self, dataset: &D) -> Result<EvalMetrics>;

    /// Save a training checkpoint.
    fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()>;

    /// Load a training checkpoint.
    fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()>;

    /// Get the current training step.
    fn current_step(&self) -> usize;

    /// Get the current loss value.
    ///
    /// Returns `None` before any step has produced a loss (implementor
    /// contract; not enforced here).
    fn current_loss(&self) -> Option<f64>;
}

/// Quantizer trait for different quantization schemes.
///
/// `quantize` and `dequantize` form a (lossy) round trip between the
/// full-precision and quantized tensor representations.
pub trait Quantizer {
    /// The tensor type used by this quantizer.
    type Tensor;

    /// The quantized tensor type.
    type QuantizedTensor;

    /// Quantize a tensor.
    ///
    /// # Arguments
    /// * `tensor` - The tensor to quantize
    /// * `block_size` - Block size for blockwise quantization
    fn quantize(&self, tensor: &Self::Tensor, block_size: usize) -> Result<Self::QuantizedTensor>;

    /// Dequantize a tensor back to full precision.
    ///
    /// NOTE(review): quantization is typically lossy, so the result is an
    /// approximation of the original tensor — confirm per scheme.
    fn dequantize(&self, quantized: &Self::QuantizedTensor) -> Result<Self::Tensor>;
}

/// Optimizer trait.
pub trait Optimizer {
    /// The tensor type used by this optimizer.
    type Tensor;

    /// Update parameters using gradients.
    ///
    /// NOTE(review): `params` and `grads` appear to be parallel slices
    /// (`grads[i]` for `params[i]`) — confirm against implementations;
    /// mismatched lengths are presumably an error.
    fn step(&mut self, params: &mut [Self::Tensor], grads: &[Self::Tensor]) -> Result<()>;

    /// Zero all gradients.
    fn zero_grad(&mut self);

    /// Get current learning rate.
    fn learning_rate(&self) -> f64;

    /// Set learning rate.
    ///
    /// Used by LR schedulers to adjust the rate between steps.
    fn set_learning_rate(&mut self, lr: f64);
}

/// Learning rate scheduler trait.
///
/// `get_lr` is a pure query for an arbitrary step; `step` advances the
/// scheduler's own internal position by one.
pub trait LrScheduler {
    /// Get learning rate for the given step.
    fn get_lr(&self, step: usize) -> f64;

    /// Update scheduler state after a step.
    fn step(&mut self);
}

/// Callback trait for training events.
///
/// All hooks default to no-ops, so implementors override only the events
/// they care about. Parameters are underscore-prefixed in the defaults to
/// silence unused-variable warnings.
pub trait TrainingCallback: Send + Sync {
    /// Called at the start of training.
    fn on_train_start(&mut self) {}

    /// Called at the end of training.
    fn on_train_end(&mut self) {}

    /// Called at the start of each epoch.
    fn on_epoch_start(&mut self, _epoch: usize) {}

    /// Called at the end of each epoch.
    fn on_epoch_end(&mut self, _epoch: usize, _metrics: &EvalMetrics) {}

    /// Called at the start of each step.
    fn on_step_start(&mut self, _step: usize) {}

    /// Called at the end of each step.
    fn on_step_end(&mut self, _step: usize, _loss: f64) {}

    /// Called when a checkpoint is saved.
    fn on_save(&mut self, _path: &Path) {}
}

// ============================================================================
// Configuration Traits
// ============================================================================

/// Trait for configuration validation.
///
/// Implement this trait on configuration structs to provide consistent
/// validation across all crates.
pub trait ConfigValidator {
    /// Validate the configuration.
    ///
    /// Returns `Ok(())` if valid, or an error describing the validation failure.
    fn validate(&self) -> Result<()>;
}

/// Trait for loading configurations from files.
///
/// Provides default implementations for YAML and JSON loading.
/// Trait for loading configurations from files.
///
/// Provides default implementations for YAML and JSON loading, so any
/// `Sized + DeserializeOwned` config type can opt in with an empty
/// `impl ConfigLoader for MyConfig {}`.
pub trait ConfigLoader: Sized + serde::de::DeserializeOwned {
    /// Load configuration from a YAML file.
    ///
    /// Reads the whole file into memory, then delegates to [`Self::from_yaml`].
    fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Self::from_yaml(text.as_str())
    }

    /// Load configuration from a YAML string.
    ///
    /// Parse failures are surfaced as `PMetalError::Config` carrying the
    /// serde error message.
    fn from_yaml(yaml: &str) -> Result<Self> {
        match serde_yaml::from_str(yaml) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(crate::PMetalError::Config(err.to_string())),
        }
    }

    /// Load configuration from a JSON file.
    ///
    /// Reads the whole file into memory, then delegates to [`Self::from_json`].
    fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Self::from_json(text.as_str())
    }

    /// Load configuration from a JSON string.
    ///
    /// Parse failures are surfaced as `PMetalError::Config` carrying the
    /// serde error message.
    fn from_json(json: &str) -> Result<Self> {
        match serde_json::from_str(json) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(crate::PMetalError::Config(err.to_string())),
        }
    }
}

// ============================================================================
// Default Value Helpers
// ============================================================================

/// Default training hyperparameters used across crates.
///
/// These constants provide consistent defaults and reduce duplication.
pub mod defaults {
    /// Default learning rate for LoRA training.
    pub const LEARNING_RATE: f64 = 2e-4;
    /// Default learning rate for embeddings (lower than base).
    pub const EMBEDDING_LR: f64 = 5e-5;
    /// Default batch size.
    pub const BATCH_SIZE: usize = 4;
    /// Default number of epochs.
    pub const EPOCHS: usize = 3;
    /// Default warmup steps.
    pub const WARMUP_STEPS: usize = 100;
    /// Default weight decay.
    pub const WEIGHT_DECAY: f64 = 0.01;
    /// Default gradient clipping norm.
    pub const MAX_GRAD_NORM: f64 = 1.0;
    /// Default random seed.
    pub const SEED: u64 = 42;
    /// Default logging frequency (in steps).
    pub const LOGGING_STEPS: usize = 10;
    /// Default maximum sequence length (in tokens).
    pub const MAX_SEQ_LEN: usize = 2048;
    /// Default LoRA rank.
    pub const LORA_R: usize = 16;
    /// Default LoRA alpha (scaling numerator; effective scale is alpha / r).
    pub const LORA_ALPHA: f32 = 32.0;
    /// Default beta for preference learning (DPO-style temperature/KL
    /// coefficient controlling deviation from the reference policy).
    pub const BETA: f64 = 0.1;
    /// Default label smoothing (0.0 = disabled).
    pub const LABEL_SMOOTHING: f64 = 0.0;
    /// Default RMS norm epsilon.
    pub const RMS_NORM_EPS: f32 = 1e-5;
    /// Default RoPE theta (rotary embedding base frequency).
    pub const ROPE_THETA: f32 = 10000.0;
}