use std::sync::Arc;
use arrow::array::{
Array, ArrayRef, Date32Array, DurationMicrosecondArray, Float64Array, Float64Builder,
StringViewBuilder, TimestampMicrosecondArray,
};
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use encoding_rs::Encoding;
use rayon::prelude::*;
use crate::arrow_convert;
use crate::constants::{
MICROS_PER_SECOND, SECONDS_PER_DAY, SPSS_EPOCH_OFFSET_DAYS, SPSS_EPOCH_OFFSET_SECONDS,
TemporalKind, VarType, is_sysmis,
};
use crate::dictionary::ResolvedDictionary;
use crate::encoding;
use crate::error::Result;
use crate::io_utils;
use crate::variable::VariableRecord;
/// Rows wider than this many bytes take the tiled path in
/// `push_raw_chunk_tiled` instead of a single whole-chunk pass per column.
/// 12 KiB per row ≈ 1.5k slots; tuned threshold — TODO confirm basis.
const WIDE_ROW_THRESHOLD: usize = 12_288;
/// Target tile size for the wide-row path (4 MiB), presumably chosen so a
/// tile stays resident in a typical L3 cache slice while every column
/// builder scans it — confirm against target hardware.
const L3_TILE_BYTES: usize = 4 * 1024 * 1024;
/// Per-segment layout info for a very long string (VLS) variable.
struct VlsSegmentInfo {
    // Payload bytes carried by this segment (at most 255; the final
    // segment of a variable may hold fewer — see `ColumnarBatchBuilder::new`).
    useful_bytes: usize,
}
/// Describes where one output column lives inside a raw case-data row.
struct ColumnMapping {
    // Index of the variable's first 8-byte slot within a row.
    slot_index: usize,
    // Numeric vs. String(width); selects which builder variant is used.
    var_type: VarType,
    // Number of physical segments; greater than 1 only for VLS strings.
    n_segments: usize,
    // Per-segment useful-byte counts; empty unless `n_segments > 1`.
    vls_layout: Vec<VlsSegmentInfo>,
}
/// Arrow builder for a single output column, one variant per SPSS type.
enum ColBuilder {
    Float64(Float64Builder),
    Str(StringViewBuilder),
}
/// Accumulates raw SPSS case data column-by-column and produces an Arrow
/// `RecordBatch` via `finish`.
pub struct ColumnarBatchBuilder {
    // One mapping per output column, kept in lock-step with `builders`.
    mappings: Vec<ColumnMapping>,
    // One Arrow builder per output column.
    builders: Vec<ColBuilder>,
    // Character encoding of string data in the source file.
    file_encoding: &'static Encoding,
    // Output schema, fixed at construction time.
    schema: Arc<Schema>,
    // Total rows pushed so far (reported by `len`).
    rows_appended: usize,
    // Scratch buffer reused for string assembly on the sequential path.
    string_buf: Vec<u8>,
    // (column index, kind) for numeric columns that `finish` converts to
    // date/timestamp/duration arrays.
    temporal_columns: Vec<(usize, TemporalKind)>,
}
impl ColumnarBatchBuilder {
    /// Builds the per-column mappings, Arrow builders, and output schema.
    ///
    /// * `dict` — resolved dictionary describing the file's variables.
    /// * `projection` — optional indices into `dict.variables` selecting a
    ///   column subset (order preserved); `None` keeps all variables.
    /// * `capacity` — expected row count; pre-sizes the numeric builders.
    pub fn new(dict: &ResolvedDictionary, projection: Option<&[usize]>, capacity: usize) -> Self {
        let vars: Vec<&VariableRecord> = match projection {
            Some(proj) => proj.iter().map(|&i| &dict.variables[i]).collect(),
            None => dict.variables.iter().collect(),
        };
        let mut mappings = Vec::with_capacity(vars.len());
        let mut builders = Vec::with_capacity(vars.len());
        let mut fields = Vec::with_capacity(vars.len());
        let mut temporal_columns = Vec::new();
        for (col_idx, var) in vars.iter().enumerate() {
            // Very long strings are split into segments of at most 255
            // useful bytes each; precompute every segment's byte count so
            // the row decoder can reassemble the value without arithmetic.
            let vls_layout = if var.n_segments > 1 {
                let width = match &var.var_type {
                    VarType::String(w) => *w,
                    _ => 0,
                };
                (0..var.n_segments)
                    .map(|seg| {
                        let bytes_before = seg * 255;
                        let remaining = width.saturating_sub(bytes_before);
                        let seg_useful = remaining.min(255);
                        VlsSegmentInfo {
                            useful_bytes: seg_useful,
                        }
                    })
                    .collect()
            } else {
                Vec::new()
            };
            mappings.push(ColumnMapping {
                slot_index: var.slot_index,
                var_type: var.var_type.clone(),
                n_segments: var.n_segments,
                vls_layout,
            });
            let output_type = arrow_convert::var_to_arrow_type(var);
            // Every column is declared nullable (numeric cells may be SYSMIS).
            fields.push(Field::new(&var.long_name, output_type, true));
            match &var.var_type {
                VarType::Numeric => {
                    // Remember temporal columns so `finish` can convert the
                    // raw values into Arrow date/timestamp/duration arrays.
                    if let Some(kind) = var
                        .print_format
                        .as_ref()
                        .and_then(|f| f.format_type.temporal_kind())
                    {
                        temporal_columns.push((col_idx, kind));
                    }
                    builders.push(ColBuilder::Float64(Float64Builder::with_capacity(capacity)));
                }
                VarType::String(_) => {
                    // Columns with value labels presumably repeat a small set
                    // of strings, so enable builder-side deduplication there.
                    let has_value_labels = dict
                        .metadata
                        .variable_value_labels
                        .contains_key(&var.long_name);
                    let sb = if has_value_labels {
                        StringViewBuilder::new().with_deduplicate_strings()
                    } else {
                        StringViewBuilder::new()
                    };
                    builders.push(ColBuilder::Str(sb));
                }
            }
        }
        ColumnarBatchBuilder {
            mappings,
            builders,
            file_encoding: dict.file_encoding,
            schema: Arc::new(Schema::new(fields)),
            rows_appended: 0,
            string_buf: Vec::with_capacity(1024),
            temporal_columns,
        }
    }

