use axum::{extract::State, http::StatusCode, response::Json};
use serde::{Deserialize, Serialize};
use super::state::BancoState;
use super::types::ErrorResponse;
pub async fn merge_models_handler(
State(state): State<BancoState>,
Json(request): Json<MergeRequest>,
) -> Result<Json<MergeResult>, (StatusCode, Json<ErrorResponse>)> {
if request.models.len() < 2 {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse::new(
"At least 2 models required for merge",
"invalid_request",
400,
)),
));
}
if request.strategy == MergeStrategy::Slerp && request.models.len() != 2 {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse::new(
"SLERP merge requires exactly 2 models",
"invalid_request",
400,
)),
));
}
let result = execute_merge(&state, &request);
state.events.emit(&super::events::BancoEvent::MergeComplete {
merge_id: result.merge_id.clone(),
strategy: format!("{:?}", result.strategy).to_lowercase(),
});
Ok(Json(result))
}
pub async fn list_merge_strategies_handler() -> Json<MergeStrategiesResponse> {
Json(MergeStrategiesResponse {
strategies: vec![
StrategyInfo {
name: "weighted_average".to_string(),
description: "Element-wise weighted average of model parameters".to_string(),
min_models: 2,
max_models: None,
},
StrategyInfo {
name: "ties".to_string(),
description: "Trim, Elect, Sign merge — reduces noise across multiple fine-tunes"
.to_string(),
min_models: 2,
max_models: None,
},
StrategyInfo {
name: "dare".to_string(),
description: "Drop And REscale — stochastic sparsity-based merge".to_string(),
min_models: 2,
max_models: None,
},
StrategyInfo {
name: "slerp".to_string(),
description: "Spherical linear interpolation — smooth two-model blending"
.to_string(),
min_models: 2,
max_models: Some(2),
},
],
})
}
#[cfg(feature = "entrenar")]
fn execute_merge(state: &BancoState, request: &MergeRequest) -> MergeResult {
use entrenar::merge::{DareConfig, EnsembleConfig, MergeError, SlerpConfig, TiesConfig};
let models: Vec<entrenar::merge::Model> =
request.models.iter().map(|_| std::collections::HashMap::new()).collect();
let merge_result: Result<entrenar::merge::Model, MergeError> = match &request.strategy {
MergeStrategy::WeightedAverage => {
let weights = request
.weights
.clone()
.unwrap_or_else(|| vec![1.0 / request.models.len() as f32; request.models.len()]);
let config = EnsembleConfig::weighted_average(weights);
entrenar::merge::ensemble_merge(&models, &config)
}
MergeStrategy::Ties => {
let density = request.density.unwrap_or(0.2);
let base = std::collections::HashMap::new();
let config = TiesConfig { density };
entrenar::merge::ties_merge(&models, &base, &config)
}
MergeStrategy::Dare => {
let drop_prob = request.drop_prob.unwrap_or(0.5);
let base = std::collections::HashMap::new();
let config = DareConfig { drop_prob, seed: request.seed };
entrenar::merge::dare_merge(&models, &base, &config)
}
MergeStrategy::Slerp => {
let t = request.interpolation_t.unwrap_or(0.5);
let config = SlerpConfig { t };
entrenar::merge::slerp_merge(&models[0], &models[1], &config)
}
};
let (status, error) = match merge_result {
Ok(_merged) => ("complete".to_string(), None),
Err(e) => ("failed".to_string(), Some(e.to_string())),
};
let _ = state;
MergeResult {
merge_id: format!("merge-{}", epoch_secs()),
strategy: request.strategy.clone(),
models: request.models.clone(),
status,
simulated: true, error,
output_path: None,
}
}
#[cfg(not(feature = "entrenar"))]
fn execute_merge(_state: &BancoState, request: &MergeRequest) -> MergeResult {
MergeResult {
merge_id: format!("merge-{}", epoch_secs()),
strategy: request.strategy.clone(),
models: request.models.clone(),
status: "dry_run".to_string(),
simulated: true,
error: None,
output_path: None,
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MergeStrategy {
WeightedAverage,
Ties,
Dare,
Slerp,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MergeRequest {
pub models: Vec<String>,
pub strategy: MergeStrategy,
#[serde(default)]
pub weights: Option<Vec<f32>>,
#[serde(default)]
pub density: Option<f32>,
#[serde(default)]
pub drop_prob: Option<f32>,
#[serde(default)]
pub interpolation_t: Option<f32>,
#[serde(default)]
pub seed: Option<u64>,
#[serde(default)]
pub output_format: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct MergeResult {
pub merge_id: String,
pub strategy: MergeStrategy,
pub models: Vec<String>,
pub status: String,
#[serde(default)]
pub simulated: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_path: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct MergeStrategiesResponse {
pub strategies: Vec<StrategyInfo>,
}
#[derive(Debug, Serialize)]
pub struct StrategyInfo {
pub name: String,
pub description: String,
pub min_models: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_models: Option<usize>,
}
fn epoch_secs() -> u64 {
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs()
}