//! Cloud fine-tuning provider integrations (`brainwires_training/cloud/mod.rs`).
// Provider backends — one submodule per cloud vendor — plus the shared
// cost-estimation and job-polling utilities they have in common.
/// Anyscale fine-tuning provider.
pub mod anyscale;
/// AWS Bedrock fine-tuning provider.
pub mod bedrock;
/// Cost estimation utilities.
pub mod cost;
/// Fireworks AI fine-tuning provider.
pub mod fireworks;
/// OpenAI fine-tuning provider.
pub mod openai;
/// Job polling utilities.
pub mod polling;
/// Together AI fine-tuning provider.
pub mod together;
/// Google Vertex AI fine-tuning provider.
pub mod vertex;
17
18use async_trait::async_trait;
19use brainwires_datasets::DataFormat;
20
21use crate::config::{AlignmentMethod, LoraConfig, TrainingHyperparams};
22use crate::error::TrainingError;
23use crate::types::{DatasetId, TrainingJobId, TrainingJobStatus, TrainingJobSummary};
24
/// Configuration for a cloud fine-tuning job.
///
/// Construct with [`CloudFineTuneConfig::new`], refine with the `with_*`
/// builder methods, then hand to [`FineTuneProvider::create_job`].
/// Dataset IDs are the handles returned by
/// [`FineTuneProvider::upload_dataset`].
#[derive(Debug, Clone)]
pub struct CloudFineTuneConfig {
    /// Base model to fine-tune (provider-specific ID).
    pub base_model: String,
    /// Uploaded training dataset ID.
    pub training_dataset: DatasetId,
    /// Optional validation dataset ID; `None` skips validation.
    pub validation_dataset: Option<DatasetId>,
    /// Training hyperparameters (defaults via `TrainingHyperparams::default()`).
    pub hyperparams: TrainingHyperparams,
    /// LoRA config (if provider supports PEFT); `None` means full fine-tune
    /// or the provider's default behavior.
    pub lora: Option<LoraConfig>,
    /// Alignment method (DPO/ORPO if provider supports it).
    pub alignment: AlignmentMethod,
    /// Suffix appended to fine-tuned model name.
    pub suffix: Option<String>,
}
43
44impl CloudFineTuneConfig {
45    /// Create a new cloud fine-tune config with default hyperparameters.
46    pub fn new(base_model: impl Into<String>, training_dataset: DatasetId) -> Self {
47        Self {
48            base_model: base_model.into(),
49            training_dataset,
50            validation_dataset: None,
51            hyperparams: TrainingHyperparams::default(),
52            lora: None,
53            alignment: AlignmentMethod::None,
54            suffix: None,
55        }
56    }
57
58    /// Set the validation dataset.
59    pub fn with_validation(mut self, dataset: DatasetId) -> Self {
60        self.validation_dataset = Some(dataset);
61        self
62    }
63
64    /// Set training hyperparameters.
65    pub fn with_hyperparams(mut self, h: TrainingHyperparams) -> Self {
66        self.hyperparams = h;
67        self
68    }
69
70    /// Set LoRA configuration.
71    pub fn with_lora(mut self, lora: LoraConfig) -> Self {
72        self.lora = Some(lora);
73        self
74    }
75
76    /// Set alignment method.
77    pub fn with_alignment(mut self, alignment: AlignmentMethod) -> Self {
78        self.alignment = alignment;
79        self
80    }
81
82    /// Set model name suffix.
83    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
84        self.suffix = Some(suffix.into());
85        self
86    }
87}
88
/// Trait for cloud fine-tuning providers.
///
/// Implemented once per vendor backend (see the submodules of this
/// module). All methods take `&self` and return concrete types, so the
/// trait is object-safe and can be used behind `dyn FineTuneProvider`.
#[async_trait]
pub trait FineTuneProvider: Send + Sync {
    /// Provider name.
    fn name(&self) -> &str;

    /// List base models available for fine-tuning.
    fn supported_base_models(&self) -> Vec<String>;

    /// Whether this provider supports DPO/preference optimization.
    fn supports_dpo(&self) -> bool;

    /// Upload a dataset (JSONL bytes) and get a dataset ID.
    ///
    /// # Errors
    /// Returns a `TrainingError` when the upload fails
    /// (provider-specific; e.g. rejected format or transport failure).
    async fn upload_dataset(
        &self,
        data: &[u8],
        format: DataFormat,
    ) -> Result<DatasetId, TrainingError>;

    /// Create a fine-tuning job.
    ///
    /// # Errors
    /// Returns a `TrainingError` if the provider rejects the config or
    /// the job cannot be created.
    async fn create_job(&self, config: CloudFineTuneConfig)
    -> Result<TrainingJobId, TrainingError>;

    /// Get the current status of a training job.
    ///
    /// # Errors
    /// Returns a `TrainingError` if the job cannot be queried
    /// (e.g. unknown job ID — provider-specific).
    async fn get_job_status(
        &self,
        job_id: &TrainingJobId,
    ) -> Result<TrainingJobStatus, TrainingError>;

    /// Cancel a running training job.
    ///
    /// # Errors
    /// Returns a `TrainingError` if cancellation fails.
    async fn cancel_job(&self, job_id: &TrainingJobId) -> Result<(), TrainingError>;

    /// List all training jobs.
    ///
    /// # Errors
    /// Returns a `TrainingError` if the listing request fails.
    async fn list_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError>;

    /// Delete a fine-tuned model.
    ///
    /// # Errors
    /// Returns a `TrainingError` if deletion fails.
    async fn delete_model(&self, model_id: &str) -> Result<(), TrainingError>;
}
127
128/// Factory for creating cloud fine-tune providers.
129pub struct FineTuneProviderFactory;
130
131impl FineTuneProviderFactory {
132    /// Create an OpenAI fine-tune provider.
133    pub fn openai(api_key: impl Into<String>) -> openai::OpenAiFineTune {
134        openai::OpenAiFineTune::new(api_key)
135    }
136
137    /// Create a Together AI fine-tune provider.
138    pub fn together(api_key: impl Into<String>) -> together::TogetherFineTune {
139        together::TogetherFineTune::new(api_key)
140    }
141
142    /// Create a Fireworks AI fine-tune provider.
143    pub fn fireworks(api_key: impl Into<String>) -> fireworks::FireworksFineTune {
144        fireworks::FireworksFineTune::new(api_key)
145    }
146
147    /// Create an Anyscale fine-tune provider.
148    pub fn anyscale(api_key: impl Into<String>) -> anyscale::AnyscaleFineTune {
149        anyscale::AnyscaleFineTune::new(api_key)
150    }
151
152    /// Create an AWS Bedrock fine-tune provider.
153    pub fn bedrock(region: impl Into<String>) -> bedrock::BedrockFineTune {
154        bedrock::BedrockFineTune::new(region)
155    }
156
157    /// Create a Google Vertex AI fine-tune provider.
158    pub fn vertex(
159        project_id: impl Into<String>,
160        location: impl Into<String>,
161    ) -> vertex::VertexFineTune {
162        vertex::VertexFineTune::new(project_id, location)
163    }
164}
165
// Flatten the per-vendor provider types (plus the cost and polling
// helpers) to `cloud::` level so callers need not name the submodules.
pub use self::anyscale::AnyscaleFineTune;
pub use self::bedrock::BedrockFineTune;
pub use self::cost::CostEstimator;
pub use self::fireworks::FireworksFineTune;
pub use self::openai::OpenAiFineTune;
pub use self::polling::JobPoller;
pub use self::together::TogetherFineTune;
pub use self::vertex::VertexFineTune;