use crate::reader::serializer::TapeSerializer;
use arrow_schema::ArrowError;
use serde::Serialize;
/// A single entry on the tape built by `TapeDecoder`.
///
/// Container markers store the tape index of their counterpart, while strings and
/// numbers store an index into the tape's string offsets.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TapeElement {
/// Start of an object; holds the tape index of the matching `EndObject`.
StartObject(u32),
/// End of an object; holds the tape index of the matching `StartObject`.
EndObject(u32),
/// Start of a list; holds the tape index of the matching `EndList`.
StartList(u32),
/// End of a list; holds the tape index of the matching `StartList`.
EndList(u32),
/// A string; holds the offset slot to pass to `Tape::get_string`.
String(u32),
/// A number kept as text; holds the offset slot to pass to `Tape::get_string`.
Number(u32),
/// The literal `true`.
True,
/// The literal `false`.
False,
/// The literal `null`.
Null,
}
/// A borrowed, read-only view over the buffers of a `TapeDecoder`, as returned by
/// `TapeDecoder::finish`.
#[derive(Debug)]
pub struct Tape<'a> {
/// The tape elements in the order they were decoded.
elements: &'a [TapeElement],
/// The concatenated text of every string and number on the tape.
strings: &'a str,
/// Boundaries into `strings`: slot `i` spans `string_offsets[i]..string_offsets[i + 1]`.
string_offsets: &'a [usize],
/// The row count reported by `num_rows`.
num_rows: usize,
}
impl<'a> Tape<'a> {
/// Returns the text stored in offset slot `idx`, i.e. the payload of a
/// `TapeElement::String` or `TapeElement::Number`.
///
/// # Panics
///
/// Panics if `idx + 1` is not a valid index into the offsets slice.
#[inline]
pub fn get_string(&self, idx: u32) -> &'a str {
    let slot = idx as usize;
    // Both lookups are bounds-checked; the end offset is read first.
    let end = self.string_offsets[slot + 1];
    let start = self.string_offsets[slot];
    // SAFETY: relies on `start..end` being an in-bounds, char-aligned range of
    // `strings`. `TapeDecoder::finish` checks that every offset is a char boundary
    // and that the last offset equals the buffer length before building a `Tape`.
    // NOTE(review): also assumes the offsets are non-decreasing — confirm this holds
    // for every writer of the offsets buffer.
    unsafe { self.strings.get_unchecked(start..end) }
}
/// Returns the tape element at index `idx`.
///
/// # Panics
///
/// Panics if `idx` is out of bounds.
pub fn get(&self, idx: u32) -> TapeElement {
    let position = idx as usize;
    self.elements[position]
}
/// Returns the index of the element that follows the value starting at `cur_idx`.
///
/// Scalars occupy a single element; objects and lists are skipped as a whole by
/// jumping past their matching end marker.
///
/// # Errors
///
/// Returns the error built by [`Tape::error`] (mentioning `expected`) when
/// `cur_idx` points at an end marker rather than at the start of a value.
pub fn next(&self, cur_idx: u32, expected: &str) -> Result<u32, ArrowError> {
    let last_idx = match self.get(cur_idx) {
        // Containers carry the index of their closing element.
        TapeElement::StartObject(end_idx) | TapeElement::StartList(end_idx) => end_idx,
        // A closing element is never the start of a value.
        TapeElement::EndObject(_) | TapeElement::EndList(_) => {
            return Err(self.error(cur_idx, expected));
        }
        TapeElement::String(_)
        | TapeElement::Number(_)
        | TapeElement::True
        | TapeElement::False
        | TapeElement::Null => cur_idx,
    };
    Ok(last_idx + 1)
}
/// Returns the number of rows recorded for this tape.
pub fn num_rows(&self) -> usize {
self.num_rows
}
/// Appends a JSON-like rendering of the element at `idx`, including any nested
/// children, to `out` and returns the index of the element that follows it.
///
/// The output feeds error messages: string contents are written verbatim,
/// without escaping.
fn serialize(&self, out: &mut String, idx: u32) -> u32 {
    match self.get(idx) {
        TapeElement::StartObject(end) => {
            out.push('{');
            let mut cur_idx = idx + 1;
            while cur_idx < end {
                // Members are stored as a key element followed by a value.
                cur_idx = self.serialize(out, cur_idx);
                out.push_str(": ");
                cur_idx = self.serialize(out, cur_idx);
                // Separate members the same way list items are separated.
                if cur_idx < end {
                    out.push_str(", ");
                }
            }
            out.push('}');
            // Skip past the matching `EndObject`.
            return end + 1;
        }
        TapeElement::EndObject(_) => out.push('}'),
        TapeElement::StartList(end) => {
            out.push('[');
            let mut cur_idx = idx + 1;
            while cur_idx < end {
                cur_idx = self.serialize(out, cur_idx);
                if cur_idx < end {
                    out.push_str(", ");
                }
            }
            out.push(']');
            // Skip past the matching `EndList`.
            return end + 1;
        }
        TapeElement::EndList(_) => out.push(']'),
        TapeElement::String(s) => {
            out.push('"');
            out.push_str(self.get_string(s));
            out.push('"')
        }
        TapeElement::Number(n) => out.push_str(self.get_string(n)),
        TapeElement::True => out.push_str("true"),
        TapeElement::False => out.push_str("false"),
        TapeElement::Null => out.push_str("null"),
    }
    // Every non-container element occupies exactly one slot.
    idx + 1
}
/// Builds an `ArrowError::JsonError` stating that `expected` was wanted but the
/// element at `idx` (rendered as text) was found instead.
pub fn error(&self, idx: u32, expected: &str) -> ArrowError {
    let mut rendered = String::with_capacity(64);
    self.serialize(&mut rendered, idx);
    let message = format!("expected {expected} got {out}", out = rendered);
    ArrowError::JsonError(message)
}
}
/// One entry on the `TapeDecoder` state stack, describing what the decoder is
/// currently in the middle of reading.
#[derive(Debug, Copy, Clone)]
enum DecoderState {
/// Inside an object; holds the tape index of its `StartObject` element.
Object(u32),
/// Inside a list; holds the tape index of its `StartList` element.
List(u32),
/// Inside a string, after the opening quote.
String,
/// A value is expected next.
Value,
/// Inside a number, after its first byte.
Number,
/// A `:` is expected next.
Colon,
/// Immediately after a backslash inside a string.
Escape,
/// Inside a `\u` escape: the accumulated high and low 16-bit values and the
/// position reached within the escape sequence.
Unicode(u16, u16, u8),
/// Inside a literal: which literal, and how many of its bytes have matched so far.
Literal(Literal, u8),
}
impl DecoderState {
fn as_str(&self) -> &'static str {
match self {
DecoderState::Object(_) => "object",
DecoderState::List(_) => "list",
DecoderState::String => "string",
DecoderState::Value => "value",
DecoderState::Number => "number",
DecoderState::Colon => "colon",
DecoderState::Escape => "escape",
DecoderState::Unicode(_, _, _) => "unicode literal",
DecoderState::Literal(d, _) => d.as_str(),
}
}
}
/// The keyword literals recognised by the decoder.
#[derive(Debug, Copy, Clone)]
enum Literal {
/// The literal `null`.
Null,
/// The literal `true`.
True,
/// The literal `false`.
False,
}
impl Literal {
    /// Returns the tape element that represents this literal.
    fn element(&self) -> TapeElement {
        match *self {
            Self::True => TapeElement::True,
            Self::False => TapeElement::False,
            Self::Null => TapeElement::Null,
        }
    }

