use crate::parsing::c::lex::tokens::{
TOK_PP_DEFINE, TOK_PP_ELIF, TOK_PP_ELSE, TOK_PP_ENDIF, TOK_PP_ERROR, TOK_PP_IDENT, TOK_PP_IF,
TOK_PP_IFDEF, TOK_PP_IFNDEF, TOK_PP_INCLUDE, TOK_PP_INCLUDE_NEXT, TOK_PP_LINE, TOK_PP_NULL,
TOK_PP_PRAGMA, TOK_PP_SCCS, TOK_PP_UNDEF, TOK_PP_WARNING, TOK_PREPROC,
};
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation id attached to the built program via `with_entry_op_id`.
pub const OP_ID: &str = "vyre-libs::parsing::c::preprocess::gpu_directive_metadata";
/// Binding slot of the read-only `tok_types` buffer (u32 per token).
pub const BINDING_TOK_TYPES: u32 = 0;
/// Binding slot of the read-only `tok_starts` buffer (u32 per token); each
/// entry is used as the source byte offset at which classification begins.
pub const BINDING_TOK_STARTS: u32 = 1;
/// Binding slot of the read-only `tok_lens` buffer (u32 per token). The
/// buffer is declared by the program but not loaded by the classifier body.
pub const BINDING_TOK_LENS: u32 = 2;
/// Binding slot of the read-only `source` buffer holding the source bytes.
pub const BINDING_SOURCE: u32 = 3;
/// Binding slot of the read-write `directive_kinds` output buffer (u32 per
/// token): zeroed for every token, then set to a `TOK_PP_*` kind on a match.
pub const BINDING_DIRECTIVE_KINDS: u32 = 4;
/// Binding slot of the read-write `directive_values` output buffer (u32 per
/// token). This program only writes zero into it.
pub const BINDING_DIRECTIVE_VALUES: u32 = 5;
/// Longest directive keyword, in bytes, that the classifier compares
/// (`include_next` is 12 bytes). One extra byte past this length is loaded so
/// the "keyword ends here" check works for the longest keyword as well.
pub const MAX_KEYWORD_LEN: u32 = 12;
// Maximum number of whitespace bytes tolerated before `#`, and again between
// `#` and the directive keyword.
const MAX_WS_PREFIX: u32 = 4;
/// How the `source` buffer is declared and how single bytes are read from it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SourceLayout {
/// `source` is declared with `DataType::U32` elements; bytes are read through
/// the helpers in `super::gpu_source_bytes`.
PackedU32,
/// `source` is declared with `DataType::U8` elements; bytes are loaded
/// directly, bounds-checked against the buffer length, and masked to 8 bits.
RawU8,
}
/// Builds the directive-classification program for a `source` buffer declared
/// with u32 elements (`SourceLayout::PackedU32`).
///
/// `num_tokens` sizes the per-token buffers (clamped to at least one element).
/// `source_len` is forwarded to the shared builder, which currently ignores
/// it: the `source` buffer is declared with a count of 0 and its byte length
/// is obtained through an expression rather than this argument.
#[must_use]
pub fn gpu_directive_metadata(num_tokens: u32, source_len: u32) -> Program {
gpu_directive_metadata_with_source_layout(num_tokens, source_len, SourceLayout::PackedU32)
}
/// Builds the directive-classification program for a `source` buffer declared
/// with u8 elements (`SourceLayout::RawU8`).
///
/// Same contract as [`gpu_directive_metadata`] apart from the source element
/// type: `num_tokens` sizes the per-token buffers (clamped to at least one
/// element) and `source_len` is currently ignored by the shared builder.
#[must_use]
pub fn gpu_directive_metadata_u8(num_tokens: u32, source_len: u32) -> Program {
gpu_directive_metadata_with_source_layout(num_tokens, source_len, SourceLayout::RawU8)
}
fn gpu_directive_metadata_with_source_layout(
num_tokens: u32,
source_len: u32,
source_layout: SourceLayout,
) -> Program {
let _ = source_len;
let t = Expr::var("t");
let source_byte_len = match source_layout {
SourceLayout::PackedU32 => super::gpu_source_bytes::packed_source_byte_len_expr(),
SourceLayout::RawU8 => Expr::buf_len("source"),
};
let safe_load = |addr: Expr| -> Expr {
match source_layout {
SourceLayout::PackedU32 => {
super::gpu_source_bytes::safe_load_source_byte_expr(addr, source_byte_len.clone())
}
SourceLayout::RawU8 => Expr::select(
Expr::lt(addr.clone(), source_byte_len.clone()),
Expr::bitand(
Expr::cast(DataType::U32, Expr::load("source", addr)),
Expr::u32(0xFF),
),
Expr::u32(0),
),
}
};
let is_ws = |b: Expr| -> Expr {
Expr::select(
Expr::or(
Expr::or(
Expr::eq(b.clone(), Expr::u32(b' ' as u32)),
Expr::eq(b.clone(), Expr::u32(b'\t' as u32)),
),
Expr::or(
Expr::eq(b.clone(), Expr::u32(0x0B)),
Expr::eq(b, Expr::u32(0x0C)),
),
),
Expr::u32(1),
Expr::u32(0),
)
};
let is_continue = |b: Expr| -> Expr {
let is_lower = Expr::and(
Expr::ge(b.clone(), Expr::u32(b'a' as u32)),
Expr::le(b.clone(), Expr::u32(b'z' as u32)),
);
let is_upper = Expr::and(
Expr::ge(b.clone(), Expr::u32(b'A' as u32)),
Expr::le(b.clone(), Expr::u32(b'Z' as u32)),
);
let is_digit = Expr::and(
Expr::ge(b.clone(), Expr::u32(b'0' as u32)),
Expr::le(b.clone(), Expr::u32(b'9' as u32)),
);
let is_under = Expr::eq(b, Expr::u32(b'_' as u32));
Expr::select(
Expr::or(Expr::or(is_lower, is_upper), Expr::or(is_digit, is_under)),
Expr::u32(1),
Expr::u32(0),
)
};
let hash_off_expr = {
let mut acc = Expr::u32(0xFFFF_FFFF);
for p in (0..=MAX_WS_PREFIX).rev() {
let mut prefix_ws = Expr::u32(1);
for q in 0..p {
prefix_ws = Expr::bitand(prefix_ws, Expr::var(format!("s_ws_{q}")));
}
let s_eq_hash = Expr::select(
Expr::eq(Expr::var(format!("s_{p}")), Expr::u32(b'#' as u32)),
Expr::u32(1),
Expr::u32(0),
);
let cond_u32 = Expr::bitand(s_eq_hash, prefix_ws);
acc = Expr::select(Expr::eq(cond_u32, Expr::u32(1)), Expr::u32(p), acc);
}
acc
};
let kw_skip_expr = {
let mut acc = Expr::u32(MAX_WS_PREFIX);
for q in (0..MAX_WS_PREFIX).rev() {
let mut prefix_ws = Expr::u32(1);
for r in 0..q {
prefix_ws = Expr::bitand(prefix_ws, Expr::var(format!("p_ws_{r}")));
}
let p_not_ws = Expr::select(
Expr::eq(Expr::var(format!("p_ws_{q}")), Expr::u32(0)),
Expr::u32(1),
Expr::u32(0),
);
let cond_u32 = Expr::bitand(p_not_ws, prefix_ws);
acc = Expr::select(Expr::eq(cond_u32, Expr::u32(1)), Expr::u32(q), acc);
}
acc
};
let keyword_match_expr = |expected: &[u32]| -> Expr {
let mut all_eq = Expr::u32(1);
for (i, byte) in expected.iter().copied().enumerate() {
let eq_byte = Expr::select(
Expr::eq(Expr::var(format!("k_{i}")), Expr::u32(byte)),
Expr::u32(1),
Expr::u32(0),
);
all_eq = Expr::bitand(all_eq, eq_byte);
}
let next_not_ident = Expr::select(
Expr::eq(
Expr::var(format!("k_is_continue_{}", expected.len())),
Expr::u32(0),
),
Expr::u32(1),
Expr::u32(0),
);
Expr::bitand(all_eq, next_not_ident)
};
let mut classify: Vec<Node> = Vec::new();
classify.push(Node::let_bind(
"tok_start",
Expr::load("tok_starts", t.clone()),
));
for p in 0..=MAX_WS_PREFIX {
classify.push(Node::let_bind(
format!("s_{p}"),
safe_load(Expr::add(Expr::var("tok_start"), Expr::u32(p))),
));
}
for p in 0..=MAX_WS_PREFIX {
classify.push(Node::let_bind(
format!("s_ws_{p}"),
is_ws(Expr::var(format!("s_{p}"))),
));
}
classify.push(Node::let_bind("hash_off", hash_off_expr));
classify.push(Node::let_bind(
"hash_idx",
Expr::add(Expr::var("tok_start"), Expr::var("hash_off")),
));
for q in 0..MAX_WS_PREFIX {
classify.push(Node::let_bind(
format!("p_{q}"),
safe_load(Expr::add(Expr::var("hash_idx"), Expr::u32(q + 1))),
));
}
for q in 0..MAX_WS_PREFIX {
classify.push(Node::let_bind(
format!("p_ws_{q}"),
is_ws(Expr::var(format!("p_{q}"))),
));
}
classify.push(Node::let_bind("kw_skip", kw_skip_expr));
classify.push(Node::let_bind(
"kw_start",
Expr::add(
Expr::add(Expr::var("hash_idx"), Expr::u32(1)),
Expr::var("kw_skip"),
),
));
for i in 0..=MAX_KEYWORD_LEN {
classify.push(Node::let_bind(
format!("k_{i}"),
safe_load(Expr::add(Expr::var("kw_start"), Expr::u32(i))),
));
}
for i in 0..=MAX_KEYWORD_LEN {
classify.push(Node::let_bind(
format!("k_is_continue_{i}"),
is_continue(Expr::var(format!("k_{i}"))),
));
}
classify.push(Node::let_bind(
"found_hash",
Expr::select(
Expr::lt(Expr::var("hash_off"), Expr::u32(MAX_WS_PREFIX + 1)),
Expr::u32(1),
Expr::u32(0),
),
));
let store_kind = |kind: u32| -> Vec<Node> {
vec![Node::store("directive_kinds", t.clone(), Expr::u32(kind))]
};
let fire = |cond_u32: Expr, kind: u32| -> Node {
Node::if_then(
Expr::eq(
Expr::bitand(Expr::var("found_hash"), cond_u32),
Expr::u32(1),
),
store_kind(kind),
)
};
classify.push(fire(
Expr::select(
Expr::eq(Expr::var("k_is_continue_0"), Expr::u32(0)),
Expr::u32(1),
Expr::u32(0),
),
TOK_PP_NULL,
));
classify.push(fire(
keyword_match_expr(&[100, 101, 102, 105, 110, 101]),
TOK_PP_DEFINE,
));
classify.push(fire(
keyword_match_expr(&[117, 110, 100, 101, 102]),
TOK_PP_UNDEF,
));
classify.push(fire(
keyword_match_expr(&[105, 110, 99, 108, 117, 100, 101, 95, 110, 101, 120, 116]),
TOK_PP_INCLUDE_NEXT,
));
classify.push(fire(
keyword_match_expr(&[105, 110, 99, 108, 117, 100, 101]),
TOK_PP_INCLUDE,
));
classify.push(fire(
keyword_match_expr(&[105, 102, 110, 100, 101, 102]),
TOK_PP_IFNDEF,
));
classify.push(fire(
keyword_match_expr(&[105, 102, 100, 101, 102]),
TOK_PP_IFDEF,
));
classify.push(fire(keyword_match_expr(&[105, 102]), TOK_PP_IF));
classify.push(fire(keyword_match_expr(&[101, 108, 105, 102]), TOK_PP_ELIF));
classify.push(fire(keyword_match_expr(&[101, 108, 115, 101]), TOK_PP_ELSE));
classify.push(fire(
keyword_match_expr(&[101, 110, 100, 105, 102]),
TOK_PP_ENDIF,
));
classify.push(fire(
keyword_match_expr(&[112, 114, 97, 103, 109, 97]),
TOK_PP_PRAGMA,
));
classify.push(fire(keyword_match_expr(&[108, 105, 110, 101]), TOK_PP_LINE));
classify.push(fire(
keyword_match_expr(&[101, 114, 114, 111, 114]),
TOK_PP_ERROR,
));
classify.push(fire(
keyword_match_expr(&[119, 97, 114, 110, 105, 110, 103]),
TOK_PP_WARNING,
));
classify.push(fire(
keyword_match_expr(&[105, 100, 101, 110, 116]),
TOK_PP_IDENT,
));
classify.push(fire(keyword_match_expr(&[115, 99, 99, 115]), TOK_PP_SCCS));
let body: Vec<Node> = vec![
Node::let_bind("t", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(t.clone(), Expr::buf_len("tok_starts")),
vec![
Node::let_bind("tok_type", Expr::load("tok_types", t.clone())),
Node::store("directive_kinds", t.clone(), Expr::u32(0)),
Node::store("directive_values", t.clone(), Expr::u32(0)),
Node::if_then(
Expr::eq(Expr::var("tok_type"), Expr::u32(TOK_PREPROC)),
classify,
),
],
),
];
let source_element = match source_layout {
SourceLayout::PackedU32 => DataType::U32,
SourceLayout::RawU8 => DataType::U8,
};
Program::wrapped(
vec![
BufferDecl::storage(
"tok_types",
BINDING_TOK_TYPES,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(num_tokens.max(1)),
BufferDecl::storage(
"tok_starts",
BINDING_TOK_STARTS,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(num_tokens.max(1)),
BufferDecl::storage(
"tok_lens",
BINDING_TOK_LENS,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(num_tokens.max(1)),
BufferDecl::storage(
"source",
BINDING_SOURCE,
BufferAccess::ReadOnly,
source_element,
)
.with_count(0),
BufferDecl::storage(
"directive_kinds",
BINDING_DIRECTIVE_KINDS,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(num_tokens.max(1)),
BufferDecl::storage(
"directive_values",
BINDING_DIRECTIVE_VALUES,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(num_tokens.max(1)),
],
[256, 1, 1],
body,
)
.with_entry_op_id(OP_ID)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The op id is an external identifier; pin its exact spelling.
    #[test]
    fn op_id_is_canonical_and_stable() {
        let expected = "vyre-libs::parsing::c::preprocess::gpu_directive_metadata";
        assert_eq!(OP_ID, expected);
    }

    /// Binding slots must stay 0..=5 in declaration order.
    #[test]
    fn binding_indices_are_canonical_and_stable() {
        let declared = [
            BINDING_TOK_TYPES,
            BINDING_TOK_STARTS,
            BINDING_TOK_LENS,
            BINDING_SOURCE,
            BINDING_DIRECTIVE_KINDS,
            BINDING_DIRECTIVE_VALUES,
        ];
        assert_eq!(declared, [0, 1, 2, 3, 4, 5]);
    }

    /// The packed-layout program declares six buffers and a 256x1x1 workgroup.
    #[test]
    fn build_program_returns_well_formed_program() {
        let program = gpu_directive_metadata(8, 64);
        let buffer_count = program.buffers().len();
        assert_eq!(buffer_count, 6);
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
    }

    /// The `source` buffer count stays 0 regardless of the `source_len` argument.
    #[test]
    fn source_buffer_is_runtime_sized_not_source_length_specialized() {
        let program = gpu_directive_metadata(8, 64);
        let buffers = program.buffers();
        let source = buffers
            .iter()
            .find(|decl| decl.name() == "source")
            .expect("Fix: source buffer must exist after directive metadata allocation");
        assert_eq!(
            source.count, 0,
            "source must be runtime-sized so one directive classifier program serves all source lengths"
        );
    }

    /// Each public entry point declares `source` with its own element type.
    #[test]
    fn source_buffer_layouts_preserve_packed_abi_and_raw_u8_variant() {
        let cases = [
            (
                gpu_directive_metadata(8, 64),
                DataType::U32,
                "Fix: packed directive metadata source buffer must exist",
            ),
            (
                gpu_directive_metadata_u8(8, 64),
                DataType::U8,
                "Fix: raw-U8 directive metadata source buffer must exist",
            ),
        ];
        for (program, expected_element, missing) in &cases {
            let source = program
                .buffers()
                .iter()
                .find(|decl| decl.name() == "source")
                .expect(*missing);
            assert_eq!(source.element(), *expected_element);
            assert_eq!(source.count, 0);
        }
    }

    /// `include_next` (12 bytes) is the longest keyword the classifier matches.
    #[test]
    fn max_keyword_len_covers_longest_directive() {
        let longest_keyword_len = 12;
        assert!(MAX_KEYWORD_LEN >= longest_keyword_len);
    }
}