use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this operation. It is used as the generator of the region emitted by
/// `rle_segment_lengths` and as the id of the feature-gated harness registration.
pub const OP_ID: &str = "vyre-primitives::decode::rle_segment_lengths";
/// Binding slot of the read-only `segments_in` buffer (one packed `u32` word per segment).
pub const BINDING_SEGMENTS_IN: u32 = 0;
/// Binding slot of the `segment_lengths_out` buffer (one unpacked run length per segment).
pub const BINDING_SEGMENT_LENGTHS_OUT: u32 = 1;
/// Binding slot of the `segment_values_out` buffer (one unpacked value per segment).
pub const BINDING_SEGMENT_VALUES_OUT: u32 = 2;
/// Largest run length the packers accept: the length is stored above the low 8 bits of a
/// `u32` word, which leaves 24 bits for it.
pub const MAX_SEGMENT_LENGTH: u32 = (1 << 24) - 1;
/// Largest value the packers accept: the value occupies the low 8 bits of a packed word.
pub const MAX_SEGMENT_VALUE: u32 = 0xFF;
/// Reason a `(length, value)` segment could not be packed into a single `u32` word.
///
/// A packed word keeps the value in its low 8 bits and the run length in the 24 bits above
/// them, so each field has a hard upper bound.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PackError {
    /// The run length of a segment does not fit the 24-bit length field.
    LengthTooLarge {
        /// Zero-based index of the offending segment in the input slice.
        segment: usize,
        /// The rejected run length.
        length: u32,
    },
    /// The value of a segment does not fit the 8-bit value field.
    ValueTooLarge {
        /// Zero-based index of the offending segment in the input slice.
        segment: usize,
        /// The rejected value, widened to `u32`.
        value: u32,
    },
}

impl std::fmt::Display for PackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LengthTooLarge { segment, length } => write!(
                f,
                "RLE segment {segment} has length {length}, which does not fit the 24-bit length field"
            ),
            Self::ValueTooLarge { segment, value } => write!(
                f,
                "RLE segment {segment} has value {value}, which does not fit the 8-bit value field"
            ),
        }
    }
}

