pub mod anyscale;
pub mod bedrock;
pub mod cost;
pub mod fireworks;
pub mod openai;
pub mod polling;
pub mod together;
pub mod vertex;
use crate::datasets::DataFormat;
use async_trait::async_trait;
use crate::config::{AlignmentMethod, LoraConfig, TrainingHyperparams};
use crate::error::TrainingError;
use crate::types::{DatasetId, TrainingJobId, TrainingJobStatus, TrainingJobSummary};
/// Settings describing a single cloud fine-tuning job.
///
/// A value of this type is what [`FineTuneProvider::create_job`] consumes. Start from
/// [`CloudFineTuneConfig::new`] and adjust individual fields with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct CloudFineTuneConfig {
    /// Name of the base model, stored exactly as the caller supplied it.
    pub base_model: String,
    /// Dataset the job trains on.
    pub training_dataset: DatasetId,
    /// Optional validation dataset; `None` when the caller did not provide one.
    pub validation_dataset: Option<DatasetId>,
    /// Training hyperparameters for the job.
    pub hyperparams: TrainingHyperparams,
    /// Optional LoRA configuration; `None` when the caller did not provide one.
    pub lora: Option<LoraConfig>,
    /// Alignment method requested for the job.
    pub alignment: AlignmentMethod,
    /// Optional suffix string.
    ///
    /// NOTE(review): presumably used to label the resulting model; confirm against the
    /// provider implementations.
    pub suffix: Option<String>,
}
impl CloudFineTuneConfig {
    /// Creates a configuration for `base_model` trained on `training_dataset`.
    ///
    /// All remaining fields start out neutral: no validation dataset, default
    /// hyperparameters, no LoRA configuration, [`AlignmentMethod::None`], and no suffix.
    pub fn new(base_model: impl Into<String>, training_dataset: DatasetId) -> Self {
        let base_model = base_model.into();
        Self {
            base_model,
            training_dataset,
            validation_dataset: None,
            hyperparams: TrainingHyperparams::default(),
            lora: None,
            alignment: AlignmentMethod::None,
            suffix: None,
        }
    }

    /// Returns the configuration with `dataset` set as the validation dataset.
    pub fn with_validation(self, dataset: DatasetId) -> Self {
        Self {
            validation_dataset: Some(dataset),
            ..self
        }
    }

    /// Returns the configuration with its hyperparameters replaced by `h`.
    pub fn with_hyperparams(self, h: TrainingHyperparams) -> Self {
        Self {
            hyperparams: h,
            ..self
        }
    }

    /// Returns the configuration with `lora` set as the LoRA configuration.
    pub fn with_lora(self, lora: LoraConfig) -> Self {
        Self {
            lora: Some(lora),
            ..self
        }
    }

    /// Returns the configuration with its alignment method replaced by `alignment`.
    pub fn with_alignment(self, alignment: AlignmentMethod) -> Self {
        Self { alignment, ..self }
    }

    /// Returns the configuration with `suffix` (converted to a `String`) set as the suffix.
    pub fn with_suffix(self, suffix: impl Into<String>) -> Self {
        let suffix = Some(suffix.into());
        Self { suffix, ..self }
    }
}
/// Common interface implemented by each fine-tuning backend.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait]`, and every fallible operation reports failures as [`TrainingError`].
#[async_trait]
pub trait FineTuneProvider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Returns the base models this provider reports as supported.
    fn supported_base_models(&self) -> Vec<String>;

    /// Reports whether the provider supports DPO.
    ///
    /// NOTE(review): "DPO" presumably means Direct Preference Optimization; confirm.
    fn supports_dpo(&self) -> bool;

    /// Uploads the raw bytes in `data`, described by `format`, and returns the id of the
    /// resulting dataset.
    async fn upload_dataset(
        &self,
        data: &[u8],
        format: DataFormat,
    ) -> Result<DatasetId, TrainingError>;

    /// Creates a fine-tuning job from `config` and returns its id.
    async fn create_job(
        &self,
        config: CloudFineTuneConfig,
    ) -> Result<TrainingJobId, TrainingError>;

    /// Fetches the current status of the job identified by `job_id`.
    async fn get_job_status(
        &self,
        job_id: &TrainingJobId,
    ) -> Result<TrainingJobStatus, TrainingError>;

    /// Cancels the job identified by `job_id`.
    async fn cancel_job(&self, job_id: &TrainingJobId) -> Result<(), TrainingError>;

    /// Lists summaries of training jobs.
    async fn list_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError>;

    /// Deletes the model identified by `model_id`.
    async fn delete_model(&self, model_id: &str) -> Result<(), TrainingError>;
}
/// Zero-sized namespace type whose associated functions construct the provider types.
///
/// It carries no state. The common traits are derived so the type can be formatted with
/// `{:?}`, copied freely, and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct FineTuneProviderFactory;
impl FineTuneProviderFactory {
    /// Constructs an [`openai::OpenAiFineTune`] from `api_key`.
    pub fn openai(api_key: impl Into<String>) -> openai::OpenAiFineTune {
        openai::OpenAiFineTune::new(api_key)
    }

    /// Constructs a [`together::TogetherFineTune`] from `api_key`.
    pub fn together(api_key: impl Into<String>) -> together::TogetherFineTune {
        together::TogetherFineTune::new(api_key)
    }

    /// Constructs a [`fireworks::FireworksFineTune`] from `api_key`.
    pub fn fireworks(api_key: impl Into<String>) -> fireworks::FireworksFineTune {
        fireworks::FireworksFineTune::new(api_key)
    }

    /// Constructs an [`anyscale::AnyscaleFineTune`] from `api_key`.
    pub fn anyscale(api_key: impl Into<String>) -> anyscale::AnyscaleFineTune {
        anyscale::AnyscaleFineTune::new(api_key)
    }

    /// Constructs a [`bedrock::BedrockFineTune`] for `region`.
    ///
    /// Unlike the API-key based constructors, this one takes only a region string.
    pub fn bedrock(region: impl Into<String>) -> bedrock::BedrockFineTune {
        bedrock::BedrockFineTune::new(region)
    }

    /// Constructs a [`vertex::VertexFineTune`] from `project_id` and `location`.
    pub fn vertex(
        project_id: impl Into<String>,
        location: impl Into<String>,
    ) -> vertex::VertexFineTune {
        vertex::VertexFineTune::new(project_id, location)
    }
}
pub use self::anyscale::AnyscaleFineTune;
pub use self::bedrock::BedrockFineTune;
pub use self::cost::CostEstimator;
pub use self::fireworks::FireworksFineTune;
pub use self::openai::OpenAiFineTune;
pub use self::polling::JobPoller;
pub use self::together::TogetherFineTune;
pub use self::vertex::VertexFineTune;