brainwires-finetune 0.11.0

Cloud fine-tuning APIs (OpenAI, Anthropic, Together, Fireworks, Anyscale, Bedrock, Vertex AI) and dataset pipelines for the Brainwires Agent Framework. Local LoRA/QLoRA/DoRA fine-tuning and training from scratch live in the sibling `rullama` workspace (`rullama-finetune`, `rullama-training`).
Documentation
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use tracing::debug;

use crate::datasets::DataFormat;

use super::{CloudFineTuneConfig, FineTuneProvider};
use crate::error::TrainingError;
use crate::types::{
    DatasetId, TrainingJobId, TrainingJobStatus, TrainingJobSummary, TrainingProgress,
};

/// Root of the OpenAI-compatible Anyscale Endpoints API used when no custom
/// base URL is configured.
const ANYSCALE_API_URL: &str = "https://api.endpoints.anyscale.com/v1";

/// Anyscale fine-tuning provider.
///
/// Talks to an OpenAI-compatible fine-tuning API, so request and response
/// bodies follow the OpenAI shapes (drop-in replacement).
///
/// No `Debug` is derived; if one is ever added, make sure `api_key` is redacted.
pub struct AnyscaleFineTune {
    /// HTTP client reused for every request this provider makes.
    client: Client,
    /// API root; endpoint paths such as `/files` and `/fine_tuning/jobs` are appended to it.
    base_url: String,
    /// Secret sent as a bearer token on each request.
    api_key: String,
}

impl AnyscaleFineTune {
    /// Create a new Anyscale fine-tune provider pointed at the public Anyscale
    /// Endpoints API (`ANYSCALE_API_URL`).
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            client: Client::new(),
            base_url: ANYSCALE_API_URL.to_string(),
        }
    }

    /// Set a custom base URL (for example a proxy or another OpenAI-compatible gateway).
    ///
    /// Trailing `/` characters are removed. Every endpoint is built as
    /// `format!("{}/<path>", base_url)`, so a trailing slash would otherwise
    /// produce `//files`-style paths that some servers reject.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        let mut url = url.into();
        let trimmed_len = url.trim_end_matches('/').len();
        url.truncate(trimmed_len);
        self.base_url = url;
        self
    }

    /// Extract a human-readable error message from an API response body
    /// (OpenAI-compatible format).
    ///
    /// Accepts both `{"error": {"message": "..."}}` and `{"error": "..."}`.
    /// Falls back to `"Unknown error"` when neither shape is present (including
    /// a `null` body).
    fn extract_error(body: &serde_json::Value) -> String {
        body.get("error")
            .and_then(|e| {
                e.get("message")
                    .and_then(|m| m.as_str())
                    .or_else(|| e.as_str())
            })
            .unwrap_or("Unknown error")
            .to_string()
    }

    /// Map an OpenAI-compatible job object to a [`TrainingJobStatus`].
    ///
    /// A missing `status` field and any unrecognised status value both map to
    /// `Pending`. For `succeeded` jobs, the model id is taken from
    /// `fine_tuned_model` (`"unknown"` if absent). For `failed` jobs, the
    /// message comes from the job's `error` field.
    fn parse_job_status(body: &serde_json::Value) -> TrainingJobStatus {
        let status_str = body
            .get("status")
            .and_then(|v| v.as_str())
            .unwrap_or("pending");

        match status_str {
            "queued" => TrainingJobStatus::Queued,
            "validating_files" => TrainingJobStatus::Validating,
            // The job object carries no progress details here, so report defaults.
            "running" => TrainingJobStatus::Running {
                progress: TrainingProgress::default(),
            },
            "succeeded" => {
                let model_id = body
                    .get("fine_tuned_model")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown")
                    .to_string();
                TrainingJobStatus::Succeeded { model_id }
            }
            "failed" => TrainingJobStatus::Failed {
                error: Self::extract_error(body),
            },
            "cancelled" => TrainingJobStatus::Cancelled,
            _ => TrainingJobStatus::Pending,
        }
    }
}

#[async_trait]
impl FineTuneProvider for AnyscaleFineTune {
    fn name(&self) -> &str {
        "anyscale"
    }

    fn supported_base_models(&self) -> Vec<String> {
        vec![
            "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
            "meta-llama/Meta-Llama-3.1-70B-Instruct".to_string(),
            "mistralai/Mixtral-8x7B-Instruct-v0.1".to_string(),
        ]
    }

    fn supports_dpo(&self) -> bool {
        false
    }