    /// Returns the literal's spelling.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::True => "true",
            Self::False => "false",
            Self::Null => "null",
        }
    }

    /// Returns the bytes of the literal's spelling.
    fn bytes(&self) -> &'static [u8] {
        let text: &'static str = self.as_str();
        text.as_bytes()
    }
}
/// Evaluates to the next item of the iterator `$next`, or `break`s out of the
/// innermost enclosing loop when the iterator is exhausted.
macro_rules! next {
    ($next:ident) => {
        if let Some(byte) = $next.next() {
            byte
        } else {
            break
        }
    };
}
/// An incremental JSON decoder that records what it reads as a flat tape of
/// `TapeElement`s plus a shared buffer of string and number bytes.
pub struct TapeDecoder {
/// The tape; index 0 always holds a placeholder `TapeElement::Null`.
elements: Vec<TapeElement>,
/// Number of rows recorded so far in the current batch.
num_rows: usize,
/// Row count at which `decode` stops consuming input.
batch_size: usize,
/// Concatenated bytes of every string and number read so far.
bytes: Vec<u8>,
/// Offsets into `bytes`; starts with 0 and gains one entry per completed string or number.
offsets: Vec<usize>,
/// Stack of in-progress decoder states; empty between top-level records.
stack: Vec<DecoderState>,
}
impl TapeDecoder {
/// Creates a decoder that reads at most `batch_size` rows per batch.
///
/// `num_fields` is only used to size the initial buffer capacities.
pub fn new(batch_size: usize, num_fields: usize) -> Self {
    // Capacity estimate: two offset entries per field, plus two further tape
    // elements per row.
    let strings_per_row = num_fields * 2;
    let tokens_per_row = 2 + strings_per_row;

    // The offsets always begin with a leading zero.
    let mut offsets = Vec::with_capacity(batch_size * strings_per_row + 1);
    offsets.push(0);

    // Slot 0 of the tape holds a placeholder element.
    let mut elements = Vec::with_capacity(batch_size * tokens_per_row);
    elements.push(TapeElement::Null);

    Self {
        elements,
        num_rows: 0,
        batch_size,
        bytes: Vec::with_capacity(strings_per_row * 8),
        offsets,
        stack: Vec::with_capacity(10),
    }
}
/// Decodes JSON bytes from `buf` onto the tape and returns how many bytes of
/// `buf` were consumed.
///
/// Decoding state is kept on `self.stack`, so a record may be split across
/// several calls. Decoding stops early once `batch_size` rows have been read.
///
/// # Errors
///
/// Returns `ArrowError::JsonError` when a byte is encountered that is not valid
/// in the current state.
pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
// The batch is already full: consume nothing.
if self.num_rows >= self.batch_size {
return Ok(0);
}
let mut iter = BufIter::new(buf);
// Each iteration resumes whatever state is on top of the stack. Inside the
// `match`, `next!` breaks out of this loop when the input runs dry.
while !iter.is_empty() {
match self.stack.last_mut() {
// Between records: only whitespace or the start of an object is accepted.
None => {
iter.skip_whitespace();
match next!(iter) {
b'{' => {
// The end index is not known yet; it is patched when `}` is seen.
let idx = self.elements.len() as u32;
self.stack.push(DecoderState::Object(idx));
self.elements.push(TapeElement::StartObject(u32::MAX));
}
b => return Err(err(b, "trimming leading whitespace")),
}
}
Some(DecoderState::Object(start_idx)) => {
// Commas are skipped together with whitespace rather than validated.
iter.advance_until(|b| !json_whitespace(b) && b != b',');
match next!(iter) {
b'"' => {
// Pushed in reverse order: the key string is read first, then the
// colon, then the value.
self.stack.push(DecoderState::Value);
self.stack.push(DecoderState::Colon);
self.stack.push(DecoderState::String);
}
b'}' => {
// Link the start and end elements to each other.
let start_idx = *start_idx;
let end_idx = self.elements.len() as u32;
self.elements[start_idx as usize] =
TapeElement::StartObject(end_idx);
self.elements.push(TapeElement::EndObject(start_idx));
self.stack.pop();
// An empty stack means a top-level object, i.e. a row, just ended.
self.num_rows += self.stack.is_empty() as usize;
if self.num_rows >= self.batch_size {
break;
}
}
b => return Err(err(b, "parsing object")),
}
}
Some(DecoderState::List(start_idx)) => {
// Commas are skipped together with whitespace rather than validated.
iter.advance_until(|b| !json_whitespace(b) && b != b',');
// Only peek: anything other than `]` is left for the `Value` state.
match iter.peek() {
Some(b']') => {
iter.next();
// Link the start and end elements to each other.
let start_idx = *start_idx;
let end_idx = self.elements.len() as u32;
self.elements[start_idx as usize] =
TapeElement::StartList(end_idx);
self.elements.push(TapeElement::EndList(start_idx));
self.stack.pop();
}
Some(_) => self.stack.push(DecoderState::Value),
None => break,
}
}
Some(DecoderState::String) => {
// Copy everything up to the next backslash or closing quote.
let s = iter.advance_until(|b| matches!(b, b'\\' | b'"'));
self.bytes.extend_from_slice(s);
match next!(iter) {
b'\\' => self.stack.push(DecoderState::Escape),
b'"' => {
// The string occupies the most recently opened offset slot.
let idx = self.offsets.len() - 1;
self.elements.push(TapeElement::String(idx as _));
self.offsets.push(self.bytes.len());
self.stack.pop();
}
// `advance_until` only stops in front of one of the two bytes above.
b => unreachable!("{}", b),
}
}
Some(state @ DecoderState::Value) => {
iter.skip_whitespace();
// Replace the `Value` state in place with the state for the value found.
*state = match next!(iter) {
b'"' => DecoderState::String,
b @ b'-' | b @ b'0'..=b'9' => {
// Numbers are stored as text; keep the byte just consumed.
self.bytes.push(b);
DecoderState::Number
}
// The first byte of the literal has already been matched.
b'n' => DecoderState::Literal(Literal::Null, 1),
b'f' => DecoderState::Literal(Literal::False, 1),
b't' => DecoderState::Literal(Literal::True, 1),
b'[' => {
let idx = self.elements.len() as u32;
self.elements.push(TapeElement::StartList(u32::MAX));
DecoderState::List(idx)
}
b'{' => {
let idx = self.elements.len() as u32;
self.elements.push(TapeElement::StartObject(u32::MAX));
DecoderState::Object(idx)
}
b => return Err(err(b, "parsing value")),
};
}
Some(DecoderState::Number) => {
// Collect bytes that may appear in a number; the text is not validated here.
let s = iter.advance_until(|b| {
!matches!(b, b'0'..=b'9' | b'-' | b'+' | b'.' | b'e' | b'E')
});
self.bytes.extend_from_slice(s);
// The number is only complete once a terminating byte has been seen;
// if the input ran out, it may continue in the next call.
if !iter.is_empty() {
self.stack.pop();
let idx = self.offsets.len() - 1;
self.elements.push(TapeElement::Number(idx as _));
self.offsets.push(self.bytes.len());
}
}
Some(DecoderState::Colon) => {
iter.skip_whitespace();
match next!(iter) {
b':' => self.stack.pop(),
b => return Err(err(b, "parsing colon")),
};
}
Some(DecoderState::Literal(literal, idx)) => {
// Compare the input against the bytes of the literal not yet matched.
let bytes = literal.bytes();
let expected = bytes.iter().skip(*idx as usize).copied();
for (expected, b) in expected.zip(&mut iter) {
match b == expected {
true => *idx += 1,
false => return Err(err(b, "parsing literal")),
}
}
// Fully matched; otherwise the input ran out and matching resumes later.
if *idx == bytes.len() as u8 {
let element = literal.element();
self.stack.pop();
self.elements.push(element);
}
}
Some(DecoderState::Escape) => {
let v = match next!(iter) {
b'u' => {
// Swap the `Escape` state for a fresh `Unicode` state.
self.stack.pop();
self.stack.push(DecoderState::Unicode(0, 0, 0));
continue;
}
b'"' => b'"',
b'\\' => b'\\',
b'/' => b'/',
b'b' => 8, b'f' => 12, b'n' => b'\n',
b'r' => b'\r',
b't' => b'\t',
b => return Err(err(b, "parsing escape sequence")),
};
self.stack.pop();
self.bytes.push(v);
}
// `idx` is the position within the escape. Inside this `loop`, `next!` only
// leaves the inner loop; the outer `while` then ends because the input is empty.
Some(DecoderState::Unicode(high, low, idx)) => loop {
match *idx {
// First four hex digits.
0..=3 => *high = *high << 4 | parse_hex(next!(iter))? as u16,
4 => {
// A value that is a valid char on its own completes the escape.
if let Some(c) = char::from_u32(*high as u32) {
write_char(c, &mut self.bytes);
self.stack.pop();
break;
}
// Otherwise a second `\uXXXX` escape must follow immediately.
match next!(iter) {
b'\\' => {}
b => return Err(err(b, "parsing surrogate pair escape")),
}
}
5 => match next!(iter) {
b'u' => {}
b => return Err(err(b, "parsing surrogate pair unicode")),
},
// Four hex digits of the second escape.
6..=9 => *low = *low << 4 | parse_hex(next!(iter))? as u16,
_ => {
// Both halves have been read: combine them into a single char.
let c = char_from_surrogate_pair(*low, *high)?;
write_char(c, &mut self.bytes);
self.stack.pop();
break;
}
}
*idx += 1;
},
}
}
// Whatever the iterator no longer holds has been consumed.
Ok(buf.len() - iter.len())
}
/// Appends each of `rows` to the tape by serializing it through a
/// `TapeSerializer`, then adds `rows.len()` to the row count.
///
/// # Errors
///
/// Returns `ArrowError::JsonError` if a record is only partially decoded, or
/// if serializing any row fails.
pub fn serialize<S: Serialize>(&mut self, rows: &[S]) -> Result<(), ArrowError> {
    // Serialized rows must not be interleaved with a half-read record.
    match self.stack.last() {
        None => {}
        Some(pending) => {
            let state = pending.as_str();
            return Err(ArrowError::JsonError(format!(
                "Cannot serialize to tape containing partial decode state {}",
                state
            )));
        }
    }

    let mut serializer =
        TapeSerializer::new(&mut self.elements, &mut self.bytes, &mut self.offsets);
    for row in rows {
        row.serialize(&mut serializer)
            .map_err(|e| ArrowError::JsonError(e.to_string()))?;
    }

    self.num_rows += rows.len();
    Ok(())
}
/// Validates the buffered data and returns a [`Tape`] borrowing from this decoder.
///
/// # Errors
///
/// Returns `ArrowError::JsonError` if:
/// - a record is only partially decoded,
/// - there are too many string offsets or tape elements to index with a `u32`,
/// - the buffered bytes are not valid UTF-8, or
/// - a string offset does not fall on a UTF-8 character boundary.
///
/// # Panics
///
/// Panics if the last offset does not equal the length of the byte buffer.
pub fn finish(&self) -> Result<Tape<'_>, ArrowError> {
    if let Some(b) = self.stack.last() {
        return Err(ArrowError::JsonError(format!(
            "Truncated record whilst reading {}",
            b.as_str()
        )));
    }

    // Tape elements refer to offset slots by `u32` index.
    if self.offsets.len() >= u32::MAX as usize {
        return Err(ArrowError::JsonError(format!("Encountered more than {} bytes of string data, consider using a smaller batch size", u32::MAX)));
    }

    // Container elements refer to other tape elements by `u32` index, with
    // `u32::MAX` used as a placeholder, so the tape length must stay below it.
    if self.elements.len() >= u32::MAX as usize {
        return Err(ArrowError::JsonError(format!("Encountered more than {} JSON elements, consider using a smaller batch size", u32::MAX)));
    }

    // The final offset must close the final string.
    assert_eq!(
        self.offsets.last().copied().unwrap_or_default(),
        self.bytes.len()
    );

    let strings = std::str::from_utf8(&self.bytes).map_err(|_| {
        ArrowError::JsonError("Encountered non-UTF-8 data".to_string())
    })?;

    // The buffer as a whole may be valid UTF-8 while an individual string starts
    // or ends in the middle of a multi-byte character.
    for offset in self.offsets.iter().copied() {
        if !strings.is_char_boundary(offset) {
            return Err(ArrowError::JsonError(
                "Encountered truncated UTF-8 sequence".to_string(),
            ));
        }
    }

    Ok(Tape {
        strings,
        elements: &self.elements,
        string_offsets: &self.offsets,
        num_rows: self.num_rows,
    })
}
/// Resets the decoder to the state produced by `new`, keeping allocated capacity.
///
/// # Panics
///
/// Panics if a record is only partially decoded.
pub fn clear(&mut self) {
    assert!(self.stack.is_empty());

    // Restore the placeholder element at tape index 0.
    self.elements.clear();
    self.elements.push(TapeElement::Null);

    // Restore the leading zero offset.
    self.offsets.clear();
    self.offsets.push(0);

    self.bytes.clear();
    self.num_rows = 0;
}
}
/// A byte iterator over a buffer with helpers for peeking and bulk skipping.
struct BufIter<'a>(std::slice::Iter<'a, u8>);

