use std::sync::mpsc::{sync_channel, Receiver};
use super::expert_forward::MAX_K;
use super::expert_io::{ExpertFiles, ExpertIoError};
/// Tag describing how a slot was populated.
///
/// NOTE(review): nothing in this part of the file uses the enum; presumably
/// `Synced` marks a slot filled by a blocking load and `Prefetched` one filled
/// by a background prefetch — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotSource {
    /// Presumably: the slot was loaded synchronously — TODO confirm.
    Synced,
    /// Presumably: the slot was filled by a prefetch — TODO confirm.
    Prefetched,
}
/// Bookkeeping for expert prefetching: the last recorded expert indices per
/// layer, plus at most one prefetch whose worker tasks may still be running.
#[derive(Debug)]
pub struct PrefetchState {
// One entry per layer: the indices most recently passed to `record_actual`,
// or `None` if nothing has been recorded since creation / `invalidate_all`.
last_token_indices: Vec<Option<[i32; MAX_K]>>,
// The prefetch started by the latest `dispatch` that has not yet been
// consumed by `wait_for` or `drain`.
in_flight: Option<InFlight>,
}
/// Record of a dispatched prefetch whose results have not been collected yet.
#[derive(Debug)]
struct InFlight {
// Layer the prefetch was issued for.
target_layer: usize,
// The predicted expert indices handed to `dispatch`.
loaded_indices: [i32; MAX_K],
// Number of worker tasks spawned, i.e. how many results are expected on `rx`.
k: usize,
// Receives one `Result` per worker task once its read has finished.
rx: Receiver<Result<(), ExpertIoError>>,
}
/// Outcome handed back by `PrefetchState::wait_for` when a prefetch for the
/// requested layer completed without any task reporting an error.
#[derive(Debug, Clone, Copy)]
pub struct PrefetchStatus {
    /// The expert indices that were passed to `dispatch` for this prefetch.
    pub loaded_indices: [i32; MAX_K],
    /// How many leading entries of `loaded_indices` were actually loaded
    /// (one worker task was spawned per entry).
    pub k: usize,
}
/// Address and length of a mutable byte buffer, kept as plain integers.
///
/// The value carries no borrow, so the compiler no longer tracks the buffer's
/// lifetime or exclusivity; `dispatch` moves it into a spawned closure.
/// Presumably the address is stored as `usize` (rather than `*mut u8`) so the
/// struct is `Send`.
#[derive(Debug, Clone, Copy)]
struct DataPrefetchPtr {
    ptr_addr: usize,
    len: usize,
}

impl DataPrefetchPtr {
    /// Captures the start address and length of `s`.
    ///
    /// In debug builds an empty slice trips an assertion.
    fn from_slice(s: &mut [u8]) -> Self {
        debug_assert!(!s.is_empty(), "data_prefetch slot must be non-empty");
        let len = s.len();
        let ptr_addr = s.as_mut_ptr() as usize;
        Self { ptr_addr, len }
    }

    /// Rebuilds the mutable slice this value was created from.
    ///
    /// # Safety
    ///
    /// The buffer passed to [`Self::from_slice`] must still be alive, and no
    /// other reference to it may be used for as long as the returned slice is
    /// in use. The lifetime `'a` is chosen by the caller and is not checked.
    unsafe fn as_mut_slice<'a>(self) -> &'a mut [u8] {
        let Self { ptr_addr, len } = self;
        let base = ptr_addr as *mut u8;
        // SAFETY: upheld by the caller per the contract above; `base`/`len`
        // are exactly the values captured from a valid `&mut [u8]`.
        unsafe { std::slice::from_raw_parts_mut(base, len) }
    }
}
/// Address of an [`ExpertFiles`] kept as a plain integer.
///
/// The value carries no borrow; `dispatch` copies it into each spawned
/// closure. Presumably the address is stored as `usize` so the struct is
/// `Send` without requiring a `'static` reference.
#[derive(Debug, Clone, Copy)]
struct ExpertFilesPtr {
    addr: usize,
}

impl ExpertFilesPtr {
    /// Records the address of `r`.
    fn from_ref(r: &ExpertFiles) -> Self {
        let raw: *const ExpertFiles = r;
        Self { addr: raw as usize }
    }

    /// Turns the stored address back into a shared reference.
    ///
    /// # Safety
    ///
    /// The `ExpertFiles` passed to [`Self::from_ref`] must still be alive and
    /// must not be mutated or moved while the returned reference is in use.
    /// The lifetime `'a` is chosen by the caller and is not checked.
    unsafe fn as_ref<'a>(self) -> &'a ExpertFiles {
        let raw = self.addr as *const ExpertFiles;
        // SAFETY: upheld by the caller per the contract above; `raw` was
        // derived from a valid `&ExpertFiles`.
        unsafe { &*raw }
    }
}
impl PrefetchState {
    /// Creates a state for `num_layers` layers with nothing recorded and no
    /// prefetch in flight.
    pub fn new(num_layers: usize) -> Self {
        Self {
            last_token_indices: vec![None; num_layers],
            in_flight: None,
        }
    }

    /// Blocks until every worker task of the in-flight prefetch (if any) has
    /// finished, discarding their results.
    ///
    /// After this returns, no worker task holds a pointer into the buffers
    /// that were handed to [`Self::dispatch`].
    pub fn drain(&mut self) {
        if let Some(in_flight) = self.in_flight.take() {
            for _ in 0..in_flight.k {
                // A receive error means every sender has been dropped, i.e.
                // no task is left that could still report (or still run).
                if in_flight.rx.recv().is_err() {
                    break;
                }
            }
        }
    }

    /// Waits for any in-flight prefetch and forgets the recorded indices of
    /// every layer.
    pub fn invalidate_all(&mut self) {
        self.drain();
        self.last_token_indices.fill(None);
    }