    /// Appends `num_rows` rows of raw case data to the column builders.
    ///
    /// `chunk` holds back-to-back rows of `slots_per_row` 8-byte slots.
    /// Rows wider than `WIDE_ROW_THRESHOLD` bytes take the cache-tiled
    /// path; otherwise chunks of 10k+ rows are processed column-parallel
    /// with rayon, and smaller ones sequentially.
    pub fn push_raw_chunk(&mut self, chunk: &[u8], num_rows: usize, slots_per_row: usize) {
        let row_bytes = slots_per_row * 8;
        if row_bytes > WIDE_ROW_THRESHOLD {
            self.push_raw_chunk_tiled(chunk, num_rows, slots_per_row);
            return;
        }
        let mappings = &self.mappings;
        let file_encoding = self.file_encoding;
        // Parallelism only pays off once there is enough work per column.
        if num_rows >= 10_000 {
            self.builders
                .par_iter_mut()
                .enumerate()
                .for_each(|(i, builder)| {
                    let mapping = &mappings[i];
                    match (&mapping.var_type, builder) {
                        (VarType::Numeric, ColBuilder::Float64(b)) => {
                            process_numeric_rows(
                                b,
                                chunk,
                                0,
                                num_rows,
                                row_bytes,
                                mapping.slot_index,
                            );
                        }
                        (VarType::String(_), ColBuilder::Str(b)) => {
                            // Each rayon task needs its own scratch buffer;
                            // the shared `self.string_buf` is only usable on
                            // the sequential path below.
                            let mut local_buf = Vec::with_capacity(256);
                            process_string_rows(
                                b,
                                &mut local_buf,
                                chunk,
                                0,
                                num_rows,
                                row_bytes,
                                slots_per_row,
                                mapping,
                                file_encoding,
                            );
                        }
                        // `mappings[i]` and `builders[i]` are built in
                        // lock-step in `new`, so their variants always agree.
                        _ => unreachable!(),
                    }
                });
        } else {
            for (i, mapping) in mappings.iter().enumerate() {
                match (&mapping.var_type, &mut self.builders[i]) {
                    (VarType::Numeric, ColBuilder::Float64(b)) => {
                        process_numeric_rows(b, chunk, 0, num_rows, row_bytes, mapping.slot_index);
                    }
                    (VarType::String(_), ColBuilder::Str(b)) => {
                        process_string_rows(
                            b,
                            &mut self.string_buf,
                            chunk,
                            0,
                            num_rows,
                            row_bytes,
                            slots_per_row,
                            mapping,
                            self.file_encoding,
                        );
                    }
                    _ => unreachable!(),
                }
            }
        }
        self.rows_appended += num_rows;
    }

