orign 0.2.3

A globally distributed container orchestrator
Documentation
// src/commands/train_cmd.rs

use crate::cli::MSSwiftCommands;
use orign::config::GlobalConfig;
use orign::models::{MSSwiftParams, TrainingRequest};
use reqwest::Client;
use serde_json::Value;
use std::error::Error;

/// Submits an ms-swift training job to the configured server.
///
/// The CLI arguments are converted into a [`TrainingRequest`] whose
/// `framework` is `"ms-swift"`, then sent as JSON in a `POST` to
/// `{server}/v1/trainings` with the configured API key as a bearer token.
///
/// On a successful status the JSON response is pretty-printed to stdout.
/// On any other status the response body is written to stderr and the
/// function still returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the global config cannot be read, if the server or
/// the API key is not set in the config, if the HTTP request fails, or if the
/// response body cannot be read or (on success) parsed as JSON.
pub async fn run_swift_training(args: MSSwiftCommands) -> Result<(), Box<dyn Error>> {
    /// Maps the parsed CLI arguments onto the request payload sent to the API.
    fn convert_to_training_request(args: MSSwiftCommands) -> TrainingRequest {
        let ms_swift_params = MSSwiftParams {
            model: args.ms_swift_args.model,
            model_type: args.ms_swift_args.model_type,
            train_type: args.ms_swift_args.train_type,
            deepspeed: args.ms_swift_args.deepspeed,
            torch_dtype: args.ms_swift_args.torch_dtype,
            max_length: args.ms_swift_args.max_length,
            dataset: args.ms_swift_args.dataset,
            val_split_ratio: args.ms_swift_args.val_split_ratio,
            num_train_epochs: args.ms_swift_args.num_train_epochs,
            eval_strategy: args.ms_swift_args.eval_strategy,
            save_strategy: args.ms_swift_args.save_strategy,
            save_total_limit: args.ms_swift_args.save_total_limit,
            lora_rank: args.ms_swift_args.lora_rank,
            lora_alpha: args.ms_swift_args.lora_alpha,
            size_factor: args.ms_swift_args.size_factor,
            max_pixels: args.ms_swift_args.max_pixels,
            // The CLI arguments never populate this field; it is always sent as `None`.
            resume_from_checkpoint: None,
            freeze_vit: Some(args.ms_swift_args.freeze_vit),
            rlhf_type: args.ms_swift_args.rlhf_type,
            gradient_accumulation_steps_total: args.ms_swift_args.gradient_accumulation_steps_total,
            learning_rate: args.ms_swift_args.learning_rate,
            save_steps: args.ms_swift_args.save_steps,
        };

        TrainingRequest {
            name: args.train_args.name,
            framework: "ms-swift".to_string(),
            namespace: args.train_args.namespace,
            vram_request: args.train_args.vram,
            cpu_request: args.train_args.cpu_request,
            trust_remote_code: args.train_args.trust_remote_code,
            adapter: args.train_args.adapter,
            buffer: args.train_args.buffer,
            accelerators: args.train_args.accelerators,
            ms_swift_params: Some(ms_swift_params),
            // Only the ms-swift parameter set is filled in by this command.
            llama_factory_params: None,
            trl_params: None,
            resume: args.train_args.resume,
            queue: args.train_args.queue,
            platform: args.train_args.platform,
            labels: None,
        }
    }

    let training_request = convert_to_training_request(args);

    // Validate the configuration before building the HTTP client. A missing
    // server is reported as an error, like a missing API key, instead of
    // panicking.
    let config = GlobalConfig::read()?;
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let client = Client::new();
    let url = format!("{}/v1/trainings", server);
    let response = client
        .post(&url)
        .header("Authorization", bearer_token)
        .json(&training_request)
        .send()
        .await?;

    if response.status().is_success() {
        let resp_json: Value = response.json().await?;
        println!("Training started successfully:");
        println!("{}", serde_json::to_string_pretty(&resp_json)?);
    } else {
        // A non-success status is reported on stderr but not turned into an
        // `Err`, matching the command's existing behaviour.
        let error_text = response.text().await?;
        eprintln!("Error starting training: {}", error_text);
    }

    Ok(())
}

/// Asks the configured server to stop the training identified by `id`.
///
/// Sends a `DELETE` to `{server}/v1/trainings/{id}` with the configured API
/// key as a bearer token. On a successful status the JSON response is
/// pretty-printed to stdout. On any other status the response body is written
/// to stderr and the function still returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the global config cannot be read, if the server or
/// the API key is not set in the config, if the HTTP request fails, or if the
/// response body cannot be read or (on success) parsed as JSON.
pub async fn stop_training(id: String) -> Result<(), Box<dyn Error>> {
    // Validate the configuration before building the HTTP client. A missing
    // server is reported as an error, like a missing API key, instead of
    // panicking.
    let config = GlobalConfig::read()?;
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    println!("Executing train stop command: {} {}", server, id);

    let client = Client::new();
    let url = format!("{}/v1/trainings/{}", server, id);
    let response = client
        .delete(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if response.status().is_success() {
        let resp_json: Value = response.json().await?;
        println!("{}", serde_json::to_string_pretty(&resp_json)?);
    } else {
        // A non-success status is reported on stderr but not turned into an
        // `Err`, matching the command's existing behaviour.
        let error_text = response.text().await?;
        eprintln!("Error stopping training: {}", error_text);
    }

    Ok(())
}