use super::gpu_source_bytes::{packed_source_byte_len_expr, safe_load_packed_byte_expr};
use crate::parsing::c::lex::tokens::{TOK_PP_IFDEF, TOK_PP_IFNDEF};
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier attached to the program built by [`gpu_ifdef_value`]
/// via `with_entry_op_id`.
pub const OP_ID: &str = "vyre-libs::parsing::c::preprocess::gpu_ifdef_value";
/// Binding slot of the read-only `tok_starts` buffer (one `u32` per token).
pub const BINDING_TOK_STARTS: u32 = 0;
/// Binding slot of the read-only `tok_lens` buffer (one `u32` per token).
pub const BINDING_TOK_LENS: u32 = 1;
/// Binding slot of the read-only `directive_kinds` buffer (one `u32` per token).
pub const BINDING_DIRECTIVE_KINDS: u32 = 2;
/// Binding slot of the read-only `source` buffer, declared with a count of 0.
pub const BINDING_SOURCE: u32 = 3;
/// Binding slot of the read-only `macro_names_packed` buffer, declared with a count of 0.
pub const BINDING_MACRO_NAMES_PACKED: u32 = 4;
/// Binding slot of the read-only `macro_offsets` buffer, declared with a count of 0.
pub const BINDING_MACRO_OFFSETS: u32 = 5;
/// Binding slot of the read-write `directive_values` output buffer (one `u32` per token).
pub const BINDING_DIRECTIVE_VALUES: u32 = 6;
/// Number of bytes probed for whitespace at each skip point of the directive:
/// before the `#` (the `#` itself may then sit at offset `0..=MAX_WS_PREFIX`),
/// between the `#` and the keyword, and between the keyword and the identifier.
const MAX_WS_PREFIX: u32 = 4;
/// Builds the GPU program that evaluates `#ifdef` / `#ifndef` directives.
///
/// One invocation handles one token index `t` (invocation id on axis 0,
/// workgroup size `[256, 1, 1]`). Invocations with `t >= num_tokens`, and
/// tokens whose `directive_kinds[t]` is neither `TOK_PP_IFDEF` nor
/// `TOK_PP_IFNDEF`, do nothing.
///
/// For a matching token the program:
/// 1. locates the `#` within the first `MAX_WS_PREFIX + 1` bytes of the token,
///    allowing only whitespace (space, tab, 0x0B, 0x0C) in front of it;
/// 2. skips up to `MAX_WS_PREFIX` whitespace bytes, then the keyword
///    (5 bytes for `ifdef`, 6 for `ifndef`, chosen from the token kind, not
///    re-checked against the source), then up to `MAX_WS_PREFIX` more
///    whitespace bytes;
/// 3. scans an identifier (`[A-Za-z_][A-Za-z0-9_]*`) that must end before
///    `tok_start + tok_len`;
/// 4. compares that identifier byte-by-byte against every name described by
///    consecutive `macro_offsets` pairs into `macro_names_packed`;
/// 5. stores `1`/`0` into `directive_values[t]`: "defined" for `ifdef`,
///    "not defined" for `ifndef`. Nothing is stored when no `#` was found.
///
/// # Parameters
/// * `num_tokens` - number of tokens; sizes the per-token buffers (at least 1
///   element each) and bounds the active invocations.
/// * `source_len` - accepted for interface compatibility but ignored: the
///   `source` buffer is declared with a count of 0 and its byte length is taken
///   from `packed_source_byte_len_expr()` inside the program.
///
/// # Returns
/// A [`Program`] with seven storage buffers bound at the `BINDING_*` slots and
/// tagged with [`OP_ID`].
#[must_use]
pub fn gpu_ifdef_value(num_tokens: u32, source_len: u32) -> Program {
    let _ = source_len;
    let t = Expr::var("t");
    let source_byte_len = packed_source_byte_len_expr();

    // Thin alias so every byte read goes through the bounds-aware helper.
    let safe_load = |buf: &'static str, addr: Expr, bound: Expr| -> Expr {
        safe_load_packed_byte_expr(buf, addr, bound)
    };

    // 1 when `b` is space, tab, vertical tab or form feed; 0 otherwise.
    // Newline bytes are deliberately not part of this set.
    let is_ws = |b: Expr| -> Expr {
        Expr::select(
            Expr::or(
                Expr::or(
                    Expr::eq(b.clone(), Expr::u32(b' ' as u32)),
                    Expr::eq(b.clone(), Expr::u32(b'\t' as u32)),
                ),
                Expr::or(
                    Expr::eq(b.clone(), Expr::u32(0x0B)),
                    Expr::eq(b, Expr::u32(0x0C)),
                ),
            ),
            Expr::u32(1),
            Expr::u32(0),
        )
    };

    // 1 when `b` may appear after the first byte of an identifier.
    let is_continue = |b: Expr| -> Expr {
        let is_lower = Expr::and(
            Expr::ge(b.clone(), Expr::u32(b'a' as u32)),
            Expr::le(b.clone(), Expr::u32(b'z' as u32)),
        );
        let is_upper = Expr::and(
            Expr::ge(b.clone(), Expr::u32(b'A' as u32)),
            Expr::le(b.clone(), Expr::u32(b'Z' as u32)),
        );
        let is_digit = Expr::and(
            Expr::ge(b.clone(), Expr::u32(b'0' as u32)),
            Expr::le(b.clone(), Expr::u32(b'9' as u32)),
        );
        let is_under = Expr::eq(b, Expr::u32(b'_' as u32));
        Expr::select(
            Expr::or(Expr::or(is_lower, is_upper), Expr::or(is_digit, is_under)),
            Expr::u32(1),
            Expr::u32(0),
        )
    };

    // 1 when `b` may start an identifier (letters and underscore, no digits).
    let is_start = |b: Expr| -> Expr {
        let is_lower = Expr::and(
            Expr::ge(b.clone(), Expr::u32(b'a' as u32)),
            Expr::le(b.clone(), Expr::u32(b'z' as u32)),
        );
        let is_upper = Expr::and(
            Expr::ge(b.clone(), Expr::u32(b'A' as u32)),
            Expr::le(b.clone(), Expr::u32(b'Z' as u32)),
        );
        let is_under = Expr::eq(b, Expr::u32(b'_' as u32));
        Expr::select(
            Expr::or(Expr::or(is_lower, is_upper), is_under),
            Expr::u32(1),
            Expr::u32(0),
        )
    };

    // Offset of `#` relative to `tok_start`: the smallest `p` in
    // `0..=MAX_WS_PREFIX` where `hs_p == '#'` and every earlier probed byte is
    // whitespace, or 0xFFFF_FFFF when there is none. The select chain is built
    // from the largest `p` down so that the smallest `p` ends up outermost.
    let hash_off_expr = {
        let mut acc = Expr::u32(0xFFFF_FFFF);
        for p in (0..=MAX_WS_PREFIX).rev() {
            let mut prefix_ws = Expr::u32(1);
            for q in 0..p {
                prefix_ws = Expr::bitand(prefix_ws, Expr::var(format!("hs_ws_{q}")));
            }
            let s_eq_hash = Expr::select(
                Expr::eq(Expr::var(format!("hs_{p}")), Expr::u32(b'#' as u32)),
                Expr::u32(1),
                Expr::u32(0),
            );
            let cond_u32 = Expr::bitand(s_eq_hash, prefix_ws);
            acc = Expr::select(Expr::eq(cond_u32, Expr::u32(1)), Expr::u32(p), acc);
        }
        acc
    };

    // Number of leading whitespace bytes among the `n` probed flags
    // `{prefix}_ws_0 .. {prefix}_ws_{n-1}`; `n` when all of them are whitespace.
    let ws_skip_expr = |prefix: &str, n: u32| -> Expr {
        let mut acc = Expr::u32(n);
        for q in (0..n).rev() {
            let mut prefix_ws = Expr::u32(1);
            for r in 0..q {
                prefix_ws = Expr::bitand(prefix_ws, Expr::var(format!("{prefix}_ws_{r}")));
            }
            let xs_q_not_ws = Expr::select(
                Expr::eq(Expr::var(format!("{prefix}_ws_{q}")), Expr::u32(0)),
                Expr::u32(1),
                Expr::u32(0),
            );
            let cond_u32 = Expr::bitand(xs_q_not_ws, prefix_ws);
            acc = Expr::select(Expr::eq(cond_u32, Expr::u32(1)), Expr::u32(q), acc);
        }
        acc
    };

    // Per-directive evaluation; runs only for ifdef/ifndef tokens (`kind` is
    // bound by the enclosing body before these nodes execute).
    let mut evaluate: Vec<Node> = Vec::new();

    // Token extent in source bytes.
    evaluate.push(Node::let_bind(
        "tok_start",
        Expr::load("tok_starts", t.clone()),
    ));
    evaluate.push(Node::let_bind("tok_len", Expr::load("tok_lens", t.clone())));
    evaluate.push(Node::let_bind(
        "tok_end",
        Expr::add(Expr::var("tok_start"), Expr::var("tok_len")),
    ));

    // Probe the first MAX_WS_PREFIX + 1 bytes of the token for the `#`.
    for p in 0..=MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("hs_{p}"),
            safe_load(
                "source",
                Expr::add(Expr::var("tok_start"), Expr::u32(p)),
                source_byte_len.clone(),
            ),
        ));
    }
    for p in 0..=MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("hs_ws_{p}"),
            is_ws(Expr::var(format!("hs_{p}"))),
        ));
    }
    evaluate.push(Node::let_bind("hash_off", hash_off_expr));
    evaluate.push(Node::let_bind(
        "hash_idx",
        Expr::add(Expr::var("tok_start"), Expr::var("hash_off")),
    ));
    // The sentinel 0xFFFF_FFFF is not below MAX_WS_PREFIX + 1, so this is 0
    // exactly when no `#` was located.
    evaluate.push(Node::let_bind(
        "found_hash",
        Expr::select(
            Expr::lt(Expr::var("hash_off"), Expr::u32(MAX_WS_PREFIX + 1)),
            Expr::u32(1),
            Expr::u32(0),
        ),
    ));

    // Skip whitespace between `#` and the keyword.
    for q in 0..MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("kp_{q}"),
            safe_load(
                "source",
                Expr::add(Expr::var("hash_idx"), Expr::u32(q + 1)),
                source_byte_len.clone(),
            ),
        ));
    }
    for q in 0..MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("kp_ws_{q}"),
            is_ws(Expr::var(format!("kp_{q}"))),
        ));
    }
    evaluate.push(Node::let_bind("kw_skip", ws_skip_expr("kp", MAX_WS_PREFIX)));
    evaluate.push(Node::let_bind(
        "kw_start",
        Expr::add(
            Expr::add(Expr::var("hash_idx"), Expr::u32(1)),
            Expr::var("kw_skip"),
        ),
    ));
    // Keyword length is derived from the token kind: `ifndef` is six bytes,
    // `ifdef` is five.
    evaluate.push(Node::let_bind(
        "kw_len_skip",
        Expr::select(
            Expr::eq(Expr::var("kind"), Expr::u32(TOK_PP_IFNDEF)),
            Expr::u32(6),
            Expr::u32(5),
        ),
    ));
    evaluate.push(Node::let_bind(
        "post_kw",
        Expr::add(Expr::var("kw_start"), Expr::var("kw_len_skip")),
    ));

    // Skip whitespace between the keyword and the identifier.
    for q in 0..MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("ip_{q}"),
            safe_load(
                "source",
                Expr::add(Expr::var("post_kw"), Expr::u32(q)),
                source_byte_len.clone(),
            ),
        ));
    }
    for q in 0..MAX_WS_PREFIX {
        evaluate.push(Node::let_bind(
            format!("ip_ws_{q}"),
            is_ws(Expr::var(format!("ip_{q}"))),
        ));
    }
    evaluate.push(Node::let_bind(
        "ident_skip",
        ws_skip_expr("ip", MAX_WS_PREFIX),
    ));
    evaluate.push(Node::let_bind(
        "ident_start_val",
        Expr::add(Expr::var("post_kw"), Expr::var("ident_skip")),
    ));

    // Never scan past the end of the token; zero when the identifier would
    // start at or beyond `tok_end`.
    evaluate.push(Node::let_bind(
        "ident_scan_limit",
        Expr::select(
            Expr::lt(Expr::var("ident_start_val"), Expr::var("tok_end")),
            Expr::sub(Expr::var("tok_end"), Expr::var("ident_start_val")),
            Expr::u32(0),
        ),
    ));

    // Measure the identifier: extend `ident_len_val` while bytes are valid and
    // latch `ident_done` at the first byte that is not.
    evaluate.push(Node::let_bind("ident_len_val", Expr::u32(0)));
    evaluate.push(Node::let_bind("ident_done", Expr::u32(0)));
    evaluate.push(Node::loop_for(
        "ident_i",
        Expr::u32(0),
        Expr::var("ident_scan_limit"),
        vec![Node::if_then(
            Expr::eq(Expr::var("ident_done"), Expr::u32(0)),
            vec![
                Node::let_bind(
                    "ident_byte",
                    safe_load(
                        "source",
                        Expr::add(Expr::var("ident_start_val"), Expr::var("ident_i")),
                        source_byte_len.clone(),
                    ),
                ),
                Node::let_bind(
                    "ident_byte_ok",
                    Expr::select(
                        Expr::eq(Expr::var("ident_i"), Expr::u32(0)),
                        is_start(Expr::var("ident_byte")),
                        is_continue(Expr::var("ident_byte")),
                    ),
                ),
                Node::if_then_else(
                    Expr::eq(Expr::var("ident_byte_ok"), Expr::u32(1)),
                    vec![Node::assign(
                        "ident_len_val",
                        Expr::add(Expr::var("ident_i"), Expr::u32(1)),
                    )],
                    vec![Node::assign("ident_done", Expr::u32(1))],
                ),
            ],
        )],
    ));

    // `macro_offsets` holds one more entry than there are macros (entry `m` and
    // `m + 1` delimit name `m`). Guard the `len - 1` so that an empty buffer
    // means "zero macros" instead of underflowing the loop bound.
    let macro_offsets_len = Expr::buf_len("macro_offsets");
    let macro_count_runtime = Expr::select(
        Expr::lt(Expr::u32(0), macro_offsets_len.clone()),
        Expr::sub(macro_offsets_len, Expr::u32(1)),
        Expr::u32(0),
    );
    // Byte capacity of the packed name table: four bytes per `u32` element.
    let macro_names_byte_cap_runtime =
        Expr::mul(Expr::buf_len("macro_names_packed"), Expr::u32(4));

    evaluate.push(Node::let_bind("def_found", Expr::u32(0)));

    // Body executed per macro `m` until a match is found.
    let compare_macro_body: Vec<Node> = vec![
        Node::let_bind(
            "m_start",
            Expr::cast(DataType::U32, Expr::load("macro_offsets", Expr::var("m"))),
        ),
        Node::let_bind(
            "m_end",
            Expr::cast(
                DataType::U32,
                Expr::load("macro_offsets", Expr::add(Expr::var("m"), Expr::u32(1))),
            ),
        ),
        Node::let_bind("m_len", Expr::sub(Expr::var("m_end"), Expr::var("m_start"))),
        // Start as a match only for a non-empty identifier of equal length.
        Node::let_bind(
            "all_match",
            Expr::select(
                Expr::and(
                    Expr::ne(Expr::var("ident_len_val"), Expr::u32(0)),
                    Expr::eq(Expr::var("m_len"), Expr::var("ident_len_val")),
                ),
                Expr::u32(1),
                Expr::u32(0),
            ),
        ),
        // Byte-wise comparison; clears `all_match` on the first difference.
        Node::loop_for(
            "name_k",
            Expr::u32(0),
            Expr::var("m_len"),
            vec![Node::if_then(
                Expr::eq(Expr::var("all_match"), Expr::u32(1)),
                vec![
                    Node::let_bind(
                        "ident_cmp_byte",
                        safe_load(
                            "source",
                            Expr::add(Expr::var("ident_start_val"), Expr::var("name_k")),
                            source_byte_len.clone(),
                        ),
                    ),
                    Node::let_bind(
                        "macro_cmp_byte",
                        safe_load(
                            "macro_names_packed",
                            Expr::add(Expr::var("m_start"), Expr::var("name_k")),
                            macro_names_byte_cap_runtime.clone(),
                        ),
                    ),
                    Node::if_then(
                        Expr::ne(Expr::var("ident_cmp_byte"), Expr::var("macro_cmp_byte")),
                        vec![Node::assign("all_match", Expr::u32(0))],
                    ),
                ],
            )],
        ),
        Node::if_then(
            Expr::eq(Expr::var("all_match"), Expr::u32(1)),
            vec![Node::assign("def_found", Expr::u32(1))],
        ),
    ];
    evaluate.push(Node::loop_for(
        "m",
        Expr::u32(0),
        macro_count_runtime,
        vec![Node::if_then(
            Expr::eq(Expr::var("def_found"), Expr::u32(0)),
            compare_macro_body,
        )],
    ));

    // `ifndef` inverts the lookup result; `ifdef` passes it through.
    evaluate.push(Node::let_bind(
        "value_out_val",
        Expr::select(
            Expr::eq(Expr::var("kind"), Expr::u32(TOK_PP_IFNDEF)),
            Expr::select(
                Expr::eq(Expr::var("def_found"), Expr::u32(1)),
                Expr::u32(0),
                Expr::u32(1),
            ),
            Expr::var("def_found"),
        ),
    ));
    // Only publish a value when the directive's `#` was actually located.
    evaluate.push(Node::if_then(
        Expr::eq(Expr::var("found_hash"), Expr::u32(1)),
        vec![Node::store(
            "directive_values",
            t.clone(),
            Expr::var("value_out_val"),
        )],
    ));

    // Entry body: bounds-check the invocation, then dispatch on token kind.
    let body: Vec<Node> = vec![
        Node::let_bind("t", Expr::InvocationId { axis: 0 }),
        Node::if_then(
            Expr::lt(t.clone(), Expr::u32(num_tokens)),
            vec![
                Node::let_bind("kind", Expr::load("directive_kinds", t.clone())),
                Node::if_then(
                    Expr::or(
                        Expr::eq(Expr::var("kind"), Expr::u32(TOK_PP_IFDEF)),
                        Expr::eq(Expr::var("kind"), Expr::u32(TOK_PP_IFNDEF)),
                    ),
                    evaluate,
                ),
            ],
        ),
    ];

    Program::wrapped(
        vec![
            BufferDecl::storage(
                "tok_starts",
                BINDING_TOK_STARTS,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(num_tokens.max(1)),
            BufferDecl::storage(
                "tok_lens",
                BINDING_TOK_LENS,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(num_tokens.max(1)),
            BufferDecl::storage(
                "directive_kinds",
                BINDING_DIRECTIVE_KINDS,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(num_tokens.max(1)),
            // Count 0: length is resolved inside the program, not baked in here.
            BufferDecl::storage(
                "source",
                BINDING_SOURCE,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(0),
            BufferDecl::storage(
                "macro_names_packed",
                BINDING_MACRO_NAMES_PACKED,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(0),
            BufferDecl::storage(
                "macro_offsets",
                BINDING_MACRO_OFFSETS,
                BufferAccess::ReadOnly,
                DataType::U32,
            )
            .with_count(0),
            BufferDecl::storage(
                "directive_values",
                BINDING_DIRECTIVE_VALUES,
                BufferAccess::ReadWrite,
                DataType::U32,
            )
            .with_count(num_tokens.max(1)),
        ],
        [256, 1, 1],
        body,
    )
    .with_entry_op_id(OP_ID)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The operation id is part of the external contract and must not drift.
    #[test]
    fn op_id_is_canonical_and_stable() {
        let expected = "vyre-libs::parsing::c::preprocess::gpu_ifdef_value";
        assert_eq!(OP_ID, expected);
    }

    /// Every binding slot keeps its published index.
    #[test]
    fn binding_indices_are_canonical_and_stable() {
        let slots = [
            (BINDING_TOK_STARTS, 0),
            (BINDING_TOK_LENS, 1),
            (BINDING_DIRECTIVE_KINDS, 2),
            (BINDING_SOURCE, 3),
            (BINDING_MACRO_NAMES_PACKED, 4),
            (BINDING_MACRO_OFFSETS, 5),
            (BINDING_DIRECTIVE_VALUES, 6),
        ];
        for (actual, expected) in slots {
            assert_eq!(actual, expected);
        }
    }

    /// The built program exposes seven buffers and the expected workgroup shape.
    #[test]
    fn build_program_returns_well_formed_program() {
        let program = gpu_ifdef_value(8, 64);
        let buffer_count = program.buffers().len();
        assert_eq!(buffer_count, 7);
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
    }

    /// The `source` declaration carries a count of 0 regardless of `source_len`.
    #[test]
    fn source_buffer_is_runtime_sized_not_source_length_specialized() {
        let program = gpu_ifdef_value(8, 64);
        let source_decl = program
            .buffers()
            .iter()
            .find(|decl| decl.name() == "source")
            .expect("Fix: source buffer must exist");
        assert_eq!(
            source_decl.count, 0,
            "source must be runtime-sized so one ifdef evaluator program serves all source lengths"
        );
    }
}