brainwires_training/
types.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
5pub struct TrainingJobId(pub String);
6
7impl std::fmt::Display for TrainingJobId {
8 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9 write!(f, "{}", self.0)
10 }
11}
12
13impl<S: Into<String>> From<S> for TrainingJobId {
14 fn from(s: S) -> Self {
15 Self(s.into())
16 }
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub struct DatasetId(pub String);
22
23impl std::fmt::Display for DatasetId {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 write!(f, "{}", self.0)
26 }
27}
28
29impl<S: Into<String>> From<S> for DatasetId {
30 fn from(s: S) -> Self {
31 Self(s.into())
32 }
33}
34
35impl DatasetId {
36 pub fn from_s3_uri(uri: &str) -> Self {
38 Self(uri.to_string())
39 }
40
41 pub fn from_gcs_uri(uri: &str) -> Self {
43 Self(uri.to_string())
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(tag = "status", rename_all = "snake_case")]
50pub enum TrainingJobStatus {
51 Pending,
53 Validating,
55 Queued,
57 Running {
59 progress: TrainingProgress,
61 },
62 Succeeded {
64 model_id: String,
66 },
67 Failed {
69 error: String,
71 },
72 Cancelled,
74}
75
76impl TrainingJobStatus {
77 pub fn is_terminal(&self) -> bool {
79 matches!(
80 self,
81 Self::Succeeded { .. } | Self::Failed { .. } | Self::Cancelled
82 )
83 }
84
85 pub fn is_running(&self) -> bool {
87 matches!(self, Self::Running { .. })
88 }
89
90 pub fn is_succeeded(&self) -> bool {
92 matches!(self, Self::Succeeded { .. })
93 }
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct TrainingProgress {
99 pub epoch: u32,
101 pub total_epochs: u32,
103 pub step: u64,
105 pub total_steps: u64,
107 pub train_loss: Option<f64>,
109 pub eval_loss: Option<f64>,
111 pub learning_rate: Option<f64>,
113 pub elapsed_secs: u64,
115}
116
117impl TrainingProgress {
118 pub fn completion_fraction(&self) -> f64 {
120 if self.total_steps == 0 {
121 return 0.0;
122 }
123 self.step as f64 / self.total_steps as f64
124 }
125}
126
127#[derive(Debug, Clone, Default, Serialize, Deserialize)]
129pub struct TrainingMetrics {
130 pub final_train_loss: Option<f64>,
132 pub final_eval_loss: Option<f64>,
134 pub total_steps: u64,
136 pub total_epochs: u32,
138 pub total_tokens_trained: Option<u64>,
140 pub duration_secs: u64,
142 pub estimated_cost_usd: Option<f64>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TrainingJobSummary {
149 pub job_id: TrainingJobId,
151 pub provider: String,
153 pub base_model: String,
155 pub status: TrainingJobStatus,
157 pub created_at: chrono::DateTime<chrono::Utc>,
159 pub metrics: Option<TrainingMetrics>,
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant is checked so that a newly added state cannot be
    /// classified by accident.
    #[test]
    fn test_job_status_terminal() {
        assert!(!TrainingJobStatus::Pending.is_terminal());
        assert!(!TrainingJobStatus::Validating.is_terminal());
        assert!(!TrainingJobStatus::Queued.is_terminal());
        assert!(
            !TrainingJobStatus::Running {
                progress: TrainingProgress::default()
            }
            .is_terminal()
        );
        assert!(
            TrainingJobStatus::Succeeded {
                model_id: "m".into()
            }
            .is_terminal()
        );
        assert!(
            TrainingJobStatus::Failed {
                error: "err".into()
            }
            .is_terminal()
        );
        assert!(TrainingJobStatus::Cancelled.is_terminal());
    }

    #[test]
    fn test_job_status_running_and_succeeded() {
        let running = TrainingJobStatus::Running {
            progress: TrainingProgress::default(),
        };
        let succeeded = TrainingJobStatus::Succeeded {
            model_id: "m".into(),
        };

        assert!(running.is_running());
        assert!(!running.is_succeeded());
        assert!(succeeded.is_succeeded());
        assert!(!succeeded.is_running());
        assert!(!TrainingJobStatus::Pending.is_running());
        assert!(!TrainingJobStatus::Cancelled.is_succeeded());
    }

    #[test]
    fn test_progress_completion() {
        let p = TrainingProgress {
            step: 50,
            total_steps: 100,
            ..Default::default()
        };
        assert!((p.completion_fraction() - 0.5).abs() < f64::EPSILON);
    }

    /// A zero `total_steps` must not divide by zero.
    #[test]
    fn test_progress_completion_zero_total_steps() {
        let p = TrainingProgress::default();
        assert!(p.completion_fraction().abs() < f64::EPSILON);

        let p = TrainingProgress {
            step: 5,
            total_steps: 0,
            ..Default::default()
        };
        assert!(p.completion_fraction().abs() < f64::EPSILON);
    }

    #[test]
    fn test_job_id_from_string() {
        let id: TrainingJobId = "ft-abc123".into();
        assert_eq!(id.0, "ft-abc123");
        assert_eq!(id.to_string(), "ft-abc123");

        let owned = TrainingJobId::from(String::from("ft-abc123"));
        assert_eq!(owned, id);
    }

    #[test]
    fn test_dataset_id_constructors() {
        let s3 = DatasetId::from_s3_uri("s3://bucket/train.jsonl");
        assert_eq!(s3.0, "s3://bucket/train.jsonl");
        assert_eq!(s3.to_string(), "s3://bucket/train.jsonl");

        let gcs = DatasetId::from_gcs_uri("gs://bucket/train.jsonl");
        assert_eq!(gcs, DatasetId::from("gs://bucket/train.jsonl"));
    }
}