//! In-memory bookkeeping for fine-tuning runs: configuration, lifecycle
//! status, per-step metrics, and export metadata.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

pub use super::training_engine::{run_lora_training, TrainingPreset};
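/// A single fine-tuning run: its configuration, lifecycle status, and the
/// metrics streamed back while it executes.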
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRun {
pub id: String,
pub dataset_id: String,
pub method: TrainingMethod,
pub config: TrainingConfig,
pub status: TrainingStatus,
pub created_at: u64,
pub metrics: Vec<TrainingMetric>,
#[serde(default)]
pub simulated: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub export_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
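/// Fine-tuning strategy for a run. `Lora` and `Qlora` train low-rank
/// adapters (QLoRA over a quantized base model); `FullFinetune` updates all
/// model weights.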
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingMethod {
Lora,
Qlora,
FullFinetune,
}
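/// Gradient-descent optimizer used for the run.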
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerType {
Adam,
AdamW,
Sgd,
}
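/// Learning-rate schedule applied over training steps.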
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerType {
Constant,
Cosine,
Linear,
StepDecay,
}
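/// Hyperparameters for a fine-tuning run. Every field has a serde default,
/// so clients can submit a partial (or empty) JSON object.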
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
#[serde(default = "default_lora_r")]
pub lora_r: u32,
#[serde(default = "default_lora_alpha")]
pub lora_alpha: u32,
#[serde(default = "default_learning_rate")]
pub learning_rate: f64,
#[serde(default = "default_epochs")]
pub epochs: u32,
#[serde(default = "default_batch_size")]
pub batch_size: u32,
#[serde(default = "default_max_seq_length")]
pub max_seq_length: u32,
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
#[serde(default = "default_optimizer")]
pub optimizer: OptimizerType,
#[serde(default = "default_scheduler")]
pub scheduler: SchedulerType,
#[serde(default = "default_warmup_steps")]
pub warmup_steps: u32,
#[serde(default = "default_grad_accum")]
pub gradient_accumulation_steps: u32,
#[serde(default = "default_max_grad_norm")]
pub max_grad_norm: f64,
}
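// serde's `default = "..."` attribute takes a path to a free function, hence
// the small helpers below.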
fn default_lora_r() -> u32 {
16
}
fn default_lora_alpha() -> u32 {
32
}
fn default_learning_rate() -> f64 {
2e-4
}
fn default_epochs() -> u32 {
3
}
fn default_batch_size() -> u32 {
4
}
fn default_max_seq_length() -> u32 {
2048
}
fn default_optimizer() -> OptimizerType {
OptimizerType::AdamW
}
fn default_scheduler() -> SchedulerType {
SchedulerType::Cosine
}
fn default_warmup_steps() -> u32 {
100
}
fn default_grad_accum() -> u32 {
4
}
fn default_max_grad_norm() -> f64 {
    1.0
}
fn default_target_modules() -> Vec<String> {
    // The attention projections LoRA commonly targets; kept in sync with
    // `TrainingConfig::default()`.
    vec!["q_proj".into(), "k_proj".into(), "v_proj".into(), "o_proj".into()]
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
lora_r: default_lora_r(),
lora_alpha: default_lora_alpha(),
learning_rate: default_learning_rate(),
epochs: default_epochs(),
batch_size: default_batch_size(),
max_seq_length: default_max_seq_length(),
            target_modules: default_target_modules(),
optimizer: default_optimizer(),
scheduler: default_scheduler(),
warmup_steps: default_warmup_steps(),
gradient_accumulation_steps: default_grad_accum(),
max_grad_norm: default_max_grad_norm(),
}
}
}
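/// Lifecycle of a run, from `Queued` through a terminal state (`Complete`,
/// `Failed`, or `Stopped`).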
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingStatus {
Queued,
Running,
Complete,
Failed,
Stopped,
}
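/// One training step's measurements. The optional fields are omitted from
/// the serialized form when absent.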
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetric {
pub step: u64,
pub loss: f32,
pub learning_rate: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub grad_norm: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens_per_sec: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub eta_secs: Option<u64>,
}
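/// Serialization format for exported weights or adapters.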
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
Safetensors,
Gguf,
Apr,
}
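/// Options for exporting a finished run: target format, whether to merge the
/// adapter into the base weights, and an optional quantization spec.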
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportRequest {
#[serde(default = "default_export_format")]
pub format: ExportFormat,
#[serde(default)]
pub merge: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub quantization: Option<String>,
}
fn default_export_format() -> ExportFormat {
ExportFormat::Safetensors
}
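/// Outcome of a completed export: what was produced, where, and how large.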
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportResult {
pub run_id: String,
pub format: ExportFormat,
pub merged: bool,
pub path: String,
pub size_bytes: u64,
}
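/// Thread-safe, in-memory registry of training runs, keyed by run id.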
pub struct TrainingStore {
    /// All known runs, keyed by run id.
    runs: RwLock<HashMap<String, TrainingRun>>,
    /// Monotonic sequence folded into generated run ids so that runs created
    /// within the same second still get unique ids.
    counter: std::sync::atomic::AtomicU64,
}
impl TrainingStore {
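    /// Creates an empty store wrapped in an `Arc` so it can be shared across
    /// threads.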
#[must_use]
pub fn new() -> Arc<Self> {
Arc::new(Self {
runs: RwLock::new(HashMap::new()),
counter: std::sync::atomic::AtomicU64::new(0),
})
}
    /// Registers a new run in `Queued` state, marked as simulated, and
    /// returns a snapshot of it.
    pub fn start(
        &self,
        dataset_id: &str,
        method: TrainingMethod,
        config: TrainingConfig,
    ) -> TrainingRun {
        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        // Read the clock once so the id and `created_at` always agree.
        let now = epoch_secs();
        let run = TrainingRun {
            id: format!("run-{now}-{seq}"),
            dataset_id: dataset_id.to_string(),
            method,
            config,
            status: TrainingStatus::Queued,
            created_at: now,
            metrics: Vec::new(),
            simulated: true,
            export_path: None,
            error: None,
        };
        self.runs
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(run.id.clone(), run.clone());
        run
    }
    /// Applies `f` to the run with id `run_id`, if it exists. A poisoned
    /// lock is recovered the same way the read paths do, since the map data
    /// remains valid even if another thread panicked while holding it.
    fn update<F: FnOnce(&mut TrainingRun)>(&self, run_id: &str, f: F) {
        let mut store = self.runs.write().unwrap_or_else(|e| e.into_inner());
        if let Some(run) = store.get_mut(run_id) {
            f(run);
        }
    }

    /// Appends a per-step metric to a run's history.
    pub fn push_metric(&self, run_id: &str, metric: TrainingMetric) {
        self.update(run_id, |run| run.metrics.push(metric));
    }

    /// Sets a run's lifecycle status.
    pub fn set_status(&self, run_id: &str, status: TrainingStatus) {
        self.update(run_id, |run| run.status = status);
    }

    /// Marks a run as failed and records the error message.
    pub fn fail(&self, run_id: &str, error: &str) {
        self.update(run_id, |run| {
            run.status = TrainingStatus::Failed;
            run.error = Some(error.to_string());
        });
    }

    /// Records where the run's exported artifact was written.
    pub fn set_export_path(&self, run_id: &str, path: &str) {
        self.update(run_id, |run| run.export_path = Some(path.to_string()));
    }

    /// Returns snapshots of all runs, newest first.
    #[must_use]
    pub fn list(&self) -> Vec<TrainingRun> {
        let store = self.runs.read().unwrap_or_else(|e| e.into_inner());
        let mut runs: Vec<TrainingRun> = store.values().cloned().collect();
        runs.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        runs
    }

    /// Looks up a single run by id, returning a clone.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<TrainingRun> {
        self.runs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
    }
    /// Marks a run as stopped. Errors if the run does not exist.
    pub fn stop(&self, id: &str) -> Result<(), TrainingError> {
        let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
        let run = store
            .get_mut(id)
            .ok_or_else(|| TrainingError::NotFound(id.to_string()))?;
        run.status = TrainingStatus::Stopped;
        Ok(())
    }

    /// Removes a run from the store. Errors if the run does not exist.
    pub fn delete(&self, id: &str) -> Result<(), TrainingError> {
        let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
        store
            .remove(id)
            .ok_or_else(|| TrainingError::NotFound(id.to_string()))?;
        Ok(())
    }
}
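/// Errors produced by training operations.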
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrainingError {
NotFound(String),
NoModel,
NoDataset(String),
LockPoisoned,
}
impl std::fmt::Display for TrainingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotFound(id) => write!(f, "Training run not found: {id}"),
Self::NoModel => write!(f, "No model loaded — load a model first"),
Self::NoDataset(id) => write!(f, "Dataset not found: {id}"),
Self::LockPoisoned => write!(f, "Internal lock error"),
}
}
}
impl std::error::Error for TrainingError {}
/// Seconds since the Unix epoch, or 0 if the system clock predates it.
fn epoch_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}