use std::fmt::Write;
use std::sync::Arc;
use arrow_array::ArrayRef;
use arrow_array::builder::GenericByteViewBuilder;
use arrow_array::types::StringViewType;
use arrow_schema::ArrowError;
use crate::reader::tape::{Tape, TapeElement};
use crate::reader::{ArrayDecoder, DecoderContext};
/// Text appended for a `TapeElement::True` when primitive coercion is enabled.
const TRUE: &str = "true";
/// Text appended for a `TapeElement::False` when primitive coercion is enabled.
const FALSE: &str = "false";
/// Decoder that turns tape elements into a string-view array
/// (see the `ArrayDecoder` implementation for this type).
pub struct StringViewArrayDecoder {
/// When set, booleans and numbers on the tape are rendered as text instead of
/// being treated as a type conflict.
coerce_primitive: bool,
/// When set, elements that cannot be decoded as a string become nulls instead
/// of producing an error.
ignore_type_conflicts: bool,
}
impl StringViewArrayDecoder {
    /// Creates a decoder, copying the `coerce_primitive` and
    /// `ignore_type_conflicts` settings out of `ctx`.
    pub fn new(ctx: &DecoderContext) -> Self {
        let coerce_primitive = ctx.coerce_primitive();
        let ignore_type_conflicts = ctx.ignore_type_conflicts();
        Self {
            coerce_primitive,
            ignore_type_conflicts,
        }
    }
}
impl ArrayDecoder for StringViewArrayDecoder {
    /// Decodes the tape elements at the positions in `pos` into a string-view array.
    ///
    /// Strings and nulls are always accepted. When `coerce_primitive` is set,
    /// booleans and numeric elements are rendered as text as well. Any other
    /// element is appended as null when `ignore_type_conflicts` is set.
    ///
    /// # Errors
    ///
    /// Returns `tape.error(p, "string")` for the first position `p` whose element
    /// is neither accepted nor ignored.
    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
        // Values of at most this many bytes are left out of the size estimate below.
        // Presumably this mirrors the Arrow view layout, which stores such short
        // values inline in the view itself rather than in a data buffer.
        const MAX_INLINE_LEN: usize = 12;

        let coerce = self.coerce_primitive;

        // First pass: validate every element and estimate how many bytes of
        // "long" string data will be produced.
        let mut data_capacity = 0;
        for &p in pos {
            match tape.get(p) {
                TapeElement::String(idx) => {
                    let s = tape.get_string(idx);
                    if s.len() > MAX_INLINE_LEN {
                        data_capacity += s.len();
                    }
                }
                TapeElement::Null => {}
                // "true" and "false" are both shorter than the inline limit.
                TapeElement::True if coerce => {}
                TapeElement::False if coerce => {}
                TapeElement::Number(idx) if coerce => {
                    let s = tape.get_string(idx);
                    if s.len() > MAX_INLINE_LEN {
                        data_capacity += s.len();
                    }
                }
                // A 64-bit integer occupies two tape slots: the high half here and
                // the low half in the following `I32` element.
                TapeElement::I64(high) if coerce => match tape.get(p + 1) {
                    TapeElement::I32(low) => {
                        let val = ((high as i64) << 32) | (low as u32) as i64;
                        // `unsigned_abs` is used because `abs` overflows on `i64::MIN`
                        // (a panic in overflow-checked builds, a negative result
                        // otherwise, which would skip the estimate for that value).
                        if val.unsigned_abs() > 999_999_999_999 {
                            data_capacity += val.to_string().len();
                        }
                    }
                    _ => unreachable!(),
                },
                // Not counted: a formatted `i32` is at most 11 bytes.
                TapeElement::I32(_) if coerce => {}
                // Floats are not formatted in this pass; a flat estimate is used.
                TapeElement::F32(_) if coerce => {
                    data_capacity += 10;
                }
                TapeElement::F64(_) if coerce => {
                    data_capacity += 10;
                }
                _ if self.ignore_type_conflicts => {}
                _ => {
                    return Err(tape.error(p, "string"));
                }
            }
        }

        // NOTE(review): `data_capacity` is a byte estimate, while the builder's
        // `with_capacity` appears to take a number of values -- confirm which is
        // intended. Kept as-is; it only affects pre-allocation, not the output.
        let mut builder = GenericByteViewBuilder::<StringViewType>::with_capacity(data_capacity);

        // Scratch buffer reused for every formatted number to avoid one
        // allocation per element.
        let mut tmp_buf = String::new();

        // Second pass: append the values. Every element was validated above, so
        // the fallback arms here cannot be reached with an unexpected element.
        for &p in pos {
            match tape.get(p) {
                TapeElement::String(idx) => {
                    builder.append_value(tape.get_string(idx));
                }
                TapeElement::Null => {
                    builder.append_null();
                }
                TapeElement::True if coerce => {
                    builder.append_value(TRUE);
                }
                TapeElement::False if coerce => {
                    builder.append_value(FALSE);
                }
                TapeElement::Number(idx) if coerce => {
                    builder.append_value(tape.get_string(idx));
                }
                TapeElement::I64(high) if coerce => match tape.get(p + 1) {
                    TapeElement::I32(low) => {
                        let val = ((high as i64) << 32) | (low as u32) as i64;
                        tmp_buf.clear();
                        write!(&mut tmp_buf, "{val}").unwrap();
                        builder.append_value(&tmp_buf);
                    }
                    _ => unreachable!(),
                },
                TapeElement::I32(n) if coerce => {
                    tmp_buf.clear();
                    write!(&mut tmp_buf, "{n}").unwrap();
                    builder.append_value(&tmp_buf);
                }
                TapeElement::F32(bits) if coerce => {
                    // The payload is a raw bit pattern (the `F64` arm below feeds the
                    // same payload into `f64::from_bits`), so it must be decoded to a
                    // float before formatting; printing it directly would emit the
                    // integer value of the bits.
                    let val = f32::from_bits(bits);
                    tmp_buf.clear();
                    write!(&mut tmp_buf, "{val}").unwrap();
                    builder.append_value(&tmp_buf);
                }
                // A 64-bit float occupies two tape slots: the high 32 bits here and
                // the low 32 bits in the following `F32` element.
                TapeElement::F64(high) if coerce => match tape.get(p + 1) {
                    TapeElement::F32(low) => {
                        let val = f64::from_bits(((high as u64) << 32) | (low as u64));
                        tmp_buf.clear();
                        write!(&mut tmp_buf, "{val}").unwrap();
                        builder.append_value(&tmp_buf);
                    }
                    _ => unreachable!(),
                },
                _ if self.ignore_type_conflicts => {
                    builder.append_null();
                }
                _ => unreachable!(),
            }
        }

        Ok(Arc::new(builder.finish()))
    }
}