brainwires_training/
manager.rs1use std::collections::HashMap;
2use tracing::info;
3
4use crate::error::TrainingError;
5use crate::types::{TrainingJobId, TrainingJobStatus, TrainingJobSummary};
6
7#[cfg(feature = "cloud")]
8use crate::cloud::{CloudFineTuneConfig, FineTuneProvider, JobPoller};
9
10#[cfg(feature = "local")]
11use crate::local::{LocalTrainingConfig, TrainedModelArtifact, TrainingBackend};
12
/// Entry point for fine-tuning work.
///
/// Holds cloud fine-tuning providers and/or a local training backend, depending on which
/// crate features (`cloud`, `local`) are compiled in.
pub struct TrainingManager {
    /// Registered cloud providers, keyed by the name each provider reports when added.
    #[cfg(feature = "cloud")]
    cloud_providers: HashMap<String, Box<dyn FineTuneProvider>>,

    /// Backend used for local training; `None` until one has been set.
    #[cfg(feature = "local")]
    local_backend: Option<Box<dyn TrainingBackend>>,
}
23
24impl TrainingManager {
25 pub fn new() -> Self {
27 Self {
28 #[cfg(feature = "cloud")]
29 cloud_providers: HashMap::new(),
30 #[cfg(feature = "local")]
31 local_backend: None,
32 }
33 }
34
35 #[cfg(feature = "cloud")]
37 pub fn add_cloud_provider(&mut self, provider: Box<dyn FineTuneProvider>) {
38 let name = provider.name().to_string();
39 info!("Registered cloud fine-tune provider: {}", name);
40 self.cloud_providers.insert(name, provider);
41 }
42
43 #[cfg(feature = "local")]
45 pub fn set_local_backend(&mut self, backend: Box<dyn TrainingBackend>) {
46 info!("Set local training backend: {}", backend.name());
47 self.local_backend = Some(backend);
48 }
49
50 #[cfg(feature = "cloud")]
52 pub fn cloud_providers(&self) -> Vec<&str> {
53 self.cloud_providers.keys().map(|s| s.as_str()).collect()
54 }
55
56 #[cfg(feature = "cloud")]
58 pub fn get_cloud_provider(&self, name: &str) -> Option<&dyn FineTuneProvider> {
59 self.cloud_providers.get(name).map(|p| p.as_ref())
60 }
61
62 #[cfg(feature = "cloud")]
64 pub async fn start_cloud_job(
65 &self,
66 provider_name: &str,
67 config: CloudFineTuneConfig,
68 ) -> Result<TrainingJobId, TrainingError> {
69 let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
70 TrainingError::Provider(format!(
71 "Unknown provider: {}. Available: {:?}",
72 provider_name,
73 self.cloud_providers.keys().collect::<Vec<_>>()
74 ))
75 })?;
76
77 info!(
78 "Starting cloud fine-tuning job on {} with model {}",
79 provider_name, config.base_model
80 );
81
82 provider.create_job(config).await
83 }
84
85 #[cfg(feature = "cloud")]
87 pub async fn wait_for_cloud_job(
88 &self,
89 provider_name: &str,
90 job_id: &TrainingJobId,
91 ) -> Result<TrainingJobStatus, TrainingError> {
92 let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
93 TrainingError::Provider(format!("Unknown provider: {}", provider_name))
94 })?;
95
96 let poller = JobPoller::default();
97 poller.poll_with_logging(provider.as_ref(), job_id).await
98 }
99
100 #[cfg(feature = "cloud")]
102 pub async fn check_cloud_job(
103 &self,
104 provider_name: &str,
105 job_id: &TrainingJobId,
106 ) -> Result<TrainingJobStatus, TrainingError> {
107 let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
108 TrainingError::Provider(format!("Unknown provider: {}", provider_name))
109 })?;
110
111 provider.get_job_status(job_id).await
112 }
113
114 #[cfg(feature = "cloud")]
116 pub async fn cancel_cloud_job(
117 &self,
118 provider_name: &str,
119 job_id: &TrainingJobId,
120 ) -> Result<(), TrainingError> {
121 let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
122 TrainingError::Provider(format!("Unknown provider: {}", provider_name))
123 })?;
124
125 provider.cancel_job(job_id).await
126 }
127
128 #[cfg(feature = "cloud")]
130 pub async fn list_all_cloud_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError> {
131 let mut all_jobs = Vec::new();
132 for provider in self.cloud_providers.values() {
133 match provider.list_jobs().await {
134 Ok(jobs) => all_jobs.extend(jobs),
135 Err(e) => {
136 tracing::warn!("Failed to list jobs from {}: {}", provider.name(), e);
137 }
138 }
139 }
140 Ok(all_jobs)
141 }
142
143 #[cfg(feature = "local")]
145 pub fn train_local(
146 &self,
147 config: LocalTrainingConfig,
148 callback: Box<dyn Fn(crate::types::TrainingProgress) + Send>,
149 ) -> Result<TrainedModelArtifact, TrainingError> {
150 let backend = self.local_backend.as_ref().ok_or_else(|| {
151 TrainingError::Backend("No local training backend configured".to_string())
152 })?;
153
154 info!("Starting local training with {} backend", backend.name());
155 backend.train(config, callback)
156 }
157}
158
159impl Default for TrainingManager {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `manager` has no cloud providers and no local backend, for whichever
    /// of those features are compiled in.
    #[allow(unused_variables)] // `manager` is unused when neither `cloud` nor `local` is enabled.
    fn assert_empty(manager: &TrainingManager) {
        #[cfg(feature = "cloud")]
        assert!(manager.cloud_providers().is_empty());

        #[cfg(feature = "local")]
        assert!(manager.local_backend.is_none());
    }

    #[test]
    fn test_training_manager_creation() {
        assert_empty(&TrainingManager::new());
    }

    #[test]
    fn test_default_is_empty() {
        assert_empty(&TrainingManager::default());
    }
}