use crate::error::Result;
use crate::filter::RemoteTarget;
use crate::graph::NodeId;
use crate::value::Value;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// How a model is trained: on a single worker or spread across several.
///
/// Serialized with serde's internally tagged representation: the variant name is
/// stored under a `"type"` key next to the variant's own fields, e.g.
/// `{"type": "DataParallel", "num_replicas": 4, "aggregation": {"method": "AllReduce"}}`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(tag = "type")]
pub enum TrainingStrategy {
    /// Train on worker 0 only. This is the default strategy.
    #[default]
    Local,
    /// Shard the input across replicas, train each, then combine their gradients.
    DataParallel {
        /// Requested replica count; `fit` caps it at the context's worker count.
        num_replicas: usize,
        /// How the per-replica gradients are combined.
        aggregation: GradientAggregation,
    },
    /// Split the model's graph nodes into partitions placed on separate targets.
    /// Not yet executable by `fit`.
    ModelParallel {
        /// The node groups and the target each group is assigned to.
        partitions: Vec<Partition>,
        /// How partitions are meant to exchange data.
        communication: CommunicationProtocol,
    },
    /// Repeated rounds of per-client training followed by state aggregation.
    Federated {
        /// Requested client count; `fit` caps it at the context's worker count.
        num_clients: usize,
        /// Number of train-then-aggregate rounds.
        rounds: usize,
        /// How client states are merged after each round.
        aggregation: FederatedAggregation,
        /// Client selection policy. NOTE(review): not applied by `fit` yet.
        client_selection: ClientSelection,
    },
    /// Population-based training. Not yet executable by `fit`.
    PopulationBased {
        /// Presumably the number of population members — confirm.
        population_size: usize,
        /// Presumably the number of exploit/explore generations — confirm.
        generations: usize,
        /// Exploit step configuration.
        exploit: ExploitStrategy,
        /// Explore step configuration.
        explore: ExploreStrategy,
    },
    /// Hand training over to an externally supplied coordinator; `fit` itself
    /// rejects this variant.
    Custom {
        /// Coordinator identifier (its format is not defined in this module).
        coordinator: String,
        /// Free-form configuration forwarded to the coordinator.
        config: serde_json::Value,
    },
}
/// How [`TrainingStrategy::DataParallel`] combines the gradients of its replicas.
///
/// Internally tagged by serde under a `"method"` key, e.g. `{"method": "AllReduce"}`
/// or `{"method": "Decentralized", "topology": "ring"}`.
///
/// Only `PartialEq` (not `Eq`) is derived so a future float-carrying variant can be
/// added to this `#[non_exhaustive]` enum without a breaking change.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum GradientAggregation {
    /// Intended: all-reduce across replicas.
    AllReduce,
    /// Intended: aggregation through a parameter server.
    ParameterServer,
    /// Intended: peer-to-peer aggregation over a named topology.
    Decentralized {
        /// Topology identifier; accepted values are not defined in this module.
        topology: String,
    },
}
/// One slice of a [`TrainingStrategy::ModelParallel`] plan: a group of graph nodes
/// and the remote target they are assigned to (presumably where they execute —
/// confirm against the ModelParallel runtime).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Partition {
    /// Graph nodes that belong to this partition.
    pub node_ids: Vec<NodeId>,
    /// Remote target this partition is assigned to.
    pub target: RemoteTarget,
}
/// How the partitions of a [`TrainingStrategy::ModelParallel`] plan are meant to
/// exchange data.
///
/// Internally tagged by serde under a `"protocol"` key, e.g.
/// `{"protocol": "Pipeline", "micro_batch_size": 4}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "protocol")]
#[non_exhaustive]
pub enum CommunicationProtocol {
    /// Intended: exchange through a shared data store.
    DataStore,
    /// Intended: direct partition-to-partition transfer.
    Direct,
    /// Intended: pipelined execution.
    Pipeline {
        /// Presumably the number of samples per pipeline micro-batch — confirm.
        micro_batch_size: usize,
    },
}
/// How [`TrainingStrategy::Federated`] merges client states after each round.
///
/// Internally tagged by serde under a `"method"` key, e.g.
/// `{"method": "FedProx", "mu": 0.01}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum FederatedAggregation {
    /// Intended: federated averaging.
    FedAvg,
    /// Intended: FedProx.
    FedProx {
        /// Presumably the proximal-term coefficient — confirm.
        mu: f64,
    },
    /// Intended: FedYogi adaptive server optimizer.
    FedYogi {
        /// Presumably the first-moment decay rate — confirm.
        beta1: f64,
        /// Presumably the second-moment decay rate — confirm.
        beta2: f64,
        /// Presumably the adaptivity/stability constant — confirm.
        tau: f64,
    },
}
/// Which clients take part in a [`TrainingStrategy::Federated`] round.
///
/// Internally tagged by serde under a `"method"` key, e.g.
/// `{"method": "Random", "fraction": 0.3}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ClientSelection {
    /// Every client participates.
    All,
    /// A random subset participates.
    Random {
        /// Presumably the fraction of clients sampled per round, in `0.0..=1.0` — confirm.
        fraction: f64,
    },
    /// Only clients advertising the listed tags participate.
    ByCapability {
        /// Tags a client must carry to be selected.
        required_tags: Vec<String>,
    },
}
/// Exploit step of [`TrainingStrategy::PopulationBased`].
///
/// Internally tagged by serde under a `"method"` key, e.g.
/// `{"method": "Truncation", "fraction": 0.2}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ExploitStrategy {
    /// Intended: truncation selection.
    Truncation {
        /// Presumably the fraction of the population replaced — confirm.
        fraction: f64,
    },
    /// Intended: threshold-based selection.
    Binary {
        /// Presumably the performance threshold used for selection — confirm.
        threshold: f64,
    },
}
/// Explore step of [`TrainingStrategy::PopulationBased`].
///
/// Internally tagged by serde under a `"method"` key, e.g.
/// `{"method": "Perturbation", "factor": 0.2}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ExploreStrategy {
    /// Intended: perturb existing hyperparameters.
    Perturbation {
        /// Presumably the multiplicative perturbation size — confirm.
        factor: f64,
    },
    /// Intended: draw fresh hyperparameters.
    Resample,
}
/// The pool of workers a [`TrainingStrategy`] drives during [`StrategyExecutor::fit`].
///
/// The strategies in this module address workers by index in `0..num_workers()`.
/// Maps exchanged through these methods are keyed by `String` — presumably node
/// ids; TODO confirm against implementors.
pub trait StrategyContext {
    /// How many workers are available.
    fn num_workers(&self) -> usize;