    /// Wide-row variant of `push_raw_chunk`: walks the chunk in tiles of
    /// roughly `L3_TILE_BYTES` so every column's parallel pass re-reads
    /// the same (hot) span of rows before moving to the next tile.
    fn push_raw_chunk_tiled(&mut self, chunk: &[u8], num_rows: usize, slots_per_row: usize) {
        let row_bytes = slots_per_row * 8;
        // At least 64 rows per tile, even when one row exceeds the tile
        // budget (this also keeps the tile size non-zero).
        let tile_rows = (L3_TILE_BYTES / row_bytes).max(64);
        let mappings = &self.mappings;
        let file_encoding = self.file_encoding;
        let mut row_offset = 0;
        while row_offset < num_rows {
            let n = (num_rows - row_offset).min(tile_rows);
            let tile_start = row_offset * row_bytes;
            self.builders
                .par_iter_mut()
                .enumerate()
                .for_each(|(i, builder)| {
                    let mapping = &mappings[i];
                    match (&mapping.var_type, builder) {
                        (VarType::Numeric, ColBuilder::Float64(b)) => {
                            process_numeric_rows(
                                b,
                                chunk,
                                tile_start,
                                n,
                                row_bytes,
                                mapping.slot_index,
                            );
                        }
                        (VarType::String(_), ColBuilder::Str(b)) => {
                            // Per-task scratch buffer (see `push_raw_chunk`).
                            let mut local_buf = Vec::with_capacity(256);
                            process_string_rows(
                                b,
                                &mut local_buf,
                                chunk,
                                tile_start,
                                n,
                                row_bytes,
                                slots_per_row,
                                mapping,
                                file_encoding,
                            );
                        }
                        _ => unreachable!(),
                    }
                });
            row_offset += n;
        }
        self.rows_appended += num_rows;
    }

    /// Consumes the builder and assembles the final `RecordBatch`,
    /// converting temporal numeric columns to their Arrow
    /// date/timestamp/duration representation.
    ///
    /// # Errors
    /// Returns an error if the finished columns do not match the schema.
    pub fn finish(self) -> Result<RecordBatch> {
        let mut columns: Vec<ArrayRef> = self
            .builders
            .into_iter()
            .map(|b| -> ArrayRef {
                match b {
                    ColBuilder::Float64(mut b) => Arc::new(b.finish()),
                    ColBuilder::Str(mut b) => Arc::new(b.finish()),
                }
            })
            .collect();
        // Post-pass: swap raw Float64 temporal columns for typed arrays.
        for &(col_idx, kind) in &self.temporal_columns {
            let float_arr = columns[col_idx]
                .as_any()
                .downcast_ref::<Float64Array>()
                .expect("temporal column should be Float64Array");
            columns[col_idx] = convert_float64_to_temporal(float_arr, kind);
        }
        let batch = RecordBatch::try_new(self.schema, columns)?;
        Ok(batch)
    }

