use crate::error::{Error, Result};
use crate::metal::{
Binding, BoundBuffer, CommandBatch, Dispatch, GpuBuffer, MetalContext, MjHeader, MjParams,
Pipeline, Pod,
};
use crate::pool::{Alloc, ScratchPool};
/// Input bytes per word. Stage 1 keeps one `u64` bitmap entry (per bitmap) and one
/// escape-info byte for every word.
pub const WORD_BYTES: usize = 64;
/// Words per Stage 1 chunk; Stage 1 keeps one quote counter and one token counter per chunk.
pub const CHUNK_WORDS: usize = 1 << 10;
/// Input bytes covered by one Stage 1 chunk (64 KiB with the values above).
pub const CHUNK_BYTES: usize = CHUNK_WORDS * WORD_BYTES;
/// Tokens per Stage 2 chunk (the unit tests pin this to `THREADGROUP_SIZE * 4`).
pub const TOKEN_CHUNK_TOKENS: usize = 1 << 10;
/// Skeleton elements per Stage 3 chunk (the unit tests pin this to `THREADGROUP_SIZE * 4`).
pub const SKELETON_CHUNK_ELEMS: usize = 1 << 10;
/// Number of sort passes needed to order depth keys in `0..max_depth`, consuming five key
/// bits per pass.
///
/// A `max_depth` of 0 is treated like 1, and at least one pass is always reported, so the
/// result is in `1..=7`.
#[must_use]
pub fn sort_passes(max_depth: u32) -> usize {
    // Key bits handled by a single pass (one 32-way digit).
    const DIGIT_BITS: u32 = 5;
    let largest_key = max_depth.saturating_sub(1);
    let key_bits = (u32::BITS - largest_key.leading_zeros()).max(1);
    key_bits.div_ceil(DIGIT_BITS) as usize
}
/// Cap used by the escape look-back; it must remain a multiple of `WORD_BYTES`
/// (the unit tests check this). NOTE(review): unit is presumably bytes — confirm in the kernel.
pub const ESCAPE_LOOKBACK_CAP: usize = 4 * 1024;
/// Largest accepted input, in bytes.
///
/// Rounded up to a whole number of 64-byte words this is `2^32 - 64`, so the padded
/// length still fits in a `u32` (presumably the reason for the 64-byte headroom, since
/// token positions are stored as `u32`).
pub const MAX_INPUT_BYTES: u64 = (u32::MAX - 64) as u64;
/// A named compute kernel whose [`Pipeline`] is built lazily on first use.
///
/// The pipeline lives in a [`core::cell::OnceCell`], so a `Stage` can be created in a
/// `const` context and only touches Metal when [`Stage::pipeline`] is first called.
#[derive(Debug)]
pub struct Stage {
    name: &'static str,
    pipeline: core::cell::OnceCell<Pipeline>,
}
impl Stage {
    /// Creates a stage for the kernel called `name`; no pipeline is built yet.
    #[must_use]
    pub const fn new(name: &'static str) -> Self {
        let pipeline = core::cell::OnceCell::new();
        Self { name, pipeline }
    }

    /// Kernel name handed to [`Pipeline::new`].
    #[must_use]
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the cached pipeline, building it from `ctx` on the first call.
    ///
    /// After the first successful build, the cached pipeline is returned and `ctx` is not
    /// consulted. NOTE(review): if one `Stage` were ever used with two different
    /// `MetalContext`s, the second would reuse the first context's pipeline — confirm that
    /// cannot happen.
    ///
    /// # Errors
    /// Propagates any failure from [`Pipeline::new`]. The cell stays empty in that case,
    /// so a later call retries the build.
    pub fn pipeline(&self, ctx: &MetalContext) -> Result<&Pipeline> {
        if let Some(cached) = self.pipeline.get() {
            return Ok(cached);
        }
        let built = Pipeline::new(ctx, self.name)?;
        // The cell was empty just above and `OnceCell` is not `Sync`, so this `set` should
        // always succeed; its result is deliberately ignored.
        let _ = self.pipeline.set(built);
        Ok(self.pipeline.get().expect("pipeline just initialized"))
    }

