use std::sync::Arc;
use vyre_runtime::megakernel::{
Megakernel, MegakernelConfig, MegakernelResidentBuffers, MegakernelWorkItem,
};
use vyre_runtime::PipelineError;
use vyre_libs::scan::LiteralMatch;
/// Construction parameters for a `MegakernelSession`.
///
/// Every field is public so callers can start from `Default::default()` and
/// override individual knobs with struct-update syntax.
#[derive(Debug, Clone)]
pub struct MegakernelSessionConfig {
/// Number of ring slots. Forwarded to `Megakernel::bootstrap_sharded` and
/// `MegakernelResidentBuffers::new`, and used as the wrap-around modulus for
/// the session's slot cursor in `submit_scan`.
pub slot_count: u32,
/// Workgroup size forwarded to `Megakernel::bootstrap_sharded`.
pub workgroup_size_x: u32,
/// Forwarded to `MegakernelResidentBuffers::new` and to `reset` on shutdown.
pub tenant_count: u32,
/// Forwarded to `MegakernelResidentBuffers::new` and to `reset` on shutdown.
pub observable_slots: u32,
/// Inner megakernel configuration.
/// NOTE(review): `MegakernelSession::new` does not read this field in the
/// visible code — confirm whether it should be forwarded to the bootstrap.
pub config: MegakernelConfig,
}
impl Default for MegakernelSessionConfig {
fn default() -> Self {
Self {
slot_count: 256,
workgroup_size_x: 256,
tenant_count: 1,
observable_slots: 0,
config: MegakernelConfig::default(),
}
}
}
/// A bootstrapped megakernel together with its resident buffers and a slot
/// cursor that wraps around the configured `slot_count`.
///
/// Rust drops fields in declaration order, so `kernel` is dropped before
/// `buffers`. NOTE(review): confirm whether the GPU backend relies on that
/// ordering before reordering these fields.
pub struct MegakernelSession {
/// Kernel that the resident buffers are dispatched against.
kernel: Megakernel,
/// Resident buffers that work items are published into and dispatched from.
buffers: MegakernelResidentBuffers,
/// The configuration this session was created with.
config: MegakernelSessionConfig,
/// Slot at which the next `submit_scan` starts publishing; reset to 0 by
/// `shutdown`.
next_slot: u32,
}
// Compile-time guard: the build fails if `MegakernelSession` ever stops being
// `Send + Sync` (for example, if a non-thread-safe field is added).
const _: () = {
    const fn require_thread_safe<S: Send + Sync>() {}
    require_thread_safe::<MegakernelSession>();
};
impl MegakernelSession {
    /// Bootstraps a sharded megakernel on `backend` and creates its resident buffers.
    ///
    /// Returns `Ok(None)` when `Megakernel::bootstrap_sharded` fails, so the caller
    /// can degrade to per-batch dispatch. In that case the bootstrap error is logged
    /// at `debug` level on the `keyhog::gpu` target instead of being returned.
    ///
    /// NOTE(review): `session_config.config` is stored but not forwarded to
    /// `bootstrap_sharded` here — confirm that is intentional.
    ///
    /// # Errors
    ///
    /// Returns the `PipelineError` from `MegakernelResidentBuffers::new` if the
    /// resident buffers cannot be created.
    pub fn new(
        backend: Arc<dyn vyre::VyreBackend>,
        session_config: MegakernelSessionConfig,
    ) -> Result<Option<Self>, PipelineError> {
        let kernel = match Megakernel::bootstrap_sharded(
            backend,
            session_config.slot_count,
            session_config.workgroup_size_x,
            Vec::new(),
        ) {
            Ok(k) => k,
            Err(error) => {
                tracing::debug!(
                    target: "keyhog::gpu",
                    %error,
                    "megakernel bootstrap failed - degrading to per-batch dispatch",
                );
                return Ok(None);
            }
        };
        let buffers = MegakernelResidentBuffers::new(
            session_config.slot_count,
            session_config.tenant_count,
            session_config.observable_slots,
        )?;
        Ok(Some(Self {
            kernel,
            buffers,
            config: session_config,
            next_slot: 0,
        }))
    }

