use crate::cli::MSSwiftCommands;
use orign::config::GlobalConfig;
use orign::models::{MSSwiftParams, TrainingRequest};
use reqwest::Client;
use serde_json::Value;
use std::error::Error;
pub async fn run_swift_training(args: MSSwiftCommands) -> Result<(), Box<dyn Error>> {
let training_request = convert_to_training_request(args);
fn convert_to_training_request(args: MSSwiftCommands) -> TrainingRequest {
let ms_swift_params = MSSwiftParams {
model: args.ms_swift_args.model,
model_type: args.ms_swift_args.model_type,
train_type: args.ms_swift_args.train_type,
deepspeed: args.ms_swift_args.deepspeed,
torch_dtype: args.ms_swift_args.torch_dtype,
max_length: args.ms_swift_args.max_length,
dataset: args.ms_swift_args.dataset,
val_split_ratio: args.ms_swift_args.val_split_ratio,
num_train_epochs: args.ms_swift_args.num_train_epochs,
eval_strategy: args.ms_swift_args.eval_strategy,
save_strategy: args.ms_swift_args.save_strategy,
save_total_limit: args.ms_swift_args.save_total_limit,
lora_rank: args.ms_swift_args.lora_rank,
lora_alpha: args.ms_swift_args.lora_alpha,
size_factor: args.ms_swift_args.size_factor,
max_pixels: args.ms_swift_args.max_pixels,
resume_from_checkpoint: None,
freeze_vit: Some(args.ms_swift_args.freeze_vit),
rlhf_type: args.ms_swift_args.rlhf_type,
gradient_accumulation_steps_total: args.ms_swift_args.gradient_accumulation_steps_total,
learning_rate: args.ms_swift_args.learning_rate,
save_steps: args.ms_swift_args.save_steps,
};
TrainingRequest {
name: args.train_args.name,
framework: "ms-swift".to_string(),
namespace: args.train_args.namespace,
vram_request: args.train_args.vram,
cpu_request: args.train_args.cpu_request,
trust_remote_code: args.train_args.trust_remote_code,
adapter: args.train_args.adapter,
buffer: args.train_args.buffer,
accelerators: args.train_args.accelerators,
ms_swift_params: Some(ms_swift_params),
llama_factory_params: None,
trl_params: None,
resume: args.train_args.resume,
queue: args.train_args.queue,
platform: args.train_args.platform,
labels: None,
}
}
let client = Client::new();
let config = GlobalConfig::read()?;
let server = config.server.unwrap();
let api_key = config.api_key.as_deref().ok_or("API key not set")?;
let bearer_token = format!("Bearer {}", api_key);
let url = format!("{}/v1/trainings", server);
let response = client
.post(&url)
.header("Authorization", bearer_token)
.json(&training_request)
.send()
.await?;
if response.status().is_success() {
let resp_json: Value = response.json().await?;
println!("Training started successfully:");
println!("{}", serde_json::to_string_pretty(&resp_json)?);
} else {
let error_text = response.text().await?;
eprintln!("Error starting training: {}", error_text);
}
Ok(())
}
pub async fn stop_training(id: String) -> Result<(), Box<dyn Error>> {
let client = Client::new();
let config = GlobalConfig::read()?;
let server = config.server.unwrap();
let api_key = config.api_key.as_deref().ok_or("API key not set")?;
let bearer_token = format!("Bearer {}", api_key);
println!("Executing train stop command: {} {}", server, id);
let url = format!("{}/v1/trainings/{}", server, id);
let response = client
.delete(&url)
.header("Authorization", bearer_token)
.send()
.await?;
if response.status().is_success() {
let resp_json: Value = response.json().await?;
println!("{}", serde_json::to_string_pretty(&resp_json)?);
} else {
let error_text = response.text().await?;
eprintln!("Error stopping training: {}", error_text);
}
Ok(())
}