    /// Upload training data via `POST {base_url}/files` with `purpose=fine-tune`.
    ///
    /// The bytes are sent unchanged under the file name `training_data.jsonl`.
    /// The `_format` argument is not used.
    ///
    /// # Errors
    /// - `TrainingError::Api` on a non-2xx response. The status code is kept
    ///   even when the error body is not JSON.
    /// - `TrainingError::Upload` when the success body has no string `id`.
    /// - Transport and decode errors are propagated via `?`.
    async fn upload_dataset(
        &self,
        data: &[u8],
        _format: DataFormat,
    ) -> Result<DatasetId, TrainingError> {
        debug!("Uploading dataset to Anyscale ({} bytes)", data.len());

        let part = reqwest::multipart::Part::bytes(data.to_vec()).file_name("training_data.jsonl");

        let form = reqwest::multipart::Form::new()
            .text("purpose", "fine-tune")
            .part("file", part);

        let response = self
            .client
            .post(format!("{}/files", self.base_url))
            .bearer_auth(&self.api_key)
            .multipart(form)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Parse error bodies leniently: gateways may answer with HTML or an
            // empty body, and the status code is the important part then.
            let body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Api {
                message: Self::extract_error(&body),
                status_code: status.as_u16(),
            });
        }
        let body: serde_json::Value = response.json().await?;

        let file_id = body
            .get("id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| TrainingError::Upload("Missing file ID".to_string()))?
            .to_string();

        Ok(DatasetId(file_id))
    }

    /// Create a fine-tuning job via `POST {base_url}/fine_tuning/jobs`.
    ///
    /// Only `training_file`, `model`, `hyperparameters.n_epochs` and the
    /// optional `suffix` are sent.
    ///
    /// # Errors
    /// - `TrainingError::Api` on a non-2xx response.
    /// - `TrainingError::Provider` when the success body has no string `id`.
    /// - Transport and decode errors are propagated via `?`.
    async fn create_job(
        &self,
        config: CloudFineTuneConfig,
    ) -> Result<TrainingJobId, TrainingError> {
        debug!(
            "Creating Anyscale fine-tuning job for: {}",
            config.base_model
        );

        let mut body = json!({
            "training_file": config.training_dataset.0,
            "model": config.base_model,
            "hyperparameters": {
                "n_epochs": config.hyperparams.epochs,
            },
        });

        if let Some(ref suffix) = config.suffix {
            body["suffix"] = json!(suffix);
        }

        let response = self
            .client
            .post(format!("{}/fine_tuning/jobs", self.base_url))
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Api {
                message: Self::extract_error(&error_body),
                status_code: status.as_u16(),
            });
        }
        let response_body: serde_json::Value = response.json().await?;

        let job_id = response_body
            .get("id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| TrainingError::Provider("Missing job ID".to_string()))?
            .to_string();

        Ok(TrainingJobId(job_id))
    }

    /// Fetch a job via `GET {base_url}/fine_tuning/jobs/{id}` and map its status.
    ///
    /// # Errors
    /// Returns `TrainingError::Api` on a non-2xx response. Without this check,
    /// an error body (which has no `status` field) would be reported as `Pending`.
    async fn get_job_status(
        &self,
        job_id: &TrainingJobId,
    ) -> Result<TrainingJobStatus, TrainingError> {
        let url = format!("{}/fine_tuning/jobs/{}", self.base_url, job_id.0);

        let response = self
            .client
            .get(&url)
            .bearer_auth(&self.api_key)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Api {
                message: Self::extract_error(&body),
                status_code: status.as_u16(),
            });
        }

        let body: serde_json::Value = response.json().await?;
        Ok(Self::parse_job_status(&body))
    }

    /// Cancel a job via `POST {base_url}/fine_tuning/jobs/{id}/cancel`.
    ///
    /// # Errors
    /// Returns `TrainingError::Provider` with the server's error message on a
    /// non-2xx response.
    async fn cancel_job(&self, job_id: &TrainingJobId) -> Result<(), TrainingError> {
        let url = format!("{}/fine_tuning/jobs/{}/cancel", self.base_url, job_id.0);
        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .send()
            .await?;

        if !response.status().is_success() {
            let body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Provider(Self::extract_error(&body)));
        }
        Ok(())
    }

    /// List jobs via `GET {base_url}/fine_tuning/jobs`.
    ///
    /// Entries missing a string `id` or `model` are skipped. `created_at` is read
    /// from the job's Unix-seconds `created_at` field; if that is absent or
    /// invalid, it falls back to the current time.
    ///
    /// NOTE(review): OpenAI-compatible list endpoints are usually paginated
    /// (`has_more` / `after`). Only the first page is read here — confirm
    /// whether Anyscale paginates.
    ///
    /// # Errors
    /// Returns `TrainingError::Api` on a non-2xx response instead of silently
    /// returning an empty list.
    async fn list_jobs(&self) -> Result<Vec<TrainingJobSummary>, TrainingError> {
        let response = self
            .client
            .get(format!("{}/fine_tuning/jobs", self.base_url))
            .bearer_auth(&self.api_key)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Api {
                message: Self::extract_error(&body),
                status_code: status.as_u16(),
            });
        }

        let body: serde_json::Value = response.json().await?;
        let Some(jobs) = body.get("data").and_then(|v| v.as_array()) else {
            return Ok(Vec::new());
        };

        Ok(jobs
            .iter()
            .filter_map(|j| {
                let created_at = j
                    .get("created_at")
                    .and_then(|v| v.as_i64())
                    .and_then(|secs| {
                        chrono::TimeZone::timestamp_opt(&chrono::Utc, secs, 0).single()
                    })
                    .unwrap_or_else(chrono::Utc::now);

                Some(TrainingJobSummary {
                    job_id: TrainingJobId(j.get("id")?.as_str()?.to_string()),
                    provider: "anyscale".to_string(),
                    base_model: j.get("model")?.as_str()?.to_string(),
                    status: Self::parse_job_status(j),
                    created_at,
                    metrics: None,
                })
            })
            .collect())
    }

    /// Delete a fine-tuned model via `DELETE {base_url}/models/{model_id}`.
    ///
    /// NOTE(review): `model_id` is inserted into the path without encoding. IDs
    /// derived from base models such as `meta-llama/...` contain `/` — confirm
    /// the server accepts that path shape.
    ///
    /// # Errors
    /// Returns `TrainingError::Provider` with the server's error message on a
    /// non-2xx response.
    async fn delete_model(&self, model_id: &str) -> Result<(), TrainingError> {
        let url = format!("{}/models/{}", self.base_url, model_id);
        let response = self
            .client
            .delete(&url)
            .bearer_auth(&self.api_key)
            .send()
            .await?;

        if !response.status().is_success() {
            let body: serde_json::Value = response.json().await.unwrap_or_default();
            return Err(TrainingError::Provider(Self::extract_error(&body)));
        }
        Ok(())
    }
}