use crate::hierarchy::StateAggregator;
use peat_mesh::routing::{AggregationError, Aggregator, DataDirection, DataPacket, DataType};
use peat_schema::hierarchy::v1::SquadSummary;
use peat_schema::node::v1::{NodeConfig, NodeState};
use serde::{Deserialize, Serialize};
/// Body of a `DataType::Telemetry` [`DataPacket`]: one node's configuration
/// paired with its state.
///
/// Encoded and decoded as JSON by [`TelemetryPayload::to_bytes`] and
/// [`TelemetryPayload::from_bytes`]. Because serialization is derived, the
/// declaration order of the fields below is also the JSON key order on output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryPayload {
/// Configuration of the reporting node.
pub config: NodeConfig,
/// State of the reporting node.
pub state: NodeState,
}
impl TelemetryPayload {
pub fn new(config: NodeConfig, state: NodeState) -> Self {
Self { config, state }
}
pub fn to_bytes(&self) -> Result<Vec<u8>, AggregationError> {
serde_json::to_vec(self).map_err(AggregationError::from)
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, AggregationError> {
serde_json::from_slice(bytes).map_err(AggregationError::from)
}
}
/// Stateless implementation of [`Aggregator`] that folds member telemetry
/// packets into a single squad-summary packet.
///
/// It is a zero-sized unit struct, so it is trivially `Copy`; `Debug` lets
/// callers include it in logs and `{:?}` output.
#[derive(Debug, Clone, Copy)]
pub struct PacketAggregator;
impl PacketAggregator {
pub fn new() -> Self {
Self
}
pub fn extract_squad_summary(
&self,
packet: &DataPacket,
) -> Result<SquadSummary, AggregationError> {
if packet.data_type != DataType::AggregatedTelemetry {
return Err(AggregationError::InvalidPacketType {
expected: "AggregatedTelemetry".to_string(),
actual: packet.data_type,
});
}
serde_json::from_slice(&packet.payload).map_err(AggregationError::from)
}
}
impl Default for PacketAggregator {
fn default() -> Self {
Self::new()
}
}
impl Aggregator for PacketAggregator {
fn aggregate_telemetry(
&self,
group_id: &str,
leader_id: &str,
telemetry_packets: Vec<DataPacket>,
) -> Result<DataPacket, AggregationError> {
if telemetry_packets.is_empty() {
return Err(AggregationError::EmptyInput);
}
for packet in &telemetry_packets {
if packet.data_type != DataType::Telemetry {
return Err(AggregationError::InvalidPacketType {
expected: "Telemetry".to_string(),
actual: packet.data_type,
});
}
}
let member_states: Result<Vec<(NodeConfig, NodeState)>, AggregationError> =
telemetry_packets
.iter()
.map(|packet| {
let payload = TelemetryPayload::from_bytes(&packet.payload)?;
Ok((payload.config, payload.state))
})
.collect();
let member_states = member_states?;
let squad_summary = StateAggregator::aggregate_squad(group_id, leader_id, member_states)
.map_err(|e| AggregationError::AggregationFailed(e.to_string()))?;
let aggregated_payload =
serde_json::to_vec(&squad_summary).map_err(AggregationError::from)?;
Ok(DataPacket {
packet_id: uuid::Uuid::new_v4().to_string(),
source_node_id: leader_id.to_string(),
destination_node_id: None,
data_type: DataType::AggregatedTelemetry,
direction: DataDirection::Upward,
hop_count: 0,
max_hops: 10,
payload: aggregated_payload,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a packet with an arbitrary type and payload, for negative-path tests
    /// that need types or payloads the `DataPacket` helper constructors don't produce.
    fn packet_with(data_type: DataType, payload: Vec<u8>) -> DataPacket {
        DataPacket {
            packet_id: "test-packet".to_string(),
            source_node_id: "node-1".to_string(),
            destination_node_id: None,
            data_type,
            direction: DataDirection::Upward,
            hop_count: 0,
            max_hops: 10,
            payload,
        }
    }

    #[test]
    fn test_aggregator_creation() {
        let aggregator = PacketAggregator::new();
        assert_eq!(std::mem::size_of_val(&aggregator), 0);
    }

    #[test]
    fn test_aggregate_empty_packets() {
        let aggregator = PacketAggregator::new();
        let result = aggregator.aggregate_telemetry("squad-1", "node-1", vec![]);
        assert!(matches!(result, Err(AggregationError::EmptyInput)));
    }

    #[test]
    fn test_aggregate_wrong_packet_type() {
        let aggregator = PacketAggregator::new();
        let command_packet = DataPacket::command("hq", "node-1", vec![1, 2, 3]);
        let result = aggregator.aggregate_telemetry("squad-1", "node-1", vec![command_packet]);
        assert!(matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. })
        ));
    }

    #[test]
    fn test_aggregate_rejects_aggregated_packets_as_input() {
        let aggregator = PacketAggregator::new();
        let aggregated = packet_with(DataType::AggregatedTelemetry, b"{}".to_vec());
        let result = aggregator.aggregate_telemetry("squad-1", "node-1", vec![aggregated]);
        assert!(matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. })
        ));
    }

    #[test]
    fn test_type_check_precedes_payload_decoding() {
        // The first packet has a valid type but an undecodable payload. The second
        // has the wrong type, which must be reported because all types are checked first.
        let aggregator = PacketAggregator::new();
        let packets = vec![
            DataPacket::telemetry("node-1", vec![1, 2, 3]),
            DataPacket::command("hq", "node-1", vec![]),
        ];
        let result = aggregator.aggregate_telemetry("squad-1", "node-1", packets);
        assert!(matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. })
        ));
    }

    #[test]
    fn test_aggregate_malformed_payload_is_error() {
        let aggregator = PacketAggregator::new();
        let packet = DataPacket::telemetry("node-1", b"not json".to_vec());
        let result = aggregator.aggregate_telemetry("squad-1", "node-1", vec![packet]);
        assert!(result.is_err());
        assert!(!matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. }) | Err(AggregationError::EmptyInput)
        ));
    }

    #[test]
    fn test_extract_summary_wrong_type() {
        let aggregator = PacketAggregator::new();
        let telemetry_packet = DataPacket::telemetry("node-1", vec![1, 2, 3]);
        let result = aggregator.extract_squad_summary(&telemetry_packet);
        assert!(matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. })
        ));
    }

    #[test]
    fn test_extract_summary_malformed_payload_is_error() {
        let aggregator = PacketAggregator::new();
        let packet = packet_with(DataType::AggregatedTelemetry, b"not json".to_vec());
        let result = aggregator.extract_squad_summary(&packet);
        assert!(result.is_err());
        assert!(!matches!(
            result,
            Err(AggregationError::InvalidPacketType { .. })
        ));
    }
}