    /// Records one dispatch of this stage into `batch` with the given bound buffers,
    /// optional parameters and dispatch geometry.
    ///
    /// # Errors
    /// Fails only if the pipeline cannot be built (see [`Stage::pipeline`]).
    pub fn encode(
        &self,
        batch: &mut CommandBatch<'_>,
        buffers: &[BoundBuffer],
        params: Option<&MjParams>,
        work: Dispatch,
    ) -> Result<()> {
        let ctx = batch.ctx();
        let compiled = self.pipeline(ctx)?;
        batch.dispatch(compiled, buffers, params, work);
        Ok(())
    }
}
/// Minimal wrapper around a [`MetalContext`] for driving a single [`Stage`] in tests.
#[derive(Debug)]
pub struct TestHarness {
    ctx: MetalContext,
}
impl TestHarness {
    /// Creates a harness around a freshly created [`MetalContext`].
    ///
    /// # Errors
    /// Propagates any failure from [`MetalContext::new`].
    pub fn new() -> Result<Self> {
        let ctx = MetalContext::new()?;
        Ok(Self { ctx })
    }

    /// The context all buffers and batches of this harness are created from.
    #[must_use]
    pub fn ctx(&self) -> &MetalContext {
        &self.ctx
    }

    /// Allocates a buffer of exactly `size_of_val(data)` bytes and copies `data` into it.
    ///
    /// # Errors
    /// Propagates allocation failures from [`GpuBuffer::alloc`].
    pub fn upload<T: Pod>(&self, data: &[T]) -> Result<GpuBuffer> {
        let byte_len = size_of_val(data);
        let mut staged = GpuBuffer::alloc(&self.ctx, byte_len)?;
        staged.write_from(data);
        Ok(staged)
    }

    /// Allocates room for `count` values of `T` and fills every byte with zero.
    ///
    /// # Errors
    /// Propagates allocation failures from [`GpuBuffer::alloc`].
    pub fn alloc_zeroed<T: Pod>(&self, count: usize) -> Result<GpuBuffer> {
        let byte_len = count * size_of::<T>();
        let mut zeroed = GpuBuffer::alloc(&self.ctx, byte_len)?;
        zeroed.contents_mut().fill(0);
        Ok(zeroed)
    }

    /// Binds `bindings` in order, encodes one dispatch of `stage`, and finishes with
    /// `commit_and_wait` on the batch.
    ///
    /// # Errors
    /// Propagates failures from batch creation, pipeline creation, or the commit.
    pub fn run(
        &self,
        stage: &Stage,
        bindings: &mut [Binding<'_>],
        params: Option<&MjParams>,
        work: Dispatch,
    ) -> Result<()> {
        let mut batch = self.ctx.batch()?;
        // Handle `i` corresponds to `bindings[i]`; read-only and read-write buffers are
        // bound through their respective batch methods.
        let bound: Vec<_> = bindings
            .iter_mut()
            .map(|binding| match binding {
                Binding::Read(buffer) => batch.bind_read(buffer),
                Binding::ReadWrite(buffer) => batch.bind_write(buffer),
            })
            .collect();
        stage.encode(&mut batch, &bound, params, work)?;
        batch.commit_and_wait()
    }

    /// Copies the whole contents of `buffer`, viewed as `T`s, into a new `Vec`.
    #[must_use]
    pub fn read_back<T: Pod>(&self, buffer: &GpuBuffer) -> Vec<T> {
        let view = buffer.as_slice::<T>();
        view.to_vec()
    }
}
/// GPU buffers owned by Stage 1 for one input document.
///
/// Sizes derive from `input_len`: `words = input_len.div_ceil(WORD_BYTES)` and
/// `chunks = words.div_ceil(CHUNK_WORDS)`.
#[derive(Debug)]
pub struct Stage1Buffers {
    /// Input bytes. When created by `new_in` this is padded with ASCII spaces to
    /// `words * WORD_BYTES` bytes; external/pooled buffers are supplied by the caller.
    pub input: GpuBuffer,
    /// `true` when `input` is caller-owned; `recycle` then does not return it to the pool.
    input_external: bool,
    /// Logical input length in bytes, before padding.
    input_len: usize,
    /// Number of `WORD_BYTES`-sized words covering the input.
    words: usize,
    /// Number of `CHUNK_WORDS`-sized chunks covering the words.
    chunks: usize,
    /// One `u64` per word (presumably a per-byte quote bitmap — confirm in the kernel).
    pub bm_quote: GpuBuffer,
    /// One `u64` per word (presumably a per-byte token bitmap — confirm in the kernel).
    pub bm_tok: GpuBuffer,
    /// One byte of escape information per word.
    pub escape_info: GpuBuffer,
    /// One `u32` counter per chunk; zeroed on creation and by `reset_for_reuse`.
    pub chunk_quote_counts: GpuBuffer,
    /// One `u32` counter per chunk; zeroed on creation and by `reset_for_reuse`.
    pub chunk_token_counts: GpuBuffer,
    /// A single `MjHeader`, initialised with `MjHeader::new()`.
    pub header: GpuBuffer,
    /// One `u32` per token; `None` until `alloc_tokens`/`alloc_tokens_in` is called.
    pub tok_pos: Option<GpuBuffer>,
    /// One byte per token; `None` until `alloc_tokens`/`alloc_tokens_in` is called.
    pub tok_kind: Option<GpuBuffer>,
}
impl Stage1Buffers {
    /// Allocates Stage 1 buffers directly (no pool) and copies `input` into a padded
    /// input buffer.
    ///
    /// # Errors
    /// Returns [`Error::InputTooLarge`] when `input` is longer than [`MAX_INPUT_BYTES`],
    /// or any error reported by the buffer allocator.
    pub fn new(ctx: &MetalContext, input: &[u8]) -> Result<Self> {
        Self::new_in(ctx, Alloc::Direct, input)
    }