impl<'a> BufIter<'a> {
    /// Creates an iterator positioned at the start of `buf`.
    fn new(buf: &'a [u8]) -> Self {
        BufIter(buf.iter())
    }

    /// Returns the bytes that have not been consumed yet.
    fn as_slice(&self) -> &'a [u8] {
        self.0.as_slice()
    }

    /// Returns `true` when no bytes remain.
    fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }

    /// Returns the next byte without consuming it.
    fn peek(&self) -> Option<u8> {
        self.as_slice().first().copied()
    }

    /// Consumes up to `skip` bytes, stopping early if the input runs out.
    fn advance(&mut self, skip: usize) {
        // `nth(k)` consumes `k + 1` items, so skipping zero bytes is a no-op.
        if let Some(last) = skip.checked_sub(1) {
            self.0.nth(last);
        }
    }

    /// Consumes bytes until `f` returns `true` for one of them and returns the
    /// consumed prefix. The matching byte itself is left unconsumed; if nothing
    /// matches, all remaining bytes are consumed.
    fn advance_until<F: FnMut(u8) -> bool>(&mut self, f: F) -> &[u8] {
        let remaining = self.as_slice();
        let split = remaining
            .iter()
            .copied()
            .position(f)
            .unwrap_or(remaining.len());
        self.advance(split);
        &remaining[..split]
    }

    /// Consumes leading JSON whitespace.
    fn skip_whitespace(&mut self) {
        self.advance_until(|byte| !json_whitespace(byte));
    }
}

