Skip to main content

brainwires_training/
types.rs

1use serde::{Deserialize, Serialize};
2
3/// Unique identifier for a training job.
4#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
5pub struct TrainingJobId(pub String);
6
7impl std::fmt::Display for TrainingJobId {
8    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9        write!(f, "{}", self.0)
10    }
11}
12
13impl<S: Into<String>> From<S> for TrainingJobId {
14    fn from(s: S) -> Self {
15        Self(s.into())
16    }
17}
18
19/// Unique identifier for an uploaded dataset.
20#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub struct DatasetId(pub String);
22
23impl std::fmt::Display for DatasetId {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        write!(f, "{}", self.0)
26    }
27}
28
29impl<S: Into<String>> From<S> for DatasetId {
30    fn from(s: S) -> Self {
31        Self(s.into())
32    }
33}
34
35impl DatasetId {
36    /// Create a DatasetId from an S3 URI (for Bedrock).
37    pub fn from_s3_uri(uri: &str) -> Self {
38        Self(uri.to_string())
39    }
40
41    /// Create a DatasetId from a GCS URI (for Vertex AI).
42    pub fn from_gcs_uri(uri: &str) -> Self {
43        Self(uri.to_string())
44    }
45}
46
47/// Status of a training job.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(tag = "status", rename_all = "snake_case")]
50pub enum TrainingJobStatus {
51    /// Job is pending.
52    Pending,
53    /// Job is validating inputs.
54    Validating,
55    /// Job is queued for execution.
56    Queued,
57    /// Job is actively running.
58    Running {
59        /// Current training progress.
60        progress: TrainingProgress,
61    },
62    /// Job completed successfully.
63    Succeeded {
64        /// ID of the fine-tuned model.
65        model_id: String,
66    },
67    /// Job failed with an error.
68    Failed {
69        /// Error description.
70        error: String,
71    },
72    /// Job was cancelled.
73    Cancelled,
74}
75
76impl TrainingJobStatus {
77    /// Whether the job has reached a terminal state.
78    pub fn is_terminal(&self) -> bool {
79        matches!(
80            self,
81            Self::Succeeded { .. } | Self::Failed { .. } | Self::Cancelled
82        )
83    }
84
85    /// Whether the job is currently running.
86    pub fn is_running(&self) -> bool {
87        matches!(self, Self::Running { .. })
88    }
89
90    /// Whether the job succeeded.
91    pub fn is_succeeded(&self) -> bool {
92        matches!(self, Self::Succeeded { .. })
93    }
94}
95
96/// Progress information for a running training job.
97#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct TrainingProgress {
99    /// Current epoch.
100    pub epoch: u32,
101    /// Total number of epochs.
102    pub total_epochs: u32,
103    /// Current training step.
104    pub step: u64,
105    /// Total training steps.
106    pub total_steps: u64,
107    /// Training loss.
108    pub train_loss: Option<f64>,
109    /// Evaluation loss.
110    pub eval_loss: Option<f64>,
111    /// Current learning rate.
112    pub learning_rate: Option<f64>,
113    /// Elapsed time in seconds.
114    pub elapsed_secs: u64,
115}
116
117impl TrainingProgress {
118    /// Fraction of training completed (0.0-1.0).
119    pub fn completion_fraction(&self) -> f64 {
120        if self.total_steps == 0 {
121            return 0.0;
122        }
123        self.step as f64 / self.total_steps as f64
124    }
125}
126
127/// Metrics from a completed training job.
128#[derive(Debug, Clone, Default, Serialize, Deserialize)]
129pub struct TrainingMetrics {
130    /// Final training loss.
131    pub final_train_loss: Option<f64>,
132    /// Final evaluation loss.
133    pub final_eval_loss: Option<f64>,
134    /// Total training steps completed.
135    pub total_steps: u64,
136    /// Total epochs completed.
137    pub total_epochs: u32,
138    /// Total tokens processed.
139    pub total_tokens_trained: Option<u64>,
140    /// Total training duration in seconds.
141    pub duration_secs: u64,
142    /// Estimated cost in USD.
143    pub estimated_cost_usd: Option<f64>,
144}
145
146/// Summary of a training job for listing.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TrainingJobSummary {
149    /// Job identifier.
150    pub job_id: TrainingJobId,
151    /// Provider name.
152    pub provider: String,
153    /// Base model being fine-tuned.
154    pub base_model: String,
155    /// Current job status.
156    pub status: TrainingJobStatus,
157    /// Job creation time.
158    pub created_at: chrono::DateTime<chrono::Utc>,
159    /// Training metrics (if available).
160    pub metrics: Option<TrainingMetrics>,
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Running` status with default (all-zero) progress.
    fn running() -> TrainingJobStatus {
        TrainingJobStatus::Running {
            progress: TrainingProgress::default(),
        }
    }

    /// Only `Succeeded`, `Failed` and `Cancelled` count as terminal.
    #[test]
    fn test_job_status_terminal() {
        assert!(!TrainingJobStatus::Pending.is_terminal());
        assert!(!TrainingJobStatus::Validating.is_terminal());
        assert!(!TrainingJobStatus::Queued.is_terminal());
        assert!(!running().is_terminal());
        assert!(
            TrainingJobStatus::Succeeded {
                model_id: "m".into()
            }
            .is_terminal()
        );
        assert!(
            TrainingJobStatus::Failed {
                error: "err".into()
            }
            .is_terminal()
        );
        assert!(TrainingJobStatus::Cancelled.is_terminal());
    }

    /// `is_running` and `is_succeeded` each match exactly one variant.
    #[test]
    fn test_job_status_predicates() {
        let succeeded = TrainingJobStatus::Succeeded {
            model_id: "m".into(),
        };
        let failed = TrainingJobStatus::Failed {
            error: "err".into(),
        };

        assert!(running().is_running());
        assert!(!running().is_succeeded());

        assert!(succeeded.is_succeeded());
        assert!(!succeeded.is_running());

        assert!(!failed.is_running());
        assert!(!failed.is_succeeded());
        assert!(!TrainingJobStatus::Pending.is_running());
        assert!(!TrainingJobStatus::Cancelled.is_succeeded());
    }

    #[test]
    fn test_progress_completion() {
        let p = TrainingProgress {
            step: 50,
            total_steps: 100,
            ..Default::default()
        };
        assert!((p.completion_fraction() - 0.5).abs() < f64::EPSILON);
    }

    /// A zero `total_steps` must yield 0.0 rather than NaN or infinity.
    #[test]
    fn test_progress_completion_zero_total() {
        let idle = TrainingProgress::default();
        assert!(idle.completion_fraction().abs() < f64::EPSILON);

        let stepped = TrainingProgress {
            step: 3,
            ..Default::default()
        };
        assert!(stepped.completion_fraction().abs() < f64::EPSILON);
    }

    #[test]
    fn test_job_id_from_string() {
        let id: TrainingJobId = "ft-abc123".into();
        assert_eq!(id.0, "ft-abc123");
        assert_eq!(id.to_string(), "ft-abc123");
    }

    /// Every `DatasetId` constructor stores its input verbatim.
    #[test]
    fn test_dataset_id_constructors() {
        let plain: DatasetId = "file-123".into();
        assert_eq!(plain.0, "file-123");
        assert_eq!(plain.to_string(), "file-123");

        let s3 = DatasetId::from_s3_uri("s3://bucket/train.jsonl");
        assert_eq!(s3.0, "s3://bucket/train.jsonl");

        let gcs = DatasetId::from_gcs_uri("gs://bucket/train.jsonl");
        assert_eq!(gcs.0, "gs://bucket/train.jsonl");
    }
}