use ringkernel_derive::RingMessage;
use rkyv::{Archive, Deserialize, Serialize};
use rustkernel_core::messages::MessageId;
/// Ring request that initializes a K-Means kernel (message `type_id = 200`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 200)]
pub struct KMeansInitRing {
    /// Identifier of this request message.
    pub id: MessageId,
    /// Number of clusters (`k`).
    pub k: u32,
    /// Number of features (dimensions) per point.
    pub n_features: u32,
    /// Initial centroids in a fixed buffer of 32 `i64` slots.
    ///
    /// NOTE(review): presumably `k * n_features` values encoded with
    /// [`to_fixed_point`] (so `k * n_features <= 32`); the layout
    /// (centroid-major vs feature-major) is not visible here — confirm.
    pub centroids_packed: [i64; 32],
}
/// Response to a K-Means initialization request (message `type_id = 201`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 201)]
pub struct KMeansInitResponse {
    /// Raw id of the request being answered.
    /// NOTE(review): presumably the `id` of the originating [`KMeansInitRing`].
    pub request_id: u64,
    /// Whether initialization succeeded.
    pub success: bool,
    /// Number of clusters (`k`), presumably echoed from the request.
    pub k: u32,
}
/// Ring request that runs the K-Means assignment step (message `type_id = 202`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 202)]
pub struct KMeansAssignRing {
    /// Identifier of this request message.
    pub id: MessageId,
    /// Iteration number the assignment step belongs to.
    pub iteration: u32,
}
/// Result of a K-Means assignment step (message `type_id = 203`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 203)]
pub struct KMeansAssignResponse {
    /// Raw id of the request being answered (presumably a [`KMeansAssignRing`]).
    pub request_id: u64,
    /// Iteration number the result belongs to.
    pub iteration: u32,
    /// Inertia as a fixed-point value.
    /// NOTE(review): presumably encoded with [`to_fixed_point`] — confirm.
    pub inertia_fp: i64,
    /// Number of points assigned in this step.
    pub points_assigned: u32,
}
/// Ring request that runs the K-Means centroid-update step (message `type_id = 204`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 204)]
pub struct KMeansUpdateRing {
    /// Identifier of this request message.
    pub id: MessageId,
    /// Iteration number the update step belongs to.
    pub iteration: u32,
}
/// Result of a K-Means centroid-update step (message `type_id = 205`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 205)]
pub struct KMeansUpdateResponse {
    /// Raw id of the request being answered (presumably a [`KMeansUpdateRing`]).
    pub request_id: u64,
    /// Iteration number the result belongs to.
    pub iteration: u32,
    /// Largest centroid shift in this step, as a fixed-point value.
    /// NOTE(review): presumably encoded with [`to_fixed_point`] — confirm.
    pub max_shift_fp: i64,
    /// Whether the kernel reports convergence after this step.
    pub converged: bool,
}
/// Ring request asking which cluster a point belongs to (message `type_id = 206`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 206)]
pub struct KMeansQueryRing {
    /// Identifier of this request message.
    pub id: MessageId,
    /// Query point with up to 8 fixed-point coordinates.
    /// NOTE(review): presumably filled with [`pack_coordinates`], which
    /// writes at most 8 slots — confirm.
    pub point: [i64; 8],
    /// Number of meaningful leading slots in `point`.
    pub n_dims: u8,
}
/// Answer to a cluster query (message `type_id = 207`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 207)]
pub struct KMeansQueryResponse {
    /// Raw id of the request being answered (presumably a [`KMeansQueryRing`]).
    pub request_id: u64,
    /// Index of the selected cluster.
    pub cluster: u32,
    /// Distance to that cluster's centroid, as a fixed-point value.
    /// NOTE(review): presumably encoded with [`to_fixed_point`]; whether this is
    /// Euclidean or squared distance is not visible here — confirm.
    pub distance_fp: i64,
}
/// Kernel-to-kernel message carrying one worker's partial centroid sum for a
/// single cluster (message `type_id = 260`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 260)]
pub struct K2KPartialCentroid {
    /// Identifier of this message.
    pub id: MessageId,
    /// Id of the sending worker.
    pub worker_id: u64,
    /// Iteration number.
    /// NOTE(review): `u64` here, while the single-kernel messages use `u32`;
    /// changing either would alter the archived layout.
    pub iteration: u64,
    /// Cluster this partial sum belongs to.
    pub cluster_id: u32,
    /// Number of points contributing to `coord_sum_fp`.
    pub point_count: u32,
    /// Per-dimension coordinate sums as fixed-point values (up to 8 dims),
    /// e.g. filled with [`pack_coordinates`].
    pub coord_sum_fp: [i64; 8],
    /// Number of meaningful leading slots in `coord_sum_fp`.
    pub n_dims: u8,
}
/// Kernel-to-kernel message with the aggregated centroid for one cluster
/// (message `type_id = 261`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 261)]
pub struct K2KCentroidAggregation {
    /// Raw id of the message being answered.
    /// NOTE(review): presumably a [`K2KPartialCentroid`] `id` — confirm.
    pub request_id: u64,
    /// Cluster whose centroid was aggregated.
    pub cluster_id: u32,
    /// Iteration number. Field order differs from sibling messages; do not
    /// reorder, since the archived layout follows declaration order.
    pub iteration: u64,
    /// New centroid coordinates as fixed-point values (up to 8 dims).
    pub new_centroid_fp: [i64; 8],
    /// Total number of points in the cluster across contributors.
    pub total_points: u32,
    /// Centroid shift relative to the previous iteration, as a fixed-point value.
    pub shift_fp: i64,
}
/// Kernel-to-kernel per-worker synchronization report for one iteration
/// (message `type_id = 262`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 262)]
pub struct K2KKMeansSync {
    /// Identifier of this message.
    pub id: MessageId,
    /// Id of the reporting worker.
    pub worker_id: u64,
    /// Iteration number being reported.
    pub iteration: u64,
    /// Worker-local inertia, e.g. encoded with [`to_fixed_point`].
    pub local_inertia_fp: i64,
    /// Number of points this worker processed.
    pub points_processed: u32,
    /// Largest centroid shift seen by this worker, as a fixed-point value.
    pub max_shift_fp: i64,
}
/// Kernel-to-kernel response summarizing the global state of an iteration
/// (message `type_id = 263`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 263)]
pub struct K2KKMeansSyncResponse {
    /// Raw id of the message being answered (presumably a [`K2KKMeansSync`]).
    pub request_id: u64,
    /// Iteration number being summarized.
    pub iteration: u64,
    /// Whether all workers have reported for this iteration.
    pub all_synced: bool,
    /// Global inertia as a fixed-point value.
    pub global_inertia_fp: i64,
    /// Largest centroid shift across workers, as a fixed-point value.
    pub global_max_shift_fp: i64,
    /// Whether the global computation is reported as converged.
    pub converged: bool,
}
/// Kernel-to-kernel broadcast of the full centroid set for an iteration
/// (message `type_id = 264`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 264)]
pub struct K2KCentroidBroadcast {
    /// Identifier of this message.
    pub id: MessageId,
    /// Iteration the centroids belong to.
    pub iteration: u64,
    /// Number of clusters (`k`).
    pub k: u32,
    /// Dimensions per centroid.
    pub n_dims: u8,
    /// Centroids in a fixed buffer of 32 `i64` slots.
    /// NOTE(review): presumably `k * n_dims` fixed-point values (so
    /// `k * n_dims <= 32`); the layout is not visible here — confirm.
    pub centroids_packed: [i64; 32],
}
/// Kernel-to-kernel acknowledgement of a centroid broadcast (message `type_id = 265`).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 265)]
pub struct K2KCentroidBroadcastAck {
    /// Raw id of the message being acknowledged (presumably a [`K2KCentroidBroadcast`]).
    pub request_id: u64,
    /// Id of the acknowledging worker.
    pub worker_id: u64,
    /// Iteration being acknowledged.
    pub iteration: u64,
    /// Whether the worker applied the broadcast centroids.
    pub applied: bool,
}
/// Converts an `f64` into a fixed-point `i64` with a scale factor of 10^8
/// (8 decimal places).
///
/// The scaled value is rounded to the nearest integer (ties away from zero),
/// not truncated. Truncation biases every result toward zero by up to a full
/// 10^-8 step. For example, `0.29 * 1e8` is `28999999.999999996` in `f64`, so
/// truncation would produce `28_999_999` instead of `29_000_000`.
///
/// Magnitudes above roughly `9.2e10` saturate to `i64::MIN` / `i64::MAX`, and
/// `NaN` maps to `0`. This follows Rust's saturating float-to-int `as` cast.
#[inline]
pub fn to_fixed_point(value: f64) -> i64 {
    const SCALE: f64 = 100_000_000.0;
    (value * SCALE).round() as i64
}
/// Converts a fixed-point `i64` with a scale factor of 10^8 back into an `f64`.
///
/// This inverts [`to_fixed_point`] up to its 10^-8 quantization step.
/// Magnitudes beyond 2^53 lose precision in the `i64 -> f64` widening.
#[inline]
pub fn from_fixed_point(fp: i64) -> f64 {
    const SCALE: f64 = 1e8;
    let widened = fp as f64;
    widened / SCALE
}
/// Packs up to eight `f64` coordinates into `output` as fixed-point values
/// (see [`to_fixed_point`]).
///
/// Coordinates beyond the eighth are silently ignored. Slots of `output` past
/// `coords.len()` are left unchanged. Callers must therefore carry the real
/// dimension count separately, for example in an `n_dims` field.
pub fn pack_coordinates(coords: &[f64], output: &mut [i64; 8]) {
    output
        .iter_mut()
        .zip(coords)
        .for_each(|(slot, &coord)| *slot = to_fixed_point(coord));
}
/// Unpacks the first `n_dims` fixed-point slots of `input` into `f64`
/// coordinates (see [`from_fixed_point`]).
///
/// `n_dims` is clamped to the buffer length of 8. A larger value yields eight
/// coordinates instead of panicking.
pub fn unpack_coordinates(input: &[i64; 8], n_dims: usize) -> Vec<f64> {
    let used = &input[..n_dims.min(input.len())];
    let mut coords = Vec::with_capacity(used.len());
    for &fp in used {
        coords.push(from_fixed_point(fp));
    }
    coords
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fixed_point_conversion() {
        let value = 2.5;
        let fp = to_fixed_point(value);
        let back = from_fixed_point(fp);
        assert!((value - back).abs() < 1e-8);
    }

    /// Negative values must round-trip symmetrically with positive ones.
    #[test]
    fn test_fixed_point_negative() {
        assert_eq!(to_fixed_point(-3.75), -375_000_000);
        assert_eq!(from_fixed_point(-375_000_000), -3.75);
    }

    #[test]
    fn test_pack_unpack_coordinates() {
        let coords = [1.5, 2.5, 3.5];
        let mut packed = [0i64; 8];
        pack_coordinates(&coords, &mut packed);
        let unpacked = unpack_coordinates(&packed, 3);
        assert_eq!(unpacked.len(), 3);
        for (a, b) in coords.iter().zip(unpacked.iter()) {
            assert!((a - b).abs() < 1e-7);
        }
    }

    /// Coordinates beyond the 8-slot buffer are dropped, and `unpack`
    /// clamps an oversized `n_dims` instead of panicking.
    #[test]
    fn test_pack_unpack_more_than_eight_dims() {
        let coords: Vec<f64> = (1..=10).map(f64::from).collect();
        let mut packed = [0i64; 8];
        pack_coordinates(&coords, &mut packed);
        assert_eq!(packed[7], to_fixed_point(8.0));
        let unpacked = unpack_coordinates(&packed, 10);
        assert_eq!(unpacked, (1..=8).map(f64::from).collect::<Vec<_>>());
    }

    #[test]
    fn test_unpack_zero_dims() {
        assert!(unpack_coordinates(&[0; 8], 0).is_empty());
    }

    #[test]
    fn test_kmeans_init_ring() {
        let msg = KMeansInitRing {
            id: MessageId(1),
            k: 3,
            n_features: 2,
            centroids_packed: [0; 32],
        };
        assert_eq!(msg.k, 3);
        assert_eq!(msg.n_features, 2);
    }

    #[test]
    fn test_k2k_partial_centroid() {
        let mut coord_sum = [0i64; 8];
        pack_coordinates(&[10.0, 20.0], &mut coord_sum);
        let msg = K2KPartialCentroid {
            id: MessageId(2),
            worker_id: 1,
            iteration: 5,
            cluster_id: 0,
            point_count: 100,
            coord_sum_fp: coord_sum,
            n_dims: 2,
        };
        assert_eq!(msg.point_count, 100);
        assert_eq!(msg.iteration, 5);
        assert_eq!(
            unpack_coordinates(&msg.coord_sum_fp, msg.n_dims as usize),
            vec![10.0, 20.0]
        );
    }

    #[test]
    fn test_k2k_kmeans_sync() {
        let msg = K2KKMeansSync {
            id: MessageId(3),
            worker_id: 2,
            iteration: 10,
            local_inertia_fp: to_fixed_point(1234.5),
            points_processed: 5000,
            max_shift_fp: to_fixed_point(0.001),
        };
        assert_eq!(msg.iteration, 10);
        assert_eq!(msg.points_processed, 5000);
    }
}