use ringkernel_core::runtime::KernelId;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Maps a [`KernelId`] to a `u64` by hashing its string form with std's `DefaultHasher`.
///
/// Ids with equal strings always map to the same value for a given build.
/// Distinct ids may collide, since this is a 64-bit hash.
///
/// `DefaultHasher`'s algorithm is unspecified and may change between Rust
/// releases. Do not persist these values or compare them across binaries
/// built with different toolchains.
pub fn kernel_id_to_u64(id: &KernelId) -> u64 {
    let key = id.as_str();
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    state.finish()
}
/// Convergence bookkeeping for an iterative algorithm driven by per-iteration deltas.
///
/// `converged` becomes `true` in either of two cases:
/// - a delta falls strictly below `convergence_threshold`;
/// - `max_iterations` updates have been recorded.
///
/// So the flag means "stop iterating" rather than strictly "reached tolerance".
/// Use [`IterativeConvergenceSummary::reached_max`] to tell the two cases apart.
#[derive(Debug, Clone)]
pub struct IterativeState {
    /// Number of `update` calls since construction or the last `reset`.
    pub iteration: u64,
    /// Delta passed to the most recent `update`; `f64::MAX` before the first one.
    pub last_delta: f64,
    /// A delta strictly below this value counts as converged.
    pub convergence_threshold: f64,
    /// Iteration budget; reaching it also sets `converged`.
    pub max_iterations: u64,
    /// Set by `update` when either stop condition holds.
    pub converged: bool,
}

impl IterativeState {
    /// Creates a fresh state: zero iterations, `last_delta = f64::MAX`, not converged.
    pub fn new(convergence_threshold: f64, max_iterations: u64) -> Self {
        Self {
            converged: false,
            iteration: 0,
            last_delta: f64::MAX,
            max_iterations,
            convergence_threshold,
        }
    }

    /// Records one iteration that produced `delta` and returns the new `converged` flag.
    ///
    /// A NaN delta never satisfies the threshold test, so only the iteration
    /// budget can end a run that keeps reporting NaN.
    pub fn update(&mut self, delta: f64) -> bool {
        self.last_delta = delta;
        self.iteration += 1;

        let below_threshold = delta < self.convergence_threshold;
        let budget_exhausted = self.iteration >= self.max_iterations;
        self.converged = below_threshold || budget_exhausted;
        self.converged
    }

    /// Returns `true` while not converged and the iteration budget is not yet used up.
    pub fn should_continue(&self) -> bool {
        !(self.converged || self.iteration >= self.max_iterations)
    }

    /// Clears progress (iteration count, last delta, converged flag).
    ///
    /// The threshold and the iteration budget are kept.
    pub fn reset(&mut self) {
        self.converged = false;
        self.last_delta = f64::MAX;
        self.iteration = 0;
    }

    /// Snapshots the current progress into an [`IterativeConvergenceSummary`].
    pub fn summary(&self) -> IterativeConvergenceSummary {
        let reached_max = self.iteration >= self.max_iterations;
        IterativeConvergenceSummary {
            iterations: self.iteration,
            final_delta: self.last_delta,
            converged: self.converged,
            reached_max,
        }
    }
}

/// Point-in-time report produced by [`IterativeState::summary`].
#[derive(Debug, Clone)]
pub struct IterativeConvergenceSummary {
    /// Iterations recorded so far.
    pub iterations: u64,
    /// Delta from the most recent update (`f64::MAX` if none yet).
    pub final_delta: f64,
    /// Copy of [`IterativeState::converged`]; true for threshold or budget stops.
    pub converged: bool,
    /// `true` when the iteration count has reached `max_iterations`.
    pub reached_max: bool,
}
/// Tracks progress through an ordered list of named pipeline stages.
///
/// The elapsed time of each stage is recorded as it completes. Timings are
/// stored per stage *position*, so a pipeline may reuse a name (e.g.
/// `["map", "reduce", "map"]`). With name-keyed storage, one occurrence would
/// overwrite another's timing, and the tracker could never report completion.
#[derive(Debug, Clone)]
pub struct PipelineTracker {
    stages: Vec<String>,
    current_stage: usize,
    /// Elapsed microseconds for each stage, indexed like `stages`.
    /// Each entry is `None` until that stage has been completed via `advance`.
    stage_timings_us: Vec<Option<u64>>,
    total_items_processed: u64,
}

