vyre-runtime 0.6.2

Persistent megakernel + io_uring zero-copy streaming runtime for vyre - GPU as VIR0 bytecode interpreter
Documentation
use super::{
    default_buffers, optimize_megakernel_program, persistent_body_with_io, wrap_megakernel_program,
    wrap_persistent_megakernel_program,
};
use rustc_hash::FxHashMap;
use std::cell::RefCell;
use std::sync::Arc;
use vyre_foundation::ir::Program;

const EMPTY_TEMPLATE_CACHE_CAP: usize = 64;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct EmptyTemplateKey {
    workgroup_size_x: u32,
    slot_count: u32,
    include_io_polling: bool,
    finite_once: bool,
    control_report_only: bool,
}

struct EmptyTemplateCache {
    entries: FxHashMap<EmptyTemplateKey, EmptyTemplateEntry>,
    clock: u64,
}

struct EmptyTemplateEntry {
    program: Arc<Program>,
    last_seen: u64,
}

impl EmptyTemplateCache {
    fn get(&mut self, key: &EmptyTemplateKey) -> Option<Arc<Program>> {
        if self.clock == u64::MAX {
            self.clock = 0;
            for entry in self.entries.values_mut() {
                entry.last_seen = 0;
            }
        }
        let entry = self.entries.get_mut(key)?;
        self.clock += 1;
        entry.last_seen = self.clock;
        Some(Arc::clone(&entry.program))
    }

    fn insert(&mut self, key: EmptyTemplateKey, program: Arc<Program>) {
        let tick = self.next_tick();
        self.entries.insert(
            key,
            EmptyTemplateEntry {
                program,
                last_seen: tick,
            },
        );
        while self.entries.len() > EMPTY_TEMPLATE_CACHE_CAP {
            let Some(evicted) = self
                .entries
                .iter()
                .filter(|(candidate, _)| **candidate != key)
                .min_by_key(|(_, entry)| entry.last_seen)
                .map(|(candidate, _)| *candidate)
            else {
                break;
            };
            self.entries.remove(&evicted);
        }
    }

    #[cfg(test)]
    fn clear(&mut self) {
        self.entries.clear();
        self.clock = 0;
    }

    fn next_tick(&mut self) -> u64 {
        if self.clock == u64::MAX {
            self.clock = 0;
            for entry in self.entries.values_mut() {
                entry.last_seen = 0;
            }
        }
        self.clock += 1;
        self.clock
    }
}

impl Default for EmptyTemplateCache {
    fn default() -> Self {
        Self {
            entries: FxHashMap::with_capacity_and_hasher(
                EMPTY_TEMPLATE_CACHE_CAP,
                Default::default(),
            ),
            clock: 0,
        }
    }
}

thread_local! {
    static EMPTY_TEMPLATE_CACHE: RefCell<EmptyTemplateCache> =
        RefCell::new(EmptyTemplateCache::default());
}

pub(super) fn cached_empty_sharded_program(
    workgroup_size_x: u32,
    slot_count: u32,
    include_io_polling: bool,
) -> Program {
    cached_empty_sharded_program_shared(workgroup_size_x, slot_count, include_io_polling)
        .as_ref()
        .clone()
}

pub(super) fn cached_empty_sharded_program_shared(
    workgroup_size_x: u32,
    slot_count: u32,
    include_io_polling: bool,
) -> Arc<Program> {
    let key = EmptyTemplateKey {
        workgroup_size_x,
        slot_count,
        include_io_polling,
        finite_once: false,
        control_report_only: false,
    };
    if let Some(program) = EMPTY_TEMPLATE_CACHE.with(|cache| cache.borrow_mut().get(&key)) {
        return program;
    }

    let program = wrap_persistent_megakernel_program(
        workgroup_size_x,
        slot_count,
        persistent_body_with_io(workgroup_size_x, &[], include_io_polling),
    );
    let program = Arc::new(program);
    EMPTY_TEMPLATE_CACHE.with(|cache| {
        cache.borrow_mut().insert(key, Arc::clone(&program));
    });
    program
}

pub(super) fn cached_empty_sharded_once_program(workgroup_size_x: u32, slot_count: u32) -> Program {
    cached_empty_sharded_once_program_shared(workgroup_size_x, slot_count)
        .as_ref()
        .clone()
}

pub(super) fn cached_empty_sharded_once_program_shared(
    workgroup_size_x: u32,
    slot_count: u32,
) -> Arc<Program> {
    let key = EmptyTemplateKey {
        workgroup_size_x,
        slot_count,
        include_io_polling: false,
        finite_once: true,
        control_report_only: false,
    };
    if let Some(program) = EMPTY_TEMPLATE_CACHE.with(|cache| cache.borrow_mut().get(&key)) {
        return program;
    }

    let program = wrap_megakernel_program(
        workgroup_size_x,
        slot_count,
        persistent_body_with_io(workgroup_size_x, &[], false),
    );
    let program = Arc::new(program);
    EMPTY_TEMPLATE_CACHE.with(|cache| {
        cache.borrow_mut().insert(key, Arc::clone(&program));
    });
    program
}

pub(super) fn cached_empty_sharded_once_control_report_program_shared(
    workgroup_size_x: u32,
    slot_count: u32,
) -> Arc<Program> {
    let key = EmptyTemplateKey {
        workgroup_size_x,
        slot_count,
        include_io_polling: false,
        finite_once: true,
        control_report_only: true,
    };
    if let Some(program) = EMPTY_TEMPLATE_CACHE.with(|cache| cache.borrow_mut().get(&key)) {
        return program;
    }

    let mut buffers = default_buffers(slot_count);
    for buffer in buffers.iter_mut().skip(1) {
        buffer.output_byte_range = Some(0..0);
    }
    let program = Arc::new(optimize_megakernel_program(Program::wrapped(
        buffers,
        [workgroup_size_x, 1, 1],
        persistent_body_with_io(workgroup_size_x, &[], false),
    )));
    EMPTY_TEMPLATE_CACHE.with(|cache| {
        cache.borrow_mut().insert(key, Arc::clone(&program));
    });
    program
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_template_cache_refreshes_hot_template_on_hit() {
        EMPTY_TEMPLATE_CACHE.with(|cache| cache.borrow_mut().clear());
        let hot = cached_empty_sharded_program_shared(1, 1, false);
        for slot_count in 2..=EMPTY_TEMPLATE_CACHE_CAP as u32 {
            let _ = cached_empty_sharded_program_shared(1, slot_count, false);
        }
        let hot_after_hit = cached_empty_sharded_program_shared(1, 1, false);
        assert!(Arc::ptr_eq(&hot, &hot_after_hit));
        let _ =
            cached_empty_sharded_program_shared(1, (EMPTY_TEMPLATE_CACHE_CAP + 1) as u32, false);
        let hot_after_eviction = cached_empty_sharded_program_shared(1, 1, false);
        assert!(Arc::ptr_eq(&hot, &hot_after_eviction));
    }

    #[test]
    fn empty_control_report_template_is_cached_by_arc() {
        EMPTY_TEMPLATE_CACHE.with(|cache| cache.borrow_mut().clear());
        let first = cached_empty_sharded_once_control_report_program_shared(64, 128);
        let second = cached_empty_sharded_once_control_report_program_shared(64, 128);

        assert!(Arc::ptr_eq(&first, &second));
    }
}