    /// Like [`Self::new`], but allocates every buffer through `alloc`.
    ///
    /// The input buffer is requested with `words * WORD_BYTES` bytes. Every byte after
    /// `input.len()` up to the end of the returned buffer is set to ASCII space (presumably
    /// so the padded tail reads as insignificant whitespace — confirm against the kernels).
    pub(crate) fn new_in(ctx: &MetalContext, alloc: Alloc<'_>, input: &[u8]) -> Result<Self> {
        check_input_len(input.len() as u64)?;
        let words = input.len().div_ceil(WORD_BYTES);
        let mut input_buf = alloc.buffer(ctx, words * WORD_BYTES)?;
        let bytes = input_buf.contents_mut();
        bytes[..input.len()].copy_from_slice(input);
        // Covers the word padding and any extra capacity the allocator may have returned.
        bytes[input.len()..].fill(b' ');
        Self::assemble(ctx, alloc, input_buf, input.len(), false)
    }

    /// Wraps a caller-owned input buffer holding `input_len` bytes of input.
    ///
    /// The buffer is marked external, so [`Self::recycle`] drops it instead of returning
    /// it to the pool.
    ///
    /// # Errors
    /// Returns [`Error::InputTooLarge`] when `input_len` exceeds [`MAX_INPUT_BYTES`], or
    /// any error reported by the buffer allocator.
    pub(crate) fn with_external_input(
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        input_buf: GpuBuffer,
        input_len: usize,
    ) -> Result<Self> {
        check_input_len(input_len as u64)?;
        // Mirror the sanity check done for pooled input: a buffer shorter than the declared
        // length would make every later stage read past its end.
        debug_assert!(
            input_buf.len() >= input_len,
            "external input buffer is shorter than the declared input length"
        );
        Self::assemble(ctx, alloc, input_buf, input_len, true)
    }

    /// Wraps an already-filled input buffer that came from the scratch pool.
    ///
    /// The buffer must cover `input_len` rounded up to a multiple of `WORD_BYTES` (checked
    /// in debug builds). It is owned by the result and returned to the pool by
    /// [`Self::recycle`].
    ///
    /// # Errors
    /// Returns [`Error::InputTooLarge`] when `input_len` exceeds [`MAX_INPUT_BYTES`], or
    /// any error reported by the buffer allocator.
    pub(crate) fn with_pooled_input(
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        input_buf: GpuBuffer,
        input_len: usize,
    ) -> Result<Self> {
        check_input_len(input_len as u64)?;
        debug_assert!(
            input_buf.len() >= input_len.next_multiple_of(WORD_BYTES),
            "pooled input buffer must cover the padded word length"
        );
        Self::assemble(ctx, alloc, input_buf, input_len, false)
    }

