use crate::domain::Domain;
use crate::hlc::{HlcClock, HlcTimestamp};
use crate::message::MessageEnvelope;
use crate::types::{BlockId, Dim3, FenceScope, GlobalThreadId, MemoryOrder, ThreadId, WarpId};
use tokio::sync::mpsc;
/// Category of a recorded metric sample; tells consumers how to
/// interpret `MetricsEntry::value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
/// A latency measurement (recorded in microseconds by `record_latency`).
Latency,
/// A throughput count for an interval (see `record_throughput`).
Throughput,
/// A monotonically accumulated increment (see `record_counter`).
Counter,
/// A point-in-time level reading (see `record_gauge`).
Gauge,
}
/// One buffered metric sample, stamped with its origin and HLC time.
#[derive(Debug, Clone)]
pub struct MetricsEntry {
/// Name of the operation being measured.
pub operation: String,
/// How `value` should be interpreted.
pub metric_type: MetricType,
/// The sample value (units depend on `metric_type`, e.g. µs for latency).
pub value: u64,
/// Hybrid-logical-clock time at which the sample was recorded.
pub timestamp: HlcTimestamp,
/// Kernel that produced the sample.
pub kernel_id: u64,
/// Business domain of the recording context, if one was set.
pub domain: Option<Domain>,
}
/// Bounded, append-only staging buffer for metric entries.
///
/// Entries recorded past `capacity` are silently dropped (see
/// `record`); `drain` hands the batch off and empties the buffer.
#[derive(Debug)]
pub struct ContextMetricsBuffer {
// Buffered samples, at most `capacity` of them.
entries: Vec<MetricsEntry>,
// Hard upper bound on `entries.len()`.
capacity: usize,
}
impl ContextMetricsBuffer {
    /// Creates a buffer that accepts at most `capacity` entries.
    ///
    /// The eager allocation is clamped to 1024 slots so a very large
    /// requested capacity does not reserve memory up front; the vector
    /// grows on demand beyond that, still bounded by `capacity`.
    pub fn new(capacity: usize) -> Self {
        let prealloc = capacity.min(1024);
        Self {
            entries: Vec::with_capacity(prealloc),
            capacity,
        }
    }

    /// Buffers `entry`; silently drops it once `capacity` is reached.
    pub fn record(&mut self, entry: MetricsEntry) {
        if self.entries.len() >= self.capacity {
            return;
        }
        self.entries.push(entry);
    }

    /// Removes and returns every buffered entry, leaving the buffer empty.
    pub fn drain(&mut self) -> Vec<MetricsEntry> {
        std::mem::take(&mut self.entries)
    }

    /// `true` once the buffer holds `capacity` entries (further
    /// `record` calls will drop their input).
    pub fn is_full(&self) -> bool {
        self.entries.len() >= self.capacity
    }

    /// Number of entries currently buffered.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when nothing is buffered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl Default for ContextMetricsBuffer {
fn default() -> Self {
Self::new(256)
}
}
/// Severity of a kernel alert; the explicit discriminants give a total
/// order (`Info < Warning < Error < Critical`) via the derived `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
Info = 0,
Warning = 1,
Error = 2,
Critical = 3,
}
/// What condition an alert reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelAlertType {
/// An operation exceeded its latency budget.
HighLatency,
/// A queue is nearing capacity (back-pressure).
QueuePressure,
/// Memory usage is nearing a limit.
MemoryPressure,
/// A processing step failed.
ProcessingError,
/// Domain-specific alert; the meaning of the code is defined by the domain.
DomainAlert(u32),
/// Free-form, caller-defined alert code.
Custom(u32),
}
/// Where an alert should be delivered. Defaults to `Host`.
#[derive(Debug, Clone, Copy, Default)]
pub enum AlertRouting {
/// Deliver to the host runtime (the default).
#[default]
Host,
/// Deliver to another kernel, identified by its kernel id.
Kernel(u64),
/// Deliver to the alert's domain handler.
Domain,
/// Deliver to an external sink (outside the runtime).
External,
}
/// An alert raised from kernel code.
///
/// `source_kernel`, `source_domain`, and `timestamp` are placeholders
/// after construction; `RingContext::emit_alert` overwrites them with
/// the emitting context's values at send time.
#[derive(Debug, Clone)]
pub struct KernelAlert {
pub severity: AlertSeverity,
pub alert_type: KernelAlertType,
/// Human-readable description of the condition.
pub message: String,
/// Kernel id of the emitter (0 until stamped by `emit_alert`).
pub source_kernel: u64,
/// Domain of the emitter, if any (stamped by `emit_alert`).
pub source_domain: Option<Domain>,
/// HLC time of emission (zero until stamped by `emit_alert`).
pub timestamp: HlcTimestamp,
/// Delivery target; defaults to `AlertRouting::Host`.
pub routing: AlertRouting,
}
impl KernelAlert {
    /// Builds an alert with the given severity, type, and message.
    ///
    /// The source kernel, source domain, and timestamp start as
    /// placeholders (0 / `None` / zero); `RingContext::emit_alert`
    /// stamps real values when the alert is sent.
    pub fn new(
        severity: AlertSeverity,
        alert_type: KernelAlertType,
        message: impl Into<String>,
    ) -> Self {
        Self {
            severity,
            alert_type,
            message: message.into(),
            source_kernel: 0,
            source_domain: None,
            timestamp: HlcTimestamp::zero(),
            routing: AlertRouting::default(),
        }
    }

