Skip to main content

brainwires_training/
manager.rs

1use std::collections::HashMap;
2use tracing::info;
3
4use crate::error::TrainingError;
5use crate::types::{TrainingJobId, TrainingJobStatus, TrainingJobSummary};
6
7#[cfg(feature = "cloud")]
8use crate::cloud::{CloudFineTuneConfig, FineTuneProvider, JobPoller};
9
10#[cfg(feature = "local")]
11use crate::local::{LocalTrainingConfig, TrainedModelArtifact, TrainingBackend};
12
/// High-level training orchestrator.
///
/// Provides a unified API across cloud and local training backends.
///
/// Which backends exist depends on the crate's `cloud` and `local` Cargo features;
/// with neither feature enabled the struct has no fields.
pub struct TrainingManager {
    /// Registered cloud fine-tuning providers, keyed by the name each provider
    /// reports from `FineTuneProvider::name()` at registration time.
    #[cfg(feature = "cloud")]
    cloud_providers: HashMap<String, Box<dyn FineTuneProvider>>,

    /// The local training backend, if one has been set; `None` on construction.
    #[cfg(feature = "local")]
    local_backend: Option<Box<dyn TrainingBackend>>,
}
23
24impl TrainingManager {
25    /// Create a new training manager.
26    pub fn new() -> Self {
27        Self {
28            #[cfg(feature = "cloud")]
29            cloud_providers: HashMap::new(),
30            #[cfg(feature = "local")]
31            local_backend: None,
32        }
33    }
34
35    /// Register a cloud fine-tuning provider.
36    #[cfg(feature = "cloud")]
37    pub fn add_cloud_provider(&mut self, provider: Box<dyn FineTuneProvider>) {
38        let name = provider.name().to_string();
39        info!("Registered cloud fine-tune provider: {}", name);
40        self.cloud_providers.insert(name, provider);
41    }
42
43    /// Set the local training backend.
44    #[cfg(feature = "local")]
45    pub fn set_local_backend(&mut self, backend: Box<dyn TrainingBackend>) {
46        info!("Set local training backend: {}", backend.name());
47        self.local_backend = Some(backend);
48    }
49
50    /// List registered cloud providers.
51    #[cfg(feature = "cloud")]
52    pub fn cloud_providers(&self) -> Vec<&str> {
53        self.cloud_providers.keys().map(|s| s.as_str()).collect()
54    }
55
56    /// Get a cloud provider by name.
57    #[cfg(feature = "cloud")]
58    pub fn get_cloud_provider(&self, name: &str) -> Option<&dyn FineTuneProvider> {
59        self.cloud_providers.get(name).map(|p| p.as_ref())
60    }
61
62    /// Start a cloud fine-tuning job.
63    #[cfg(feature = "cloud")]
64    pub async fn start_cloud_job(
65        &self,
66        provider_name: &str,
67        config: CloudFineTuneConfig,
68    ) -> Result<TrainingJobId, TrainingError> {
69        let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
70            TrainingError::Provider(format!(
71                "Unknown provider: {}. Available: {:?}",
72                provider_name,
73                self.cloud_providers.keys().collect::<Vec<_>>()
74            ))
75        })?;
76
77        info!(
78            "Starting cloud fine-tuning job on {} with model {}",
79            provider_name, config.base_model
80        );
81
82        provider.create_job(config).await
83    }
84
85    /// Poll a cloud job until completion.
86    #[cfg(feature = "cloud")]
87    pub async fn wait_for_cloud_job(
88        &self,
89        provider_name: &str,
90        job_id: &TrainingJobId,
91    ) -> Result<TrainingJobStatus, TrainingError> {
92        let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
93            TrainingError::Provider(format!("Unknown provider: {}", provider_name))
94        })?;
95
96        let poller = JobPoller::default();
97        poller.poll_with_logging(provider.as_ref(), job_id).await
98    }
99
100    /// Check status of a cloud job.
101    #[cfg(feature = "cloud")]
102    pub async fn check_cloud_job(
103        &self,
104        provider_name: &str,
105        job_id: &TrainingJobId,
106    ) -> Result<TrainingJobStatus, TrainingError> {
107        let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
108            TrainingError::Provider(format!("Unknown provider: {}", provider_name))
109        })?;
110
111        provider.get_job_status(job_id).await
112    }
113
114    /// Cancel a cloud job.
115    #[cfg(feature = "cloud")]
116    pub async fn cancel_cloud_job(
117        &self,
118        provider_name: &str,
119        job_id: &TrainingJobId,
120    ) -> Result<(), TrainingError> {
121        let provider = self.cloud_providers.get(provider_name).ok_or_else(|| {
122            TrainingError::Provider(format!("Unknown provider: {}", provider_name))
123        })?;
124
125        provider.cancel_job(job_id).await
126    }
127
128    /// List all jobs across all cloud providers.
129    #[cfg(feature = "cloud")]
130    pub async fn list_all_cloud_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError> {
131        let mut all_jobs = Vec::new();
132        for provider in self.cloud_providers.values() {
133            match provider.list_jobs().await {
134                Ok(jobs) => all_jobs.extend(jobs),
135                Err(e) => {
136                    tracing::warn!("Failed to list jobs from {}: {}", provider.name(), e);
137                }
138            }
139        }
140        Ok(all_jobs)
141    }
142
143    /// Run local training.
144    #[cfg(feature = "local")]
145    pub fn train_local(
146        &self,
147        config: LocalTrainingConfig,
148        callback: Box<dyn Fn(crate::types::TrainingProgress) + Send>,
149    ) -> Result<TrainedModelArtifact, TrainingError> {
150        let backend = self.local_backend.as_ref().ok_or_else(|| {
151            TrainingError::Backend("No local training backend configured".to_string())
152        })?;
153
154        info!("Starting local training with {} backend", backend.name());
155        backend.train(config, callback)
156    }
157}
158
159impl Default for TrainingManager {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `manager` has no cloud providers and no local backend registered.
    fn assert_unconfigured(manager: &TrainingManager) {
        #[cfg(feature = "cloud")]
        assert!(manager.cloud_providers().is_empty());

        #[cfg(feature = "local")]
        assert!(manager.local_backend.is_none());

        // With no backend features enabled there is nothing to inspect; touch the
        // binding so the build stays free of unused-variable warnings.
        #[cfg(not(any(feature = "cloud", feature = "local")))]
        let _ = manager;
    }

    #[test]
    fn test_training_manager_creation() {
        let manager = TrainingManager::new();
        assert_unconfigured(&manager);
    }

    #[test]
    fn test_default_is_unconfigured() {
        let manager = TrainingManager::default();
        assert_unconfigured(&manager);
    }
}