    /// Allocates the remaining Stage 1 buffers around an already-populated input buffer.
    ///
    /// Only the per-chunk counters (zeroed) and the header (`MjHeader::new()`) are
    /// initialised here. `bm_quote`, `bm_tok` and `escape_info` are left as returned by
    /// the allocator (presumably fully overwritten by the Stage 1 kernels — confirm).
    fn assemble(
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        input_buf: GpuBuffer,
        input_len: usize,
        input_external: bool,
    ) -> Result<Self> {
        let words = input_len.div_ceil(WORD_BYTES);
        let chunks = words.div_ceil(CHUNK_WORDS);
        let bm_quote = alloc.buffer(ctx, words * size_of::<u64>())?;
        let bm_tok = alloc.buffer(ctx, words * size_of::<u64>())?;
        let escape_info = alloc.buffer(ctx, words)?;
        let mut chunk_quote_counts = alloc.buffer(ctx, chunks * size_of::<u32>())?;
        let mut chunk_token_counts = alloc.buffer(ctx, chunks * size_of::<u32>())?;
        let mut header = alloc.buffer(ctx, size_of::<MjHeader>())?;
        chunk_quote_counts.contents_mut().fill(0);
        chunk_token_counts.contents_mut().fill(0);
        header.as_mut_slice::<MjHeader>()[0] = MjHeader::new();
        Ok(Self {
            input: input_buf,
            input_external,
            input_len,
            words,
            chunks,
            bm_quote,
            bm_tok,
            escape_info,
            chunk_quote_counts,
            chunk_token_counts,
            header,
            tok_pos: None,
            tok_kind: None,
        })
    }

    /// Logical input length in bytes, excluding padding.
    #[must_use]
    pub fn input_len(&self) -> usize {
        self.input_len
    }

    /// Number of `WORD_BYTES`-sized words covering the input.
    #[must_use]
    pub fn words(&self) -> usize {
        self.words
    }

    /// Number of `CHUNK_WORDS`-sized chunks covering the words.
    #[must_use]
    pub fn chunks(&self) -> usize {
        self.chunks
    }

    /// Copies the header out of its buffer.
    #[must_use]
    pub fn read_header(&self) -> MjHeader {
        self.header.as_slice::<MjHeader>()[0]
    }

    /// Re-zeroes the per-chunk counters and re-initialises the header so the buffers can
    /// be reused.
    ///
    /// `input`, `bm_quote`, `bm_tok` and `escape_info` are left untouched. The token
    /// buffers are dropped (set to `None`), not returned to any pool.
    pub fn reset_for_reuse(&mut self) {
        self.chunk_quote_counts.contents_mut().fill(0);
        self.chunk_token_counts.contents_mut().fill(0);
        self.header.as_mut_slice::<MjHeader>()[0] = MjHeader::new();
        self.tok_pos = None;
        self.tok_kind = None;
    }

    /// Allocates the token buffers directly (no pool) for `token_count` tokens.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub fn alloc_tokens(&mut self, ctx: &MetalContext, token_count: usize) -> Result<()> {
        self.alloc_tokens_in(ctx, Alloc::Direct, token_count)
    }

    /// Allocates `tok_pos` (one `u32` per token) and `tok_kind` (one byte per token)
    /// through `alloc`, replacing and dropping any previous token buffers.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures. If the second allocation fails, `tok_pos`
    /// has already been replaced.
    pub(crate) fn alloc_tokens_in(
        &mut self,
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        token_count: usize,
    ) -> Result<()> {
        self.tok_pos = Some(alloc.buffer(ctx, token_count * size_of::<u32>())?);
        self.tok_kind = Some(alloc.buffer(ctx, token_count)?);
        Ok(())
    }

