use cubecl::prelude::*;
/// Computes a per-token allow mask by walking every vocabulary token through a DFA.
///
/// Each unit handles the token whose index equals `ABSOLUTE_POS`; units whose
/// position is at or beyond `allow.len()` do nothing.
///
/// The characters of token `tok` are
/// `vocab_chars[vocab_offsets[tok]..vocab_offsets[tok + 1]]`. Starting from
/// `start_state`, each character `c` is mapped to a class with `char_classes[c]`
/// and the state advances to `transitions[state * num_classes + class]`.
///
/// `allow[tok]` is written as `0` when any of the following holds, and as `1`
/// otherwise:
///
/// - the token is empty (`start == end`),
/// - the token contains a character value `>= 128`,
/// - the walk reaches `reject_state`.
///
/// At most `max_token_len` characters are examined per token. Characters beyond
/// that limit are never read, so a longer token is judged on its first
/// `max_token_len` characters only.
/// NOTE(review): confirm that callers always pass a `max_token_len` that covers
/// the longest token, otherwise such tokens can be allowed on a prefix match.
///
/// Index ranges the body relies on: `vocab_offsets` is read at `tok + 1` for every
/// `tok < allow.len()`, `char_classes` is read at any character value below 128,
/// and `transitions` is read at `state * num_classes + class`.
/// NOTE(review): the kernel is declared with `launch_unchecked`; presumably these
/// reads are not bounds-checked on the device — confirm against the cubecl docs.
#[cube(launch_unchecked)]
pub fn kernel_compute_token_mask_dfa(
transitions: &Array<u32>,
char_classes: &Array<u32>,
vocab_offsets: &Array<u32>,
vocab_chars: &Array<u32>,
allow: &mut Array<u32>,
num_classes: u32,
start_state: u32,
reject_state: u32,
max_token_len: u32,
) {
// One mask slot per token: the output length doubles as the token count.
let n_alw = allow.len();
// Units past the last token have nothing to do.
if ABSOLUTE_POS < n_alw {
let num_classes_u = num_classes as usize;
let max_token_len_u = max_token_len as usize;
let tok = ABSOLUTE_POS;
let tok_next = tok + 1usize;
// Half-open character range [start, end) of this token inside `vocab_chars`.
let start = vocab_offsets[tok] as usize;
let end = vocab_offsets[tok_next] as usize;
let mut state: u32 = start_state;
// `rejected` is a sticky flag (0 = still viable, 1 = rejected).
let mut rejected: u32 = 0u32;
// An empty token is never allowed.
if start == end {
rejected = 1u32;
}
// The loop always runs `max_token_len` iterations; the `pos < end` test and the
// `rejected` flag turn the iterations that are not needed into no-ops.
for i in 0..max_token_len_u {
let pos = start + i;
if pos < end && rejected == 0u32 {
let c = vocab_chars[pos];
// Only character values below 128 are looked up in `char_classes`.
if c < 128u32 {
let class_u = char_classes[c as usize] as usize;
// The transition table is flattened with one row of `num_classes` entries per state.
let idx = (state as usize) * num_classes_u + class_u;
state = transitions[idx];
if state == reject_state {
rejected = 1u32;
}
} else {
// Any character value of 128 or above rejects the token outright.
rejected = 1u32;
}
}
}
// Every in-range unit writes its slot, so the output needs no prior initialisation.
if rejected == 0u32 {
allow[tok] = 1u32;
} else {
allow[tok] = 0u32;
}
}
}
/// Picks a one-dimensional launch shape that provides at least `n` units.
///
/// Cubes hold 256 units each; the cube count is `n / 256` rounded up and is
/// never allowed to drop below one, so `n == 0` still yields a single cube.
fn elementwise_launch_dims(n: u32) -> (CubeCount, CubeDim) {
    const UNITS_PER_CUBE: u32 = 256;

    let cubes_needed = n.div_ceil(UNITS_PER_CUBE);
    let cube_total = std::cmp::max(cubes_needed, 1);

    let count = CubeCount::Static(cube_total, 1, 1);
    let dim = CubeDim::new_1d(UNITS_PER_CUBE);
    (count, dim)
}
/// Host-side inputs for the DFA token-mask kernel.
///
/// The struct only borrows its tables, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DfaMaskInputs<'a> {
    /// Flattened transition table; the next state for `(state, class)` is read
    /// from index `state * num_classes + class`.
    pub transitions: &'a [u32],
    /// Maps a character value to its class. The kernel looks up every character
    /// value below 128 here, so the table is expected to cover indices `0..128`.
    pub char_classes: &'a [u32],
    /// Token boundaries inside `vocab_chars`: token `i` spans
    /// `vocab_offsets[i]..vocab_offsets[i + 1]`, so `N` tokens need `N + 1` entries.
    pub vocab_offsets: &'a [u32],
    /// Concatenated characters of all tokens, one character value per element.
    /// Any value of 128 or above makes the containing token rejected.
    pub vocab_chars: &'a [u32],
    /// Number of character classes, i.e. the row stride of `transitions`.
    pub num_classes: u32,
    /// DFA state every token walk starts from.
    pub start_state: u32,
    /// DFA state that marks a token as rejected as soon as it is reached.
    pub reject_state: u32,
    /// Maximum number of characters the kernel examines per token.
    pub max_token_len: u32,
}
/// Uploads the DFA tables and vocabulary, launches the token-mask kernel and
/// returns the device buffer holding the mask.
///
/// The returned tuple is `(allow_handle, vocab_size)`: `allow_handle` refers to a
/// buffer of `vocab_size` `u32` values in which the kernel writes `1` for an
/// allowed token and `0` for a rejected one. `vocab_size` is
/// `inputs.vocab_offsets.len() - 1`. The function only enqueues the work; it does
/// not read the result back.
///
/// # Panics
///
/// Panics if `inputs.vocab_offsets` is empty (the token count would underflow),
/// or if the token count does not fit in a `u32`.
///
/// # Preconditions not checked here
///
/// The kernel indexes `char_classes` with any character value below 128,
/// `transitions` with `state * num_classes + class`, and `vocab_chars` with the
/// ranges described by `vocab_offsets`. The contents of `inputs` must keep all of
/// those indices inside the corresponding slices.
pub fn compute_token_mask_dfa_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    inputs: &DfaMaskInputs<'_>,
) -> (cubecl::server::Handle, usize) {
    // Without this check an empty slice panics on underflow in debug builds and
    // wraps to `usize::MAX` in release builds, which would then be handed to the
    // kernel as the output length.
    assert!(
        !inputs.vocab_offsets.is_empty(),
        "vocab_offsets must hold at least one entry (N + 1 offsets for N tokens)"
    );
    let vocab_size = inputs.vocab_offsets.len() - 1;
    // The launch shape is computed in u32; refuse to truncate silently.
    let launch_units =
        u32::try_from(vocab_size).expect("token count must fit in u32 to size the kernel launch");
    let bytes_u32 = std::mem::size_of::<u32>();

    let trans_handle = client.create_from_slice(u32_slice_as_bytes(inputs.transitions));
    let class_handle = client.create_from_slice(u32_slice_as_bytes(inputs.char_classes));
    let offsets_handle = client.create_from_slice(u32_slice_as_bytes(inputs.vocab_offsets));
    let chars_handle = client.create_from_slice(u32_slice_as_bytes(inputs.vocab_chars));
    // The kernel writes every slot below `vocab_size`, so no initial contents are needed.
    let allow_handle = client.empty(vocab_size * bytes_u32);

    let (count, dim) = elementwise_launch_dims(launch_units);
    // SAFETY: every `ArrayArg` length is the element count of the slice its handle
    // was created from, and `allow_handle` was allocated for exactly `vocab_size`
    // `u32` values, so the lengths passed to the kernel match the device buffers.
    // The kernel only touches `allow[tok]` and `vocab_offsets[tok + 1]` for
    // `tok < vocab_size`, which the `vocab_size + 1` offsets cover. Indices derived
    // from the *contents* of the inputs are the caller's responsibility, as listed
    // in the function documentation.
    unsafe {
        kernel_compute_token_mask_dfa::launch_unchecked::<R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(trans_handle, inputs.transitions.len()),
            ArrayArg::from_raw_parts(class_handle, inputs.char_classes.len()),
            ArrayArg::from_raw_parts(offsets_handle, inputs.vocab_offsets.len()),
            ArrayArg::from_raw_parts(chars_handle, inputs.vocab_chars.len()),
            ArrayArg::from_raw_parts(allow_handle.clone(), vocab_size),
            inputs.num_classes,
            inputs.start_state,
            inputs.reject_state,
            inputs.max_token_len,
        );
    }
    (allow_handle, vocab_size)
}
/// Reinterprets a `u32` slice as its underlying bytes without copying.
///
/// The bytes appear in the host's native byte order and the result borrows
/// from `s`.
fn u32_slice_as_bytes(s: &[u32]) -> &[u8] {
    let byte_len = s.len() * std::mem::size_of::<u32>();
    let first_byte = s.as_ptr() as *const u8;
    // SAFETY: `first_byte` points at the start of `s`, which is valid for reads of
    // `byte_len` bytes for as long as `s` is borrowed; `u8` has no alignment
    // requirement and every bit pattern is a valid `u8`.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
#[cfg(all(test, feature = "cuda"))]
mod cuda_tests {
    use super::*;
    use cubecl_cuda::{CudaDevice, CudaRuntime};

    /// Opens a compute client on CUDA device 0.
    fn cuda_client() -> ComputeClient<CudaRuntime> {
        CudaRuntime::client(&CudaDevice { index: 0 })
    }

    /// Reads `n` little-endian `u32` values back from a device buffer.
    fn read_u32(
        client: &ComputeClient<CudaRuntime>,
        handle: cubecl::server::Handle,
        n: usize,
    ) -> Vec<u32> {
        let raw = client.read_one(handle).expect("CUDA read_one failed");
        assert_eq!(raw.len(), n * std::mem::size_of::<u32>());
        raw.chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    /// CPU verdict for a single token: `true` when the walk never rejects.
    ///
    /// Empty tokens, character values of 128 or above, and reaching
    /// `reject_state` all reject. The whole token is walked; `max_token_len`
    /// is not consulted here.
    fn cpu_accepts(inputs: &DfaMaskInputs<'_>, tok: usize) -> bool {
        let start = inputs.vocab_offsets[tok] as usize;
        let end = inputs.vocab_offsets[tok + 1] as usize;
        if start == end {
            return false;
        }
        let mut state = inputs.start_state;
        for &c in &inputs.vocab_chars[start..end] {
            if c >= 128 {
                return false;
            }
            let class = inputs.char_classes[c as usize];
            state = inputs.transitions[(state * inputs.num_classes + class) as usize];
            if state == inputs.reject_state {
                return false;
            }
        }
        true
    }

    /// CPU reference mask: one `0`/`1` entry per token.
    fn reference_walk(inputs: &DfaMaskInputs<'_>) -> Vec<u32> {
        let token_count = inputs.vocab_offsets.len() - 1;
        (0..token_count)
            .map(|tok| u32::from(cpu_accepts(inputs, tok)))
            .collect()
    }

    /// Flattens tokens into `(vocab_offsets, vocab_chars)`, one `char` per element.
    fn pack_vocab<S: AsRef<str>>(tokens: &[S]) -> (Vec<u32>, Vec<u32>) {
        let mut offsets: Vec<u32> = vec![0];
        let mut chars: Vec<u32> = Vec::new();
        for tok in tokens {
            chars.extend(tok.as_ref().chars().map(|c| c as u32));
            offsets.push(chars.len() as u32);
        }
        (offsets, chars)
    }

    /// Runs the kernel on CUDA device 0 and reads the mask back to the host.
    fn gpu_mask(inputs: &DfaMaskInputs<'_>) -> Vec<u32> {
        let client = cuda_client();
        let (handle, n) = compute_token_mask_dfa_to_gpu::<CudaRuntime>(&client, inputs);
        read_u32(&client, handle, n)
    }

    #[test]
    fn dfa_kernel_matches_hand_built_walk() {
        // Class 0 is 'a', class 1 is everything else.
        let mut classes = vec![1u32; 128];
        classes[b'a' as usize] = 0;
        // State 0 loops on 'a'; anything else leads to the absorbing state 1.
        let transitions: Vec<u32> = vec![0, 1, 1, 1];
        let tokens: &[&str] = &["a", "aa", "aaa", "ab", "ba", ""];
        let (vocab_offsets, vocab_chars) = pack_vocab(tokens);

        let inputs = DfaMaskInputs {
            transitions: &transitions,
            char_classes: &classes,
            vocab_offsets: &vocab_offsets,
            vocab_chars: &vocab_chars,
            num_classes: 2u32,
            start_state: 0,
            reject_state: 1,
            max_token_len: 8,
        };

        let expected = reference_walk(&inputs);
        assert_eq!(expected, vec![1, 1, 1, 0, 0, 0]);

        let got = gpu_mask(&inputs);
        assert_eq!(got, expected, "GPU mask must match CPU reference walk");
    }

    #[test]
    fn dfa_kernel_accepts_digit_sequences() {
        // Class 0 is an ASCII digit, class 1 is everything else.
        let mut classes = vec![1u32; 128];
        for digit in b'0'..=b'9' {
            classes[digit as usize] = 0;
        }
        let transitions: Vec<u32> = vec![0, 1, 1, 1];

        // 25 purely numeric tokens, five mixed tokens, the empty token, and one
        // seven-digit token.
        let mut tokens: Vec<String> = (0..20u32)
            .chain(100..105u32)
            .map(|n| n.to_string())
            .collect();
        tokens.extend(["1a", "2b", "3c", "4d", "x9", "", "1234567"].map(String::from));
        let (vocab_offsets, vocab_chars) = pack_vocab(tokens.as_slice());

        let inputs = DfaMaskInputs {
            transitions: &transitions,
            char_classes: &classes,
            vocab_offsets: &vocab_offsets,
            vocab_chars: &vocab_chars,
            num_classes: 2u32,
            start_state: 0,
            reject_state: 1,
            max_token_len: 8,
        };

        let expected = reference_walk(&inputs);
        let got = gpu_mask(&inputs);
        assert_eq!(got, expected, "digit-DFA GPU mask must match CPU reference");

        let accepted: u32 = got.iter().sum();
        assert_eq!(accepted, 26);
    }

    #[test]
    fn dfa_kernel_rejects_non_ascii() {
        // A single-state DFA that never reaches `reject_state`; only the
        // character-value check can reject a token here.
        let classes = vec![0u32; 128];
        let transitions: Vec<u32> = vec![0];
        let tokens = ["abc", "héllo", "x"].map(String::from);
        let (vocab_offsets, vocab_chars) = pack_vocab(&tokens[..]);

        let inputs = DfaMaskInputs {
            transitions: &transitions,
            char_classes: &classes,
            vocab_offsets: &vocab_offsets,
            vocab_chars: &vocab_chars,
            num_classes: 1u32,
            start_state: 0,
            reject_state: u32::MAX,
            max_token_len: 8,
        };

        let expected = reference_walk(&inputs);
        assert_eq!(expected, vec![1, 0, 1]);

        let got = gpu_mask(&inputs);
        assert_eq!(got, expected, "non-ASCII must reject on both CPU and GPU");
    }
}