use super::collective::CollectiveError;
use super::process::{Communicator, ProcessError};
use oxiarc_lz4::{compress as lz4_compress, decompress as lz4_decompress};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Error type returned by the topology, measurement, optimization and
/// compression helpers in this module.
///
/// `Display` and `std::error::Error` are generated by `thiserror` from the
/// `#[error(...)]` attributes below.
#[derive(Error, Debug)]
pub enum OptimizationError {
/// Wraps a [`ProcessError`]; `#[from]` lets `?` convert it automatically.
#[error("Process error: {0}")]
Process(#[from] ProcessError),
/// Wraps a [`CollectiveError`]; `#[from]` lets `?` convert it automatically.
#[error("Collective operation error: {0}")]
Collective(#[from] CollectiveError),
/// Topology detection failure, carrying a free-form description.
#[error("Topology detection failed: {0}")]
TopologyError(String),
/// Bandwidth/latency measurement failure, carrying a free-form description.
#[error("Measurement failed: {0}")]
MeasurementError(String),
/// Optimization failure, carrying a free-form description.
#[error("Optimization failed: {0}")]
OptimizationFailed(String),
/// Failure while framing, (de)serializing or (de)compressing a payload;
/// this is the variant produced by `compress_data` and `decompress_data`.
#[error("Compression error: {0}")]
CompressionError(String),
}
/// Shape of the interconnect between communicating processes.
///
/// The value is a small `Copy` type and can be serialized with serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NetworkTopology {
/// Every process is treated as directly linked to every other process.
FullyConnected,
/// Tree-shaped network.
/// NOTE(review): `arity` is presumably the branching factor — confirm.
Tree { arity: usize },
/// Processes arranged in a ring.
Ring,
/// Mesh described by three extents.
/// NOTE(review): axis order/meaning of `dims` is not shown here — confirm.
Mesh { dims: [usize; 3] },
/// Hypercube; `detect_topology` fills `dimension` with log2 of the
/// communicator size.
Hypercube { dimension: usize },
/// Fat-tree network with the given number of levels.
FatTree { levels: usize },
/// Any topology not covered by the other variants.
Custom,
}
impl NetworkTopology {
    /// Selects the collective algorithm to use for the operation named `op`
    /// on this topology.
    ///
    /// Only three (topology, operation) pairs have a specialised choice:
    /// tree + `"broadcast"`, ring + `"reduce"` and hypercube + `"allreduce"`.
    /// Every other combination — including unknown operation names — yields
    /// [`Algorithm::Default`]. The comparison on `op` is exact and
    /// case-sensitive.
    pub fn optimal_algorithm(&self, op: &str) -> Algorithm {
        match (self, op) {
            (NetworkTopology::Tree { .. }, "broadcast") => Algorithm::TreeBroadcast,
            (NetworkTopology::Ring, "reduce") => Algorithm::RingReduce,
            (NetworkTopology::Hypercube { .. }, "allreduce") => Algorithm::HypercubeAllReduce,
            _ => Algorithm::Default,
        }
    }