impl PipelineTracker {
    /// Creates a tracker positioned at the first stage, with no timings recorded.
    pub fn new(stages: Vec<String>) -> Self {
        let stage_timings_us = vec![None; stages.len()];
        Self {
            stages,
            current_stage: 0,
            stage_timings_us,
            total_items_processed: 0,
        }
    }

    /// Name of the stage currently in progress, or `None` if there are no stages.
    pub fn current_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage).map(String::as_str)
    }

    /// Name of the stage after the current one, or `None` when on the last stage.
    pub fn next_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage + 1).map(String::as_str)
    }

    /// Records `elapsed_us` for the current stage, then moves to the next stage.
    ///
    /// Returns `true` if the tracker moved on. Returns `false` when it is already
    /// on the last stage; that stage's timing is still recorded, replacing any
    /// earlier value.
    pub fn advance(&mut self, elapsed_us: u64) -> bool {
        if let Some(slot) = self.stage_timings_us.get_mut(self.current_stage) {
            *slot = Some(elapsed_us);
        }
        if self.current_stage + 1 < self.stages.len() {
            self.current_stage += 1;
            true
        } else {
            false
        }
    }

    /// Adds `count` to the running total of processed items.
    pub fn record_items(&mut self, count: u64) {
        self.total_items_processed += count;
    }

    /// Total items recorded via [`PipelineTracker::record_items`] since creation or `reset`.
    pub fn total_items_processed(&self) -> u64 {
        self.total_items_processed
    }

    /// `true` once every stage has a recorded timing; vacuously `true` with no stages.
    pub fn is_complete(&self) -> bool {
        self.stage_timings_us.iter().all(Option::is_some)
    }

    /// Sum of all recorded stage timings, in microseconds.
    pub fn total_time_us(&self) -> u64 {
        self.stage_timings_us.iter().flatten().sum()
    }

    /// Recorded time for the stage named `stage`, in microseconds.
    ///
    /// If the name occurs more than once, the recorded timings of all its
    /// occurrences are summed. Returns `None` if no occurrence has been timed yet.
    pub fn stage_timing(&self, stage: &str) -> Option<u64> {
        self.stages
            .iter()
            .zip(&self.stage_timings_us)
            .filter(|(name, _)| name.as_str() == stage)
            .filter_map(|(_, timing)| *timing)
            .fold(None, |acc, t| Some(acc.unwrap_or(0) + t))
    }

    /// Returns to the first stage and clears all timings and the item count.
    pub fn reset(&mut self) {
        self.current_stage = 0;
        self.stage_timings_us.fill(None);
        self.total_items_processed = 0;
    }
}
/// Accumulates results from a scatter/gather round across `worker_count` workers.
///
/// Only the first result from each worker is kept; later results from the
/// same worker are silently ignored. Results are stored in arrival order.
#[derive(Debug)]
pub struct ScatterGatherState<T> {
    /// Number of workers expected to respond.
    pub worker_count: usize,
    /// Accepted results, one per responding worker, in arrival order.
    pub results: Vec<T>,
    /// Workers whose result has been accepted (parallel to `results`).
    pub responded_workers: Vec<KernelId>,
    /// Caller-supplied start timestamp; stored as-is and not used here.
    pub start_time_us: u64,
}

impl<T> ScatterGatherState<T> {
    /// Creates an empty state, preallocating room for `worker_count` responses.
    pub fn new(worker_count: usize, start_time_us: u64) -> Self {
        let results = Vec::with_capacity(worker_count);
        let responded_workers = Vec::with_capacity(worker_count);
        Self {
            worker_count,
            results,
            responded_workers,
            start_time_us,
        }
    }

    /// Accepts `result` from `worker` unless that worker has already responded.
    pub fn receive_result(&mut self, worker: KernelId, result: T) {
        if self.responded_workers.contains(&worker) {
            return;
        }
        self.results.push(result);
        self.responded_workers.push(worker);
    }

    /// `true` once at least `worker_count` distinct workers have responded.
    pub fn is_complete(&self) -> bool {
        self.pending_count() == 0
    }

    /// Number of workers still expected; never negative (saturates at zero).
    pub fn pending_count(&self) -> usize {
        let responded = self.responded_workers.len();
        self.worker_count.saturating_sub(responded)
    }

