#[cfg(not(feature = "std"))]
use alloc::{format, string::ToString};
use crate::core::error::{OxiRouterError, Result};
use crate::ml::{ModelState, ModelType};
/// How [`merge_states`] combines a remote [`ModelState`] into a local one.
#[derive(Debug, Clone, PartialEq)]
pub enum MergeStrategy {
    /// Element-wise mean of the local and remote parameters.
    Average,
    /// Element-wise interpolation `(1 - w) * local + w * remote`, where `w` is
    /// the share given to the remote state; `w` is clamped to `[0.0, 1.0]`.
    WeightedAverage(f32),
    /// Replace the local state with the remote one only if the remote state has
    /// a strictly higher `iterations` count.
    KeepLatest,
    /// Replace the local state with the remote one only if the remote state's
    /// reward is strictly higher.
    KeepBest,
}
/// Merges `remote` into `local` in place according to `strategy`.
///
/// Before anything is modified, the two states are checked for compatibility.
/// They must have:
/// - the same `model_type`, which must not be [`ModelType::Ensemble`];
/// - the same `feature_dim`;
/// - identical `source_ids` (same ids in the same order);
/// - `weights` of equal length.
///
/// If any check fails, `local` is left untouched.
///
/// Strategies:
/// - [`MergeStrategy::Average`]: element-wise mean of `weights` and of
///   `extra_params`.
/// - [`MergeStrategy::WeightedAverage`]: element-wise
///   `(1 - w) * local + w * remote` over `weights` and `extra_params`, with `w`
///   clamped to `[0.0, 1.0]`.
/// - [`MergeStrategy::KeepLatest`]: replaces `local` with a copy of `remote`
///   when `remote.iterations > local.iterations`; on a tie `local` is kept.
/// - [`MergeStrategy::KeepBest`]: treats the last element of `extra_params`
///   (or `0.0` if empty) as the reward and replaces `local` with a copy of
///   `remote` only when the remote reward is strictly greater.
///
/// `extra_params` lengths are not validated. The averaging strategies combine
/// only the overlapping prefix and leave any trailing entries of `local`
/// unchanged.
///
/// # Errors
///
/// Returns [`OxiRouterError::IncompatibleModel`] in two cases:
/// - the states fail any of the compatibility checks above;
/// - `strategy` is `WeightedAverage` with a NaN weight.
pub fn merge_states(
    local: &mut ModelState,
    remote: &ModelState,
    strategy: MergeStrategy,
) -> Result<()> {
    /// Sets each `*d = (1 - w) * *d + w * *s` over the pairs produced by
    /// zipping `dst` with `src`, stopping at the shorter of the two.
    /// Generic over the element type so it makes no assumption about how
    /// `extra_params` is typed.
    fn blend_into<'a, 'b, T>(
        dst: impl Iterator<Item = &'a mut T>,
        src: impl IntoIterator<Item = &'b T>,
        w: f32,
    ) where
        T: 'a + 'b + Copy + From<f32>,
        T: core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
    {
        let keep = T::from(1.0 - w);
        let take = T::from(w);
        for (d, s) in dst.zip(src) {
            *d = keep * *d + take * *s;
        }
    }

    if local.config.model_type != remote.config.model_type {
        return Err(OxiRouterError::IncompatibleModel {
            reason: format!(
                "model_type mismatch: local={:?}, remote={:?}",
                local.config.model_type, remote.config.model_type
            ),
        });
    }
    if local.config.model_type == ModelType::Ensemble {
        return Err(OxiRouterError::IncompatibleModel {
            reason:
                "Merging Ensemble models is not supported; merge individual components separately"
                    .to_string(),
        });
    }
    if local.config.feature_dim != remote.config.feature_dim {
        return Err(OxiRouterError::IncompatibleModel {
            reason: format!(
                "feature_dim mismatch: local={}, remote={}",
                local.config.feature_dim, remote.config.feature_dim
            ),
        });
    }
    if local.source_ids != remote.source_ids {
        let (n_local, n_remote) = (local.source_ids.len(), remote.source_ids.len());
        // With equal counts, reporting "local has N, remote has N" would be
        // confusing; say what actually differs.
        let reason = if n_local == n_remote {
            format!(
                "source_ids mismatch: both have {} sources but the ids or their order differ",
                n_local
            )
        } else {
            format!(
                "source_ids mismatch: local has {} sources, remote has {}",
                n_local, n_remote
            )
        };
        return Err(OxiRouterError::IncompatibleModel { reason });
    }
    if local.weights.len() != remote.weights.len() {
        return Err(OxiRouterError::IncompatibleModel {
            reason: format!(
                "weights length mismatch: local={}, remote={}",
                local.weights.len(),
                remote.weights.len()
            ),
        });
    }

    match strategy {
        MergeStrategy::Average => {
            for (l, r) in local.weights.iter_mut().zip(&remote.weights) {
                *l = (*l + *r) * 0.5;
            }
            for (l, r) in local.extra_params.iter_mut().zip(&remote.extra_params) {
                *l = (*l + *r) * 0.5;
            }
        }
        MergeStrategy::WeightedAverage(w) => {
            // `f32::clamp` passes NaN through unchanged, which would poison every
            // merged parameter; reject it before `local` is touched.
            if w.is_nan() {
                return Err(OxiRouterError::IncompatibleModel {
                    reason: "WeightedAverage weight is NaN".to_string(),
                });
            }
            let w = w.clamp(0.0, 1.0);
            blend_into(local.weights.iter_mut(), &remote.weights, w);
            // Blend `extra_params` as well, consistent with `Average`.
            blend_into(local.extra_params.iter_mut(), &remote.extra_params, w);
        }
        MergeStrategy::KeepLatest => {
            if remote.iterations > local.iterations {
                *local = remote.clone();
            }
        }
        MergeStrategy::KeepBest => {
            let local_reward = local.extra_params.last().copied().unwrap_or(0.0);
            let remote_reward = remote.extra_params.last().copied().unwrap_or(0.0);
            if remote_reward > local_reward {
                *local = remote.clone();
            }
        }
    }
    Ok(())
}