    /// Returns every owned buffer to `pool`.
    ///
    /// A caller-owned (external) input buffer is dropped instead of pooled.
    pub(crate) fn recycle(self, pool: &ScratchPool) {
        let Self {
            input,
            input_external,
            bm_quote,
            bm_tok,
            escape_info,
            chunk_quote_counts,
            chunk_token_counts,
            header,
            tok_pos,
            tok_kind,
            ..
        } = self;
        if !input_external {
            pool.put_back(input);
        }
        for buf in [
            bm_quote,
            bm_tok,
            escape_info,
            chunk_quote_counts,
            chunk_token_counts,
            header,
        ] {
            pool.put_back(buf);
        }
        if let Some(buf) = tok_pos {
            pool.put_back(buf);
        }
        if let Some(buf) = tok_kind {
            pool.put_back(buf);
        }
    }
}
fn check_input_len(len: u64) -> Result<()> {
if len > MAX_INPUT_BYTES {
return Err(Error::InputTooLarge {
len,
max: MAX_INPUT_BYTES,
});
}
Ok(())
}
/// GPU buffers owned by Stage 2 for `token_total` tokens.
///
/// `chunks = token_total.div_ceil(TOKEN_CHUNK_TOKENS)`. The list buffers stay `None`
/// until `alloc_lists`/`alloc_lists_in` is called.
#[derive(Debug)]
pub struct Stage2Buffers {
    /// Number of tokens these buffers were sized for.
    token_total: usize,
    /// Number of `TOKEN_CHUNK_TOKENS`-sized token chunks.
    chunks: usize,
    /// Four `u32` counters per chunk.
    pub chunk_counts: GpuBuffer,
    /// One `u64` per chunk.
    pub chunk_string_bytes: GpuBuffer,
    /// One `u64` per chunk.
    pub chunk_error: GpuBuffer,
    /// One `u32` per token.
    pub tape_ofs: GpuBuffer,
    /// One `u32` per skeleton element.
    pub skel_token_index: Option<GpuBuffer>,
    /// One `u32` per skeleton element.
    pub skel_pos: Option<GpuBuffer>,
    /// One byte per skeleton element.
    pub skel_byte: Option<GpuBuffer>,
    /// One `u32` per string token.
    pub string_tokens: Option<GpuBuffer>,
    /// One `u32` per scalar token.
    pub scalar_tokens: Option<GpuBuffer>,
}
impl Stage2Buffers {
    /// Allocates Stage 2 buffers for `token_total` tokens directly (no pool).
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub fn new(ctx: &MetalContext, token_total: usize) -> Result<Self> {
        Self::new_in(ctx, Alloc::Direct, token_total)
    }