    /// Reports whether ranks `src` and `dst` share a direct link in a network
    /// of `size` processes.
    ///
    /// * `FullyConnected` — always `true`.
    /// * `Ring` — `true` for neighbouring ranks, including the wrap-around
    ///   link that closes the ring between rank `0` and rank `size - 1`.
    ///   A rank is never considered directly connected to itself.
    /// * Every other topology — `false`; adjacency is not modelled for them.
    ///
    /// Ranks are not validated against `size`.
    pub fn has_direct_connection(&self, src: usize, dst: usize, size: usize) -> bool {
        match self {
            NetworkTopology::FullyConnected => true,
            NetworkTopology::Ring => {
                // `abs_diff` avoids the lossy `usize -> i64` casts.
                let distance = src.abs_diff(dst);
                // The second clause is the edge closing the ring; `distance != 0`
                // keeps a single-process ring from linking a rank to itself, and
                // `checked_sub` keeps `size == 0` from underflowing.
                distance == 1 || (distance != 0 && size.checked_sub(1) == Some(distance))
            }
            NetworkTopology::Tree { .. }
            | NetworkTopology::Mesh { .. }
            | NetworkTopology::Hypercube { .. }
            | NetworkTopology::FatTree { .. }
            | NetworkTopology::Custom => false,
        }
    }
}
/// Collective-communication algorithm choices.
///
/// `NetworkTopology::optimal_algorithm` currently produces only `Default`,
/// `TreeBroadcast`, `RingReduce` and `HypercubeAllReduce`; the remaining
/// variants are not selected anywhere in the visible code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
/// Fallback used when no specialised algorithm applies.
Default,
/// Chosen for `"broadcast"` on a tree topology.
TreeBroadcast,
/// Chosen for `"reduce"` on a ring topology.
RingReduce,
/// Chosen for `"allreduce"` on a hypercube topology.
HypercubeAllReduce,
/// Recursive-doubling algorithm.
RecursiveDoubling,
/// Pairwise-exchange algorithm.
PairwiseExchange,
}
/// Collection of point-to-point bandwidth samples plus summary statistics.
///
/// NOTE(review): the unit of the bandwidth values is not fixed by this type —
/// confirm with the code that produces the samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthModel {
// Recorded samples as `(src, dst, bandwidth)`, in insertion order.
measurements: Vec<(usize, usize, f64)>,
// Mean of all recorded bandwidths; 0.0 until the first sample is added.
average: f64,
// Smallest recorded bandwidth; 0.0 until the first sample is added.
min: f64,
// Largest recorded bandwidth; 0.0 until the first sample is added.
max: f64,
}
impl BandwidthModel {
    /// Builds a model that holds no samples; `average`, `min` and `max` all
    /// start at `0.0`.
    pub fn new() -> Self {
        let measurements = Vec::new();
        Self {
            measurements,
            average: 0.0,
            min: 0.0,
            max: 0.0,
        }
    }

    /// Records one `(src, dst, bandwidth)` sample and refreshes the summary
    /// statistics so they reflect every sample stored so far.
    pub fn add_measurement(&mut self, src: usize, dst: usize, bandwidth: f64) {
        let sample = (src, dst, bandwidth);
        self.measurements.push(sample);
        self.update_statistics();
    }

    /// Recomputes `average`, `min` and `max` from all stored samples.
    ///
    /// Leaves the statistics untouched when no sample has been recorded.
    fn update_statistics(&mut self) {
        let count = self.measurements.len();
        if count == 0 {
            return;
        }

        // Each call yields a fresh pass over the bandwidth column.
        let bandwidths = || self.measurements.iter().map(|&(_, _, bw)| bw);

        let mean = bandwidths().sum::<f64>() / count as f64;
        let lowest = bandwidths().fold(f64::INFINITY, f64::min);
        let highest = bandwidths().fold(f64::NEG_INFINITY, f64::max);

        self.average = mean;
        self.min = lowest;
        self.max = highest;
    }

    /// Estimates the bandwidth between `src` and `dst`.
    ///
    /// Returns the earliest recorded sample for that pair, treating the pair
    /// as unordered; when the pair was never measured, falls back to the
    /// average of all samples (`0.0` for an empty model).
    pub fn estimate(&self, src: usize, dst: usize) -> f64 {
        self.measurements
            .iter()
            .find(|&&(s, d, _)| (s, d) == (src, dst) || (s, d) == (dst, src))
            .map_or(self.average, |&(_, _, bw)| bw)
    }
}
impl Default for BandwidthModel {
fn default() -> Self {
Self::new()
}
}
/// Collection of point-to-point latency samples plus summary statistics.
///
/// NOTE(review): the unit of the latency values is not fixed by this type —
/// confirm with the code that produces the samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyModel {
// Recorded samples as `(src, dst, latency)`, in insertion order.
measurements: Vec<(usize, usize, f64)>,
// Mean of all recorded latencies; 0.0 until the first sample is added.
average: f64,
// Smallest recorded latency; 0.0 until the first sample is added.
min: f64,
// Largest recorded latency; 0.0 until the first sample is added.
max: f64,
}
impl LatencyModel {
    /// Builds a model that holds no samples; `average`, `min` and `max` all
    /// start at `0.0`.
    pub fn new() -> Self {
        let measurements = Vec::new();
        Self {
            measurements,
            average: 0.0,
            min: 0.0,
            max: 0.0,
        }
    }