    /// Runs `plan` on worker `worker_idx` with `input` and optional targets `y`,
    /// returning the worker's outputs.
    fn execute_on_worker(
        &self,
        worker_idx: usize,
        plan: &serde_json::Value,
        input: &Value,
        y: Option<&Value>,
    ) -> Result<HashMap<String, Value>>;

    /// Reads the state of `node_ids` from worker `worker_idx`.
    fn get_state(
        &self,
        worker_idx: usize,
        node_ids: &[String],
    ) -> Result<HashMap<String, Value>>;

    /// Writes `states` into worker `worker_idx`.
    fn set_state(
        &self,
        worker_idx: usize,
        states: &HashMap<String, Value>,
    ) -> Result<()>;

    /// Reads the gradients of `node_ids` from worker `worker_idx`.
    fn get_gradients(&self, worker_idx: usize, node_ids: &[String]) -> Result<HashMap<String, Value>>;

    /// Applies `gradients` on worker `worker_idx`.
    fn apply_gradients(
        &self,
        worker_idx: usize,
        gradients: &HashMap<String, Value>,
    ) -> Result<()>;
}
/// Something that can run a training job against a [`StrategyContext`].
pub trait StrategyExecutor {
    /// Trains on `input` (with optional targets `y`) using the workers in `ctx`,
    /// returning a map of results for the nodes named in `node_ids`.
    fn fit(
        &self,
        ctx: &dyn StrategyContext,
        input: &Value,
        y: Option<&Value>,
        node_ids: &[String],
    ) -> Result<HashMap<String, Value>>;
}
/// Combines the gradient maps collected from several workers into one map.
pub trait GradientAggregator {
    /// Merges `gradients` (one map per worker) into a single gradient map.
    fn aggregate(
        &self,
        gradients: &[HashMap<String, Value>],
    ) -> Result<HashMap<String, Value>>;
}
/// Combines the state maps collected from several workers into one map.
pub trait StateAggregator {
    /// Merges `states` (one map per worker) into a single state map.
    fn aggregate(
        &self,
        states: &[HashMap<String, Value>],
    ) -> Result<HashMap<String, Value>>;
}
impl StrategyExecutor for TrainingStrategy {
fn fit(
&self,
ctx: &dyn StrategyContext,
input: &Value,
y: Option<&Value>,
node_ids: &[String],
) -> Result<HashMap<String, Value>> {
match self {
TrainingStrategy::Local => {
ctx.execute_on_worker(0, &serde_json::json!({}), input, y)
}
TrainingStrategy::DataParallel {
num_replicas,
aggregation,
} => {
let n = (*num_replicas).min(ctx.num_workers());
let shards = shard_value(input, n);
for (i, shard) in shards.iter().enumerate() {
ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
}
let mut all_grads = Vec::new();
for i in 0..n {
all_grads.push(ctx.get_gradients(i, node_ids)?);
}
let averaged = aggregation.aggregate(&all_grads)?;
for i in 0..n {
ctx.apply_gradients(i, &averaged)?;
}
ctx.get_state(0, node_ids)
}
TrainingStrategy::Federated {
num_clients,
rounds,
aggregation,
..
} => {
let n = (*num_clients).min(ctx.num_workers());
let shards = shard_value(input, n);
for _round in 0..*rounds {
for (i, shard) in shards.iter().enumerate().take(n) {
ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
}
let mut all_states = Vec::new();
for i in 0..n {
all_states.push(ctx.get_state(i, node_ids)?);
}
let aggregated = aggregation.aggregate(&all_states)?;
for i in 0..n {
ctx.set_state(i, &aggregated)?;
}
}
ctx.get_state(0, node_ids)
}
TrainingStrategy::ModelParallel { .. } => {
Err(crate::error::SomaError::Other(
"ModelParallel strategy execution not yet implemented".into(),
))
}
TrainingStrategy::PopulationBased { .. } => {
Err(crate::error::SomaError::Other(
"PopulationBased strategy execution not yet implemented".into(),
))
}
TrainingStrategy::Custom { .. } => Err(crate::error::SomaError::Other(
"Custom strategy requires a user-provided coordinator".into(),
)),
}
}
}
impl GradientAggregator for GradientAggregation {
    /// Merges the per-worker gradient maps in `gradients` into one map.
    ///
    /// NOTE(review): placeholder. Every method currently returns a copy of the first
    /// worker's gradients, or an empty map when `gradients` is empty. No element-wise
    /// reduction (e.g. averaging) is performed yet.
    fn aggregate(&self, gradients: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
        let leader = match gradients.first() {
            Some(first) => first.clone(),
            None => HashMap::new(),
        };
        // Kept exhaustive so that adding a method forces this impl to be revisited.
        match self {
            Self::AllReduce | Self::ParameterServer | Self::Decentralized { .. } => Ok(leader),
        }
    }
}
impl StateAggregator for FederatedAggregation {
    /// Merges the per-client state maps in `states` into one map.
    ///
    /// NOTE(review): placeholder. Every method currently returns a copy of the first
    /// client's state, or an empty map when `states` is empty. None of the
    /// method-specific parameters (`mu`, `beta1`, `beta2`, `tau`) are used yet.
    fn aggregate(&self, states: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
        // Kept exhaustive so that adding a method forces this impl to be revisited.
        match self {
            Self::FedAvg | Self::FedProx { .. } | Self::FedYogi { .. } => {
                let merged = states.first().cloned().unwrap_or_else(HashMap::new);
                Ok(merged)
            }
        }
    }
}
/// Splits `value` into `n` pieces, one per worker.
///
/// Some tensors are sharded. This requires a non-empty `shape`, at least `n` rows
/// in the leading dimension, and a flat `values` length equal to the product of
/// `shape`. Such a tensor is cut along its leading dimension into `n` contiguous
/// row blocks of `rows / n` rows each. The last block also takes the `rows % n`
/// leftover rows.
///
/// Every other value is cloned `n` times instead: a non-tensor, a 0-d tensor, a
/// tensor with fewer than `n` rows, or a tensor whose data does not match its shape.
///
/// Returns an empty `Vec` when `n == 0`.
fn shard_value(value: &Value, n: usize) -> Vec<Value> {
    if n == 0 {
        // Nothing to shard into; this also avoids dividing by zero below.
        return Vec::new();
    }
    match value {
        Value::Tensor { values, shape }
            if !shape.is_empty()
                && shape[0] >= n
                && values.len() == shape.iter().product::<usize>() =>
        {
            let rows = shape[0];
            // Elements per leading-dimension row. For a 1-D tensor this is the empty
            // product, 1. A zero-sized trailing dimension must yield 0 here: clamping
            // it to 1 would slice past the end of an empty `values`.
            let row_size: usize = shape[1..].iter().product();
            let shard_rows = rows / n;
            (0..n)
                .map(|i| {
                    let start = i * shard_rows;
                    let end = if i == n - 1 { rows } else { start + shard_rows };
                    let mut shard_shape = shape.clone();
                    shard_shape[0] = end - start;
                    Value::tensor(
                        values[start * row_size..end * row_size].to_vec(),
                        shard_shape,
                    )
                })
                .collect()
        }
        _ => vec![value.clone(); n],
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `strategy` as JSON and decodes it again, panicking if either step fails.
    fn json_roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let encoded = serde_json::to_string(strategy).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn default_is_local() {
        let strategy = TrainingStrategy::default();
        assert!(matches!(strategy, TrainingStrategy::Local));
    }

    #[test]
    fn serde_roundtrip_data_parallel() {
        let original = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        let decoded = json_roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::DataParallel {
                num_replicas: 4,
                ..
            }
        ));
    }

    #[test]
    fn serde_roundtrip_model_parallel() {
        let trunk = Partition {
            node_ids: vec!["embed".into(), "backbone".into()],
            target: RemoteTarget::Tag("gpu-0".into()),
        };
        let head = Partition {
            node_ids: vec!["head_a".into()],
            target: RemoteTarget::Tag("gpu-1".into()),
        };
        let original = TrainingStrategy::ModelParallel {
            partitions: vec![trunk, head],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        };
        let decoded = json_roundtrip(&original);
        assert!(matches!(decoded, TrainingStrategy::ModelParallel { .. }));
    }

    #[test]
    fn serde_roundtrip_federated() {
        let original = TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        };
        let decoded = json_roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::Federated {
                num_clients: 10,
                rounds: 50,
                ..
            }
        ));
    }

    #[test]
    fn serde_roundtrip_pbt() {
        let original = TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        };
        let decoded = json_roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::PopulationBased {
                population_size: 20,
                ..
            }
        ));
    }
}