Skip to main content

batuta/serve/banco/
handlers_merge.rs

1//! Model merge endpoint handlers — TIES, DARE, SLERP, weighted average.
2//!
3//! With `ml` feature: uses entrenar's merge module for real tensor merging.
4//! Without `ml`: returns dry-run merge results for API testing.
5
6use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12/// POST /api/v1/models/merge — merge two or more models.
13pub async fn merge_models_handler(
14    State(state): State<BancoState>,
15    Json(request): Json<MergeRequest>,
16) -> Result<Json<MergeResult>, (StatusCode, Json<ErrorResponse>)> {
17    if request.models.len() < 2 {
18        return Err((
19            StatusCode::BAD_REQUEST,
20            Json(ErrorResponse::new(
21                "At least 2 models required for merge",
22                "invalid_request",
23                400,
24            )),
25        ));
26    }
27
28    // SLERP only works with exactly 2 models
29    if request.strategy == MergeStrategy::Slerp && request.models.len() != 2 {
30        return Err((
31            StatusCode::BAD_REQUEST,
32            Json(ErrorResponse::new(
33                "SLERP merge requires exactly 2 models",
34                "invalid_request",
35                400,
36            )),
37        ));
38    }
39
40    let result = execute_merge(&state, &request);
41    state.events.emit(&super::events::BancoEvent::MergeComplete {
42        merge_id: result.merge_id.clone(),
43        strategy: format!("{:?}", result.strategy).to_lowercase(),
44    });
45    Ok(Json(result))
46}
47
48/// GET /api/v1/models/merge/strategies — list available merge strategies.
49pub async fn list_merge_strategies_handler() -> Json<MergeStrategiesResponse> {
50    Json(MergeStrategiesResponse {
51        strategies: vec![
52            StrategyInfo {
53                name: "weighted_average".to_string(),
54                description: "Element-wise weighted average of model parameters".to_string(),
55                min_models: 2,
56                max_models: None,
57            },
58            StrategyInfo {
59                name: "ties".to_string(),
60                description: "Trim, Elect, Sign merge — reduces noise across multiple fine-tunes"
61                    .to_string(),
62                min_models: 2,
63                max_models: None,
64            },
65            StrategyInfo {
66                name: "dare".to_string(),
67                description: "Drop And REscale — stochastic sparsity-based merge".to_string(),
68                min_models: 2,
69                max_models: None,
70            },
71            StrategyInfo {
72                name: "slerp".to_string(),
73                description: "Spherical linear interpolation — smooth two-model blending"
74                    .to_string(),
75                min_models: 2,
76                max_models: Some(2),
77            },
78        ],
79    })
80}
81
/// Execute a model merge through entrenar (built with the `entrenar` feature).
///
/// Every requested model is represented by an empty placeholder map, so this
/// exercises entrenar's merge API without real weights; the returned result is
/// therefore always flagged `simulated`.
///
/// Optional request parameters fall back to these defaults: equal weights for
/// weighted average, `density = 0.2` for TIES, `drop_prob = 0.5` for DARE and
/// `t = 0.5` for SLERP.
///
/// The result's `status` is `"complete"` when the merge call succeeds and
/// `"failed"` otherwise, in which case `error` carries the message. A SLERP
/// request that does not name exactly two models is reported as failed rather
/// than panicking on an out-of-range index.
#[cfg(feature = "entrenar")]
fn execute_merge(state: &BancoState, request: &MergeRequest) -> MergeResult {
    use entrenar::merge::{DareConfig, EnsembleConfig, SlerpConfig, TiesConfig};

    // Build placeholder models from the request model names
    // Real merge requires loaded model weights — this validates the entrenar API
    let models: Vec<entrenar::merge::Model> =
        request.models.iter().map(|_| std::collections::HashMap::new()).collect();

    // Errors are stringified per branch so the SLERP arity failure can share
    // the same path as errors reported by entrenar.
    let merge_result: Result<entrenar::merge::Model, String> = match &request.strategy {
        MergeStrategy::WeightedAverage => {
            let model_count = request.models.len();
            let weights = request
                .weights
                .clone()
                .unwrap_or_else(|| vec![1.0 / model_count as f32; model_count]);
            let config = EnsembleConfig::weighted_average(weights);
            entrenar::merge::ensemble_merge(&models, &config).map_err(|e| e.to_string())
        }
        MergeStrategy::Ties => {
            let density = request.density.unwrap_or(0.2);
            let base = std::collections::HashMap::new();
            let config = TiesConfig { density };
            entrenar::merge::ties_merge(&models, &base, &config).map_err(|e| e.to_string())
        }
        MergeStrategy::Dare => {
            let drop_prob = request.drop_prob.unwrap_or(0.5);
            let base = std::collections::HashMap::new();
            let config = DareConfig { drop_prob, seed: request.seed };
            entrenar::merge::dare_merge(&models, &base, &config).map_err(|e| e.to_string())
        }
        // The HTTP handler already enforces the two-model rule; matching on the
        // slice keeps this function panic-free if it is ever reached without
        // that check.
        MergeStrategy::Slerp => match models.as_slice() {
            [first, second] => {
                let t = request.interpolation_t.unwrap_or(0.5);
                let config = SlerpConfig { t };
                entrenar::merge::slerp_merge(first, second, &config).map_err(|e| e.to_string())
            }
            _ => Err("SLERP merge requires exactly 2 models".to_string()),
        },
    };

    let (status, error) = match merge_result {
        Ok(_merged) => ("complete".to_string(), None),
        Err(message) => ("failed".to_string(), Some(message)),
    };

    let _ = state; // used when loading real model weights

    MergeResult {
        merge_id: format!("merge-{}", epoch_secs()),
        strategy: request.strategy.clone(),
        models: request.models.clone(),
        status,
        simulated: true, // Uses empty placeholder tensors
        error,
        output_path: None,
    }
}
137
138/// Dry-run merge (no ml feature).
139#[cfg(not(feature = "entrenar"))]
140fn execute_merge(_state: &BancoState, request: &MergeRequest) -> MergeResult {
141    MergeResult {
142        merge_id: format!("merge-{}", epoch_secs()),
143        strategy: request.strategy.clone(),
144        models: request.models.clone(),
145        status: "dry_run".to_string(),
146        simulated: true,
147        error: None,
148        output_path: None,
149    }
150}
151
152// ============================================================================
153// Types
154// ============================================================================
155
156/// Merge strategy.
157#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
158#[serde(rename_all = "snake_case")]
159pub enum MergeStrategy {
160    WeightedAverage,
161    Ties,
162    Dare,
163    Slerp,
164}
165
166/// Merge request.
167#[derive(Debug, Clone, Deserialize)]
168pub struct MergeRequest {
169    /// Model identifiers (file paths or model IDs).
170    pub models: Vec<String>,
171    /// Merge strategy.
172    pub strategy: MergeStrategy,
173    /// Weights for weighted_average (one per model, auto-normalized).
174    #[serde(default)]
175    pub weights: Option<Vec<f32>>,
176    /// TIES density parameter (0.0-1.0, default 0.2).
177    #[serde(default)]
178    pub density: Option<f32>,
179    /// DARE drop probability (0.0-1.0, default 0.5).
180    #[serde(default)]
181    pub drop_prob: Option<f32>,
182    /// SLERP interpolation parameter (0.0-1.0, default 0.5).
183    #[serde(default)]
184    pub interpolation_t: Option<f32>,
185    /// Random seed for reproducibility (DARE).
186    #[serde(default)]
187    pub seed: Option<u64>,
188    /// Output format (safetensors, gguf, apr).
189    #[serde(default)]
190    pub output_format: Option<String>,
191}
192
193/// Merge result.
194#[derive(Debug, Clone, Serialize)]
195pub struct MergeResult {
196    pub merge_id: String,
197    pub strategy: MergeStrategy,
198    pub models: Vec<String>,
199    pub status: String,
200    /// True when merge runs on empty placeholder tensors (API validation only).
201    #[serde(default)]
202    pub simulated: bool,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    pub error: Option<String>,
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub output_path: Option<String>,
207}
208
209/// Merge strategies list response.
210#[derive(Debug, Serialize)]
211pub struct MergeStrategiesResponse {
212    pub strategies: Vec<StrategyInfo>,
213}
214
215/// Strategy info.
216#[derive(Debug, Serialize)]
217pub struct StrategyInfo {
218    pub name: String,
219    pub description: String,
220    pub min_models: usize,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub max_models: Option<usize>,
223}
224
/// Whole seconds elapsed since the Unix epoch according to the system clock.
///
/// Returns `0` if the clock reports a time earlier than the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}