    /// Records one `(src, dst, latency)` sample and refreshes the summary
    /// statistics so they reflect every sample stored so far.
    pub fn add_measurement(&mut self, src: usize, dst: usize, latency: f64) {
        let sample = (src, dst, latency);
        self.measurements.push(sample);
        self.update_statistics();
    }

    /// Recomputes `average`, `min` and `max` from all stored samples.
    ///
    /// Leaves the statistics untouched when no sample has been recorded.
    fn update_statistics(&mut self) {
        let count = self.measurements.len();
        if count == 0 {
            return;
        }

        // Each call yields a fresh pass over the latency column.
        let latencies = || self.measurements.iter().map(|sample| sample.2);

        let mean = latencies().sum::<f64>() / count as f64;
        let fastest = latencies().fold(f64::INFINITY, f64::min);
        let slowest = latencies().fold(f64::NEG_INFINITY, f64::max);

        self.average = mean;
        self.min = fastest;
        self.max = slowest;
    }

    /// Estimates the latency between `src` and `dst`.
    ///
    /// Returns the earliest recorded sample for that pair, treating the pair
    /// as unordered; when the pair was never measured, falls back to the
    /// average of all samples (`0.0` for an empty model).
    pub fn estimate(&self, src: usize, dst: usize) -> f64 {
        let same_link = |s: usize, d: usize| (s == src && d == dst) || (s == dst && d == src);
        match self.measurements.iter().find(|m| same_link(m.0, m.1)) {
            Some(&(_, _, lat)) => lat,
            None => self.average,
        }
    }
}
impl Default for LatencyModel {
fn default() -> Self {
Self::new()
}
}
/// Guesses the network topology from the communicator size alone.
///
/// A size that is a power of two and at least 8 is reported as a
/// [`NetworkTopology::Hypercube`] whose `dimension` is log2 of the size;
/// every other size is reported as [`NetworkTopology::FullyConnected`].
/// No probing of the actual network takes place.
///
/// # Errors
///
/// The current implementation never returns an error.
pub async fn detect_topology(comm: &Communicator) -> Result<NetworkTopology, OptimizationError> {
    let size = comm.size();
    let topology = if size >= 8 && size.is_power_of_two() {
        // For a power of two, the count of trailing zero bits *is* log2.
        // Staying in integer arithmetic avoids depending on the precision of
        // floating-point `log2`.
        NetworkTopology::Hypercube {
            dimension: size.trailing_zeros() as usize,
        }
    } else {
        NetworkTopology::FullyConnected
    };
    Ok(topology)
}
/// Placeholder bandwidth measurement between two ranks.
///
/// All arguments are currently ignored and no traffic is exchanged; the
/// function always resolves to the same fixed value.
///
/// # Errors
///
/// The current implementation never returns an error.
pub async fn measure_bandwidth(
    _src: usize,
    _dst: usize,
    _comm: &Communicator,
) -> Result<f64, OptimizationError> {
    // Fixed stand-in until a real measurement is implemented.
    const PLACEHOLDER_BANDWIDTH: f64 = 1000.0;
    Ok(PLACEHOLDER_BANDWIDTH)
}
/// Placeholder latency measurement between two ranks.
///
/// All arguments are currently ignored and no traffic is exchanged; the
/// function always resolves to the same fixed value.
///
/// # Errors
///
/// The current implementation never returns an error.
pub async fn measure_latency(
    _src: usize,
    _dst: usize,
    _comm: &Communicator,
) -> Result<f64, OptimizationError> {
    // Fixed stand-in until a real measurement is implemented.
    const PLACEHOLDER_LATENCY: f64 = 10.0;
    Ok(PLACEHOLDER_LATENCY)
}
pub fn optimize_collective(
_op: &str,
topology: &NetworkTopology,
) -> Result<Algorithm, OptimizationError> {
Ok(topology.optimal_algorithm(_op))
}
/// Placeholder for overlapping computation with communication.
///
/// The current body performs no work and always resolves to `Ok(())`.
pub async fn overlap_compute_communicate() -> Result<(), OptimizationError> {
Ok(())
}
pub fn compress_data<T: Serialize>(data: &[T]) -> Result<Vec<u8>, OptimizationError> {
let json_bytes = serde_json::to_vec(data)
.map_err(|e| OptimizationError::CompressionError(format!("Serialization error: {}", e)))?;
let uncompressed_size = json_bytes.len() as u64;
let compressed = lz4_compress(&json_bytes).map_err(|e| {
OptimizationError::CompressionError(format!("LZ4 compression error: {}", e))
})?;
let mut result = Vec::with_capacity(8 + compressed.len());
result.extend_from_slice(&uncompressed_size.to_le_bytes());
result.extend_from_slice(&compressed);
Ok(result)
}
/// Reverses `compress_data`: reads the size header, LZ4-decompresses the
/// payload and deserializes the resulting JSON into a `Vec<T>`.
///
/// Expected input layout: 8 bytes holding the uncompressed length as a
/// little-endian `u64`, followed by the LZ4-compressed JSON.
///
/// # Errors
///
/// Returns [`OptimizationError::CompressionError`] when
/// * `data` is shorter than the 8-byte header,
/// * the size stored in the header does not fit in `usize` on this target,
/// * LZ4 decompression fails, or
/// * the decompressed bytes are not valid JSON for `Vec<T>`.
pub fn decompress_data<T: for<'de> Deserialize<'de>>(
    data: &[u8],
) -> Result<Vec<T>, OptimizationError> {
    const HEADER_LEN: usize = 8;

    if data.len() < HEADER_LEN {
        return Err(OptimizationError::CompressionError(format!(
            "Data too short: expected at least 8 bytes, got {}",
            data.len()
        )));
    }

    let (header, payload) = data.split_at(HEADER_LEN);
    let size_bytes: [u8; HEADER_LEN] = header.try_into().map_err(|_| {
        OptimizationError::CompressionError("Failed to read uncompressed size header".to_string())
    })?;

    // The header comes from the wire, so convert with a check: a plain
    // `as usize` would silently truncate on targets where `usize` is narrower
    // than 64 bits and hand the decompressor a wrong size.
    let declared_size = u64::from_le_bytes(size_bytes);
    let uncompressed_size = usize::try_from(declared_size).map_err(|_| {
        OptimizationError::CompressionError(format!(
            "Uncompressed size {} in header does not fit in usize on this platform",
            declared_size
        ))
    })?;

    let json_bytes = lz4_decompress(payload, uncompressed_size).map_err(|e| {
        OptimizationError::CompressionError(format!("LZ4 decompression error: {}", e))
    })?;

    serde_json::from_slice(&json_bytes)
        .map_err(|e| OptimizationError::CompressionError(format!("Deserialization error: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A tree topology selects the tree broadcast for `"broadcast"`.
    #[test]
    fn test_network_topology() {
        let topology = NetworkTopology::Tree { arity: 2 };
        assert_eq!(
            topology.optimal_algorithm("broadcast"),
            Algorithm::TreeBroadcast
        );
    }

    /// A measured pair returns its own sample; the average tracks all samples.
    #[test]
    fn test_bandwidth_model() {
        let mut model = BandwidthModel::new();
        model.add_measurement(0, 1, 1000.0);
        model.add_measurement(1, 2, 950.0);
        model.add_measurement(2, 3, 1050.0);
        assert_eq!(model.estimate(0, 1), 1000.0);
        assert!((model.average - 1000.0).abs() < 50.0);
    }

    /// A measured pair returns its own sample; the average tracks all samples.
    #[test]
    fn test_latency_model() {
        let mut model = LatencyModel::new();
        model.add_measurement(0, 1, 10.0);
        model.add_measurement(1, 2, 12.0);
        model.add_measurement(2, 3, 11.0);
        assert_eq!(model.estimate(0, 1), 10.0);
        assert!((model.average - 11.0).abs() < 1.0);
    }

    /// Fully connected links everything; a ring links neighbours only.
    #[test]
    fn test_topology_direct_connection() {
        let topology = NetworkTopology::FullyConnected;
        assert!(topology.has_direct_connection(0, 1, 4));
        assert!(topology.has_direct_connection(0, 3, 4));
        let ring = NetworkTopology::Ring;
        assert!(ring.has_direct_connection(0, 1, 4));
        assert!(!ring.has_direct_connection(0, 2, 4));
    }

    /// Floating-point payloads survive a compress/decompress round trip.
    #[test]
    fn test_compress_decompress_roundtrip_floats() {
        let data: Vec<f64> = vec![1.0, 2.5, -3.14, 0.0, f64::MAX];
        let compressed = compress_data(&data).expect("compression should succeed");
        let recovered: Vec<f64> =
            decompress_data(&compressed).expect("decompression should succeed");
        assert_eq!(data.len(), recovered.len());
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!(
                (a - b).abs() < f64::EPSILON * 100.0,
                "mismatch: {} vs {}",
                a,
                b
            );
        }
    }

    /// String payloads survive a compress/decompress round trip.
    #[test]
    fn test_compress_decompress_roundtrip_strings() {
        let data: Vec<String> = vec![
            "hello".to_string(),
            "world".to_string(),
            "oxiarc".to_string(),
        ];
        let compressed = compress_data(&data).expect("compression should succeed");
        let recovered: Vec<String> =
            decompress_data(&compressed).expect("decompression should succeed");
        assert_eq!(data, recovered);
    }

    /// An empty slice round-trips to an empty vector.
    #[test]
    fn test_compress_empty_slice() {
        let data: Vec<u32> = vec![];
        let compressed = compress_data(&data).expect("compression of empty slice should succeed");
        let recovered: Vec<u32> =
            decompress_data(&compressed).expect("decompression should succeed");
        assert_eq!(recovered, data);
    }

    /// Repetitive input shrinks below its in-memory size and still round-trips.
    #[test]
    fn test_compress_highly_compressible() {
        let data: Vec<u32> = vec![42u32; 10_000];
        let compressed = compress_data(&data).expect("compression should succeed");
        assert!(
            compressed.len() < data.len() * 4,
            "expected compression, got {} bytes for {} elements",
            compressed.len(),
            data.len()
        );
        let recovered: Vec<u32> =
            decompress_data(&compressed).expect("decompression should succeed");
        assert_eq!(data, recovered);
    }

    /// Input shorter than the 8-byte size header is rejected by the length
    /// guard.
    #[test]
    fn test_decompress_invalid_data() {
        // The input must be under 8 bytes to hit the guard; a 9-byte literal
        // such as b"too short" would pass it and exercise the LZ4 path instead.
        let bad_data = b"short";
        let result: Result<Vec<u32>, _> = decompress_data(bad_data);
        assert!(result.is_err(), "should fail on short data");
        assert!(
            matches!(
                &result,
                Err(OptimizationError::CompressionError(msg)) if msg.starts_with("Data too short")
            ),
            "short input should be rejected by the header length check"
        );

        let empty: Result<Vec<u32>, _> = decompress_data(&[]);
        assert!(empty.is_err(), "should fail on empty data");
    }
}