    /// Consumes the state and returns the accepted results in arrival order.
    pub fn take_results(self) -> Vec<T> {
        let Self { results, .. } = self;
        results
    }
}
/// Tracks a set of fan-out destinations and which of them have had a delivery
/// marked since the latest broadcast.
///
/// Destinations are compared by their `as_str()` value. Delivery status is
/// kept only for registered destinations, so `delivery_count()` never exceeds
/// `destination_count()`.
#[derive(Debug, Clone)]
pub struct FanOutTracker {
    destinations: Vec<KernelId>,
    /// Delivery flag per destination id string; reset to `false` by `record_broadcast`.
    delivery_status: HashMap<String, bool>,
    broadcast_count: u64,
}

impl FanOutTracker {
    /// Creates an empty tracker with no destinations and no broadcasts.
    pub fn new() -> Self {
        Self {
            destinations: Vec::new(),
            delivery_status: HashMap::new(),
            broadcast_count: 0,
        }
    }

    /// Returns `true` when `dest` (compared by id string) is a registered destination.
    fn is_destination(&self, dest: &KernelId) -> bool {
        self.destinations
            .iter()
            .any(|d| d.as_str() == dest.as_str())
    }

    /// Registers `dest`; adding an id that is already registered is a no-op.
    pub fn add_destination(&mut self, dest: KernelId) {
        if !self.is_destination(&dest) {
            self.destinations.push(dest);
        }
    }

    /// Unregisters `dest` and forgets its delivery status.
    pub fn remove_destination(&mut self, dest: &KernelId) {
        self.destinations.retain(|d| d.as_str() != dest.as_str());
        self.delivery_status.remove(dest.as_str());
    }

    /// Registered destinations, in registration order.
    pub fn destinations(&self) -> &[KernelId] {
        &self.destinations
    }

    /// Counts a new broadcast and marks every registered destination as not yet delivered.
    pub fn record_broadcast(&mut self) {
        self.broadcast_count += 1;
        for dest in &self.destinations {
            self.delivery_status
                .insert(dest.as_str().to_string(), false);
        }
    }

    /// Marks `dest` as delivered.
    ///
    /// Ids that are not (or no longer) registered are ignored. Counting them
    /// would let `delivery_count()` exceed `destination_count()`.
    pub fn mark_delivered(&mut self, dest: &KernelId) {
        if self.is_destination(dest) {
            self.delivery_status.insert(dest.as_str().to_string(), true);
        }
    }

    /// Number of registered destinations currently marked as delivered.
    pub fn delivery_count(&self) -> usize {
        self.delivery_status.values().filter(|&&v| v).count()
    }

    /// Number of `record_broadcast` calls so far.
    pub fn broadcast_count(&self) -> u64 {
        self.broadcast_count
    }

    /// Number of registered destinations.
    pub fn destination_count(&self) -> usize {
        self.destinations.len()
    }
}

impl Default for FanOutTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Control messages exchanged between kernels (K2K).
///
/// Each variant is plain data. How a message is acted on is decided by the
/// sending and receiving code, which is not part of this type.
#[derive(Debug, Clone)]
pub enum K2KControlMessage {
    /// Start signal tagged with a correlation id.
    Start { correlation_id: u64 },
    /// Stop signal with a free-form reason string.
    Stop { reason: String },
    /// Status request tagged with a correlation id (NOTE(review): presumably echoed in the reply).
    GetStatus { correlation_id: u64 },
    /// Report of an iteration number and its delta from the worker with numeric id `worker_id`.
    IterationComplete { iteration: u64, delta: f64, worker_id: u64 },
    /// Convergence notice carrying the iteration count and final delta.
    Converged { iterations: u64, final_delta: f64 },
    /// Error report with a message and a numeric code (code meanings are not defined here).
    Error { message: String, code: u32 },
    /// Heartbeat carrying a sequence number and a timestamp (µs, per the field name).
    Heartbeat { sequence: u64, timestamp_us: u64 },
    /// Barrier arrival of worker `worker_id` at barrier `barrier_id`.
    Barrier { barrier_id: u64, worker_id: u64 },
}
/// One worker's result, tagged with the worker id, a correlation id and the
/// processing time reported for it (microseconds, per the field name).
#[derive(Debug, Clone)]
pub struct K2KWorkerResult<T> {
    /// Kernel that produced `result`.
    pub worker_id: KernelId,
    /// Correlation id associating this result with its request.
    pub correlation_id: u64,
    /// The worker's payload.
    pub result: T,
    /// Reported processing time.
    pub processing_time_us: u64,
}

