use crate::region::wrap_anonymous;
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use vyre_primitives::nn::quest_paging_passes::{
quest_score_pages_body, quest_select_top_k_body, quest_zero_fill_body,
};
/// Operation identifier for this op: passed to `wrap_anonymous` when the
/// program body is wrapped, and used as the `id` of the harness `OpEntry`.
const OP_ID: &str = "vyre-libs::nn::attention::quest_paging";
/// Placeholder score (`f32::MIN`, the most negative finite `f32`).
///
/// It seeds the running best score during top-k selection, is written into
/// `scores` for every page that was selected, and is forwarded to
/// `quest_select_top_k_body` on the non-unrolled path.
const SCORE_SENTINEL: f32 = f32::MIN;
#[must_use]
pub fn quest_paging(
query: &str,
page_metadata: &str,
scores: &str,
io_queue: &str,
num_pages: u32,
k: u32,
d_head: u32,
) -> Program {
if num_pages <= 8 && k <= 4 && d_head <= 16 {
let mut score_exprs = Vec::with_capacity(num_pages as usize);
for page in 0..num_pages {
let mut score = Expr::f32(0.0);
for dim in 0..d_head {
score = Expr::add(
score,
Expr::mul(
Expr::load(query, Expr::u32(dim)),
Expr::load(page_metadata, Expr::u32(page * d_head + dim)),
),
);
}
score_exprs.push(score);
}
let mut selected = Vec::<Expr>::with_capacity(k as usize);
for _rank in 0..k {
let mut best_score = Expr::f32(SCORE_SENTINEL);
let mut best_idx = Expr::u32(0);
for page in 0..num_pages {
let mut eligible = Expr::bool(true);
for prior in &selected {
eligible = Expr::select(
eligible,
Expr::ne(Expr::u32(page), prior.clone()),
Expr::bool(false),
);
}
let better = Expr::select(
eligible,
Expr::gt(score_exprs[page as usize].clone(), best_score.clone()),
Expr::bool(false),
);
best_score = Expr::select(
better.clone(),
score_exprs[page as usize].clone(),
best_score,
);
best_idx = Expr::select(better, Expr::u32(page), best_idx);
}
selected.push(best_idx);
}
let mut stores = Vec::with_capacity((num_pages * 2) as usize);
for page in 0..num_pages {
let mut picked = Expr::bool(false);
for prior in &selected {
picked = Expr::select(
picked,
Expr::bool(true),
Expr::eq(Expr::u32(page), prior.clone()),
);
}
stores.push(Node::store(
scores,
Expr::u32(page),
Expr::select(
picked,
Expr::f32(SCORE_SENTINEL),
score_exprs[page as usize].clone(),
),
));
}
for slot in 0..num_pages {
let value = if slot < k {
selected[slot as usize].clone()
} else {
Expr::u32(0)
};
stores.push(Node::store(io_queue, Expr::u32(slot), value));
}
return Program::wrapped(
vec![
BufferDecl::storage(query, 0, BufferAccess::ReadOnly, DataType::F32)
.with_count(d_head),
BufferDecl::storage(page_metadata, 1, BufferAccess::ReadOnly, DataType::F32)
.with_count(num_pages * d_head),
BufferDecl::storage(scores, 2, BufferAccess::ReadWrite, DataType::F32)
.with_count(num_pages),
BufferDecl::storage(io_queue, 3, BufferAccess::ReadWrite, DataType::U32)
.with_count(num_pages),
],
[1, 1, 1],
vec![wrap_anonymous(
OP_ID,
vec![Node::if_then(
Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0)),
stores,
)],
)],
);
}
let t = Expr::InvocationId { axis: 0 };
let body = vec![
Node::Block(quest_zero_fill_body(io_queue, num_pages)),
Node::Block(quest_score_pages_body(
query,
page_metadata,
scores,
num_pages,
d_head,
)),
Node::barrier(),
Node::Block(vec![Node::if_then(
Expr::eq(t.clone(), Expr::u32(0)),
quest_select_top_k_body(scores, io_queue, num_pages, k, SCORE_SENTINEL),
)]),
];
Program::wrapped(
vec![
BufferDecl::storage(query, 0, BufferAccess::ReadOnly, DataType::F32).with_count(d_head),
BufferDecl::storage(page_metadata, 1, BufferAccess::ReadOnly, DataType::F32)
.with_count(num_pages * d_head),
BufferDecl::storage(scores, 2, BufferAccess::ReadWrite, DataType::F32)
.with_count(num_pages),
BufferDecl::storage(io_queue, 3, BufferAccess::ReadWrite, DataType::U32)
.with_count(num_pages),
],
[256, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
)
}
inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        // Fixture shape: num_pages = 4, k = 2, d_head = 2.
        build: || quest_paging("q", "meta", "scores", "io", 4, 2, 2),
        test_inputs: Some(|| {
            let query: &[f32] = &[1.0, 0.0];
            // Four pages of two values each; their dot products with `query`
            // are 0.0, 1.0, 2.0 and 0.5.
            let page_metadata: &[f32] = &[0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.5, 0.0];
            // Both output buffers start zeroed: four 4-byte elements each.
            let zeroed_scores = vec![0u8; 4 * 4];
            let zeroed_io_queue = vec![0u8; 4 * 4];
            vec![vec![
                vyre_primitives::wire::pack_f32_slice(query),
                vyre_primitives::wire::pack_f32_slice(page_metadata),
                zeroed_scores,
                zeroed_io_queue,
            ]]
        }),
        expected_output: Some(|| {
            // Pages 2 and 1 are selected (in that order), so their score slots
            // hold the sentinel while pages 0 and 3 keep their raw scores.
            let expected_scores: &[f32] = &[0.0, SCORE_SENTINEL, SCORE_SENTINEL, 0.5];
            let expected_io_queue = [2u32, 1, 0, 0];
            vec![vec![
                vyre_primitives::wire::pack_f32_slice(expected_scores),
                crate::test_support::byte_pack::u32_bytes(&expected_io_queue),
            ]]
        }),
        category: Some("nn"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::{bytes_to_u32 as decode_u32, decode_f32, f32_bytes};
    use vyre_reference::value::Value;

    /// A NaN query component must show up as NaN in the decoded scores output.
    #[test]
    fn quest_paging_nan_in_query_produces_nan_scores() {
        let nan_query = [f32::NAN, 0.0];
        let page_rows = [0.0f32, 0.0, 1.0, 0.0];
        let built = quest_paging("q", "meta", "scores", "io", 2, 1, 2);
        let inputs = [
            Value::from(f32_bytes(&nan_query)),
            Value::from(f32_bytes(&page_rows)),
            Value::from(vec![0u8; 8]),
            Value::from(vec![0u8; 8]),
        ];
        let results = vyre_reference::reference_eval(&built, &inputs)
            .expect("Fix: quest_paging must not panic on NaN query");
        let decoded_scores = decode_f32(&results[0].to_bytes());
        let saw_nan = decoded_scores.iter().any(|score| score.is_nan());
        assert!(
            saw_nan,
            "quest_paging NaN query must produce at least one NaN score"
        );
    }

    /// With no pages at all, evaluation succeeds and both outputs are empty.
    #[test]
    fn quest_paging_zero_pages() {
        let unit_query = [1.0f32, 0.0];
        let built = quest_paging("q", "meta", "scores", "io", 0, 0, 2);
        let inputs = [
            Value::from(f32_bytes(&unit_query)),
            Value::from(Vec::new()),
            Value::from(Vec::new()),
            Value::from(Vec::new()),
        ];
        let results = vyre_reference::reference_eval(&built, &inputs)
            .expect("Fix: quest_paging num_pages=0 must not panic");
        for output in [&results[0], &results[1]] {
            assert!(output.to_bytes().is_empty());
        }
    }

    /// A single page with top-1 selection must report page index 0.
    #[test]
    fn quest_paging_single_page() {
        let unit_query = [1.0f32, 0.0];
        let page_rows = [2.0f32, 0.0];
        let built = quest_paging("q", "meta", "scores", "io", 1, 1, 2);
        let inputs = [
            Value::from(f32_bytes(&unit_query)),
            Value::from(f32_bytes(&page_rows)),
            Value::from(vec![0u8; 4]),
            Value::from(vec![0u8; 4]),
        ];
        let results = vyre_reference::reference_eval(&built, &inputs)
            .expect("Fix: quest_paging single page must execute");
        let queue = decode_u32(&results[1].to_bytes());
        assert_eq!(queue[0], 0, "single page top-1 must be index 0");
    }

    /// Selecting zero pages leaves every `io_queue` slot at zero.
    #[test]
    fn quest_paging_k_zero() {
        let unit_query = [1.0f32, 0.0];
        let page_rows = [1.0f32, 0.0, 2.0, 0.0];
        let built = quest_paging("q", "meta", "scores", "io", 2, 0, 2);
        let inputs = [
            Value::from(f32_bytes(&unit_query)),
            Value::from(f32_bytes(&page_rows)),
            Value::from(vec![0u8; 8]),
            Value::from(vec![0u8; 8]),
        ];
        let results = vyre_reference::reference_eval(&built, &inputs)
            .expect("Fix: quest_paging k=0 must not panic");
        let queue = decode_u32(&results[1].to_bytes());
        assert_eq!(queue, vec![0, 0], "k=0 must zero-fill io_queue");
    }
}