pub mod adapters;
pub mod alignment;
pub mod architectures;
pub mod burn_backend;
pub mod burn_modules;
pub mod checkpointing;
pub mod dataset_loader;
pub mod export;
pub mod lr_schedule;
pub mod quantization;
pub mod weight_loader;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use crate::config::{AlignmentMethod, LoraConfig, TrainingHyperparams};
use crate::error::TrainingError;
use crate::types::TrainingProgress;
/// Hardware device a local training run can target.
///
/// Derives serde `Serialize`/`Deserialize`, so variant and field names (and
/// field order) are part of the serialized representation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ComputeDevice {
/// Host CPU.
Cpu,
/// A GPU identified by index, with its reported name and memory size.
Gpu {
/// Device index (NOTE(review): presumably the backend's enumeration order — confirm).
index: usize,
/// Human-readable device name.
name: String,
/// Video memory in megabytes (rendered as "MB VRAM" by the `Display` impl).
vram_mb: u64,
},
/// Apple Metal (MPS) device.
Mps,
}
impl std::fmt::Display for ComputeDevice {
    /// Writes a short human-readable label for the device: `CPU`,
    /// `GPU:<index> (<name>, <vram>MB VRAM)` or `MPS (Apple Metal)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Mps => "MPS (Apple Metal)",
            Self::Cpu => "CPU",
            // The GPU label embeds runtime data, so it is formatted directly
            // instead of being emitted as a fixed string.
            Self::Gpu {
                index,
                name,
                vram_mb,
            } => {
                return write!(f, "GPU:{} ({}, {}MB VRAM)", index, name, vram_mb);
            }
        };
        f.write_str(label)
    }
}
/// Everything needed to run one local fine-tuning job.
///
/// Construct with [`LocalTrainingConfig::new`] and adjust via the `with_*`
/// builder methods or by assigning the public fields directly.
#[derive(Debug, Clone)]
pub struct LocalTrainingConfig {
/// Base model to train from (NOTE(review): file vs. directory layout depends on the loader — confirm).
pub model_path: PathBuf,
/// Training dataset.
pub dataset_path: PathBuf,
/// Optional validation dataset; `new` leaves this as `None`.
pub validation_path: Option<PathBuf>,
/// Optional tokenizer location; `new` leaves this as `None`
/// (NOTE(review): fallback behavior when unset is up to the backend — confirm).
pub tokenizer_path: Option<PathBuf>,
/// Directory for training outputs (presumably checkpoints and the exported model — confirm).
pub output_dir: PathBuf,
/// Optimization hyperparameters; `new` uses `TrainingHyperparams::default()`.
pub hyperparams: TrainingHyperparams,
/// LoRA adapter settings; `new` uses `LoraConfig::default()`.
pub lora: LoraConfig,
/// Alignment method to apply; `new` uses `AlignmentMethod::None`.
pub alignment: AlignmentMethod,
/// Device to train on; `new` uses `ComputeDevice::Cpu`.
pub device: ComputeDevice,
/// Whether gradient checkpointing is enabled; `new` sets `true`.
pub gradient_checkpointing: bool,
/// Whether mixed-precision training is enabled; `new` sets `false`.
pub mixed_precision: bool,
}
impl LocalTrainingConfig {
pub fn new(
model_path: impl Into<PathBuf>,
dataset_path: impl Into<PathBuf>,
output_dir: impl Into<PathBuf>,
) -> Self {
Self {
model_path: model_path.into(),
dataset_path: dataset_path.into(),
validation_path: None,
tokenizer_path: None,
output_dir: output_dir.into(),
hyperparams: TrainingHyperparams::default(),
lora: LoraConfig::default(),
alignment: AlignmentMethod::None,
device: ComputeDevice::Cpu,
gradient_checkpointing: true,
mixed_precision: false,
}
}
pub fn with_device(mut self, device: ComputeDevice) -> Self {
self.device = device;
self
}
pub fn with_validation(mut self, path: impl Into<PathBuf>) -> Self {
self.validation_path = Some(path.into());
self
}
pub fn with_tokenizer(mut self, path: impl Into<PathBuf>) -> Self {
self.tokenizer_path = Some(path.into());
self
}
}
/// Result of a completed training run, as returned by [`TrainingBackend::train`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedModelArtifact {
/// Location of the produced model.
pub model_path: PathBuf,
/// Name of the output format as a free-form string
/// (NOTE(review): valid values are not defined here — consider an enum; confirm against `export`).
pub format: String,
/// Identifier of the model training started from (presumably a name or path — confirm).
pub base_model: String,
/// Metrics collected for the run.
pub metrics: crate::types::TrainingMetrics,
/// LoRA settings used for the run, if any (presumably `None` for non-LoRA runs — confirm).
pub lora_config: Option<LoraConfig>,
}
/// A training engine that can run a local fine-tuning job.
///
/// The `Send + Sync` supertraits allow a backend to be shared across threads.
pub trait TrainingBackend: Send + Sync {
/// Short identifier for this backend.
fn name(&self) -> &str;
/// Devices this backend reports as usable for training.
fn available_devices(&self) -> Vec<ComputeDevice>;
/// Runs the job described by `config`, reporting progress through `callback`.
///
/// `callback` is a boxed `Fn + Send`, so implementors may invoke it from a
/// different thread (NOTE(review): whether they do is implementation-specific).
///
/// # Errors
///
/// Returns a [`TrainingError`] if the run cannot complete.
fn train(
&self,
config: LocalTrainingConfig,
callback: Box<dyn Fn(TrainingProgress) + Send>,
) -> Result<TrainedModelArtifact, TrainingError>;
}
pub use burn_backend::BurnBackend;