    /// Number of rows appended so far.
    pub fn len(&self) -> usize {
        self.rows_appended
    }
}
/// Appends `num_rows` values of one numeric column to `builder`.
///
/// Rows are `row_bytes` long and start at `base_offset` inside `chunk`;
/// the column's value is the little-endian f64 in slot `slot_index` of
/// each row. SPSS system-missing values are appended as nulls.
///
/// # Panics
/// Panics if `chunk` is too short for the requested rows. (The previous
/// raw-pointer version performed no bounds check, so an out-of-range
/// offset was undefined behavior; now it is a defined panic.)
#[inline(always)]
fn process_numeric_rows(
    builder: &mut Float64Builder,
    chunk: &[u8],
    base_offset: usize,
    num_rows: usize,
    row_bytes: usize,
    slot_index: usize,
) {
    let slot_offset = slot_index * 8;
    for row in 0..num_rows {
        let offset = base_offset + row * row_bytes + slot_offset;
        // Bounds-checked read: `try_into` cannot fail on a length-8 slice,
        // and this lowers to the same unaligned 8-byte load as the old
        // `unsafe` pointer cast.
        let bytes: [u8; 8] = chunk[offset..offset + 8]
            .try_into()
            .expect("slice is exactly 8 bytes");
        let val = f64::from_le_bytes(bytes);
        if is_sysmis(val) {
            builder.append_null();
        } else {
            builder.append_value(val);
        }
    }
}
/// Appends `num_rows` values of one string column to `builder`.
///
/// Each row consists of `slots_per_row` 8-byte slots starting at
/// `base_offset + row * row_bytes`; the string's raw bytes are gathered
/// from the slots described by `mapping` and decoded with
/// `file_encoding`. `string_buf` is caller-provided scratch space reused
/// across rows.
///
/// # Panics
/// Panics if `chunk` is too short to hold the requested rows — the
/// hoisted check below is what makes the raw-parts cast sound. (The
/// previous version only bounds-checked the row start, so a chunk with
/// fewer than `slots_per_row * 8` trailing bytes could be read past its
/// end: undefined behavior.)
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn process_string_rows(
    builder: &mut StringViewBuilder,
    string_buf: &mut Vec<u8>,
    chunk: &[u8],
    base_offset: usize,
    num_rows: usize,
    row_bytes: usize,
    slots_per_row: usize,
    mapping: &ColumnMapping,
    file_encoding: &'static Encoding,
) {
    let width = match &mapping.var_type {
        VarType::String(w) => *w,
        _ => unreachable!(),
    };
    // Hoisted bounds check: the widest read below ends at
    // `row_start + slots_per_row * 8` for the final row. One check here
    // covers every iteration.
    if num_rows > 0 {
        let end = base_offset + (num_rows - 1) * row_bytes + slots_per_row * 8;
        assert!(
            end <= chunk.len(),
            "raw chunk too small: need {end} bytes, have {}",
            chunk.len()
        );
    }
    for row in 0..num_rows {
        let row_start = base_offset + row * row_bytes;
        // SAFETY: the assert above guarantees
        // `row_start + slots_per_row * 8 <= chunk.len()`, and `[u8; 8]` has
        // alignment 1, so viewing this row as `slots_per_row` 8-byte groups
        // is in-bounds and well-aligned.
        let raw_slots: &[[u8; 8]] = unsafe {
            std::slice::from_raw_parts(chunk[row_start..].as_ptr() as *const [u8; 8], slots_per_row)
        };
        push_string_from_raw_slots(
            builder,
            string_buf,
            raw_slots,
            mapping.slot_index,
            width,
            mapping.n_segments,
            &mapping.vls_layout,
            file_encoding,
        );
    }
}
/// Converts a raw SPSS numeric column into the Arrow temporal array
/// implied by `kind`, preserving the null mask.
///
/// The arithmetic rebases values by the SPSS epoch offsets imported from
/// `crate::constants`; durations are plain seconds with no epoch. Values
/// under null positions are converted along with everything else but are
/// masked out by the carried-over validity buffer.
#[inline(never)]
fn convert_float64_to_temporal(arr: &Float64Array, kind: TemporalKind) -> ArrayRef {
    let nulls = arr.nulls().cloned();
    let values = arr.values();
    match kind {
        TemporalKind::Date => {
            // seconds → days, rebased to the Unix epoch.
            // NOTE(review): `as i32` truncates toward zero, so fractional
            // day values before the Unix epoch round in the opposite
            // direction from `floor` — confirm whether pre-1970 dates can
            // carry fractions here.
            let converted: Vec<i32> = values
                .iter()
                .map(|&v| (v / SECONDS_PER_DAY - SPSS_EPOCH_OFFSET_DAYS as f64) as i32)
                .collect();
            Arc::new(Date32Array::new(converted.into(), nulls))
        }
        TemporalKind::Timestamp => {
            // seconds → microseconds, rebased to the Unix epoch.
            let converted: Vec<i64> = values
                .iter()
                .map(|&v| ((v - SPSS_EPOCH_OFFSET_SECONDS) * MICROS_PER_SECOND) as i64)
                .collect();
            Arc::new(TimestampMicrosecondArray::new(converted.into(), nulls))
        }
        TemporalKind::Duration => {
            // Durations have no epoch; just seconds → microseconds.
            let converted: Vec<i64> = values
                .iter()
                .map(|&v| (v * MICROS_PER_SECOND) as i64)
                .collect();
            Arc::new(DurationMicrosecondArray::new(converted.into(), nulls))
        }
    }
}
/// Assembles one string cell's raw bytes from a row's 8-byte slots,
/// decodes it, and appends it to `builder`.
///
/// * Short strings (`n_segments <= 1`) occupy `ceil(width / 8)`
///   consecutive slots starting at `start_slot`.
/// * Very long strings are stored as VLS segments: each segment occupies
///   a fixed 32 slots (256 bytes) in the row — hence the `slot += 32`
///   stride — but carries at most 255 useful bytes, described by
///   `vls_layout`.
///
/// `string_buf` is caller-provided scratch space; it is cleared on entry.
#[inline]
#[allow(clippy::too_many_arguments)]
fn push_string_from_raw_slots(
    builder: &mut StringViewBuilder,
    string_buf: &mut Vec<u8>,
    raw_slots: &[[u8; 8]],
    start_slot: usize,
    width: usize,
    n_segments: usize,
    vls_layout: &[VlsSegmentInfo],
    file_encoding: &'static Encoding,
) {
    string_buf.clear();
    if n_segments <= 1 {
        let n_slots = width.div_ceil(8);
        for i in 0..n_slots {
            let idx = start_slot + i;
            // Out-of-range slots are silently skipped rather than panicking
            // — presumably defensive against short rows; TODO confirm.
            if idx < raw_slots.len() {
                string_buf.extend_from_slice(&raw_slots[idx]);
            }
        }
    } else {
        let mut slot = start_slot;
        let mut cumulative = 0;
        for seg_info in vls_layout {
            cumulative += seg_info.useful_bytes;
            let slots_to_read = seg_info.useful_bytes.div_ceil(8);
            for i in 0..slots_to_read {
                if slot + i < raw_slots.len() {
                    string_buf.extend_from_slice(&raw_slots[slot + i]);
                }
            }
            // Drop the slack bytes of this segment's final slot so the next
            // segment's payload lands immediately after this one's.
            string_buf.truncate(cumulative);
            // Advance to the next segment's fixed 32-slot region.
            slot += 32;
        }
    }
    // Clamp to the declared width, strip trailing padding, then decode
    // (lossily) from the file encoding into UTF-8.
    string_buf.truncate(width);
    let trimmed = io_utils::trim_trailing_padding(string_buf);
    let decoded = encoding::decode_str_lossy(trimmed, file_encoding);
    builder.append_value(&*decoded);
}