use std::marker::PhantomData;
use std::sync::Arc;
use arrow_array::ArrayRef;
use arrow_array::builder::PrimitiveBuilder;
use arrow_array::types::ArrowTimestampType;
use arrow_cast::parse::string_to_datetime;
use arrow_schema::{ArrowError, DataType, TimeUnit};
use chrono::TimeZone;
use crate::reader::tape::{Tape, TapeElement};
use crate::reader::{ArrayDecoder, DecoderContext};
/// Decoder that turns JSON tape elements into an Arrow timestamp array whose stored
/// values are expressed in `P::UNIT`.
///
/// String elements are parsed as datetimes with the configured `timezone`; numeric
/// elements are used as the raw stored `i64` value.
pub struct TimestampArrayDecoder<P, Tz>
where
    P: ArrowTimestampType,
    Tz: TimeZone,
{
    /// Arrow type attached to every array this decoder builds.
    data_type: DataType,
    /// Timezone handed to `string_to_datetime` when parsing string elements.
    timezone: Tz,
    /// When set, elements that fail to decode are emitted as nulls rather than errors.
    ignore_type_conflicts: bool,
    /// Zero-sized marker binding the decoder to `P`. The `fn(P) -> P` form makes the
    /// struct invariant in `P` and keeps it `Send`/`Sync` independently of `P`.
    phantom: PhantomData<fn(P) -> P>,
}
impl<P: ArrowTimestampType, Tz: TimeZone> TimestampArrayDecoder<P, Tz> {
pub fn new(ctx: &DecoderContext, data_type: &DataType, timezone: Tz) -> Self {
Self {
data_type: data_type.clone(),
timezone,
ignore_type_conflicts: ctx.ignore_type_conflicts(),
phantom: Default::default(),
}
}
}
impl<P, Tz> ArrayDecoder for TimestampArrayDecoder<P, Tz>
where
    P: ArrowTimestampType,
    Tz: TimeZone + Send,
{
    /// Decodes the tape elements at the indices in `pos` into a timestamp array of
    /// `self.data_type`, one slot per index.
    ///
    /// JSON `null` produces a null slot. When the decoder was created with
    /// `ignore_type_conflicts`, any element that fails to decode also produces a null
    /// slot instead of aborting the whole batch.
    ///
    /// # Errors
    ///
    /// With `ignore_type_conflicts` disabled, returns the first element error: an
    /// unparseable string or number, a nanosecond overflow, or an unexpected tape element.
    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
        let mut builder =
            PrimitiveBuilder::<P>::with_capacity(pos.len()).with_data_type(self.data_type.clone());
        for &p in pos {
            match self.decode_element(tape, p) {
                Ok(Some(value)) => builder.append_value(value),
                Ok(None) => builder.append_null(),
                Err(_) if self.ignore_type_conflicts => builder.append_null(),
                Err(e) => return Err(e),
            }
        }
        Ok(Arc::new(builder.finish()))
    }
}

impl<P, Tz> TimestampArrayDecoder<P, Tz>
where
    P: ArrowTimestampType,
    Tz: TimeZone + Send,
{
    /// Decodes the single tape element at `p` into a raw `P::UNIT` value, returning
    /// `Ok(None)` for JSON `null`.
    fn decode_element(&self, tape: &Tape<'_>, p: u32) -> Result<Option<i64>, ArrowError> {
        let value = match tape.get(p) {
            TapeElement::Null => return Ok(None),
            TapeElement::String(idx) => {
                let s = tape.get_string(idx);
                let date = string_to_datetime(&self.timezone, s).map_err(|e| {
                    ArrowError::JsonError(format!(
                        "failed to parse \"{s}\" as {}: {}",
                        self.data_type, e
                    ))
                })?;
                match P::UNIT {
                    TimeUnit::Second => date.timestamp(),
                    TimeUnit::Millisecond => date.timestamp_millis(),
                    TimeUnit::Microsecond => date.timestamp_micros(),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    })?,
                }
            }
            TapeElement::Number(idx) => {
                let s = tape.get_string(idx);
                parse_json_number_as_i64(s).ok_or_else(|| {
                    ArrowError::JsonError(format!("failed to parse {s} as {}", self.data_type))
                })?
            }
            TapeElement::I32(v) => i64::from(v),
            TapeElement::I64(high) => match tape.get(p + 1) {
                // An I64 is stored as its high half followed by an I32 holding the low half.
                TapeElement::I32(low) => (i64::from(high) << 32) | i64::from(low as u32),
                _ => unreachable!(),
            },
            _ => return Err(tape.error(p, "primitive")),
        };
        Ok(Some(value))
    }
}

/// Parses a JSON number literal as an `i64`.
///
/// Integer literals are parsed exactly. Other literals (fractions, exponents, integers
/// too large for `i64`) are parsed as `f64` and truncated toward zero, but only when the
/// result fits in `i64`. Otherwise `None` is returned, rather than letting an `as` cast
/// silently saturate to `i64::MIN`/`i64::MAX`.
fn parse_json_number_as_i64(s: &str) -> Option<i64> {
    let b = s.as_bytes();
    lexical_core::parse::<i64>(b)
        .ok()
        .or_else(|| lexical_core::parse::<f64>(b).ok().and_then(checked_f64_to_i64))
}

/// Truncates `x` toward zero if it lies in `i64`'s range; `None` for NaN, infinities, or
/// out-of-range magnitudes.
fn checked_f64_to_i64(x: f64) -> Option<i64> {
    // 2^63 is exactly representable as f64, so [-2^63, 2^63) is exactly the set of floats
    // whose truncation fits in i64. NaN fails the comparison and is rejected.
    const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;
    (-TWO_POW_63..TWO_POW_63).contains(&x).then_some(x as i64)
}