    /// Waits for the in-flight prefetch and reports what it loaded.
    ///
    /// Returns `None` when nothing was in flight, when any task failed or
    /// vanished without reporting, or when the prefetch targeted a layer
    /// other than `layer_idx`. In every case the in-flight prefetch is
    /// consumed and all of its tasks have finished when this returns.
    pub fn wait_for(&mut self, layer_idx: usize) -> Option<PrefetchStatus> {
        let in_flight = self.in_flight.take()?;
        let mut all_ok = true;
        // Keep receiving after a failure: every task must be done before the
        // caller may touch the destination buffers again.
        for _ in 0..in_flight.k {
            match in_flight.rx.recv() {
                Ok(Ok(())) => {}
                Ok(Err(_)) => all_ok = false,
                Err(_) => {
                    // All senders are gone; nothing more will arrive.
                    all_ok = false;
                    break;
                }
            }
        }
        (all_ok && in_flight.target_layer == layer_idx).then_some(PrefetchStatus {
            loaded_indices: in_flight.loaded_indices,
            k: in_flight.k,
        })
    }

    /// Returns the indices last recorded for `layer_idx`, or `None` if
    /// nothing is recorded or the layer is out of range.
    pub fn predict_for(&self, layer_idx: usize) -> Option<[i32; MAX_K]> {
        self.last_token_indices.get(layer_idx).copied().flatten()
    }

    /// Stores `actual` as the latest indices for `layer_idx`.
    ///
    /// An out-of-range `layer_idx` is silently ignored.
    pub fn record_actual(
        &mut self,
        layer_idx: usize,
        actual: [i32; MAX_K],
    ) {
        if let Some(slot) = self.last_token_indices.get_mut(layer_idx) {
            *slot = Some(actual);
        }
    }

    /// Starts a background load of the first `k` `predicted` experts of
    /// `target_layer` into the first `k` `data_prefetch` buffers.
    ///
    /// Any previous prefetch is drained first. One task per expert is spawned
    /// on `pool`; each calls `expert_files.read_expert` and reports its
    /// result, which [`Self::wait_for`] or [`Self::drain`] later collects.
    ///
    /// The spawned tasks access `data_prefetch` and `expert_files` through
    /// raw addresses that the borrow checker does not track. The caller must
    /// keep both alive, and must not touch the buffers, until `wait_for`,
    /// `drain`, `invalidate_all` or the drop of this state has returned.
    ///
    /// # Panics
    ///
    /// Panics if `k > MAX_K`. The check runs before any task is spawned, so
    /// no task is left running with pointers into the caller's memory.
    pub fn dispatch(
        &mut self,
        target_layer: usize,
        predicted: [i32; MAX_K],
        k: usize,
        data_prefetch: [&mut [u8]; MAX_K],
        pool: &rayon::ThreadPool,
        expert_files: &ExpertFiles,
    ) {
        // Validate up front: failing on an out-of-bounds index in the middle
        // of the spawn loop would unwind while already-spawned tasks still
        // write through raw pointers into the caller's buffers.
        assert!(
            k <= MAX_K,
            "dispatch: k = {} exceeds MAX_K = {}",
            k,
            MAX_K
        );

        self.drain();

        let (tx, rx) = sync_channel::<Result<(), ExpertIoError>>(MAX_K);
        let efp = ExpertFilesPtr::from_ref(expert_files);
        let slots = data_prefetch.map(DataPrefetchPtr::from_slice);

        // Register the prefetch before spawning so that `drain` (also run on
        // drop) waits for whatever was spawned even if the loop below unwinds.
        self.in_flight = Some(InFlight {
            target_layer,
            loaded_indices: predicted,
            k,
            rx,
        });

        let jobs = slots
            .iter()
            .copied()
            .zip(predicted.iter().copied())
            .take(k);
        for (dst_ptr, expert) in jobs {
            // NOTE(review): a negative prediction wraps to a huge index here;
            // presumably `read_expert` rejects out-of-range experts — confirm.
            let expert_idx = expert as usize;
            let tx = tx.clone();
            pool.spawn(move || {
                // SAFETY: each task gets a distinct destination buffer, and
                // the caller keeps the buffers alive and untouched until this
                // prefetch has been waited on or drained (see method docs).
                let dst = unsafe { dst_ptr.as_mut_slice() };
                // SAFETY: the caller keeps `expert_files` alive until this
                // prefetch has been waited on or drained (see method docs).
                let efs = unsafe { efp.as_ref() };
                let r = efs.read_expert(target_layer, expert_idx, dst);
                // The channel holds MAX_K results, so this never blocks; a
                // send error only means the receiver is already gone.
                let _ = tx.send(r);
            });
        }
        // Drop the original sender so the receiver sees a disconnect once
        // every task has finished, even if one ends without sending.
        drop(tx);
    }
}
impl Drop for PrefetchState {
fn drop(&mut self) {
self.drain();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A layer has no prediction until something is recorded for it, and
    /// `invalidate_all` clears the recorded value again.
    #[test]
    fn predict_for_returns_none_until_recorded() {
        const LAYER: usize = 3;
        let mut state = PrefetchState::new(8);
        assert_eq!(state.predict_for(LAYER), None);

        let routed = [0i32; MAX_K];
        state.record_actual(LAYER, routed);
        assert_eq!(state.predict_for(LAYER), Some(routed));

        state.invalidate_all();
        assert_eq!(state.predict_for(LAYER), None);
    }

    /// Asking about a layer beyond the configured count yields `None`
    /// instead of panicking.
    #[test]
    fn predict_for_out_of_range_layer_is_none() {
        let state = PrefetchState::new(2);
        assert_eq!(state.predict_for(99), None);
    }
}