// Implementing `Error` lets callers propagate `PackError` with `?` into `Box<dyn Error>` and
// similar error-erasing types.
impl std::error::Error for PackError {}
#[must_use]
pub fn rle_segment_lengths(segment_count: u32) -> Program {
let body = vec![
Node::let_bind("seg_idx", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("seg_idx"), Expr::u32(segment_count)),
vec![
Node::let_bind("packed", Expr::load("segments_in", Expr::var("seg_idx"))),
Node::let_bind("length", Expr::shr(Expr::var("packed"), Expr::u32(8))),
Node::let_bind("value", Expr::bitand(Expr::var("packed"), Expr::u32(0xFF))),
Node::store(
"segment_lengths_out",
Expr::var("seg_idx"),
Expr::var("length"),
),
Node::store(
"segment_values_out",
Expr::var("seg_idx"),
Expr::var("value"),
),
],
),
];
let buffers = vec![
BufferDecl::storage(
"segments_in",
BINDING_SEGMENTS_IN,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(segment_count),
BufferDecl::storage(
"segment_lengths_out",
BINDING_SEGMENT_LENGTHS_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(segment_count),
BufferDecl::storage(
"segment_values_out",
BINDING_SEGMENT_VALUES_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(segment_count),
];
let entry = vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}];
Program::wrapped(buffers, [256, 1, 1], entry)
}
/// Packs `(length, value)` segments into one `u32` word each and returns them in a new vector.
///
/// This is the allocating counterpart of [`pack_rle_segments_into`]; the vector is created
/// with capacity for `segments.len()` words.
///
/// # Errors
///
/// Returns whatever [`PackError`] [`pack_rle_segments_into`] reports for the first segment
/// that cannot be packed.
pub fn pack_rle_segments(segments: &[(u32, u8)]) -> Result<Vec<u32>, PackError> {
    let mut words = Vec::with_capacity(segments.len());
    pack_rle_segments_into(segments, &mut words).map(|()| words)
}
/// Packs `(length, value)` segments into `out`, one `u32` word per segment.
///
/// `out` is cleared first and then grown to hold `segments.len()` words. Each word is
/// `(length << 8) | value`.
///
/// # Errors
///
/// Stops at the first segment that cannot be packed:
///
/// * [`PackError::LengthTooLarge`] when its length exceeds [`MAX_SEGMENT_LENGTH`];
/// * [`PackError::ValueTooLarge`] when its value exceeds [`MAX_SEGMENT_VALUE`]. With a `u8`
///   value and a limit of `0xFF` this cannot currently happen; the check guards the limit.
///
/// On error `out` holds the words of the segments that preceded the rejected one.
pub fn pack_rle_segments_into(segments: &[(u32, u8)], out: &mut Vec<u32>) -> Result<(), PackError> {
    out.clear();
    out.reserve(segments.len());
    for (segment, &(length, value)) in segments.iter().enumerate() {
        // The length is validated before the value, so it wins when both are out of range.
        match (length, u32::from(value)) {
            (length, _) if length > MAX_SEGMENT_LENGTH => {
                return Err(PackError::LengthTooLarge { segment, length });
            }
            (_, value) if value > MAX_SEGMENT_VALUE => {
                return Err(PackError::ValueTooLarge { segment, value });
            }
            (length, value) => out.push((length << 8) | value),
        }
    }
    Ok(())
}
/// CPU reference for the unpack step: splits packed words into `(lengths, values)`.
///
/// Allocating wrapper around [`rle_segment_lengths_cpu_into`]; both returned vectors have
/// one entry per input word.
#[must_use]
pub fn rle_segment_lengths_cpu(segments_in: &[u32]) -> (Vec<u32>, Vec<u32>) {
    let (mut run_lengths, mut run_values) = (Vec::new(), Vec::new());
    rle_segment_lengths_cpu_into(segments_in, &mut run_lengths, &mut run_values);
    (run_lengths, run_values)
}
/// Splits packed segment words into caller-provided `lengths` and `values` vectors.
///
/// Both vectors are cleared and then receive one entry per word of `segments_in`:
/// `word >> 8` goes to `lengths` and `word & 0xFF` goes to `values`.
pub fn rle_segment_lengths_cpu_into(
    segments_in: &[u32],
    lengths: &mut Vec<u32>,
    values: &mut Vec<u32>,
) {
    let segment_count = segments_in.len();

    lengths.clear();
    values.clear();
    lengths.reserve(segment_count);
    values.reserve(segment_count);

    // Upper 24 bits carry the run length, the low byte carries the value.
    lengths.extend(segments_in.iter().map(|&word| word >> 8));
    values.extend(segments_in.iter().map(|&word| word & 0xFF));
}
/// Computes the start offset of every segment plus the total of all lengths.
///
/// Allocating wrapper around [`rle_segment_start_offsets_cpu_into`]; returns
/// `(offsets, total)`.
#[must_use]
pub fn rle_segment_start_offsets_cpu(segment_lengths: &[u32]) -> (Vec<u32>, u32) {
    let mut starts = Vec::new();
    let decoded_len = rle_segment_start_offsets_cpu_into(segment_lengths, &mut starts);
    (starts, decoded_len)
}
/// Writes the exclusive prefix sum of `segment_lengths` into `offsets` and returns the total.
///
/// `offsets` is cleared first; afterwards `offsets[i]` is the sum of all lengths before
/// index `i`. Additions saturate at `u32::MAX` instead of wrapping, for both the stored
/// offsets and the returned total.
pub fn rle_segment_start_offsets_cpu_into(segment_lengths: &[u32], offsets: &mut Vec<u32>) -> u32 {
    offsets.clear();
    offsets.reserve(segment_lengths.len());
    // Record the running sum *before* adding each length, so the scan is exclusive.
    segment_lengths.iter().fold(0_u32, |running, &length| {
        offsets.push(running);
        running.saturating_add(length)
    })
}
/// Expands packed RLE segment words into the decoded byte stream.
///
/// Allocating wrapper around [`rle_decode_cpu_into`].
#[must_use]
pub fn rle_decode_cpu(segments_in: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::new();
    rle_decode_cpu_into(segments_in, &mut bytes);
    bytes
}
/// Expands packed RLE segment words into `output`, replacing its previous contents.
///
/// For every word, `word >> 8` copies of the byte `word & 0xFF` are appended in input
/// order; zero-length segments contribute nothing. Capacity is reserved up front from the
/// sum of all run lengths, computed as a `u32` that saturates at `u32::MAX`.
pub fn rle_decode_cpu_into(segments_in: &[u32], output: &mut Vec<u8>) {
    output.clear();

    // First pass: size the buffer once so the second pass does not reallocate repeatedly.
    let mut decoded_len = 0_u32;
    for &word in segments_in {
        decoded_len = decoded_len.saturating_add(word >> 8);
    }
    output.reserve(decoded_len as usize);

    // Second pass: append each run by growing the vector with the run's fill byte.
    for &word in segments_in {
        let run = (word >> 8) as usize;
        let fill = (word & 0xFF) as u8;
        let grown = output.len().saturating_add(run);
        output.resize(grown, fill);
    }
}
/// Serializes `words` into a byte buffer, each word as four little-endian bytes in order.
///
/// Only compiled with the `inventory-registry` feature, where it builds harness fixtures.
#[cfg(feature = "inventory-registry")]
fn fixture_u32(words: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(words.len() * 4);
    for word in words {
        bytes.extend_from_slice(&word.to_le_bytes());
    }
    bytes
}
// Registers this op with the crate harness; compiled only with the `inventory-registry`
// feature.
// NOTE(review): the role of each `OpEntry::new` argument below is inferred from the values
// passed here — confirm against `crate::harness::OpEntry`.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
// Program builder: the fixture uses three segments.
|| rle_segment_lengths(3),
// Presumably the input buffers in binding order: the packed segments followed by two
// zero-filled three-word buffers for the outputs.
Some(|| {
let packed = pack_rle_segments(&[(2, b'A'), (0, b'X'), (3, b'B')])
.expect("fixture RLE segments fit the 24-bit length field");
vec![vec![
fixture_u32(&packed),
fixture_u32(&[0, 0, 0]),
fixture_u32(&[0, 0, 0]),
]]
}),
// Presumably the expected outputs: the unpacked lengths, then the unpacked values.
Some(|| vec![vec![
fixture_u32(&[2, 0, 3]),
fixture_u32(&[u32::from(b'A'), u32::from(b'X'), u32::from(b'B')]),
]]),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Packs `segments`, panicking with the shared message when packing fails.
    fn pack_ok(segments: &[(u32, u8)]) -> Vec<u32> {
        pack_rle_segments(segments).expect("pack must succeed")
    }

    // --- packing -------------------------------------------------------------------------

    #[test]
    fn pack_then_unpack_round_trips_simple_segments() {
        let (lengths, values) = rle_segment_lengths_cpu(&pack_ok(&[(1, 0xAB), (5, 0xCD)]));
        assert_eq!(lengths, [1, 5]);
        assert_eq!(values, [0xAB, 0xCD]);
    }

    #[test]
    fn pack_rejects_length_at_field_boundary() {
        // 2^24 is the first length that no longer fits the 24-bit field.
        let first_unrepresentable = 1u32 << 24;
        match pack_rle_segments(&[(first_unrepresentable, 0)]) {
            Err(PackError::LengthTooLarge { segment: 0, length }) => {
                assert_eq!(length, first_unrepresentable);
            }
            other => panic!("expected LengthTooLarge at the 24-bit boundary; got {other:?}"),
        }
    }

    #[test]
    fn pack_handles_max_representable_length() {
        let packed =
            pack_rle_segments(&[(MAX_SEGMENT_LENGTH, 0xFF)]).expect("max-length must pack");
        let (lengths, values) = rle_segment_lengths_cpu(&packed);
        assert_eq!(lengths, [MAX_SEGMENT_LENGTH]);
        assert_eq!(values, [0xFF]);
    }

    #[test]
    fn pack_handles_zero_length_segment_as_no_op() {
        let packed = pack_rle_segments(&[(0, 0xAB)]).expect("zero-length must pack");
        let (lengths, _) = rle_segment_lengths_cpu(&packed);
        assert_eq!(lengths, [0]);
    }

    #[test]
    fn pack_preserves_per_segment_index_in_error() {
        // Ten segments with lengths 0..10, except index 3 which is far out of range.
        let segments: Vec<(u32, u8)> = (0..10u32)
            .map(|i| (if i == 3 { 1 << 25 } else { i }, 0))
            .collect();
        let other = pack_rle_segments(&segments);
        assert!(
            matches!(other, Err(PackError::LengthTooLarge { segment: 3, .. })),
            "expected error at segment 3; got {other:?}"
        );
    }

    // --- start offsets -------------------------------------------------------------------

    #[test]
    fn start_offsets_are_exclusive_prefix_sum() {
        let (offsets, total) = rle_segment_start_offsets_cpu(&[3, 5, 2, 7]);
        assert_eq!(offsets, [0, 3, 8, 10]);
        assert_eq!(total, 17, "sum of lengths");
    }

    #[test]
    fn start_offsets_handle_zero_length_runs_correctly() {
        let (offsets, total) = rle_segment_start_offsets_cpu(&[3, 0, 5, 0, 2]);
        assert_eq!(offsets, [0, 3, 3, 8, 8]);
        assert_eq!(total, 10);
    }

    #[test]
    fn start_offsets_handle_empty_input() {
        let (offsets, total) = rle_segment_start_offsets_cpu(&[]);
        assert!(offsets.is_empty());
        assert_eq!(total, 0);
    }

    // --- end-to-end decode ---------------------------------------------------------------

    #[test]
    fn end_to_end_decode_expands_runs_in_order() {
        let decoded = rle_decode_cpu(&pack_ok(&[(3, b'A'), (2, b'B'), (1, b'C')]));
        assert_eq!(decoded, b"AAABBC".to_vec());
    }

    #[test]
    fn end_to_end_decode_handles_long_run() {
        let decoded = rle_decode_cpu(&pack_ok(&[(1000, 0x42)]));
        assert_eq!(decoded.len(), 1000);
        assert!(decoded.iter().all(|&b| b == 0x42));
    }

    #[test]
    fn end_to_end_decode_handles_alternating_short_runs() {
        // 256 single-byte runs alternating between 0xAA (even index) and 0xBB (odd index).
        let segments: Vec<(u32, u8)> = (0..256)
            .map(|i| (1, if i % 2 == 0 { 0xAA } else { 0xBB }))
            .collect();
        let decoded = rle_decode_cpu(&pack_ok(&segments));
        assert_eq!(decoded.len(), 256);
        let pattern = [0xAAu8, 0xBB];
        for (byte, expected) in decoded.iter().zip(pattern.iter().cycle()) {
            assert_eq!(byte, expected);
        }
    }

    #[test]
    fn end_to_end_decode_handles_empty_input() {
        assert!(rle_decode_cpu(&[]).is_empty());
    }

    #[test]
    fn end_to_end_decode_handles_zero_length_segments_as_skips() {
        let decoded = rle_decode_cpu(&pack_ok(&[(2, b'A'), (0, b'X'), (3, b'B')]));
        assert_eq!(decoded, b"AABBB".to_vec());
    }

    // --- `_into` variants must not reallocate when the caller's buffer is big enough ------

    #[test]
    fn pack_into_reuses_existing_capacity() {
        let mut out: Vec<u32> = Vec::with_capacity(64);
        let reserved = out.capacity();
        pack_rle_segments_into(&[(2, b'A'), (4, b'B')], &mut out)
            .expect("pack_into must succeed");
        assert_eq!(out.len(), 2);
        assert_eq!(
            out.capacity(),
            reserved,
            "pack_into must reuse caller-owned capacity"
        );
    }

    #[test]
    fn cpu_unpack_into_reuses_existing_capacity() {
        let packed = pack_ok(&[(2, b'A'), (4, b'B')]);
        let mut lengths: Vec<u32> = Vec::with_capacity(64);
        let mut values: Vec<u32> = Vec::with_capacity(64);
        let reserved = (lengths.capacity(), values.capacity());
        rle_segment_lengths_cpu_into(&packed, &mut lengths, &mut values);
        assert_eq!(lengths, [2, 4]);
        assert_eq!(values, [u32::from(b'A'), u32::from(b'B')]);
        assert_eq!(lengths.capacity(), reserved.0);
        assert_eq!(values.capacity(), reserved.1);
    }

    #[test]
    fn start_offsets_into_reuses_existing_capacity() {
        let mut offsets: Vec<u32> = Vec::with_capacity(64);
        let reserved = offsets.capacity();
        let total = rle_segment_start_offsets_cpu_into(&[2, 0, 4], &mut offsets);
        assert_eq!(offsets, [0, 2, 2]);
        assert_eq!(total, 6);
        assert_eq!(offsets.capacity(), reserved);
    }

    #[test]
    fn decode_into_reuses_existing_capacity_without_intermediate_vectors() {
        let packed = pack_ok(&[(2, b'A'), (0, b'X'), (3, b'B')]);
        let mut decoded: Vec<u8> = Vec::with_capacity(64);
        let reserved = decoded.capacity();
        rle_decode_cpu_into(&packed, &mut decoded);
        assert_eq!(decoded, b"AABBB".to_vec());
        assert_eq!(decoded.capacity(), reserved);
    }

    // --- program builder -----------------------------------------------------------------

    #[test]
    fn build_program_returns_well_formed_program() {
        let program = rle_segment_lengths(8);
        let buffer_count = program.buffers().len();
        assert_eq!(buffer_count, 3, "segments_in + lengths_out + values_out");
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
    }

    #[test]
    fn build_program_is_deterministic_across_calls() {
        let [first, second] = [rle_segment_lengths(32), rle_segment_lengths(32)];
        assert_eq!(first.buffers().len(), second.buffers().len());
        assert_eq!(first.workgroup_size(), second.workgroup_size());
    }

    // --- pinned public constants ---------------------------------------------------------

    #[test]
    fn op_id_is_canonical_and_stable() {
        assert_eq!(OP_ID, "vyre-primitives::decode::rle_segment_lengths");
    }

    #[test]
    fn binding_indices_are_canonical_and_stable() {
        assert_eq!(BINDING_SEGMENTS_IN, 0);
        assert_eq!(BINDING_SEGMENT_LENGTHS_OUT, 1);
        assert_eq!(BINDING_SEGMENT_VALUES_OUT, 2);
    }

    #[test]
    fn max_segment_length_is_canonical_24_bit_field_max() {
        assert_eq!(MAX_SEGMENT_LENGTH, (1u32 << 24) - 1);
        assert_eq!(MAX_SEGMENT_VALUE, 0xFF);
    }
}