1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10pub use super::training_engine::{run_lora_training, TrainingPreset};
12
/// A single fine-tuning run tracked by the [`TrainingStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRun {
    /// Unique id, formatted as `run-<epoch-secs>-<seq>` by `TrainingStore::start`.
    pub id: String,
    /// Id of the dataset this run trains on.
    pub dataset_id: String,
    /// Fine-tuning method (LoRA, QLoRA, or full fine-tune).
    pub method: TrainingMethod,
    /// Hyperparameters for the run.
    pub config: TrainingConfig,
    /// Current lifecycle state.
    pub status: TrainingStatus,
    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Per-step metrics appended as the run progresses.
    pub metrics: Vec<TrainingMetric>,
    /// True when the run is simulated; deserializes to `false` when absent.
    #[serde(default)]
    pub simulated: bool,
    /// Filesystem path of the exported artifact, once set via `set_export_path`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub export_path: Option<String>,
    /// Failure message, set when the run is marked failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
36
/// Supported fine-tuning strategies; serialized in snake_case
/// (`lora`, `qlora`, `full_finetune`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingMethod {
    Lora,
    Qlora,
    FullFinetune,
}
45
/// Supported optimizers; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerType {
    Adam,
    AdamW,
    Sgd,
}
54
/// Learning-rate schedule shapes; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerType {
    Constant,
    Cosine,
    Linear,
    StepDecay,
}
64
/// Hyperparameters for a training run. Every field carries a serde default,
/// so a client may submit a partial (or empty) JSON config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// LoRA rank (default 16).
    #[serde(default = "default_lora_r")]
    pub lora_r: u32,
    /// LoRA scaling alpha (default 32).
    #[serde(default = "default_lora_alpha")]
    pub lora_alpha: u32,
    /// Learning rate (default 2e-4).
    #[serde(default = "default_learning_rate")]
    pub learning_rate: f64,
    /// Passes over the dataset (default 3).
    #[serde(default = "default_epochs")]
    pub epochs: u32,
    /// Batch size (default 4).
    #[serde(default = "default_batch_size")]
    pub batch_size: u32,
    /// Maximum sequence length (default 2048).
    #[serde(default = "default_max_seq_length")]
    pub max_seq_length: u32,
    /// Module names LoRA is applied to. NOTE: deserializes to an EMPTY list
    /// when absent, whereas `Default::default()` fills in the attention
    /// projections — the two defaults intentionally differ? TODO confirm.
    #[serde(default)]
    pub target_modules: Vec<String>,
    /// Optimizer (default AdamW).
    #[serde(default = "default_optimizer")]
    pub optimizer: OptimizerType,
    /// LR schedule (default cosine).
    #[serde(default = "default_scheduler")]
    pub scheduler: SchedulerType,
    /// Warmup steps (default 100).
    #[serde(default = "default_warmup_steps")]
    pub warmup_steps: u32,
    /// Steps to accumulate gradients over before an update (default 4).
    #[serde(default = "default_grad_accum")]
    pub gradient_accumulation_steps: u32,
    /// Gradient clipping threshold (default 1.0).
    #[serde(default = "default_max_grad_norm")]
    pub max_grad_norm: f64,
}
93
// Serde default providers for `TrainingConfig` (numeric fields).
// Kept as free functions because `#[serde(default = "...")]` requires a path.

/// Default LoRA rank.
fn default_lora_r() -> u32 { 16 }

/// Default LoRA scaling alpha.
fn default_lora_alpha() -> u32 { 32 }

/// Default learning rate.
fn default_learning_rate() -> f64 { 2e-4 }

/// Default epoch count.
fn default_epochs() -> u32 { 3 }

/// Default batch size.
fn default_batch_size() -> u32 { 4 }

/// Default maximum sequence length.
fn default_max_seq_length() -> u32 { 2048 }
112fn default_optimizer() -> OptimizerType {
113 OptimizerType::AdamW
114}
115fn default_scheduler() -> SchedulerType {
116 SchedulerType::Cosine
117}
/// Default warmup step count.
fn default_warmup_steps() -> u32 { 100 }

/// Default gradient accumulation steps.
fn default_grad_accum() -> u32 { 4 }

