1use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12pub async fn merge_models_handler(
14 State(state): State<BancoState>,
15 Json(request): Json<MergeRequest>,
16) -> Result<Json<MergeResult>, (StatusCode, Json<ErrorResponse>)> {
17 if request.models.len() < 2 {
18 return Err((
19 StatusCode::BAD_REQUEST,
20 Json(ErrorResponse::new(
21 "At least 2 models required for merge",
22 "invalid_request",
23 400,
24 )),
25 ));
26 }
27
28 if request.strategy == MergeStrategy::Slerp && request.models.len() != 2 {
30 return Err((
31 StatusCode::BAD_REQUEST,
32 Json(ErrorResponse::new(
33 "SLERP merge requires exactly 2 models",
34 "invalid_request",
35 400,
36 )),
37 ));
38 }
39
40 let result = execute_merge(&state, &request);
41 state.events.emit(&super::events::BancoEvent::MergeComplete {
42 merge_id: result.merge_id.clone(),
43 strategy: format!("{:?}", result.strategy).to_lowercase(),
44 });
45 Ok(Json(result))
46}
47
48pub async fn list_merge_strategies_handler() -> Json<MergeStrategiesResponse> {
50 Json(MergeStrategiesResponse {
51 strategies: vec![
52 StrategyInfo {
53 name: "weighted_average".to_string(),
54 description: "Element-wise weighted average of model parameters".to_string(),
55 min_models: 2,
56 max_models: None,
57 },
58 StrategyInfo {
59 name: "ties".to_string(),
60 description: "Trim, Elect, Sign merge — reduces noise across multiple fine-tunes"
61 .to_string(),
62 min_models: 2,
63 max_models: None,
64 },
65 StrategyInfo {
66 name: "dare".to_string(),
67 description: "Drop And REscale — stochastic sparsity-based merge".to_string(),
68 min_models: 2,
69 max_models: None,
70 },
71 StrategyInfo {
72 name: "slerp".to_string(),
73 description: "Spherical linear interpolation — smooth two-model blending"
74 .to_string(),
75 min_models: 2,
76 max_models: Some(2),
77 },
78 ],
79 })
80}
81
/// Runs the requested merge strategy through `entrenar::merge`.
///
/// The inputs handed to `entrenar` are empty placeholder models (one empty map
/// per requested model name); no weights are loaded here, which is why the
/// result is always flagged `simulated`. Optional request parameters fall back
/// to: equal weights (weighted average), `density = 0.2` (TIES),
/// `drop_prob = 0.5` (DARE) and `t = 0.5` (SLERP).
///
/// Never panics on the model count: a SLERP request that does not carry exactly
/// two models yields a `failed` result instead of indexing out of bounds.
///
/// Returns a [`MergeResult`] whose `status` is `"complete"` or `"failed"`; on
/// failure `error` holds the error message.
#[cfg(feature = "entrenar")]
fn execute_merge(_state: &BancoState, request: &MergeRequest) -> MergeResult {
    use entrenar::merge::{DareConfig, EnsembleConfig, SlerpConfig, TiesConfig};
    use std::collections::HashMap;

    // One empty placeholder per requested model name.
    let models: Vec<entrenar::merge::Model> =
        request.models.iter().map(|_| HashMap::new()).collect();

    let outcome: Result<entrenar::merge::Model, String> = match &request.strategy {
        MergeStrategy::WeightedAverage => {
            let weights = request.weights.clone().unwrap_or_else(|| {
                let count = request.models.len();
                vec![1.0 / count as f32; count]
            });
            let config = EnsembleConfig::weighted_average(weights);
            entrenar::merge::ensemble_merge(&models, &config).map_err(|e| e.to_string())
        }
        MergeStrategy::Ties => {
            let density = request.density.unwrap_or(0.2);
            let base = HashMap::new();
            let config = TiesConfig { density };
            entrenar::merge::ties_merge(&models, &base, &config).map_err(|e| e.to_string())
        }
        MergeStrategy::Dare => {
            let drop_prob = request.drop_prob.unwrap_or(0.5);
            let base = HashMap::new();
            let config = DareConfig { drop_prob, seed: request.seed };
            entrenar::merge::dare_merge(&models, &base, &config).map_err(|e| e.to_string())
        }
        // The HTTP handler already rejects other counts, but this function must
        // not rely on its caller to stay panic-free.
        MergeStrategy::Slerp => match models.as_slice() {
            [first, second] => {
                let t = request.interpolation_t.unwrap_or(0.5);
                let config = SlerpConfig { t };
                entrenar::merge::slerp_merge(first, second, &config).map_err(|e| e.to_string())
            }
            _ => Err("SLERP merge requires exactly 2 models".to_string()),
        },
    };

    // The merged model itself is discarded; only success or failure is reported.
    let (status, error) = match outcome {
        Ok(_) => ("complete", None),
        Err(message) => ("failed", Some(message)),
    };

    MergeResult {
        merge_id: format!("merge-{}", epoch_secs()),
        strategy: request.strategy.clone(),
        models: request.models.clone(),
        status: status.to_string(),
        simulated: true,
        error,
        output_path: None,
    }
}
137
138#[cfg(not(feature = "entrenar"))]
140fn execute_merge(_state: &BancoState, request: &MergeRequest) -> MergeResult {
141 MergeResult {
142 merge_id: format!("merge-{}", epoch_secs()),
143 strategy: request.strategy.clone(),
144 models: request.models.clone(),
145 status: "dry_run".to_string(),
146 simulated: true,
147 error: None,
148 output_path: None,
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
158#[serde(rename_all = "snake_case")]
159pub enum MergeStrategy {
160 WeightedAverage,
161 Ties,
162 Dare,
163 Slerp,
164}
165
166#[derive(Debug, Clone, Deserialize)]
168pub struct MergeRequest {
169 pub models: Vec<String>,
171 pub strategy: MergeStrategy,
173 #[serde(default)]
175 pub weights: Option<Vec<f32>>,
176 #[serde(default)]
178 pub density: Option<f32>,
179 #[serde(default)]
181 pub drop_prob: Option<f32>,
182 #[serde(default)]
184 pub interpolation_t: Option<f32>,
185 #[serde(default)]
187 pub seed: Option<u64>,
188 #[serde(default)]
190 pub output_format: Option<String>,
191}
192
193#[derive(Debug, Clone, Serialize)]
195pub struct MergeResult {
196 pub merge_id: String,
197 pub strategy: MergeStrategy,
198 pub models: Vec<String>,
199 pub status: String,
200 #[serde(default)]
202 pub simulated: bool,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 pub error: Option<String>,
205 #[serde(skip_serializing_if = "Option::is_none")]
206 pub output_path: Option<String>,
207}
208
209#[derive(Debug, Serialize)]
211pub struct MergeStrategiesResponse {
212 pub strategies: Vec<StrategyInfo>,
213}
214
215#[derive(Debug, Serialize)]
217pub struct StrategyInfo {
218 pub name: String,
219 pub description: String,
220 pub min_models: usize,
221 #[serde(skip_serializing_if = "Option::is_none")]
222 pub max_models: Option<usize>,
223}
224
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // A clock set before 1970 maps to a zero duration.
        Err(_) => 0,
    }
}