impl<T> K2KWorkerResult<T> {
    /// Bundles the four fields into a `K2KWorkerResult`; no validation is performed.
    pub fn new(
        worker_id: KernelId,
        correlation_id: u64,
        result: T,
        processing_time_us: u64,
    ) -> Self {
        Self {
            result,
            processing_time_us,
            worker_id,
            correlation_id,
        }
    }
}
/// Priority level for K2K messages, represented as a single byte.
///
/// The variants are declared in ascending discriminant order, so the derived
/// ordering gives `Low < Normal < High < Critical < RealTime`. `Normal` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
#[repr(u8)]
pub enum K2KPriority {
    Low = 0,
    #[default]
    Normal = 64,
    High = 128,
    Critical = 192,
    RealTime = 255,
}

/// Converts a priority to its raw `repr(u8)` discriminant.
impl From<K2KPriority> for u8 {
    fn from(priority: K2KPriority) -> Self {
        priority as Self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iterative_state_convergence() {
        let mut tracker = IterativeState::new(1e-6, 100);
        assert!(tracker.should_continue());
        assert!(!tracker.converged);

        // Deltas above the 1e-6 threshold must not converge.
        tracker.update(0.1);
        assert!(!tracker.converged);
        assert_eq!(tracker.iteration, 1);
        tracker.update(0.01);
        assert!(!tracker.converged);

        // Dropping below the threshold flips `converged`.
        tracker.update(1e-7);
        assert!(tracker.converged);

        let report = tracker.summary();
        assert_eq!(report.iterations, 3);
        assert!(report.converged);
    }

    #[test]
    fn test_iterative_state_max_iterations() {
        let mut tracker = IterativeState::new(1e-6, 3);
        // None of these reach the threshold; the third update exhausts the budget.
        for delta in [0.1, 0.05, 0.01] {
            tracker.update(delta);
        }
        assert!(tracker.converged);
        assert!(tracker.summary().reached_max);
    }

    #[test]
    fn test_pipeline_tracker() {
        let stages: Vec<String> = ["ingest", "transform", "output"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let mut pipeline = PipelineTracker::new(stages);
        assert_eq!(pipeline.current_stage(), Some("ingest"));
        assert_eq!(pipeline.next_stage(), Some("transform"));

        pipeline.advance(1000);
        assert_eq!(pipeline.current_stage(), Some("transform"));
        pipeline.advance(2000);
        assert_eq!(pipeline.current_stage(), Some("output"));
        pipeline.advance(500);

        assert!(pipeline.is_complete());
        assert_eq!(pipeline.total_time_us(), 1000 + 2000 + 500);
    }

    #[test]
    fn test_scatter_gather_state() {
        let mut gather: ScatterGatherState<i32> = ScatterGatherState::new(3, 0);
        assert!(!gather.is_complete());
        assert_eq!(gather.pending_count(), 3);

        for (worker, value) in [("worker1", 10), ("worker2", 20)] {
            gather.receive_result(KernelId::new(worker), value);
        }
        assert_eq!(gather.pending_count(), 1);

        gather.receive_result(KernelId::new("worker3"), 30);
        assert!(gather.is_complete());
        assert_eq!(gather.take_results(), vec![10, 20, 30]);
    }

    #[test]
    fn test_fan_out_tracker() {
        let mut fan_out = FanOutTracker::new();
        // The repeated "dest1" must be deduplicated.
        for name in ["dest1", "dest2", "dest1"] {
            fan_out.add_destination(KernelId::new(name));
        }
        assert_eq!(fan_out.destination_count(), 2);

        fan_out.record_broadcast();
        assert_eq!(fan_out.broadcast_count(), 1);
        assert_eq!(fan_out.delivery_count(), 0);

        fan_out.mark_delivered(&KernelId::new("dest1"));
        assert_eq!(fan_out.delivery_count(), 1);
    }

    #[test]
    fn test_kernel_id_to_u64() {
        let [hash_a, hash_b, hash_a_again] = ["kernel-a", "kernel-b", "kernel-a"]
            .map(|name| kernel_id_to_u64(&KernelId::new(name)));
        assert_ne!(hash_a, hash_b);
        assert_eq!(hash_a, hash_a_again);
    }
}