use core::arch::asm;
use crate::json::simple::{SemiIndex as SimpleSemiIndex, State as SimpleState};
use crate::json::standard::{SemiIndex, State};
use crate::json::BitWriter;
/// Character classification of one input chunk (at most 16 bytes).
///
/// Every `u16` field is a bitmask over the chunk: bit `i` is set when byte `i`
/// of the chunk belongs to that class. As produced by `classify_chars_sve2_simd`,
/// only the low `len` bits of any mask can be set.
#[derive(Clone, Copy, Debug, Default)]
struct CharClass {
    /// `"` bytes.
    quotes: u16,
    /// `\` bytes.
    backslashes: u16,
    /// `{` and `[` bytes.
    opens: u16,
    /// `}` and `]` bytes.
    closes: u16,
    /// `,` and `:` bytes.
    delims: u16,
    /// ASCII letters, digits, `.`, `-` and `+`.
    value_chars: u16,
    /// `quotes | backslashes`: the only bytes that can change state inside a string.
    string_special: u16,
    /// Number of bytes covered by this chunk.
    len: usize,
}
/// Returns the SVE vector length in bytes, as reported by `cntb`.
///
/// # Safety
///
/// The running CPU must support SVE2 (this function is compiled with the
/// `sve2` target feature enabled).
#[inline]
#[target_feature(enable = "sve2")]
unsafe fn get_vl() -> usize {
    let vl: usize;
    // SAFETY: `cntb` only writes the byte count of a vector into a general
    // register; it accesses no memory and no stack, matching the declared
    // `pure, nomem, nostack` options. SVE2 support is the caller's contract.
    unsafe {
        asm!(
            "cntb {vl}",
            vl = out(reg) vl,
            options(pure, nomem, nostack)
        );
    }
    vl
}
#[inline]
#[target_feature(enable = "sve2")]
unsafe fn classify_chars_sve2_simd(data: *const u8, len: usize) -> CharClass {
let vl = unsafe { get_vl() };
let chunk_len = len.min(vl).min(16);
if chunk_len == 0 {
return CharClass::default();
}
let mut quotes_buf = [0u8; 16];
let mut backslashes_buf = [0u8; 16];
let mut opens_buf = [0u8; 16];
let mut closes_buf = [0u8; 16];
let mut delims_buf = [0u8; 16];
let mut value_chars_buf = [0u8; 16];
unsafe {
asm!(
"whilelt p0.b, xzr, {len}",
"ld1b z0.b, p0/z, [{data}]",
"mov z1.b, #0x22",
"cmpeq p1.b, p0/z, z0.b, z1.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{quotes_buf}]",
"mov z1.b, #0x5C",
"cmpeq p1.b, p0/z, z0.b, z1.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{backslashes_buf}]",
"mov z1.b, #0x7B",
"cmpeq p1.b, p0/z, z0.b, z1.b",
"mov z1.b, #0x5B",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{opens_buf}]",
"mov z1.b, #0x7D",
"cmpeq p1.b, p0/z, z0.b, z1.b",
"mov z1.b, #0x5D",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{closes_buf}]",
"mov z1.b, #0x2C",
"cmpeq p1.b, p0/z, z0.b, z1.b",
"mov z1.b, #0x3A",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{delims_buf}]",
"mov z1.b, #0x30",
"sub z2.b, z0.b, z1.b",
"mov z1.b, #10",
"cmplo p1.b, p0/z, z2.b, z1.b",
"mov z1.b, #0x41",
"sub z2.b, z0.b, z1.b",
"mov z1.b, #26",
"cmplo p2.b, p0/z, z2.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0x61",
"sub z2.b, z0.b, z1.b",
"mov z1.b, #26",
"cmplo p2.b, p0/z, z2.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0x2E",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0x2D",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0x2B",
"cmpeq p2.b, p0/z, z0.b, z1.b",
"orr p1.b, p0/z, p1.b, p2.b",
"mov z1.b, #0",
"mov z1.b, p1/m, #0xFF",
"st1b z1.b, p0, [{value_chars_buf}]",
data = in(reg) data,
len = in(reg) chunk_len,
quotes_buf = in(reg) quotes_buf.as_mut_ptr(),
backslashes_buf = in(reg) backslashes_buf.as_mut_ptr(),
opens_buf = in(reg) opens_buf.as_mut_ptr(),
closes_buf = in(reg) closes_buf.as_mut_ptr(),
delims_buf = in(reg) delims_buf.as_mut_ptr(),
value_chars_buf = in(reg) value_chars_buf.as_mut_ptr(),
out("z0") _,
out("z1") _,
out("z2") _,
out("p0") _,
out("p1") _,
out("p2") _,
options(nostack)
);
}
fn buf_to_mask(buf: &[u8; 16], len: usize) -> u16 {
let mut mask = 0u16;
for (i, &byte) in buf.iter().enumerate().take(len) {
if byte != 0 {
mask |= 1u16 << i;
}
}
mask
}
let quotes = buf_to_mask("es_buf, chunk_len);
let backslashes = buf_to_mask(&backslashes_buf, chunk_len);
CharClass {
quotes,
backslashes,
opens: buf_to_mask(&opens_buf, chunk_len),
closes: buf_to_mask(&closes_buf, chunk_len),
delims: buf_to_mask(&delims_buf, chunk_len),
value_chars: buf_to_mask(&value_chars_buf, chunk_len),
string_special: quotes | backslashes,
len: chunk_len,
}
}
/// Runs the standard semi-index state machine over one classified chunk.
///
/// Every byte of the chunk (`class.len` bytes) appends exactly one bit to `ib`.
/// `bp` bits are written as follows:
/// - `1` for `{`/`[`;
/// - `0` for `}`/`]`;
/// - `1 0` when a string or a bare value begins.
///
/// Inside strings, runs of ordinary bytes are emitted with a single
/// `write_zeros` call by jumping straight to the next quote or backslash.
///
/// Returns the state after the chunk's last byte, so it can be carried into
/// the next chunk.
#[inline]
fn process_chunk_standard(
    class: CharClass,
    mut state: State,
    ib: &mut BitWriter,
    bp: &mut BitWriter,
) -> State {
    let end = class.len;
    let mut pos = 0;
    while pos < end {
        let here = 1u16 << pos;
        let has = |mask: u16| (mask & here) != 0;
        state = match state {
            // Brackets and quotes are handled identically whether or not a
            // bare value was in progress.
            State::InJson | State::InValue if has(class.opens) => {
                bp.write_1();
                ib.write_1();
                State::InJson
            }
            State::InJson | State::InValue if has(class.closes) => {
                bp.write_0();
                ib.write_0();
                State::InJson
            }
            State::InJson | State::InValue if has(class.quotes) => {
                bp.write_1();
                bp.write_0();
                ib.write_1();
                State::InString
            }
            // Start of a bare value (number or literal).
            State::InJson if has(class.value_chars) => {
                bp.write_1();
                bp.write_0();
                ib.write_1();
                State::InValue
            }
            // Continuation of a bare value: not a new token.
            State::InValue if has(class.value_chars) => {
                ib.write_0();
                State::InValue
            }
            // Whitespace, delimiters and anything else.
            State::InJson | State::InValue => {
                ib.write_0();
                State::InJson
            }
            State::InEscape => {
                ib.write_0();
                State::InString
            }
            State::InString => {
                // Only quotes and backslashes matter inside a string, so skip
                // to the first one at or after `pos`.
                let pending = class.string_special & !(here - 1);
                if pending == 0 {
                    ib.write_zeros(end - pos);
                    return State::InString;
                }
                let hit = pending.trailing_zeros() as usize;
                if hit > pos {
                    ib.write_zeros(hit - pos);
                    pos = hit;
                }
                ib.write_0();
                if (class.quotes & (1u16 << pos)) != 0 {
                    State::InJson
                } else {
                    State::InEscape
                }
            }
        };
        pos += 1;
    }
    state
}
/// Runs the simple semi-index state machine over one classified chunk.
///
/// Outside strings, each structural byte gets an `ib` bit of `1` plus two `bp`
/// bits:
/// - `1 1` for `{`/`[`;
/// - `0 0` for `}`/`]`;
/// - `0 1` for `,`/`:`.
///
/// Every other byte, including all string contents, gets an `ib` bit of `0`
/// and no `bp` bits. A backslash inside a string makes the following byte
/// inert.
///
/// Returns the state after the chunk's last byte.
#[inline]
fn process_chunk_simple(
    class: CharClass,
    mut state: SimpleState,
    ib: &mut BitWriter,
    bp: &mut BitWriter,
) -> SimpleState {
    for pos in 0..class.len {
        let here = 1u16 << pos;
        let has = |mask: u16| (mask & here) != 0;
        state = match state {
            SimpleState::InJson if has(class.opens) => {
                bp.write_1();
                bp.write_1();
                ib.write_1();
                SimpleState::InJson
            }
            SimpleState::InJson if has(class.closes) => {
                bp.write_0();
                bp.write_0();
                ib.write_1();
                SimpleState::InJson
            }
            SimpleState::InJson if has(class.delims) => {
                bp.write_0();
                bp.write_1();
                ib.write_1();
                SimpleState::InJson
            }
            SimpleState::InJson => {
                ib.write_0();
                if has(class.quotes) {
                    SimpleState::InString
                } else {
                    SimpleState::InJson
                }
            }
            SimpleState::InString => {
                ib.write_0();
                if has(class.quotes) {
                    SimpleState::InJson
                } else if has(class.backslashes) {
                    SimpleState::InEscape
                } else {
                    SimpleState::InString
                }
            }
            SimpleState::InEscape => {
                ib.write_0();
                SimpleState::InString
            }
        };
    }
    state
}
/// Builds the standard-cursor semi-index of `json` using SVE2 classification.
///
/// The input is processed in chunks of at most 16 bytes. Each chunk is
/// classified with SVE2 and then run through the standard state machine, which
/// appends interest bits to `ib` and balanced-parenthesis bits to `bp`. The
/// state in effect after the last byte is returned in the `state` field
/// (e.g. `InString` for an unterminated string).
///
/// # Safety
///
/// The caller must ensure the running CPU supports SVE2, for example by
/// checking `has_sve2()` first. Calling a `#[target_feature]` function on a
/// CPU lacking that feature is undefined behavior.
#[target_feature(enable = "sve2")]
pub unsafe fn build_semi_index_standard(json: &[u8]) -> SemiIndex {
    let word_capacity = json.len().div_ceil(64);
    let mut ib = BitWriter::with_capacity(word_capacity);
    // `bp` can receive up to two bits per input byte.
    let mut bp = BitWriter::with_capacity(word_capacity * 2);
    let mut state = State::InJson;
    let mut rest = json;
    while !rest.is_empty() {
        // SAFETY: SVE2 support is guaranteed by this function's contract, and
        // `rest` is a live slice, so its pointer is readable for `rest.len()`
        // bytes.
        let class = unsafe { classify_chars_sve2_simd(rest.as_ptr(), rest.len()) };
        state = process_chunk_standard(class, state, &mut ib, &mut bp);
        // `class.len` never exceeds `rest.len()`, so this slice is in bounds.
        rest = &rest[class.len..];
    }
    SemiIndex {
        state,
        ib: ib.finish(),
        bp: bp.finish(),
    }
}
/// Builds the simple-cursor semi-index of `json` using SVE2 classification.
///
/// The input is processed in chunks of at most 16 bytes. Each chunk is
/// classified with SVE2 and then run through the simple state machine, which
/// marks structural bytes (`{ } [ ] , :` outside strings) in `ib` and emits
/// two `bp` bits per structural byte. The state in effect after the last byte
/// is returned in the `state` field.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SVE2, for example by
/// checking `has_sve2()` first. Calling a `#[target_feature]` function on a
/// CPU lacking that feature is undefined behavior.
#[target_feature(enable = "sve2")]
pub unsafe fn build_semi_index_simple(json: &[u8]) -> SimpleSemiIndex {
    let word_capacity = json.len().div_ceil(64);
    let mut ib = BitWriter::with_capacity(word_capacity);
    // `bp` can receive up to two bits per input byte.
    let mut bp = BitWriter::with_capacity(word_capacity * 2);
    let mut state = SimpleState::InJson;
    let mut rest = json;
    while !rest.is_empty() {
        // SAFETY: SVE2 support is guaranteed by this function's contract, and
        // `rest` is a live slice, so its pointer is readable for `rest.len()`
        // bytes.
        let class = unsafe { classify_chars_sve2_simd(rest.as_ptr(), rest.len()) };
        state = process_chunk_simple(class, state, &mut ib, &mut bp);
        // `class.len` never exceeds `rest.len()`, so this slice is in bounds.
        rest = &rest[class.len..];
    }
    SimpleSemiIndex {
        state,
        ib: ib.finish(),
        bp: bp.finish(),
    }
}
/// Returns `true` if the running CPU supports SVE2, using runtime detection.
///
/// Check this before calling the `unsafe` SVE2 builders in this module.
#[cfg(feature = "std")]
#[inline]
pub fn has_sve2() -> bool {
    std::arch::is_aarch64_feature_detected!("sve2")
}
/// Without the `std` feature this module performs no runtime feature
/// detection, so it conservatively reports that SVE2 is unavailable.
#[cfg(not(feature = "std"))]
#[inline]
pub const fn has_sve2() -> bool {
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns whether SVE2 tests can run, printing a skip notice if not.
    fn sve2_available() -> bool {
        let available = has_sve2();
        if !available {
            eprintln!("Skipping SVE2 test: CPU doesn't support SVE2");
        }
        available
    }

    fn get_bit(words: &[u64], i: usize) -> bool {
        let word_idx = i / 64;
        let bit_idx = i % 64;
        if word_idx < words.len() {
            (words[word_idx] >> bit_idx) & 1 == 1
        } else {
            false
        }
    }

    fn bits_to_string(words: &[u64], n: usize) -> String {
        (0..n)
            .map(|i| if get_bit(words, i) { '1' } else { '0' })
            .collect()
    }

    /// Asserts that the SVE2 standard semi-index of `json` matches the scalar
    /// builder in `ib`, in `bp` and in the final state.
    fn assert_standard_matches(json: &[u8]) {
        if !sve2_available() {
            return;
        }
        // SAFETY: SVE2 support was checked just above.
        let sve2 = unsafe { build_semi_index_standard(json) };
        let scalar = crate::json::standard::build_semi_index(json);
        let input = String::from_utf8_lossy(json);
        // `ib` carries one bit per byte; `bp` at most two.
        let (ib_bits, bp_bits) = (json.len(), json.len() * 2);
        assert_eq!(
            bits_to_string(&sve2.ib, ib_bits),
            bits_to_string(&scalar.ib, ib_bits),
            "IB mismatch for {input}"
        );
        assert_eq!(
            bits_to_string(&sve2.bp, bp_bits),
            bits_to_string(&scalar.bp, bp_bits),
            "BP mismatch for {input}"
        );
        assert_eq!(sve2.state, scalar.state, "final state mismatch for {input}");
    }

    /// Asserts that the SVE2 simple semi-index of `json` matches the scalar
    /// builder in `ib`, in `bp` and in the final state.
    fn assert_simple_matches(json: &[u8]) {
        if !sve2_available() {
            return;
        }
        // SAFETY: SVE2 support was checked just above.
        let sve2 = unsafe { build_semi_index_simple(json) };
        let scalar = crate::json::simple::build_semi_index(json);
        let input = String::from_utf8_lossy(json);
        let (ib_bits, bp_bits) = (json.len(), json.len() * 2);
        assert_eq!(
            bits_to_string(&sve2.ib, ib_bits),
            bits_to_string(&scalar.ib, ib_bits),
            "IB mismatch for {input}"
        );
        assert_eq!(
            bits_to_string(&sve2.bp, bp_bits),
            bits_to_string(&scalar.bp, bp_bits),
            "BP mismatch for {input}"
        );
        assert_eq!(sve2.state, scalar.state, "final state mismatch for {input}");
    }

    /// Input whose backslash sits on the last byte of the first 16-byte chunk,
    /// so the escaped quote is the first byte of the second chunk.
    const ESCAPE_AT_CHUNK_BOUNDARY: &[u8] = br#"{"k":"xxxxxxxxx\"y"}"#;

    #[test]
    fn test_sve2_matches_scalar_empty_input() {
        assert_standard_matches(b"");
    }

    #[test]
    fn test_sve2_matches_scalar_empty_object() {
        assert_standard_matches(b"{}");
    }

    #[test]
    fn test_sve2_matches_scalar_simple_object() {
        assert_standard_matches(br#"{"a":"b"}"#);
    }

    #[test]
    fn test_sve2_matches_scalar_array() {
        assert_standard_matches(b"[1,2,3]");
    }

    #[test]
    fn test_sve2_matches_scalar_nested() {
        assert_standard_matches(br#"{"a":{"b":1}}"#);
    }

    #[test]
    fn test_sve2_matches_scalar_escaped() {
        assert_standard_matches(br#"{"a":"b\"c"}"#);
    }

    #[test]
    fn test_sve2_matches_scalar_long_input() {
        assert_standard_matches(br#"{"name":"value","number":12345,"array":[1,2,3]}"#);
    }

    #[test]
    fn test_sve2_matches_scalar_string_spans_chunks() {
        assert_standard_matches(
            br#"{"key":"a string value that is much longer than sixteen bytes","n":-1.5e+3}"#,
        );
    }

    #[test]
    fn test_sve2_matches_scalar_escape_at_chunk_boundary() {
        assert_eq!(ESCAPE_AT_CHUNK_BOUNDARY[15], b'\\');
        assert_standard_matches(ESCAPE_AT_CHUNK_BOUNDARY);
    }

    #[test]
    fn test_sve2_matches_scalar_whitespace_and_literals() {
        assert_standard_matches(b"[ {\"a\" : [true, false, null, -1.5e+3]} ,\n  {} ]");
    }

    #[test]
    fn test_sve2_matches_scalar_unterminated_string() {
        assert_standard_matches(br#"{"a":"unterminated"#);
    }

    #[test]
    fn test_sve2_simple_matches_scalar() {
        assert_simple_matches(br#"{"a":"b"}"#);
    }

    #[test]
    fn test_sve2_simple_matches_scalar_long_input() {
        assert_simple_matches(br#"{"name":"value","number":12345,"array":[1,2,3]}"#);
    }

    #[test]
    fn test_sve2_simple_matches_scalar_escape_at_chunk_boundary() {
        assert_eq!(ESCAPE_AT_CHUNK_BOUNDARY[15], b'\\');
        assert_simple_matches(ESCAPE_AT_CHUNK_BOUNDARY);
    }
}