use crate::error::Error::RowConvertError;
use crate::error::{Error, Result};
use crate::metadata::DataType;
use crate::row::InternalRow;
use crate::row::datum::{
MICROS_PER_MILLI, MILLIS_PER_SECOND, NANOS_PER_MILLI, append_decimal_to_builder,
millis_nanos_to_micros, millis_nanos_to_nanos,
};
use arrow::array::{
ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int8Builder, Int16Builder,
Int32Builder, Int64Builder, StringBuilder, Time32MillisecondBuilder, Time32SecondBuilder,
Time64MicrosecondBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder,
TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder,
};
use arrow_schema::DataType as ArrowDataType;
const VARIABLE_WIDTH_AVG_BYTES: usize = 64;
pub struct ColumnWriter {
pos: usize,
nullable: bool,
inner: TypedWriter,
}
enum TypedWriter {
Bool(BooleanBuilder),
Int8(Int8Builder),
Int16(Int16Builder),
Int32(Int32Builder),
Int64(Int64Builder),
Float32(Float32Builder),
Float64(Float64Builder),
Char {
len: usize,
builder: StringBuilder,
},
String(StringBuilder),
Bytes(BinaryBuilder),
Binary {
len: usize,
builder: FixedSizeBinaryBuilder,
},
Decimal128 {
src_precision: usize,
src_scale: usize,
target_precision: u32,
target_scale: i64,
builder: Decimal128Builder,
},
Date32(Date32Builder),
Time32Second(Time32SecondBuilder),
Time32Millisecond(Time32MillisecondBuilder),
Time64Microsecond(Time64MicrosecondBuilder),
Time64Nanosecond(Time64NanosecondBuilder),
TimestampNtzSecond {
precision: u32,
builder: TimestampSecondBuilder,
},
TimestampNtzMillisecond {
precision: u32,
builder: TimestampMillisecondBuilder,
},
TimestampNtzMicrosecond {
precision: u32,
builder: TimestampMicrosecondBuilder,
},
TimestampNtzNanosecond {
precision: u32,
builder: TimestampNanosecondBuilder,
},
TimestampLtzSecond {
precision: u32,
builder: TimestampSecondBuilder,
},
TimestampLtzMillisecond {
precision: u32,
builder: TimestampMillisecondBuilder,
},
TimestampLtzMicrosecond {
precision: u32,
builder: TimestampMicrosecondBuilder,
},
TimestampLtzNanosecond {
precision: u32,
builder: TimestampNanosecondBuilder,
},
}
macro_rules! with_builder {
($self:expr, $b:ident => $body:expr) => {
match $self {
TypedWriter::Bool($b) => $body,
TypedWriter::Int8($b) => $body,
TypedWriter::Int16($b) => $body,
TypedWriter::Int32($b) => $body,
TypedWriter::Int64($b) => $body,
TypedWriter::Float32($b) => $body,
TypedWriter::Float64($b) => $body,
TypedWriter::Char { builder: $b, .. } => $body,
TypedWriter::String($b) => $body,
TypedWriter::Bytes($b) => $body,
TypedWriter::Binary { builder: $b, .. } => $body,
TypedWriter::Decimal128 { builder: $b, .. } => $body,
TypedWriter::Date32($b) => $body,
TypedWriter::Time32Second($b) => $body,
TypedWriter::Time32Millisecond($b) => $body,
TypedWriter::Time64Microsecond($b) => $body,
TypedWriter::Time64Nanosecond($b) => $body,
TypedWriter::TimestampNtzSecond { builder: $b, .. } => $body,
TypedWriter::TimestampNtzMillisecond { builder: $b, .. } => $body,
TypedWriter::TimestampNtzMicrosecond { builder: $b, .. } => $body,
TypedWriter::TimestampNtzNanosecond { builder: $b, .. } => $body,
TypedWriter::TimestampLtzSecond { builder: $b, .. } => $body,
TypedWriter::TimestampLtzMillisecond { builder: $b, .. } => $body,
TypedWriter::TimestampLtzMicrosecond { builder: $b, .. } => $body,
TypedWriter::TimestampLtzNanosecond { builder: $b, .. } => $body,
}
};
}
impl ColumnWriter {
pub fn create(
fluss_type: &DataType,
arrow_type: &ArrowDataType,
pos: usize,
capacity: usize,
) -> Result<Self> {
let nullable = fluss_type.is_nullable();
let inner = match fluss_type {
DataType::Boolean(_) => TypedWriter::Bool(BooleanBuilder::with_capacity(capacity)),
DataType::TinyInt(_) => TypedWriter::Int8(Int8Builder::with_capacity(capacity)),
DataType::SmallInt(_) => TypedWriter::Int16(Int16Builder::with_capacity(capacity)),
DataType::Int(_) => TypedWriter::Int32(Int32Builder::with_capacity(capacity)),
DataType::BigInt(_) => TypedWriter::Int64(Int64Builder::with_capacity(capacity)),
DataType::Float(_) => TypedWriter::Float32(Float32Builder::with_capacity(capacity)),
DataType::Double(_) => TypedWriter::Float64(Float64Builder::with_capacity(capacity)),
DataType::Char(t) => TypedWriter::Char {
len: t.length() as usize,
builder: StringBuilder::with_capacity(
capacity,
capacity.saturating_mul(VARIABLE_WIDTH_AVG_BYTES),
),
},
DataType::String(_) => TypedWriter::String(StringBuilder::with_capacity(
capacity,
capacity.saturating_mul(VARIABLE_WIDTH_AVG_BYTES),
)),
DataType::Bytes(_) => TypedWriter::Bytes(BinaryBuilder::with_capacity(
capacity,
capacity.saturating_mul(VARIABLE_WIDTH_AVG_BYTES),
)),
DataType::Binary(t) => {
let arrow_len: i32 = t.length().try_into().map_err(|_| Error::IllegalArgument {
message: format!(
"Binary length {} exceeds Arrow's maximum (i32::MAX)",
t.length()
),
})?;
TypedWriter::Binary {
len: t.length(),
builder: FixedSizeBinaryBuilder::with_capacity(capacity, arrow_len),
}
}
DataType::Decimal(dt) => {
let (target_p, target_s) = match arrow_type {
ArrowDataType::Decimal128(p, s) => (*p, *s),
_ => {
return Err(Error::IllegalArgument {
message: format!(
"Expected Decimal128 Arrow type for Decimal, got: {arrow_type:?}"
),
});
}
};
if target_s < 0 {
return Err(Error::IllegalArgument {
message: format!("Negative decimal scale {target_s} is not supported"),
});
}
let builder = Decimal128Builder::with_capacity(capacity)
.with_precision_and_scale(target_p, target_s)
.map_err(|e| Error::IllegalArgument {
message: format!(
"Invalid decimal precision {target_p} or scale {target_s}: {e}"
),
})?;
TypedWriter::Decimal128 {
src_precision: dt.precision() as usize,
src_scale: dt.scale() as usize,
target_precision: target_p as u32,
target_scale: target_s as i64,
builder,
}
}
DataType::Date(_) => TypedWriter::Date32(Date32Builder::with_capacity(capacity)),
DataType::Time(_) => match arrow_type {
ArrowDataType::Time32(arrow_schema::TimeUnit::Second) => {
TypedWriter::Time32Second(Time32SecondBuilder::with_capacity(capacity))
}
ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond) => {
TypedWriter::Time32Millisecond(Time32MillisecondBuilder::with_capacity(
capacity,
))
}
ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond) => {
TypedWriter::Time64Microsecond(Time64MicrosecondBuilder::with_capacity(
capacity,
))
}
ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond) => {
TypedWriter::Time64Nanosecond(Time64NanosecondBuilder::with_capacity(capacity))
}
_ => {
return Err(Error::IllegalArgument {
message: format!("Unsupported Arrow type for Time: {arrow_type:?}"),
});
}
},
DataType::Timestamp(t) => {
let precision = t.precision();
match arrow_type {
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, _) => {
TypedWriter::TimestampNtzSecond {
precision,
builder: TimestampSecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
TypedWriter::TimestampNtzMillisecond {
precision,
builder: TimestampMillisecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
TypedWriter::TimestampNtzMicrosecond {
precision,
builder: TimestampMicrosecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => {
TypedWriter::TimestampNtzNanosecond {
precision,
builder: TimestampNanosecondBuilder::with_capacity(capacity),
}
}
_ => {
return Err(Error::IllegalArgument {
message: format!(
"Unsupported Arrow type for Timestamp: {arrow_type:?}"
),
});
}
}
}
DataType::TimestampLTz(t) => {
let precision = t.precision();
match arrow_type {
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, _) => {
TypedWriter::TimestampLtzSecond {
precision,
builder: TimestampSecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
TypedWriter::TimestampLtzMillisecond {
precision,
builder: TimestampMillisecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
TypedWriter::TimestampLtzMicrosecond {
precision,
builder: TimestampMicrosecondBuilder::with_capacity(capacity),
}
}
ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => {
TypedWriter::TimestampLtzNanosecond {
precision,
builder: TimestampNanosecondBuilder::with_capacity(capacity),
}
}
_ => {
return Err(Error::IllegalArgument {
message: format!(
"Unsupported Arrow type for TimestampLTz: {arrow_type:?}"
),
});
}
}
}
_ => {
return Err(Error::IllegalArgument {
message: format!("Unsupported Fluss DataType: {fluss_type:?}"),
});
}
};
Ok(Self {
pos,
nullable,
inner,
})
}
#[inline]
pub fn write_field(&mut self, row: &dyn InternalRow) -> Result<()> {
if self.nullable && row.is_null_at(self.pos)? {
self.append_null();
return Ok(());
}
self.write_non_null(row)
}
pub fn finish(&mut self) -> ArrayRef {
self.as_builder_mut().finish()
}
pub fn finish_cloned(&self) -> ArrayRef {
self.as_builder_ref().finish_cloned()
}
fn append_null(&mut self) {
with_builder!(&mut self.inner, b => b.append_null());
}
fn as_builder_mut(&mut self) -> &mut dyn ArrayBuilder {
with_builder!(&mut self.inner, b => b)
}
fn as_builder_ref(&self) -> &dyn ArrayBuilder {
with_builder!(&self.inner, b => b)
}
#[inline]
fn write_non_null(&mut self, row: &dyn InternalRow) -> Result<()> {
let pos = self.pos;
match &mut self.inner {
TypedWriter::Bool(b) => {
b.append_value(row.get_boolean(pos)?);
Ok(())
}
TypedWriter::Int8(b) => {
b.append_value(row.get_byte(pos)?);
Ok(())
}
TypedWriter::Int16(b) => {
b.append_value(row.get_short(pos)?);
Ok(())
}
TypedWriter::Int32(b) => {
b.append_value(row.get_int(pos)?);
Ok(())
}
TypedWriter::Int64(b) => {
b.append_value(row.get_long(pos)?);
Ok(())
}
TypedWriter::Float32(b) => {
b.append_value(row.get_float(pos)?);
Ok(())
}
TypedWriter::Float64(b) => {
b.append_value(row.get_double(pos)?);
Ok(())
}
TypedWriter::Char { len, builder } => {
let v = row.get_char(pos, *len)?;
builder.append_value(v);
Ok(())
}
TypedWriter::String(b) => {
let v = row.get_string(pos)?;
b.append_value(v);
Ok(())
}
TypedWriter::Bytes(b) => {
let v = row.get_bytes(pos)?;
b.append_value(v);
Ok(())
}
TypedWriter::Binary { len, builder } => {
let v = row.get_binary(pos, *len)?;
builder.append_value(v).map_err(|e| RowConvertError {
message: format!("Failed to append binary value: {e}"),
})?;
Ok(())
}
TypedWriter::Decimal128 {
src_precision,
src_scale,
target_precision,
target_scale,
builder,
} => {
let decimal = row.get_decimal(pos, *src_precision, *src_scale)?;
append_decimal_to_builder(&decimal, *target_precision, *target_scale, builder)
}
TypedWriter::Date32(b) => {
let date = row.get_date(pos)?;
b.append_value(date.get_inner());
Ok(())
}
TypedWriter::Time32Second(b) => {
let millis = row.get_time(pos)?.get_inner();
if millis % MILLIS_PER_SECOND as i32 != 0 {
return Err(RowConvertError {
message: format!(
"Time value {millis} ms has sub-second precision but schema expects seconds only"
),
});
}
b.append_value(millis / MILLIS_PER_SECOND as i32);
Ok(())
}
TypedWriter::Time32Millisecond(b) => {
b.append_value(row.get_time(pos)?.get_inner());
Ok(())
}
TypedWriter::Time64Microsecond(b) => {
let millis = row.get_time(pos)?.get_inner();
let micros = (millis as i64)
.checked_mul(MICROS_PER_MILLI)
.ok_or_else(|| RowConvertError {
message: format!(
"Time value {millis} ms overflows when converting to microseconds"
),
})?;
b.append_value(micros);
Ok(())
}
TypedWriter::Time64Nanosecond(b) => {
let millis = row.get_time(pos)?.get_inner();
let nanos = (millis as i64)
.checked_mul(NANOS_PER_MILLI)
.ok_or_else(|| RowConvertError {
message: format!(
"Time value {millis} ms overflows when converting to nanoseconds"
),
})?;
b.append_value(nanos);
Ok(())
}
TypedWriter::TimestampNtzSecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ntz(pos, *precision)?;
builder.append_value(ts.get_millisecond() / MILLIS_PER_SECOND);
Ok(())
}
TypedWriter::TimestampNtzMillisecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ntz(pos, *precision)?;
builder.append_value(ts.get_millisecond());
Ok(())
}
TypedWriter::TimestampNtzMicrosecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ntz(pos, *precision)?;
builder.append_value(millis_nanos_to_micros(
ts.get_millisecond(),
ts.get_nano_of_millisecond(),
)?);
Ok(())
}
TypedWriter::TimestampNtzNanosecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ntz(pos, *precision)?;
builder.append_value(millis_nanos_to_nanos(
ts.get_millisecond(),
ts.get_nano_of_millisecond(),
)?);
Ok(())
}
TypedWriter::TimestampLtzSecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ltz(pos, *precision)?;
builder.append_value(ts.get_epoch_millisecond() / MILLIS_PER_SECOND);
Ok(())
}
TypedWriter::TimestampLtzMillisecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ltz(pos, *precision)?;
builder.append_value(ts.get_epoch_millisecond());
Ok(())
}
TypedWriter::TimestampLtzMicrosecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ltz(pos, *precision)?;
builder.append_value(millis_nanos_to_micros(
ts.get_epoch_millisecond(),
ts.get_nano_of_millisecond(),
)?);
Ok(())
}
TypedWriter::TimestampLtzNanosecond {
precision, builder, ..
} => {
let ts = row.get_timestamp_ltz(pos, *precision)?;
builder.append_value(millis_nanos_to_nanos(
ts.get_epoch_millisecond(),
ts.get_nano_of_millisecond(),
)?);
Ok(())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::metadata::DataTypes;
use crate::record::to_arrow_type;
use crate::row::{Date, Datum, GenericRow, Time, TimestampLtz, TimestampNtz};
use arrow::array::*;
use bigdecimal::BigDecimal;
use std::str::FromStr;
fn writer_for(fluss_type: &DataType, capacity: usize) -> ColumnWriter {
let arrow_type = to_arrow_type(fluss_type).unwrap();
ColumnWriter::create(fluss_type, &arrow_type, 0, capacity).unwrap()
}
fn write_one(fluss_type: &DataType, datum: Datum) -> ArrayRef {
let mut w = writer_for(fluss_type, 4);
w.write_field(&GenericRow::from_data(vec![datum])).unwrap();
w.finish()
}
#[test]
fn write_all_scalar_types() {
let arr = write_one(&DataTypes::boolean(), Datum::Bool(true));
assert!(
arr.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.value(0)
);
let arr = write_one(&DataTypes::tinyint(), Datum::Int8(42));
assert_eq!(
arr.as_any().downcast_ref::<Int8Array>().unwrap().value(0),
42
);
let arr = write_one(&DataTypes::smallint(), Datum::Int16(1000));
assert_eq!(
arr.as_any().downcast_ref::<Int16Array>().unwrap().value(0),
1000
);
let arr = write_one(&DataTypes::int(), Datum::Int32(100_000));
assert_eq!(
arr.as_any().downcast_ref::<Int32Array>().unwrap().value(0),
100_000
);
let arr = write_one(&DataTypes::bigint(), Datum::Int64(9_000_000_000));
assert_eq!(
arr.as_any().downcast_ref::<Int64Array>().unwrap().value(0),
9_000_000_000
);
let arr = write_one(&DataTypes::float(), Datum::Float32(1.5.into()));
assert!(
(arr.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.value(0)
- 1.5)
.abs()
< 0.001
);
let arr = write_one(&DataTypes::double(), Datum::Float64(1.125.into()));
assert!(
(arr.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0)
- 1.125)
.abs()
< 0.001
);
let arr = write_one(&DataTypes::string(), Datum::String("hello".into()));
assert_eq!(
arr.as_any().downcast_ref::<StringArray>().unwrap().value(0),
"hello"
);
let arr = write_one(&DataTypes::char(10), Datum::String("world".into()));
assert_eq!(
arr.as_any().downcast_ref::<StringArray>().unwrap().value(0),
"world"
);
let arr = write_one(&DataTypes::bytes(), Datum::Blob(vec![1, 2, 3].into()));
assert_eq!(
arr.as_any().downcast_ref::<BinaryArray>().unwrap().value(0),
&[1, 2, 3]
);
let arr = write_one(
&DataTypes::binary(4),
Datum::Blob(vec![10, 20, 30, 40].into()),
);
assert_eq!(
arr.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.unwrap()
.value(0),
&[10, 20, 30, 40]
);
let arr = write_one(&DataTypes::date(), Datum::Date(Date::new(19000)));
assert_eq!(
arr.as_any().downcast_ref::<Date32Array>().unwrap().value(0),
19000
);
let arr = write_one(
&DataTypes::time_with_precision(3),
Datum::Time(Time::new(45_000)),
);
assert_eq!(
arr.as_any()
.downcast_ref::<Time32MillisecondArray>()
.unwrap()
.value(0),
45_000
);
let decimal =
crate::row::Decimal::from_big_decimal(BigDecimal::from_str("123.45").unwrap(), 10, 2)
.unwrap();
let arr = write_one(&DataTypes::decimal(10, 2), Datum::Decimal(decimal));
assert_eq!(
arr.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap()
.value(0),
12345
);
let arr = write_one(
&DataTypes::timestamp_with_precision(3),
Datum::TimestampNtz(TimestampNtz::new(1_700_000_000_000)),
);
assert_eq!(
arr.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap()
.value(0),
1_700_000_000_000
);
let arr = write_one(
&DataTypes::timestamp_ltz_with_precision(3),
Datum::TimestampLtz(TimestampLtz::new(1_700_000_000_000)),
);
assert_eq!(
arr.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap()
.value(0),
1_700_000_000_000
);
}
#[test]
fn write_null_and_multiple_rows() {
let arr = write_one(&DataTypes::int(), Datum::Null);
assert!(arr.is_null(0));
let mut w = writer_for(&DataTypes::int(), 8);
for val in [10, 20, 30] {
w.write_field(&GenericRow::from_data(vec![val])).unwrap();
}
let arr = w.finish();
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(int_arr.len(), 3);
assert_eq!(int_arr.value(0), 10);
assert_eq!(int_arr.value(1), 20);
assert_eq!(int_arr.value(2), 30);
let mut w = writer_for(&DataTypes::int(), 4);
w.write_field(&GenericRow::from_data(vec![42_i32])).unwrap();
assert_eq!(w.finish_cloned().len(), 1);
w.write_field(&GenericRow::from_data(vec![99_i32])).unwrap();
let int_arr = w
.finish()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.clone();
assert_eq!((int_arr.value(0), int_arr.value(1)), (42, 99));
}
#[test]
fn unsupported_type_returns_error() {
let fluss_type = DataTypes::array(DataTypes::int());
let arrow_type = ArrowDataType::List(arrow_schema::FieldRef::new(
arrow_schema::Field::new("item", ArrowDataType::Int32, true),
));
assert!(ColumnWriter::create(&fluss_type, &arrow_type, 0, 4).is_err());
}
}