/// Default gradient clipping threshold.
fn default_max_grad_norm() -> f64 { 1.0 }
127
128impl Default for TrainingConfig {
129 fn default() -> Self {
130 Self {
131 lora_r: default_lora_r(),
132 lora_alpha: default_lora_alpha(),
133 learning_rate: default_learning_rate(),
134 epochs: default_epochs(),
135 batch_size: default_batch_size(),
136 max_seq_length: default_max_seq_length(),
137 target_modules: vec![
138 "q_proj".into(),
139 "k_proj".into(),
140 "v_proj".into(),
141 "o_proj".into(),
142 ],
143 optimizer: default_optimizer(),
144 scheduler: default_scheduler(),
145 warmup_steps: default_warmup_steps(),
146 gradient_accumulation_steps: default_grad_accum(),
147 max_grad_norm: default_max_grad_norm(),
148 }
149 }
150}
151
/// Lifecycle states of a training run; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingStatus {
    Queued,
    Running,
    Complete,
    Failed,
    Stopped,
}
162
/// One training-progress sample, recorded against a global step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetric {
    /// Step this sample belongs to.
    pub step: u64,
    /// Training loss at this step.
    pub loss: f32,
    /// Learning rate in effect at this step.
    pub learning_rate: f64,
    /// Gradient norm, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grad_norm: Option<f32>,
    /// Throughput in tokens per second, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens_per_sec: Option<u64>,
    /// Estimated seconds remaining, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eta_secs: Option<u64>,
}
176
/// Supported export artifact formats; serialized in snake_case
/// (`safetensors`, `gguf`, `apr`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
    Safetensors,
    Gguf,
    Apr,
}
185
/// Client request for exporting a run's trained weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportRequest {
    /// Output format; defaults to safetensors when absent.
    #[serde(default = "default_export_format")]
    pub format: ExportFormat,
    /// Defaults to false when absent. Presumably merges the trained adapter
    /// into the base weights — confirm against the export implementation.
    #[serde(default)]
    pub merge: bool,
    /// Optional quantization identifier; semantics defined by the exporter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
}
196
197fn default_export_format() -> ExportFormat {
198 ExportFormat::Safetensors
199}
200
/// Outcome of a completed export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportResult {
    /// Id of the run that was exported.
    pub run_id: String,
    /// Format the artifact was written in.
    pub format: ExportFormat,
    /// Whether weights were merged during export.
    pub merged: bool,
    /// Filesystem path of the written artifact.
    pub path: String,
    /// Size of the artifact in bytes.
    pub size_bytes: u64,
}
210
/// Thread-safe, in-memory registry of training runs.
pub struct TrainingStore {
    /// Runs keyed by their generated id.
    runs: RwLock<HashMap<String, TrainingRun>>,
    /// Monotonic sequence used to disambiguate ids created in the same second.
    counter: std::sync::atomic::AtomicU64,
}
220
221impl TrainingStore {
222 #[must_use]
223 pub fn new() -> Arc<Self> {
224 Arc::new(Self {
225 runs: RwLock::new(HashMap::new()),
226 counter: std::sync::atomic::AtomicU64::new(0),
227 })
228 }
229
230 pub fn start(
232 &self,
233 dataset_id: &str,
234 method: TrainingMethod,
235 config: TrainingConfig,
236 ) -> TrainingRun {
237 let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
238 let run = TrainingRun {
239 id: format!("run-{}-{seq}", epoch_secs()),
240 dataset_id: dataset_id.to_string(),
241 method,
242 config,
243 status: TrainingStatus::Queued,
244 created_at: epoch_secs(),
245 metrics: Vec::new(),
246 simulated: true, export_path: None,
248 error: None,
249 };
250 if let Ok(mut store) = self.runs.write() {
251 store.insert(run.id.clone(), run.clone());
252 }
253 run
254 }
255
256 pub fn push_metric(&self, run_id: &str, metric: TrainingMetric) {
258 if let Ok(mut store) = self.runs.write() {
259 if let Some(run) = store.get_mut(run_id) {
260 run.metrics.push(metric);
261 }
262 }
263 }
264
265 pub fn set_status(&self, run_id: &str, status: TrainingStatus) {
267 if let Ok(mut store) = self.runs.write() {
268 if let Some(run) = store.get_mut(run_id) {
269 run.status = status;
270 }
271 }
272 }
273
274 pub fn fail(&self, run_id: &str, error: &str) {
276 if let Ok(mut store) = self.runs.write() {
277 if let Some(run) = store.get_mut(run_id) {
278 run.status = TrainingStatus::Failed;
279 run.error = Some(error.to_string());
280 }
281 }
282 }
283
284 pub fn set_export_path(&self, run_id: &str, path: &str) {
286 if let Ok(mut store) = self.runs.write() {
287 if let Some(run) = store.get_mut(run_id) {
288 run.export_path = Some(path.to_string());
289 }
290 }
291 }
292
293 #[must_use]
295 pub fn list(&self) -> Vec<TrainingRun> {
296 let store = self.runs.read().unwrap_or_else(|e| e.into_inner());
297 let mut runs: Vec<TrainingRun> = store.values().cloned().collect();
298 runs.sort_by(|a, b| b.created_at.cmp(&a.created_at));
299 runs
300 }
301
302 #[must_use]
304 pub fn get(&self, id: &str) -> Option<TrainingRun> {
305 self.runs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
306 }
307
308 pub fn stop(&self, id: &str) -> Result<(), TrainingError> {
310 let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
311 let run = store.get_mut(id).ok_or(TrainingError::NotFound(id.to_string()))?;
312 run.status = TrainingStatus::Stopped;
313 Ok(())
314 }
315
316 pub fn delete(&self, id: &str) -> Result<(), TrainingError> {
318 let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
319 store.remove(id).ok_or(TrainingError::NotFound(id.to_string()))?;
320 Ok(())
321 }
322}
323
/// Errors surfaced by training-store and training-run operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrainingError {
    /// No training run exists with the given id.
    NotFound(String),
    /// No model is currently loaded.
    NoModel,
    /// The referenced dataset does not exist.
    NoDataset(String),
    /// An internal `RwLock` was poisoned by a panicking thread.
    LockPoisoned,
}
336
337impl std::fmt::Display for TrainingError {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 match self {
340 Self::NotFound(id) => write!(f, "Training run not found: {id}"),
341 Self::NoModel => write!(f, "No model loaded — load a model first"),
342 Self::NoDataset(id) => write!(f, "Dataset not found: {id}"),
343 Self::LockPoisoned => write!(f, "Internal lock error"),
344 }
345 }
346}
347
348impl std::error::Error for TrainingError {}
349
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns 0 if the system clock reads before the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0)
}