// brainwires_training/cloud/mod.rs
pub mod anyscale;
3pub mod bedrock;
5pub mod cost;
7pub mod fireworks;
9pub mod openai;
11pub mod polling;
13pub mod together;
15pub mod vertex;
17
18use async_trait::async_trait;
19use brainwires_datasets::DataFormat;
20
21use crate::config::{AlignmentMethod, LoraConfig, TrainingHyperparams};
22use crate::error::TrainingError;
23use crate::types::{DatasetId, TrainingJobId, TrainingJobStatus, TrainingJobSummary};
24
/// Provider-agnostic configuration for a cloud fine-tuning job.
///
/// Built with [`CloudFineTuneConfig::new`] and the `with_*` builder methods,
/// then handed to a [`FineTuneProvider`] via `create_job`.
#[derive(Debug, Clone)]
pub struct CloudFineTuneConfig {
    /// Identifier of the base model to fine-tune (format is provider-specific).
    pub base_model: String,
    /// Dataset to train on — presumably the id returned by
    /// `FineTuneProvider::upload_dataset`; confirm per provider.
    pub training_dataset: DatasetId,
    /// Optional held-out dataset used for validation; `None` disables validation.
    pub validation_dataset: Option<DatasetId>,
    /// Training hyperparameters; `new` initializes this to `TrainingHyperparams::default()`.
    pub hyperparams: TrainingHyperparams,
    /// Optional LoRA adapter configuration; `None` leaves the provider's
    /// default training mode in effect — confirm per provider.
    pub lora: Option<LoraConfig>,
    /// Alignment method to apply; `new` initializes this to `AlignmentMethod::None`.
    pub alignment: AlignmentMethod,
    /// Optional suffix for naming the resulting fine-tuned model.
    pub suffix: Option<String>,
}
43
44impl CloudFineTuneConfig {
45 pub fn new(base_model: impl Into<String>, training_dataset: DatasetId) -> Self {
47 Self {
48 base_model: base_model.into(),
49 training_dataset,
50 validation_dataset: None,
51 hyperparams: TrainingHyperparams::default(),
52 lora: None,
53 alignment: AlignmentMethod::None,
54 suffix: None,
55 }
56 }
57
58 pub fn with_validation(mut self, dataset: DatasetId) -> Self {
60 self.validation_dataset = Some(dataset);
61 self
62 }
63
64 pub fn with_hyperparams(mut self, h: TrainingHyperparams) -> Self {
66 self.hyperparams = h;
67 self
68 }
69
70 pub fn with_lora(mut self, lora: LoraConfig) -> Self {
72 self.lora = Some(lora);
73 self
74 }
75
76 pub fn with_alignment(mut self, alignment: AlignmentMethod) -> Self {
78 self.alignment = alignment;
79 self
80 }
81
82 pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
84 self.suffix = Some(suffix.into());
85 self
86 }
87}
88
/// Common interface implemented by each cloud fine-tuning backend in this
/// module (OpenAI, Together, Fireworks, Anyscale, Bedrock, Vertex).
#[async_trait]
pub trait FineTuneProvider: Send + Sync {
    /// Short identifier for this provider.
    fn name(&self) -> &str;

    /// Base models this provider can fine-tune.
    fn supported_base_models(&self) -> Vec<String>;

    /// Whether this provider supports DPO alignment jobs — presumably
    /// checked before submitting a config with a DPO `AlignmentMethod`;
    /// confirm against callers.
    fn supports_dpo(&self) -> bool;

    /// Uploads raw dataset bytes in the given `format` and returns the
    /// provider-side id to reference from a [`CloudFineTuneConfig`].
    async fn upload_dataset(
        &self,
        data: &[u8],
        format: DataFormat,
    ) -> Result<DatasetId, TrainingError>;

    /// Starts a fine-tuning job described by `config`, returning its id.
    async fn create_job(&self, config: CloudFineTuneConfig)
        -> Result<TrainingJobId, TrainingError>;

    /// Fetches the current status of a previously created job.
    async fn get_job_status(
        &self,
        job_id: &TrainingJobId,
    ) -> Result<TrainingJobStatus, TrainingError>;

    /// Requests cancellation of the given job.
    async fn cancel_job(&self, job_id: &TrainingJobId) -> Result<(), TrainingError>;

    /// Lists fine-tuning jobs known to this provider.
    async fn list_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError>;

    /// Deletes a fine-tuned model by its provider-side model id.
    async fn delete_model(&self, model_id: &str) -> Result<(), TrainingError>;
}
127
128pub struct FineTuneProviderFactory;
130
131impl FineTuneProviderFactory {
132 pub fn openai(api_key: impl Into<String>) -> openai::OpenAiFineTune {
134 openai::OpenAiFineTune::new(api_key)
135 }
136
137 pub fn together(api_key: impl Into<String>) -> together::TogetherFineTune {
139 together::TogetherFineTune::new(api_key)
140 }
141
142 pub fn fireworks(api_key: impl Into<String>) -> fireworks::FireworksFineTune {
144 fireworks::FireworksFineTune::new(api_key)
145 }
146
147 pub fn anyscale(api_key: impl Into<String>) -> anyscale::AnyscaleFineTune {
149 anyscale::AnyscaleFineTune::new(api_key)
150 }
151
152 pub fn bedrock(region: impl Into<String>) -> bedrock::BedrockFineTune {
154 bedrock::BedrockFineTune::new(region)
155 }
156
157 pub fn vertex(
159 project_id: impl Into<String>,
160 location: impl Into<String>,
161 ) -> vertex::VertexFineTune {
162 vertex::VertexFineTune::new(project_id, location)
163 }
164}
165
166pub use self::anyscale::AnyscaleFineTune;
167pub use self::bedrock::BedrockFineTune;
168pub use self::cost::CostEstimator;
169pub use self::fireworks::FireworksFineTune;
170pub use self::openai::OpenAiFineTune;
171pub use self::polling::JobPoller;
172pub use self::together::TogetherFineTune;
173pub use self::vertex::VertexFineTune;