use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable identifier of this op; used as the `generator` of the region the builder emits.
pub const OP_ID: &str = "vyre-primitives::parsing::whitespace_classify_word";
/// Binding slot of the packed input buffer (`bytes_in`).
pub const BINDING_BYTES_IN: u32 = 0;
/// Binding slot of the per-word mask buffer (`whitespace_mask_out`).
pub const BINDING_WHITESPACE_MASK_OUT: u32 = 1;
// The four ASCII bytes classified as whitespace. Each is widened from `u8` to `u32`
// (lossless) so it can be compared against a byte extracted from a packed `u32` word.
const WS_SP: u32 = b' ' as u32;
const WS_TAB: u32 = b'\t' as u32;
const WS_LF: u32 = b'\n' as u32;
const WS_CR: u32 = b'\r' as u32;
#[must_use]
pub fn whitespace_classify_word(word_count: u32) -> Program {
let body = vec![
Node::let_bind("word_idx", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("word_idx"), Expr::u32(word_count)),
vec![
Node::let_bind("word", Expr::load("bytes_in", Expr::var("word_idx"))),
Node::let_bind("b0", Expr::bitand(Expr::var("word"), Expr::u32(0xFF))),
Node::let_bind(
"b1",
Expr::bitand(Expr::shr(Expr::var("word"), Expr::u32(8)), Expr::u32(0xFF)),
),
Node::let_bind(
"b2",
Expr::bitand(Expr::shr(Expr::var("word"), Expr::u32(16)), Expr::u32(0xFF)),
),
Node::let_bind(
"b3",
Expr::bitand(Expr::shr(Expr::var("word"), Expr::u32(24)), Expr::u32(0xFF)),
),
Node::let_bind(
"ws0",
Expr::or(
Expr::or(
Expr::eq(Expr::var("b0"), Expr::u32(WS_SP)),
Expr::eq(Expr::var("b0"), Expr::u32(WS_TAB)),
),
Expr::or(
Expr::eq(Expr::var("b0"), Expr::u32(WS_LF)),
Expr::eq(Expr::var("b0"), Expr::u32(WS_CR)),
),
),
),
Node::let_bind(
"ws1",
Expr::or(
Expr::or(
Expr::eq(Expr::var("b1"), Expr::u32(WS_SP)),
Expr::eq(Expr::var("b1"), Expr::u32(WS_TAB)),
),
Expr::or(
Expr::eq(Expr::var("b1"), Expr::u32(WS_LF)),
Expr::eq(Expr::var("b1"), Expr::u32(WS_CR)),
),
),
),
Node::let_bind(
"ws2",
Expr::or(
Expr::or(
Expr::eq(Expr::var("b2"), Expr::u32(WS_SP)),
Expr::eq(Expr::var("b2"), Expr::u32(WS_TAB)),
),
Expr::or(
Expr::eq(Expr::var("b2"), Expr::u32(WS_LF)),
Expr::eq(Expr::var("b2"), Expr::u32(WS_CR)),
),
),
),
Node::let_bind(
"ws3",
Expr::or(
Expr::or(
Expr::eq(Expr::var("b3"), Expr::u32(WS_SP)),
Expr::eq(Expr::var("b3"), Expr::u32(WS_TAB)),
),
Expr::or(
Expr::eq(Expr::var("b3"), Expr::u32(WS_LF)),
Expr::eq(Expr::var("b3"), Expr::u32(WS_CR)),
),
),
),
Node::let_bind(
"bit0",
Expr::select(Expr::var("ws0"), Expr::u32(1), Expr::u32(0)),
),
Node::let_bind(
"bit1",
Expr::select(Expr::var("ws1"), Expr::u32(2), Expr::u32(0)),
),
Node::let_bind(
"bit2",
Expr::select(Expr::var("ws2"), Expr::u32(4), Expr::u32(0)),
),
Node::let_bind(
"bit3",
Expr::select(Expr::var("ws3"), Expr::u32(8), Expr::u32(0)),
),
Node::let_bind(
"mask",
Expr::bitor(
Expr::bitor(Expr::var("bit0"), Expr::var("bit1")),
Expr::bitor(Expr::var("bit2"), Expr::var("bit3")),
),
),
Node::store(
"whitespace_mask_out",
Expr::var("word_idx"),
Expr::var("mask"),
),
],
),
];
let buffers = vec![
BufferDecl::storage(
"bytes_in",
BINDING_BYTES_IN,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(word_count),
BufferDecl::storage(
"whitespace_mask_out",
BINDING_WHITESPACE_MASK_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(word_count),
];
let entry = vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}];
Program::wrapped(buffers, [256, 1, 1], entry)
}
/// Reports whether `byte` is one of the four recognised ASCII whitespace bytes:
/// space (`0x20`), horizontal tab (`0x09`), line feed (`0x0A`) or carriage return (`0x0D`).
///
/// Every other byte value yields `false`.
#[must_use]
#[inline]
pub const fn is_structural_whitespace(byte: u8) -> bool {
    byte == b' ' || byte == b'\t' || byte == b'\n' || byte == b'\r'
}
/// Computes the per-word whitespace masks on the host and returns them in a fresh vector.
///
/// The result has one mask per element of `words_in`, in the same order. This is the
/// allocating convenience wrapper around [`whitespace_classify_word_cpu_into`].
#[must_use]
pub fn whitespace_classify_word_cpu(words_in: &[u32]) -> Vec<u32> {
    // Size the vector up front so the `_into` call does not need to grow it.
    let mut masks = Vec::with_capacity(words_in.len());
    whitespace_classify_word_cpu_into(words_in, &mut masks);
    masks
}
/// Computes the per-word whitespace masks on the host, writing them into `out`.
///
/// `out` is cleared first, so any previous contents are discarded while its allocation is
/// kept for reuse. Afterwards it holds one mask per element of `words_in`, in order.
/// Each word is read as four bytes, lowest byte first; bit `n` of the mask is set when
/// byte `n` satisfies [`is_structural_whitespace`].
pub fn whitespace_classify_word_cpu_into(words_in: &[u32], out: &mut Vec<u32>) {
    out.clear();
    out.reserve(words_in.len());
    out.extend(words_in.iter().map(|packed| {
        // Little-endian byte order puts lane 0 in the low 8 bits of the word.
        let lanes = packed.to_le_bytes();
        lanes
            .iter()
            .enumerate()
            .filter(|(_, byte)| is_structural_whitespace(**byte))
            .fold(0u32, |mask, (lane, _)| mask | (1u32 << lane))
    }));
}
/// Packs four bytes into one `u32` in little-endian order.
///
/// `b0` lands in bits 0..8, `b1` in bits 8..16, `b2` in bits 16..24 and `b3` in bits 24..32.
#[must_use]
#[inline]
pub const fn pack_bytes_le(b0: u8, b1: u8, b2: u8, b3: u8) -> u32 {
    u32::from_le_bytes([b0, b1, b2, b3])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A word made only of letters yields an all-zero mask.
    #[test]
    fn classify_all_non_whitespace_emits_zero_mask() {
        let input = [pack_bytes_le(b'a', b'b', b'c', b'd')];
        assert_eq!(whitespace_classify_word_cpu(&input), vec![0]);
    }

    /// One of each whitespace byte sets all four low bits.
    #[test]
    fn classify_all_whitespace_emits_full_4_bit_mask() {
        let input = [pack_bytes_le(b' ', b'\t', b'\n', b'\r')];
        assert_eq!(whitespace_classify_word_cpu(&input), vec![0b1111]);
    }

    /// Whitespace in lanes 1 and 3 sets exactly bits 1 and 3.
    #[test]
    fn classify_mixed_word_marks_correct_lanes() {
        let input = [pack_bytes_le(b'a', b' ', b'b', b'\t')];
        assert_eq!(whitespace_classify_word_cpu(&input), vec![0b1010]);
    }

    /// A single space moved through each lane sets only that lane's bit.
    #[test]
    fn classify_every_per_lane_position_independently() {
        for lane in 0..4u32 {
            let mut lanes = [b'x'; 4];
            lanes[lane as usize] = b' ';
            let [b0, b1, b2, b3] = lanes;
            let masks = whitespace_classify_word_cpu(&[pack_bytes_le(b0, b1, b2, b3)]);
            assert_eq!(
                masks[0],
                1u32 << lane,
                "lane {lane} whitespace must set bit {lane}"
            );
        }
    }

    /// Neighbours of the whitespace byte values are not classified as whitespace.
    #[test]
    fn classify_rejects_close_byte_values_that_are_not_whitespace() {
        let input = [pack_bytes_le(0x21, 0x08, 0x0B, 0x0E)];
        assert_eq!(
            whitespace_classify_word_cpu(&input),
            vec![0],
            "values adjacent to but not exactly the whitespace set must NOT classify as ws"
        );
    }

    /// Bytes above the ASCII range never classify as whitespace.
    #[test]
    fn classify_does_not_match_unicode_whitespace() {
        let input = [pack_bytes_le(0xA0, 0xC2, 0xE2, 0x80)];
        assert_eq!(
            whitespace_classify_word_cpu(&input),
            vec![0],
            "structural-parser whitespace is ASCII only by contract"
        );
    }

    /// Sixty-four identical words produce sixty-four identical masks.
    #[test]
    fn classify_handles_long_input_byte_for_byte() {
        let input = vec![pack_bytes_le(b' ', b'x', b' ', b'x'); 64];
        let masks = whitespace_classify_word_cpu(&input);
        assert_eq!(masks.len(), 64);
        masks.iter().for_each(|mask| assert_eq!(*mask, 0b0101));
    }

    /// No input words means no output masks.
    #[test]
    fn classify_empty_input_emits_empty_output() {
        assert!(whitespace_classify_word_cpu(&[]).is_empty());
    }

    /// Even an all-whitespace word leaves bits 4 and above clear.
    #[test]
    fn classify_does_not_set_high_bits() {
        let input = [pack_bytes_le(b' ', b' ', b' ', b' ')];
        let masks = whitespace_classify_word_cpu(&input);
        assert_eq!(
            masks[0] >> 4,
            0,
            "high 28 bits must remain zero (reserved for lane widening)"
        );
    }

    /// The first argument occupies the low byte of the packed word.
    #[test]
    fn pack_bytes_le_is_canonical() {
        assert_eq!(pack_bytes_le(0x78, 0x56, 0x34, 0x12), 0x1234_5678);
        assert_eq!(pack_bytes_le(0xFF, 0, 0, 0), 0xFF);
        assert_eq!(pack_bytes_le(0, 0xFF, 0, 0), 0xFF00);
    }

    /// Writing into a pre-sized vector must not change its capacity.
    #[test]
    fn classify_into_reuses_output_capacity() {
        let input = [pack_bytes_le(b' ', b'x', b'\n', b'y')];
        let mut masks: Vec<u32> = Vec::with_capacity(32);
        let capacity_before = masks.capacity();
        whitespace_classify_word_cpu_into(&input, &mut masks);
        assert_eq!(masks, vec![0b0101]);
        assert_eq!(masks.capacity(), capacity_before);
    }

    /// The built program declares two buffers and the expected workgroup size.
    #[test]
    fn build_program_returns_well_formed_program() {
        let built = whitespace_classify_word(64);
        assert_eq!(built.buffers().len(), 2, "bytes_in + whitespace_mask_out");
        assert_eq!(built.workgroup_size(), [256, 1, 1]);
    }

    /// Two builds with the same word count agree on buffer count and workgroup size.
    #[test]
    fn build_program_is_deterministic_across_calls() {
        let first = whitespace_classify_word(128);
        let second = whitespace_classify_word(128);
        assert_eq!(first.buffers().len(), second.buffers().len());
        assert_eq!(first.workgroup_size(), second.workgroup_size());
    }

    /// Pins the op identifier string.
    #[test]
    fn op_id_is_canonical_and_stable() {
        assert_eq!(OP_ID, "vyre-primitives::parsing::whitespace_classify_word");
    }

    /// Pins the binding slot numbers.
    #[test]
    fn binding_indices_are_canonical_and_stable() {
        assert_eq!(BINDING_BYTES_IN, 0);
        assert_eq!(BINDING_WHITESPACE_MASK_OUT, 1);
    }

    /// Exhaustively checks all 256 byte values against the four-byte whitespace set.
    #[test]
    fn is_structural_whitespace_matches_only_the_canonical_four() {
        const CANONICAL: [u8; 4] = [0x20, 0x09, 0x0A, 0x0D];
        for byte in u8::MIN..=u8::MAX {
            let expected = CANONICAL.contains(&byte);
            assert_eq!(
                is_structural_whitespace(byte),
                expected,
                "byte 0x{byte:02X} structural-ws classification"
            );
        }
    }
}