    /// Warning-level alert for an operation that blew its latency
    /// budget; the measured latency is appended to the message.
    pub fn high_latency(message: impl Into<String>, latency_us: u64) -> Self {
        let text = format!("{} ({}µs)", message.into(), latency_us);
        Self::new(AlertSeverity::Warning, KernelAlertType::HighLatency, text)
    }

    /// Error-level alert for a processing failure.
    pub fn error(message: impl Into<String>) -> Self {
        Self::new(
            AlertSeverity::Error,
            KernelAlertType::ProcessingError,
            message,
        )
    }

    /// Warning-level alert for queue back-pressure; the utilization
    /// percentage is appended to the message.
    pub fn queue_pressure(message: impl Into<String>, utilization_pct: u32) -> Self {
        let text = format!("{} ({}% full)", message.into(), utilization_pct);
        Self::new(AlertSeverity::Warning, KernelAlertType::QueuePressure, text)
    }

    /// Builder-style override of the delivery target.
    pub fn with_routing(mut self, routing: AlertRouting) -> Self {
        self.routing = routing;
        self
    }
}
/// Per-thread execution context handed to kernel code: thread/block
/// identity, an HLC clock, metrics buffering, and alert emission.
///
/// Borrows the `HlcClock` for its lifetime `'a`.
pub struct RingContext<'a> {
/// This thread's position within its block.
pub thread_id: ThreadId,
/// This block's position within the grid.
pub block_id: BlockId,
/// Dimensions of a block (threads per block).
pub block_dim: Dim3,
/// Dimensions of the grid (blocks per grid).
pub grid_dim: Dim3,
// Shared hybrid logical clock used for timestamps.
clock: &'a HlcClock,
// Id of the kernel this context belongs to.
kernel_id: u64,
// Backend the context targets (affects fence/warp behavior).
backend: ContextBackend,
// Optional business-domain tag stamped onto metrics and alerts.
domain: Option<Domain>,
// Staging buffer for metrics recorded via `record_*`.
metrics_buffer: ContextMetricsBuffer,
// Optional channel for `emit_alert`; alerts are dropped if `None`.
alert_sender: Option<mpsc::UnboundedSender<KernelAlert>>,
}
/// Execution backend a `RingContext` targets.
///
/// Fieldless, so `Copy` is free; deriving `Copy`/`PartialEq`/`Eq` lets
/// callers pass backends by value and compare them without cloning
/// (e.g. the `self.backend.clone()` otherwise needed in `thread_fence`).
/// `Clone` is kept so existing `.clone()` call sites still compile.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextBackend {
    /// Host CPU backend.
    Cpu,
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU (wgpu) backend.
    Wgpu,
}
impl<'a> RingContext<'a> {
pub fn new(
thread_id: ThreadId,
block_id: BlockId,
block_dim: Dim3,
grid_dim: Dim3,
clock: &'a HlcClock,
kernel_id: u64,
backend: ContextBackend,
) -> Self {
Self {
thread_id,
block_id,
block_dim,
grid_dim,
clock,
kernel_id,
backend,
domain: None,
metrics_buffer: ContextMetricsBuffer::default(),
alert_sender: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_options(
thread_id: ThreadId,
block_id: BlockId,
block_dim: Dim3,
grid_dim: Dim3,
clock: &'a HlcClock,
kernel_id: u64,
backend: ContextBackend,
domain: Option<Domain>,
metrics_capacity: usize,
alert_sender: Option<mpsc::UnboundedSender<KernelAlert>>,
) -> Self {
Self {
thread_id,
block_id,
block_dim,
grid_dim,
clock,
kernel_id,
backend,
domain,
metrics_buffer: ContextMetricsBuffer::new(metrics_capacity),
alert_sender,
}
}
#[inline]
pub fn thread_id(&self) -> ThreadId {
self.thread_id
}
#[inline]
pub fn block_id(&self) -> BlockId {
self.block_id
}
#[inline]
pub fn global_thread_id(&self) -> GlobalThreadId {
GlobalThreadId::from_block_thread(self.block_id, self.thread_id, self.block_dim)
}
#[inline]
pub fn warp_id(&self) -> WarpId {
let linear = self
.thread_id
.linear_for_dim(self.block_dim.x, self.block_dim.y);
WarpId::from_thread_linear(linear)
}
#[inline]
pub fn lane_id(&self) -> u32 {
let linear = self
.thread_id
.linear_for_dim(self.block_dim.x, self.block_dim.y);
WarpId::lane_id(linear)
}
#[inline]
pub fn block_dim(&self) -> Dim3 {
self.block_dim
}
#[inline]
pub fn grid_dim(&self) -> Dim3 {
self.grid_dim
}
#[inline]
pub fn kernel_id(&self) -> u64 {
self.kernel_id
}
#[inline]
pub fn sync_threads(&self) {
match self.backend {
ContextBackend::Cpu => {
}
_ => {
}
}
}
#[inline]
pub fn sync_grid(&self) {
match self.backend {
ContextBackend::Cpu => {
}
_ => {
}
}
}
#[inline]
pub fn sync_warp(&self) {
match self.backend {
ContextBackend::Cpu => {
}
_ => {
}
}
}
#[inline]
pub fn thread_fence(&self, scope: FenceScope) {
match (self.backend.clone(), scope) {
(ContextBackend::Cpu, _) => {
std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
}
_ => {
}
}
}
#[inline]
pub fn fence_thread(&self) {
self.thread_fence(FenceScope::Thread);
}
#[inline]
pub fn fence_block(&self) {
self.thread_fence(FenceScope::Block);
}
#[inline]
pub fn fence_device(&self) {
self.thread_fence(FenceScope::Device);
}
#[inline]
pub fn fence_system(&self) {
self.thread_fence(FenceScope::System);
}
#[inline]
pub fn now(&self) -> HlcTimestamp {
self.clock.now()
}
#[inline]
pub fn tick(&self) -> HlcTimestamp {
self.clock.tick()
}
#[inline]
pub fn update_clock(&self, received: &HlcTimestamp) -> crate::error::Result<HlcTimestamp> {
self.clock.update(received)
}
#[inline]
pub fn atomic_add(
&self,
ptr: &std::sync::atomic::AtomicU64,
val: u64,
order: MemoryOrder,
) -> u64 {
let ordering = match order {
MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
MemoryOrder::Release => std::sync::atomic::Ordering::Release,
MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
};
ptr.fetch_add(val, ordering)
}
#[inline]
pub fn atomic_cas(
&self,
ptr: &std::sync::atomic::AtomicU64,
expected: u64,
desired: u64,
success: MemoryOrder,
failure: MemoryOrder,
) -> Result<u64, u64> {
let success_ord = match success {
MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
MemoryOrder::Release => std::sync::atomic::Ordering::Release,
MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
};
let failure_ord = match failure {
MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
MemoryOrder::Release => std::sync::atomic::Ordering::Release,
MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
};
ptr.compare_exchange(expected, desired, success_ord, failure_ord)
}
#[inline]
pub fn atomic_exchange(
&self,
ptr: &std::sync::atomic::AtomicU64,
val: u64,
order: MemoryOrder,
) -> u64 {
let ordering = match order {
MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
MemoryOrder::Release => std::sync::atomic::Ordering::Release,
MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
};
ptr.swap(val, ordering)
}
#[inline]
pub fn warp_shuffle<T: Copy>(&self, value: T, src_lane: u32) -> T {
match self.backend {
ContextBackend::Cpu => {
let _ = src_lane;
value
}
_ => {
let _ = src_lane;
value
}
}
}
#[inline]
pub fn warp_shuffle_down<T: Copy>(&self, value: T, delta: u32) -> T {
self.warp_shuffle(value, self.lane_id().saturating_add(delta))
}
#[inline]
pub fn warp_shuffle_up<T: Copy>(&self, value: T, delta: u32) -> T {
self.warp_shuffle(value, self.lane_id().saturating_sub(delta))
}
#[inline]
pub fn warp_shuffle_xor<T: Copy>(&self, value: T, mask: u32) -> T {
self.warp_shuffle(value, self.lane_id() ^ mask)
}
#[inline]
pub fn warp_ballot(&self, predicate: bool) -> u32 {
match self.backend {
ContextBackend::Cpu => {
if predicate {
1
} else {
0
}
}
_ => {
if predicate {
1 << self.lane_id()
} else {
0
}
}
}
}
#[inline]
pub fn warp_all(&self, predicate: bool) -> bool {
match self.backend {
ContextBackend::Cpu => predicate,
_ => {
predicate
}
}
}
#[inline]
pub fn warp_any(&self, predicate: bool) -> bool {
match self.backend {
ContextBackend::Cpu => predicate,
_ => {
predicate
}
}
}
#[inline]
pub fn k2k_send(
&self,
_target_kernel: u64,
_envelope: &MessageEnvelope,
) -> crate::error::Result<()> {
Err(crate::error::RingKernelError::NotSupported(
"K2K messaging requires runtime".to_string(),
))
}
#[inline]
pub fn k2k_try_recv(&self) -> crate::error::Result<MessageEnvelope> {
Err(crate::error::RingKernelError::NotSupported(
"K2K messaging requires runtime".to_string(),
))
}
#[inline]
pub fn domain(&self) -> Option<&Domain> {
self.domain.as_ref()
}
#[inline]
pub fn set_domain(&mut self, domain: Domain) {
self.domain = Some(domain);
}
#[inline]
pub fn clear_domain(&mut self) {
self.domain = None;
}
pub fn record_latency(&mut self, operation: &str, latency_us: u64) {
let entry = MetricsEntry {
operation: operation.to_string(),
metric_type: MetricType::Latency,
value: latency_us,
timestamp: self.clock.now(),
kernel_id: self.kernel_id,
domain: self.domain,
};
self.metrics_buffer.record(entry);
}
pub fn record_throughput(&mut self, operation: &str, count: u64) {
let entry = MetricsEntry {
operation: operation.to_string(),
metric_type: MetricType::Throughput,
value: count,
timestamp: self.clock.now(),
kernel_id: self.kernel_id,
domain: self.domain,
};
self.metrics_buffer.record(entry);
}
pub fn record_counter(&mut self, operation: &str, increment: u64) {
let entry = MetricsEntry {
operation: operation.to_string(),
metric_type: MetricType::Counter,
value: increment,
timestamp: self.clock.now(),
kernel_id: self.kernel_id,
domain: self.domain,
};
self.metrics_buffer.record(entry);
}
pub fn record_gauge(&mut self, operation: &str, value: u64) {
let entry = MetricsEntry {
operation: operation.to_string(),
metric_type: MetricType::Gauge,
value,
timestamp: self.clock.now(),
kernel_id: self.kernel_id,
domain: self.domain,
};
self.metrics_buffer.record(entry);
}
pub fn flush_metrics(&mut self) -> Vec<MetricsEntry> {
self.metrics_buffer.drain()
}
pub fn metrics_count(&self) -> usize {
self.metrics_buffer.len()
}
pub fn metrics_buffer_full(&self) -> bool {
self.metrics_buffer.is_full()
}
pub fn emit_alert(&self, alert: impl Into<KernelAlert>) {
if let Some(ref sender) = self.alert_sender {
let mut alert = alert.into();
alert.source_kernel = self.kernel_id;
alert.source_domain = self.domain;
alert.timestamp = self.clock.now();
let _ = sender.send(alert);
}
}
#[inline]
pub fn has_alert_channel(&self) -> bool {
self.alert_sender.is_some()
}
pub fn alert_if_slow(&self, operation: &str, latency_us: u64, threshold_us: u64) {
if latency_us > threshold_us {
self.emit_alert(KernelAlert::high_latency(
format!("{} exceeded threshold", operation),
latency_us,
));
}
}
}
// Manual `Debug`: prints only the identity/configuration fields and
// omits `clock`, `domain`, `metrics_buffer`, and `alert_sender` —
// presumably to keep output compact or because not every omitted field
// is `Debug`; NOTE(review): confirm the omissions are intentional.
impl<'a> std::fmt::Debug for RingContext<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RingContext")
.field("thread_id", &self.thread_id)
.field("block_id", &self.block_id)
.field("block_dim", &self.block_dim)
.field("grid_dim", &self.grid_dim)
.field("kernel_id", &self.kernel_id)
.field("backend", &self.backend)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Minimal CPU-backend context: thread 0 of block 0 in a 256-thread,
// single-block launch, kernel id 1.
fn make_test_context(clock: &HlcClock) -> RingContext<'_> {
RingContext::new(
ThreadId::new_1d(0),
BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
clock,
1,
ContextBackend::Cpu,
)
}
// The thread/block/global ids reported match the construction values.
#[test]
fn test_thread_identity() {
let clock = HlcClock::new(1);
let ctx = make_test_context(&clock);
assert_eq!(ctx.thread_id().x, 0);
assert_eq!(ctx.block_id().x, 0);
assert_eq!(ctx.global_thread_id().x, 0);
}
// Linear thread 35 maps to warp 1, lane 3 — consistent with 32-lane
// warps (35 = 32 + 3).
#[test]
fn test_warp_id() {
let clock = HlcClock::new(1);
let ctx = RingContext::new(
ThreadId::new_1d(35), BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
&clock,
1,
ContextBackend::Cpu,
);
assert_eq!(ctx.warp_id().0, 1);
assert_eq!(ctx.lane_id(), 3);
}
// HLC timestamps from tick() never go backwards relative to now().
#[test]
fn test_hlc_operations() {
let clock = HlcClock::new(1);
let ctx = make_test_context(&clock);
let ts1 = ctx.now();
let ts2 = ctx.tick();
assert!(ts2 >= ts1);
}
// CPU backend ballot is a single-lane mask: 1 for true, 0 for false.
#[test]
fn test_warp_ballot_cpu() {
let clock = HlcClock::new(1);
let ctx = make_test_context(&clock);
assert_eq!(ctx.warp_ballot(true), 1);
assert_eq!(ctx.warp_ballot(false), 0);
}
// Domain tag round-trips through set_domain / clear_domain.
#[test]
fn test_domain_operations() {
let clock = HlcClock::new(1);
let mut ctx = make_test_context(&clock);
assert!(ctx.domain().is_none());
ctx.set_domain(Domain::OrderMatching);
assert_eq!(ctx.domain(), Some(&Domain::OrderMatching));
ctx.clear_domain();
assert!(ctx.domain().is_none());
}
// new_with_options applies the domain and kernel id as given.
#[test]
fn test_context_with_domain() {
let clock = HlcClock::new(1);
let ctx = RingContext::new_with_options(
ThreadId::new_1d(0),
BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
&clock,
42,
ContextBackend::Cpu,
Some(Domain::RiskManagement),
128,
None,
);
assert_eq!(ctx.domain(), Some(&Domain::RiskManagement));
assert_eq!(ctx.kernel_id(), 42);
}
// Buffer fills to capacity (3), drains completely, and reports
// len / is_full / is_empty consistently along the way.
#[test]
fn test_metrics_buffer() {
let mut buffer = ContextMetricsBuffer::new(3);
assert!(buffer.is_empty());
assert!(!buffer.is_full());
assert_eq!(buffer.len(), 0);
let entry = MetricsEntry {
operation: "test".to_string(),
metric_type: MetricType::Latency,
value: 100,
timestamp: HlcTimestamp::zero(),
kernel_id: 1,
domain: None,
};
buffer.record(entry.clone());
assert_eq!(buffer.len(), 1);
buffer.record(entry.clone());
buffer.record(entry.clone());
assert!(buffer.is_full());
let entries = buffer.drain();
assert_eq!(entries.len(), 3);
assert!(buffer.is_empty());
}
// Each record_* call buffers one entry stamped with the context's
// kernel id and domain; flush_metrics returns them in order and
// empties the buffer.
#[test]
fn test_record_metrics() {
let clock = HlcClock::new(1);
let mut ctx = RingContext::new_with_options(
ThreadId::new_1d(0),
BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
&clock,
100,
ContextBackend::Cpu,
Some(Domain::Compliance),
256,
None,
);
ctx.record_latency("process_order", 500);
ctx.record_throughput("orders_per_sec", 1000);
ctx.record_counter("total_orders", 1);
ctx.record_gauge("queue_depth", 42);
assert_eq!(ctx.metrics_count(), 4);
let metrics = ctx.flush_metrics();
assert_eq!(metrics.len(), 4);
assert_eq!(metrics[0].operation, "process_order");
assert_eq!(metrics[0].metric_type, MetricType::Latency);
assert_eq!(metrics[0].value, 500);
assert_eq!(metrics[0].kernel_id, 100);
assert_eq!(metrics[0].domain, Some(Domain::Compliance));
assert_eq!(metrics[1].metric_type, MetricType::Throughput);
assert_eq!(metrics[2].metric_type, MetricType::Counter);
assert_eq!(metrics[3].metric_type, MetricType::Gauge);
assert_eq!(metrics[3].value, 42);
assert_eq!(ctx.metrics_count(), 0);
}
// Constructors set the expected severity/type and embed the numeric
// detail in the message; with_routing overrides the routing target.
#[test]
fn test_kernel_alert_constructors() {
let alert = KernelAlert::high_latency("Slow", 500);
assert_eq!(alert.severity, AlertSeverity::Warning);
assert_eq!(alert.alert_type, KernelAlertType::HighLatency);
assert!(alert.message.contains("500µs"));
let alert = KernelAlert::error("Failed");
assert_eq!(alert.severity, AlertSeverity::Error);
assert_eq!(alert.alert_type, KernelAlertType::ProcessingError);
let alert = KernelAlert::queue_pressure("Input queue", 85);
assert_eq!(alert.alert_type, KernelAlertType::QueuePressure);
assert!(alert.message.contains("85%"));
let alert = KernelAlert::new(
AlertSeverity::Critical,
KernelAlertType::Custom(42),
"Custom alert",
)
.with_routing(AlertRouting::External);
assert_eq!(alert.severity, AlertSeverity::Critical);
assert!(matches!(alert.routing, AlertRouting::External));
}
// emit_alert stamps the context's kernel id and domain onto the alert
// before it reaches the channel.
#[test]
fn test_emit_alert_with_channel() {
let (tx, mut rx) = mpsc::unbounded_channel();
let clock = HlcClock::new(1);
let ctx = RingContext::new_with_options(
ThreadId::new_1d(0),
BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
&clock,
42,
ContextBackend::Cpu,
Some(Domain::OrderMatching),
256,
Some(tx),
);
assert!(ctx.has_alert_channel());
ctx.emit_alert(KernelAlert::error("Test error"));
let alert = rx.try_recv().expect("Should receive alert");
assert_eq!(alert.source_kernel, 42);
assert_eq!(alert.source_domain, Some(Domain::OrderMatching));
assert_eq!(alert.alert_type, KernelAlertType::ProcessingError);
}
// Without a channel, emit_alert must be a silent no-op (not a panic).
#[test]
fn test_emit_alert_without_channel() {
let clock = HlcClock::new(1);
let ctx = make_test_context(&clock);
assert!(!ctx.has_alert_channel());
ctx.emit_alert(KernelAlert::error("No-op"));
}
// Alerts fire only above the threshold, and carry the operation name
// and measured latency in the message.
#[test]
fn test_alert_if_slow() {
let (tx, mut rx) = mpsc::unbounded_channel();
let clock = HlcClock::new(1);
let ctx = RingContext::new_with_options(
ThreadId::new_1d(0),
BlockId::new_1d(0),
Dim3::new_1d(256),
Dim3::new_1d(1),
&clock,
1,
ContextBackend::Cpu,
None,
256,
Some(tx),
);
ctx.alert_if_slow("fast_op", 50, 100);
assert!(rx.try_recv().is_err());
ctx.alert_if_slow("slow_op", 150, 100);
let alert = rx.try_recv().expect("Should receive alert");
assert!(alert.message.contains("slow_op"));
assert!(alert.message.contains("150µs"));
}
// The derived Ord on AlertSeverity follows the declared discriminants.
#[test]
fn test_alert_severity_ordering() {
assert!(AlertSeverity::Info < AlertSeverity::Warning);
assert!(AlertSeverity::Warning < AlertSeverity::Error);
assert!(AlertSeverity::Error < AlertSeverity::Critical);
}
}