impl<'a> Iterator for BufIter<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<Self::Item> {
        let byte = self.0.next()?;
        Some(*byte)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<'a> ExactSizeIterator for BufIter<'a> {}
/// Builds the `ArrowError::JsonError` reported when byte `b` is not valid in
/// the situation described by `ctx`.
fn err(b: u8, ctx: &str) -> ArrowError {
    let unexpected = char::from(b);
    let message = format!("Encountered unexpected '{}' whilst {ctx}", unexpected);
    ArrowError::JsonError(message)
}
/// Combines a UTF-16 surrogate pair into the `char` it encodes.
///
/// Note the argument order: the low (second) surrogate comes first.
///
/// # Errors
///
/// Returns `ArrowError::JsonError` unless `high` is a high surrogate
/// (`0xD800..=0xDBFF`) and `low` is a low surrogate (`0xDC00..=0xDFFF`).
fn char_from_surrogate_pair(low: u16, high: u16) -> Result<char, ArrowError> {
    match (high, low) {
        (0xD800..=0xDBFF, 0xDC00..=0xDFFF) => {
            // Both subtractions are in range here, so they cannot underflow.
            let n = ((((high - 0xD800) as u32) << 10) | (low - 0xDC00) as u32) + 0x1_0000;
            char::from_u32(n).ok_or_else(|| {
                ArrowError::JsonError(format!("Invalid UTF-16 surrogate pair {n}"))
            })
        }
        // Reject anything else up front: the arithmetic above would underflow for a
        // `low` below 0xDC00 and silently produce a wrong char for one above 0xDFFF.
        _ => Err(ArrowError::JsonError(format!(
            "Invalid UTF-16 surrogate pair. High: {high:#06X}, Low: {low:#06X}"
        ))),
    }
}
/// Appends the UTF-8 encoding of `c` to `out`.
fn write_char(c: char, out: &mut Vec<u8>) {
    // Four bytes is the longest UTF-8 encoding of a single char.
    let mut scratch = [0u8; 4];
    let encoded: &str = c.encode_utf8(&mut scratch);
    out.extend_from_slice(encoded.as_bytes());
}
/// Returns `true` if `b` is a space, line feed, carriage return or tab.
#[inline]
fn json_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\n' || b == b'\r' || b == b'\t'
}
/// Converts an ASCII hexadecimal digit (either case) to its numeric value.
///
/// # Errors
///
/// Returns the error built by `err` with the context `"unicode escape"` when
/// `b` is not a hexadecimal digit.
fn parse_hex(b: u8) -> Result<u8, ArrowError> {
    match char::from(b).to_digit(16) {
        Some(digit) => Ok(digit as u8),
        None => Err(err(b, "unicode escape")),
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Pins the in-memory size of the element and state types at eight bytes each.
#[test]
fn test_sizes() {
    use std::mem::size_of;

    let state_size = size_of::<DecoderState>();
    assert_eq!(state_size, 8);

    let element_size = size_of::<TapeElement>();
    assert_eq!(element_size, 8);
}
// Decodes several newline-separated records and checks the exact tape layout,
// the concatenated string buffer and the offsets into it.
#[test]
fn test_basic() {
let a = r#"
{"hello": "world", "foo": 2, "bar": 45}
{"foo": "bar"}
{"fiz": null}
{"a": true, "b": false, "c": null}
{"a": "", "": "a"}
{"a": "b", "object": {"nested": "hello", "foo": 23}, "b": {}, "c": {"foo": null }}
{"a": ["", "foo", ["bar", "c"]], "b": {"1": []}, "c": {"2": [1, 2, 3]} }
"#;
let mut decoder = TapeDecoder::new(16, 2);
decoder.decode(a.as_bytes()).unwrap();
let finished = decoder.finish().unwrap();
// Start and end markers hold each other's tape index; strings and numbers
// hold their slot in the offsets asserted further down.
assert_eq!(
finished.elements,
&[
// Index 0 is the placeholder pushed by `TapeDecoder::new`.
TapeElement::Null,
TapeElement::StartObject(8), TapeElement::String(0), TapeElement::String(1), TapeElement::String(2), TapeElement::Number(3), TapeElement::String(4), TapeElement::Number(5), TapeElement::EndObject(1),
TapeElement::StartObject(12), TapeElement::String(6), TapeElement::String(7), TapeElement::EndObject(9),
TapeElement::StartObject(16), TapeElement::String(8), TapeElement::Null, TapeElement::EndObject(13),
TapeElement::StartObject(24), TapeElement::String(9), TapeElement::True, TapeElement::String(10), TapeElement::False, TapeElement::String(11), TapeElement::Null, TapeElement::EndObject(17),
TapeElement::StartObject(30), TapeElement::String(12), TapeElement::String(13), TapeElement::String(14), TapeElement::String(15), TapeElement::EndObject(25),
TapeElement::StartObject(49), TapeElement::String(16), TapeElement::String(17), TapeElement::String(18), TapeElement::StartObject(40), TapeElement::String(19), TapeElement::String(20), TapeElement::String(21), TapeElement::Number(22), TapeElement::EndObject(35),
TapeElement::String(23), TapeElement::StartObject(43), TapeElement::EndObject(42),
TapeElement::String(24), TapeElement::StartObject(48), TapeElement::String(25), TapeElement::Null, TapeElement::EndObject(45),
TapeElement::EndObject(31),
TapeElement::StartObject(75), TapeElement::String(26), TapeElement::StartList(59), TapeElement::String(27), TapeElement::String(28), TapeElement::StartList(58), TapeElement::String(29), TapeElement::String(30), TapeElement::EndList(55),
TapeElement::EndList(52),
TapeElement::String(31), TapeElement::StartObject(65), TapeElement::String(32), TapeElement::StartList(64), TapeElement::EndList(63),
TapeElement::EndObject(61),
TapeElement::String(33), TapeElement::StartObject(74), TapeElement::String(34), TapeElement::StartList(73), TapeElement::Number(35), TapeElement::Number(36), TapeElement::Number(37), TapeElement::EndList(69),
TapeElement::EndObject(67),
TapeElement::EndObject(50)
]
);
// Keys, string values and numbers are concatenated in the order they appear.
assert_eq!(
finished.strings,
"helloworldfoo2bar45foobarfizabcaaabobjectnestedhellofoo23bcfooafoobarcb1c2123"
);
// Slot `i` spans `string_offsets[i]..string_offsets[i + 1]`; repeated values
// mark empty strings.
assert_eq!(
&finished.string_offsets,
&[
0, 5, 10, 13, 14, 17, 19, 22, 25, 28, 29, 30, 31, 32, 32, 32, 33, 34, 35,
41, 47, 52, 55, 57, 58, 59, 62, 63, 63, 66, 69, 70, 71, 72, 73, 74, 75,
76, 77
]
)
}
/// Checks the error reported for malformed, truncated and non-UTF-8 input.
#[test]
fn test_invalid() {
    // Runs `decode` on a fresh decoder and returns the message of the error it must
    // produce.
    let decode_error = |input: &[u8]| -> String {
        let mut decoder = TapeDecoder::new(16, 2);
        let message = decoder.decode(input).unwrap_err().to_string();
        message
    };

    // Runs `decode` (which must succeed) followed by `finish` on a fresh decoder and
    // returns the message of the error `finish` must produce.
    let finish_error = |input: &[u8]| -> String {
        let mut decoder = TapeDecoder::new(16, 2);
        decoder.decode(input).unwrap();
        let message = decoder.finish().unwrap_err().to_string();
        message
    };

    // Input rejected while decoding.
    assert_eq!(
        decode_error(b"hello"),
        "Json error: Encountered unexpected 'h' whilst trimming leading whitespace"
    );
    assert_eq!(
        decode_error(b"{\"hello\": }"),
        "Json error: Encountered unexpected '}' whilst parsing value"
    );
    assert_eq!(
        decode_error(b"{\"hello\": [ false, tru ]}"),
        "Json error: Encountered unexpected ' ' whilst parsing literal"
    );
    assert_eq!(
        decode_error(b"{\"hello\": \"\\ud8\"}"),
        "Json error: Encountered unexpected '\"' whilst unicode escape"
    );
    assert_eq!(
        decode_error(b"{\"hello\": \"\\ud83d\"}"),
        "Json error: Encountered unexpected '\"' whilst parsing surrogate pair escape"
    );

    // Truncated records are only detected when the tape is finished.
    assert_eq!(
        finish_error(b"{\"he"),
        "Json error: Truncated record whilst reading string"
    );
    assert_eq!(
        finish_error(b"{\"hello\" : "),
        "Json error: Truncated record whilst reading value"
    );
    assert_eq!(
        finish_error(b"{\"hello\" : ["),
        "Json error: Truncated record whilst reading list"
    );
    assert_eq!(
        finish_error(b"{\"hello\" : tru"),
        "Json error: Truncated record whilst reading true"
    );
    assert_eq!(
        finish_error(b"{\"hello\" : nu"),
        "Json error: Truncated record whilst reading null"
    );

    // Invalid UTF-8 is likewise only detected when the tape is finished.
    assert_eq!(
        finish_error(b"{\"hello\" : \"world\xFF\"}"),
        "Json error: Encountered non-UTF-8 data"
    );
    // A multi-byte character split across two strings is valid UTF-8 overall, but
    // the string boundary falls inside it.
    assert_eq!(
        finish_error(b"{\"\xe2\" : \"\x96\xa1\"}"),
        "Json error: Encountered truncated UTF-8 sequence"
    );
}
}