use crate::filter::RemoteTarget;
use crate::graph::NodeId;
use serde::{Deserialize, Serialize};
/// A serde-(de)serializable description of how a training run is distributed.
///
/// Uses serde's internally tagged representation: the variant name is stored
/// under a `"type"` key next to the variant's own fields, e.g. `{"type":"Local"}`
/// or `{"type":"DataParallel","num_replicas":4,"aggregation":{"method":"AllReduce"}}`.
/// Renaming a variant or field therefore changes the wire format.
///
/// `Default` yields [`TrainingStrategy::Local`]. The enum is `#[non_exhaustive]`,
/// so matches in other crates must include a wildcard arm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum TrainingStrategy {
    /// Local training; carries no parameters. This is the [`Default`] variant.
    #[default]
    Local,
    /// Data parallelism across several replicas.
    DataParallel {
        /// Number of replicas.
        num_replicas: usize,
        /// How gradients from the replicas are combined.
        aggregation: GradientAggregation,
    },
    /// Model parallelism: the graph is split into [`Partition`]s.
    ModelParallel {
        /// Groups of graph nodes, each paired with the target it is assigned to.
        /// NOTE(review): nothing here checks that partitions are disjoint or
        /// cover the whole graph — confirm where that is validated.
        partitions: Vec<Partition>,
        /// How the partitions exchange data.
        communication: CommunicationProtocol,
    },
    /// Federated training over a set of clients.
    Federated {
        /// Number of participating clients.
        num_clients: usize,
        /// Number of federated rounds.
        rounds: usize,
        /// How client updates are merged.
        aggregation: FederatedAggregation,
        /// Which clients take part (presumably chosen per round — confirm).
        client_selection: ClientSelection,
    },
    /// Population-based training.
    PopulationBased {
        /// Number of members in the population.
        population_size: usize,
        /// Number of generations to run.
        generations: usize,
        /// How members to copy from are chosen.
        exploit: ExploitStrategy,
        /// How copied members are varied.
        explore: ExploreStrategy,
    },
    /// A user-supplied strategy handled by an external coordinator.
    Custom {
        /// Identifier of the coordinator; how it is resolved is not defined in
        /// this file.
        coordinator: String,
        /// Arbitrary JSON configuration whose schema is up to the coordinator;
        /// it is not validated here.
        config: serde_json::Value,
    },
}
/// How gradients from the replicas of [`TrainingStrategy::DataParallel`] are
/// combined.
///
/// Internally tagged on `"method"`: serializes as `{"method":"AllReduce"}`,
/// `{"method":"ParameterServer"}` or
/// `{"method":"Decentralized","topology":"..."}`.
// `PartialEq` lets configs be compared directly (e.g. after a serde round-trip).
// `Eq` is intentionally not derived so that future variants of this
// `#[non_exhaustive]` enum may carry `f64` parameters, as sibling enums do.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum GradientAggregation {
    /// All-reduce aggregation (no parameters).
    AllReduce,
    /// Parameter-server aggregation (no parameters).
    ParameterServer,
    /// Decentralized aggregation over a named topology.
    Decentralized {
        /// Topology identifier; accepted values are not defined in this file.
        topology: String,
    },
}
/// One slice of a [`TrainingStrategy::ModelParallel`] layout: a group of graph
/// nodes together with the target they are assigned to.
///
/// Field order matters for the serialized form, so keep it stable.
/// NOTE(review): presumably each node id should appear in exactly one
/// partition; this type does not enforce that — confirm with the scheduler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Partition {
    /// Graph nodes belonging to this partition.
    pub node_ids: Vec<NodeId>,
    /// Target these nodes are assigned to.
    pub target: RemoteTarget,
}
/// How the partitions of [`TrainingStrategy::ModelParallel`] exchange data.
///
/// Internally tagged on `"protocol"`, e.g. `{"protocol":"Direct"}` or
/// `{"protocol":"Pipeline","micro_batch_size":4}`.
// `PartialEq` (not `Eq`) keeps room for future float-carrying variants of this
// `#[non_exhaustive]` enum.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "protocol")]
#[non_exhaustive]
pub enum CommunicationProtocol {
    /// Exchange through a data store (no parameters).
    DataStore,
    /// Direct exchange between partitions (no parameters).
    Direct,
    /// Pipelined exchange.
    Pipeline {
        /// Micro-batch size; presumably the number of samples per pipeline
        /// step — confirm with the executor.
        micro_batch_size: usize,
    },
}
/// How client updates are merged in [`TrainingStrategy::Federated`].
///
/// Internally tagged on `"method"`, e.g. `{"method":"FedAvg"}` or
/// `{"method":"FedProx","mu":0.01}`.
// Only `PartialEq` is derived: the `f64` fields rule out `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum FederatedAggregation {
    /// Federated averaging (no parameters).
    FedAvg,
    /// FedProx aggregation.
    FedProx {
        /// Presumably the FedProx proximal-term coefficient — confirm against
        /// the coordinator implementation.
        mu: f64,
    },
    /// FedYogi aggregation.
    FedYogi {
        /// Presumably the first-moment decay rate — confirm.
        beta1: f64,
        /// Presumably the second-moment decay rate — confirm.
        beta2: f64,
        /// Presumably the adaptivity/stability constant — confirm.
        tau: f64,
    },
}
/// Which clients take part in [`TrainingStrategy::Federated`] training.
///
/// Internally tagged on `"method"`, e.g. `{"method":"All"}` or
/// `{"method":"Random","fraction":0.3}`.
// Only `PartialEq` is derived: the `f64` field rules out `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ClientSelection {
    /// Every client participates (no parameters).
    All,
    /// A random subset of clients participates.
    Random {
        /// Fraction of clients to select. Not range-checked here; presumably
        /// expected in `(0.0, 1.0]` — confirm where it is consumed.
        fraction: f64,
    },
    /// Clients are selected by capability tags.
    ByCapability {
        /// Tags a client must have; matching semantics (all vs. any) are not
        /// defined in this file.
        required_tags: Vec<String>,
    },
}
/// Exploit step of [`TrainingStrategy::PopulationBased`] training.
///
/// Internally tagged on `"method"`, e.g.
/// `{"method":"Truncation","fraction":0.2}`.
// Only `PartialEq` is derived: the `f64` fields rule out `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ExploitStrategy {
    /// Truncation selection.
    Truncation {
        /// Presumably the fraction of the population replaced/copied from —
        /// confirm with the PBT coordinator. Not range-checked here.
        fraction: f64,
    },
    /// Threshold-based selection.
    Binary {
        /// Decision threshold; its scale is not defined in this file.
        threshold: f64,
    },
}
/// Explore step of [`TrainingStrategy::PopulationBased`] training.
///
/// Internally tagged on `"method"`, e.g.
/// `{"method":"Perturbation","factor":0.2}` or `{"method":"Resample"}`.
// Only `PartialEq` is derived: the `f64` field rules out `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "method")]
#[non_exhaustive]
pub enum ExploreStrategy {
    /// Perturb copied hyperparameters.
    Perturbation {
        /// Perturbation magnitude; presumably a multiplicative factor —
        /// confirm with the PBT coordinator.
        factor: f64,
    },
    /// Resample hyperparameters (no parameters).
    Resample,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `strategy` to a JSON string and parses it back.
    fn roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let json = serde_json::to_string(strategy).expect("strategy should serialize");
        serde_json::from_str(&json).expect("serialized strategy should deserialize")
    }

    /// Tolerant float comparison: serde_json does not promise bit-exact float
    /// round-trips unless its `float_roundtrip` feature is enabled.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn default_is_local() {
        assert!(matches!(
            TrainingStrategy::default(),
            TrainingStrategy::Local
        ));
    }

    #[test]
    fn serde_roundtrip_local() {
        assert!(matches!(
            roundtrip(&TrainingStrategy::Local),
            TrainingStrategy::Local
        ));
    }

    #[test]
    fn wire_format_uses_internal_tags() {
        // Pin the serialized shape so renaming a variant or tag key fails here
        // instead of silently breaking stored configs.
        assert_eq!(
            serde_json::to_value(TrainingStrategy::Local).unwrap(),
            json!({ "type": "Local" })
        );
        let data_parallel = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        assert_eq!(
            serde_json::to_value(&data_parallel).unwrap(),
            json!({
                "type": "DataParallel",
                "num_replicas": 4,
                "aggregation": { "method": "AllReduce" }
            })
        );
    }

    #[test]
    fn serde_roundtrip_data_parallel() {
        let strategy = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        match roundtrip(&strategy) {
            TrainingStrategy::DataParallel {
                num_replicas,
                aggregation,
            } => {
                assert_eq!(num_replicas, 4);
                assert!(matches!(aggregation, GradientAggregation::AllReduce));
            }
            other => panic!("expected DataParallel, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_model_parallel() {
        let strategy = TrainingStrategy::ModelParallel {
            partitions: vec![
                Partition {
                    node_ids: vec!["embed".into(), "backbone".into()],
                    target: RemoteTarget::Tag("gpu-0".into()),
                },
                Partition {
                    node_ids: vec!["head_a".into()],
                    target: RemoteTarget::Tag("gpu-1".into()),
                },
            ],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::ModelParallel {
                partitions,
                communication,
            } => {
                // Partition count and per-partition node counts must survive.
                let sizes: Vec<usize> = partitions.iter().map(|p| p.node_ids.len()).collect();
                assert_eq!(sizes, [2, 1]);
                assert!(matches!(
                    communication,
                    CommunicationProtocol::Pipeline {
                        micro_batch_size: 4
                    }
                ));
            }
            other => panic!("expected ModelParallel, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_federated() {
        let strategy = TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Federated {
                num_clients,
                rounds,
                aggregation,
                client_selection,
            } => {
                assert_eq!(num_clients, 10);
                assert_eq!(rounds, 50);
                match aggregation {
                    FederatedAggregation::FedProx { mu } => assert_close(mu, 0.01),
                    other => panic!("expected FedProx, got {other:?}"),
                }
                match client_selection {
                    ClientSelection::Random { fraction } => assert_close(fraction, 0.3),
                    other => panic!("expected Random selection, got {other:?}"),
                }
            }
            other => panic!("expected Federated, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_pbt() {
        let strategy = TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::PopulationBased {
                population_size,
                generations,
                exploit,
                explore,
            } => {
                assert_eq!(population_size, 20);
                assert_eq!(generations, 50);
                match exploit {
                    ExploitStrategy::Truncation { fraction } => assert_close(fraction, 0.2),
                    other => panic!("expected Truncation, got {other:?}"),
                }
                match explore {
                    ExploreStrategy::Perturbation { factor } => assert_close(factor, 0.2),
                    other => panic!("expected Perturbation, got {other:?}"),
                }
            }
            other => panic!("expected PopulationBased, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_custom_preserves_config() {
        let config = json!({ "lr": 0.5, "workers": ["a", "b"], "nested": { "k": 1 } });
        let strategy = TrainingStrategy::Custom {
            coordinator: "my_coordinator".into(),
            config: config.clone(),
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Custom {
                coordinator,
                config: parsed,
            } => {
                assert_eq!(coordinator, "my_coordinator");
                assert_eq!(parsed, config);
            }
            other => panic!("expected Custom, got {other:?}"),
        }
    }
}