    /// Publishes `work_items` starting at the current slot cursor, dispatches the
    /// megakernel once, and advances the cursor by the number of published items
    /// (wrapping at `slot_count`).
    ///
    /// An empty `work_items` slice is a no-op: nothing is published or dispatched.
    ///
    /// NOTE(review): the readback is currently only measured for tracing; no
    /// `LiteralMatch` values are decoded from it, so a successful call always
    /// returns an empty `Vec`. Callers should not read an empty result as
    /// "no matches" until decoding is implemented.
    ///
    /// # Errors
    ///
    /// Propagates the `PipelineError` from publishing or dispatching. If publishing
    /// fails, the cursor is unchanged. If dispatch fails, the cursor has already
    /// been advanced past the published items.
    pub fn submit_scan(
        &mut self,
        work_items: &[MegakernelWorkItem],
    ) -> Result<Vec<LiteralMatch>, PipelineError> {
        if work_items.is_empty() {
            return Ok(Vec::new());
        }
        // The literal `0` is passed through exactly as before.
        // NOTE(review): presumably a tenant index — confirm against `publish_work_items`.
        // NOTE(review): if `published < work_items.len()`, the remainder is neither
        // retried nor reported here — confirm partial publishes cannot happen.
        let published = self
            .buffers
            .publish_work_items(self.next_slot, 0, work_items)?;
        self.next_slot = Self::advance_slot(self.next_slot, published, self.config.slot_count);

        let readback = self.buffers.dispatch(&self.kernel)?;
        tracing::trace!(
            target: "keyhog::gpu",
            published,
            readback_control_bytes = readback.control_bytes.len(),
            readback_ring_bytes = readback.ring_bytes.len(),
            "megakernel scan dispatch completed",
        );
        Ok(Vec::new())
    }

    /// Computes `(cursor + published) % slot_count` without panicking.
    ///
    /// The sum is widened to `u64` because `cursor + published` may exceed
    /// `u32::MAX`, even though the result is always below `slot_count`. This avoids
    /// an overflow panic in debug builds and silent wrap in release builds.
    /// A `slot_count` of zero yields `0` instead of a division-by-zero panic.
    fn advance_slot(cursor: u32, published: u32, slot_count: u32) -> u32 {
        let wrapped = (u64::from(cursor) + u64::from(published))
            .checked_rem(u64::from(slot_count))
            .unwrap_or(0);
        // Lossless: either 0, or strictly less than `slot_count`, which is a `u32`.
        wrapped as u32
    }

    /// Runs `dispatch_update` on the resident buffers against this session's kernel.
    ///
    /// # Errors
    ///
    /// Propagates the `PipelineError` from `dispatch_update`.
    pub fn flush(&mut self) -> Result<(), PipelineError> {
        self.buffers.dispatch_update(&self.kernel)?;
        tracing::trace!(
            target: "keyhog::gpu",
            "megakernel flush completed",
        );
        Ok(())
    }

    /// Best-effort teardown: resets the resident buffers using the configured
    /// `tenant_count` and `observable_slots`, then rewinds the slot cursor to 0.
    ///
    /// A reset failure is logged at `warn` level on `keyhog::gpu` and is not
    /// returned. The cursor is rewound either way.
    pub fn shutdown(&mut self) {
        if let Err(error) = self.buffers.reset(
            self.config.tenant_count,
            self.config.observable_slots,
        ) {
            tracing::warn!(
                target: "keyhog::gpu",
                %error,
                "megakernel resident buffer reset failed during shutdown",
            );
        }
        self.next_slot = 0;
        tracing::debug!(
            target: "keyhog::gpu",
            "megakernel session shutdown",
        );
    }

    /// Number of ring slots this session was configured with.
    #[must_use]
    pub fn slot_count(&self) -> u32 {
        self.config.slot_count
    }

    /// Slot at which the next `submit_scan` will start publishing.
    #[must_use]
    pub fn next_slot(&self) -> u32 {
        self.next_slot
    }
}