    /// Allocates the per-chunk and per-token Stage 2 buffers through `alloc`.
    ///
    /// The list buffers are left as `None`.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub(crate) fn new_in(
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        token_total: usize,
    ) -> Result<Self> {
        const U32: usize = size_of::<u32>();
        const U64: usize = size_of::<u64>();
        let chunks = token_total.div_ceil(TOKEN_CHUNK_TOKENS);
        // Allocated in field order.
        let chunk_counts = alloc.buffer(ctx, chunks * 4 * U32)?;
        let chunk_string_bytes = alloc.buffer(ctx, chunks * U64)?;
        let chunk_error = alloc.buffer(ctx, chunks * U64)?;
        let tape_ofs = alloc.buffer(ctx, token_total * U32)?;
        Ok(Self {
            token_total,
            chunks,
            chunk_counts,
            chunk_string_bytes,
            chunk_error,
            tape_ofs,
            skel_token_index: None,
            skel_pos: None,
            skel_byte: None,
            string_tokens: None,
            scalar_tokens: None,
        })
    }

    /// Number of tokens these buffers were sized for.
    #[must_use]
    pub fn token_total(&self) -> usize {
        self.token_total
    }

    /// Number of `TOKEN_CHUNK_TOKENS`-sized token chunks.
    #[must_use]
    pub fn chunks(&self) -> usize {
        self.chunks
    }

    /// Allocates the skeleton, string and scalar lists directly (no pool).
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub fn alloc_lists(
        &mut self,
        ctx: &MetalContext,
        skeleton_total: usize,
        string_total: usize,
        scalar_total: usize,
    ) -> Result<()> {
        self.alloc_lists_in(ctx, Alloc::Direct, skeleton_total, string_total, scalar_total)
    }

    /// Allocates the list buffers through `alloc`, replacing any previous ones.
    ///
    /// Each field is assigned as soon as its buffer is allocated, so on failure the
    /// fields allocated before the failing one have already been replaced.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub(crate) fn alloc_lists_in(
        &mut self,
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        skeleton_total: usize,
        string_total: usize,
        scalar_total: usize,
    ) -> Result<()> {
        let buffer_of = |bytes: usize| alloc.buffer(ctx, bytes);
        let index_bytes = skeleton_total * size_of::<u32>();
        self.skel_token_index = Some(buffer_of(index_bytes)?);
        self.skel_pos = Some(buffer_of(index_bytes)?);
        self.skel_byte = Some(buffer_of(skeleton_total)?);
        self.string_tokens = Some(buffer_of(string_total * size_of::<u32>())?);
        self.scalar_tokens = Some(buffer_of(scalar_total * size_of::<u32>())?);
        Ok(())
    }

    /// Returns every owned buffer, including any allocated lists, to `pool`.
    pub(crate) fn recycle(self, pool: &ScratchPool) {
        let Self {
            chunk_counts,
            chunk_string_bytes,
            chunk_error,
            tape_ofs,
            skel_token_index,
            skel_pos,
            skel_byte,
            string_tokens,
            scalar_tokens,
            ..
        } = self;
        let always = [chunk_counts, chunk_string_bytes, chunk_error, tape_ofs];
        let optional = [skel_token_index, skel_pos, skel_byte, string_tokens, scalar_tokens];
        for buf in always.into_iter().chain(optional.into_iter().flatten()) {
            pool.put_back(buf);
        }
    }
}
/// Bytes of per-chunk context state stored in `Stage3Buffers::chunk_ctx`.
pub const CTX_STATE_BYTES: usize = 32;
/// GPU buffers owned by Stage 3 for `skeleton_total` skeleton elements.
///
/// `chunks = skeleton_total.div_ceil(SKELETON_CHUNK_ELEMS)`.
#[derive(Debug)]
pub struct Stage3Buffers {
    /// Number of skeleton elements (always > 0).
    skeleton_total: usize,
    /// Number of `SKELETON_CHUNK_ELEMS`-sized chunks.
    chunks: usize,
    /// Number of sort passes (always > 0).
    passes: usize,
    /// One `i64` per chunk.
    pub chunk_depth: GpuBuffer,
    /// One `u32` per skeleton element.
    pub depths: GpuBuffer,
    /// One `u64` per chunk.
    pub chunk_error: GpuBuffer,
    /// 32 `u32` counters per chunk (presumably one per 5-bit sort digit value — confirm).
    pub sort_hist: GpuBuffer,
    /// A single `u32`.
    pub max_key: GpuBuffer,
    /// One `u32` per skeleton element.
    pub sorted: GpuBuffer,
    /// A second `u32`-per-element buffer, present only when `passes > 1`
    /// (presumably the alternate target between sort passes).
    pub sorted_scratch: Option<GpuBuffer>,
    /// `CTX_STATE_BYTES` bytes per chunk.
    pub chunk_ctx: GpuBuffer,
    /// One `u32` per skeleton element.
    pub match_index: GpuBuffer,
    /// One byte per skeleton element.
    pub context_opener: GpuBuffer,
    /// One `u32` per skeleton element.
    pub child_counts: GpuBuffer,
}
impl Stage3Buffers {
    /// Allocates Stage 3 buffers directly (no pool).
    ///
    /// # Panics
    /// Panics if `skeleton_total` or `passes` is zero.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub fn new(ctx: &MetalContext, skeleton_total: usize, passes: usize) -> Result<Self> {
        Self::new_in(ctx, Alloc::Direct, skeleton_total, passes)
    }

    /// Allocates all Stage 3 buffers through `alloc`; `sorted_scratch` is only allocated
    /// when more than one sort pass is needed.
    ///
    /// # Panics
    /// Panics if `skeleton_total` or `passes` is zero.
    ///
    /// # Errors
    /// Propagates buffer-allocation failures.
    pub(crate) fn new_in(
        ctx: &MetalContext,
        alloc: Alloc<'_>,
        skeleton_total: usize,
        passes: usize,
    ) -> Result<Self> {
        const U32: usize = size_of::<u32>();
        assert!(skeleton_total > 0, "empty skeletons never dispatch CB3");
        assert!(passes > 0, "the sort always runs at least one pass");
        let chunks = skeleton_total.div_ceil(SKELETON_CHUNK_ELEMS);
        // Allocated in field order.
        let chunk_depth = alloc.buffer(ctx, chunks * size_of::<i64>())?;
        let depths = alloc.buffer(ctx, skeleton_total * U32)?;
        let chunk_error = alloc.buffer(ctx, chunks * size_of::<u64>())?;
        let sort_hist = alloc.buffer(ctx, 32 * chunks * U32)?;
        let max_key = alloc.buffer(ctx, U32)?;
        let sorted = alloc.buffer(ctx, skeleton_total * U32)?;
        let sorted_scratch = (passes > 1)
            .then(|| alloc.buffer(ctx, skeleton_total * U32))
            .transpose()?;
        let chunk_ctx = alloc.buffer(ctx, chunks * CTX_STATE_BYTES)?;
        let match_index = alloc.buffer(ctx, skeleton_total * U32)?;
        let context_opener = alloc.buffer(ctx, skeleton_total)?;
        let child_counts = alloc.buffer(ctx, skeleton_total * U32)?;
        Ok(Self {
            skeleton_total,
            chunks,
            passes,
            chunk_depth,
            depths,
            chunk_error,
            sort_hist,
            max_key,
            sorted,
            sorted_scratch,
            chunk_ctx,
            match_index,
            context_opener,
            child_counts,
        })
    }

    /// Returns every owned buffer, including the optional sort scratch, to `pool`.
    pub(crate) fn recycle(self, pool: &ScratchPool) {
        let Self {
            chunk_depth,
            depths,
            chunk_error,
            sort_hist,
            max_key,
            sorted,
            sorted_scratch,
            chunk_ctx,
            match_index,
            context_opener,
            child_counts,
            ..
        } = self;
        let required = [
            chunk_depth,
            depths,
            chunk_error,
            sort_hist,
            max_key,
            sorted,
            chunk_ctx,
            match_index,
            context_opener,
            child_counts,
        ];
        for buf in required.into_iter().chain(sorted_scratch) {
            pool.put_back(buf);
        }
    }

    /// Number of skeleton elements these buffers were sized for.
    #[must_use]
    pub fn skeleton_total(&self) -> usize {
        self.skeleton_total
    }

    /// Number of `SKELETON_CHUNK_ELEMS`-sized chunks.
    #[must_use]
    pub fn chunks(&self) -> usize {
        self.chunks
    }

    /// Number of sort passes these buffers were sized for.
    #[must_use]
    pub fn passes(&self) -> usize {
        self.passes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Geometry constants must agree with each other and with the kernel threadgroup size.
    #[test]
    fn pipeline_geometry_constants_are_consistent() {
        assert_eq!(CHUNK_BYTES, WORD_BYTES * CHUNK_WORDS);
        assert_eq!(CHUNK_BYTES, 65536);
        assert_eq!(ESCAPE_LOOKBACK_CAP % WORD_BYTES, 0);
        assert_eq!(TOKEN_CHUNK_TOKENS, crate::metal::THREADGROUP_SIZE * 4);
        assert_eq!(SKELETON_CHUNK_ELEMS, crate::metal::THREADGROUP_SIZE * 4);
    }

    #[test]
    fn sort_pass_counts_cover_the_clean_key_range() {
        assert_eq!(sort_passes(1), 1);
        assert_eq!(sort_passes(31), 1);
        assert_eq!(sort_passes(32), 1, "key_max 31 still fits one digit");
        assert_eq!(sort_passes(33), 2);
        assert_eq!(sort_passes(1024), 2, "the simdjson-parity default");
        assert_eq!(sort_passes(1025), 3);
        assert_eq!(sort_passes(u32::MAX), 7);
        assert_eq!(sort_passes(0), 1);
    }

    #[cfg(feature = "cpu-reference")]
    #[test]
    fn max_input_matches_the_reference_pipeline() {
        assert_eq!(MAX_INPUT_BYTES, crate::reference::MAX_INPUT_BYTES);
    }

    // The limit is inclusive: MAX_INPUT_BYTES itself is accepted and one byte more is
    // rejected with the offending length and the limit.
    #[test]
    fn input_length_limit_is_inclusive() {
        assert!(check_input_len(0).is_ok());
        assert!(check_input_len(MAX_INPUT_BYTES).is_ok());
        assert!(matches!(
            check_input_len(MAX_INPUT_BYTES + 1),
            Err(Error::InputTooLarge { len, max })
                if len == MAX_INPUT_BYTES + 1 && max == MAX_INPUT_BYTES
        ));
    }

    // Even the largest accepted input, padded to whole words, stays addressable with u32.
    #[test]
    fn largest_padded_input_fits_in_u32() {
        let padded = MAX_INPUT_BYTES.next_multiple_of(WORD_BYTES as u64);
        assert!(padded <= u64::from(u32::MAX));
    }

    #[test]
    fn stage_is_const_constructible_and_named() {
        const fn make() -> Stage {
            Stage::new("classify_escape_utf8")
        }
        let k1 = make();
        assert_eq!(k1.name(), "classify_escape_utf8");
    }
}