use super::{
Column,
ColumnInt128,
ColumnInt32,
ColumnInt64,
ColumnRef,
};
use crate::{
types::Type,
Error,
Result,
};
use bytes::BytesMut;
use std::sync::Arc;
/// Column of `Decimal(precision, scale)` values.
///
/// Values are kept as unscaled integers (for scale 2, `123.45` is held as
/// `12345`) inside an integer column whose width is chosen from `precision`
/// when the column is constructed.
pub struct ColumnDecimal {
    /// The `Type::Decimal` this column was created with.
    type_: Type,
    /// Declared precision; selects the width of the backing integer column.
    precision: usize,
    /// Number of digits after the decimal point.
    scale: usize,
    /// Backing integer storage (`ColumnInt32`, `ColumnInt64` or `ColumnInt128`
    /// as chosen by `new` / `with_data`).
    data: ColumnRef,
}
impl ColumnDecimal {
pub fn new(type_: Type) -> Self {
let (precision, scale) = match &type_ {
Type::Decimal { precision, scale } => (*precision, *scale),
_ => panic!("ColumnDecimal requires Decimal type"),
};
let data: ColumnRef = if precision <= 9 {
Arc::new(ColumnInt32::new())
} else if precision <= 18 {
Arc::new(ColumnInt64::new())
} else {
Arc::new(ColumnInt128::new())
};
Self { type_, precision, scale, data }
}
pub fn with_data(mut self, data: Vec<i128>) -> Self {
if self.precision <= 9 {
let mut col = ColumnInt32::new();
for value in data {
col.append(value as i32);
}
self.data = Arc::new(col);
} else if self.precision <= 18 {
let mut col = ColumnInt64::new();
for value in data {
col.append(value as i64);
}
self.data = Arc::new(col);
} else {
let mut col = ColumnInt128::new();
for value in data {
col.append(value);
}
self.data = Arc::new(col);
}
self
}
pub fn append_from_string(&mut self, s: &str) -> Result<()> {
let value = parse_decimal(s, self.scale)?;
self.append(value);
Ok(())
}
pub fn append(&mut self, value: i128) {
let data_mut =
Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
if self.precision <= 9 {
let col = data_mut
.as_any_mut()
.downcast_mut::<ColumnInt32>()
.expect("Expected ColumnInt32");
col.append(value as i32);
} else if self.precision <= 18 {
let col = data_mut
.as_any_mut()
.downcast_mut::<ColumnInt64>()
.expect("Expected ColumnInt64");
col.append(value as i64);
} else {
let col = data_mut
.as_any_mut()
.downcast_mut::<ColumnInt128>()
.expect("Expected ColumnInt128");
col.append(value);
}
}
pub fn at(&self, index: usize) -> i128 {
if self.precision <= 9 {
let col = self
.data
.as_any()
.downcast_ref::<ColumnInt32>()
.expect("Expected ColumnInt32");
col.at(index) as i128
} else if self.precision <= 18 {
let col = self
.data
.as_any()
.downcast_ref::<ColumnInt64>()
.expect("Expected ColumnInt64");
col.at(index) as i128
} else {
let col = self
.data
.as_any()
.downcast_ref::<ColumnInt128>()
.expect("Expected ColumnInt128");
col.at(index)
}
}
pub fn as_string(&self, index: usize) -> String {
format_decimal(self.at(index), self.scale)
}
pub fn precision(&self) -> usize {
self.precision
}
pub fn scale(&self) -> usize {
self.scale
}
pub fn len(&self) -> usize {
self.data.size()
}
pub fn is_empty(&self) -> bool {
self.data.size() == 0
}
}
impl Column for ColumnDecimal {
    /// Returns the `Type::Decimal` this column was created with.
    fn column_type(&self) -> &Type {
        &self.type_
    }

    /// Number of rows, as reported by the backing integer column.
    fn size(&self) -> usize {
        self.data.size()
    }

    /// Removes all rows from the backing column.
    ///
    /// # Panics
    ///
    /// Panics if the backing storage `Arc` has other references.
    fn clear(&mut self) {
        let data_mut =
            Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
        data_mut.clear();
    }

    /// Forwards the capacity reservation to the backing column.
    ///
    /// # Panics
    ///
    /// Panics if the backing storage `Arc` has other references.
    fn reserve(&mut self, new_cap: usize) {
        let data_mut =
            Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
        data_mut.reserve(new_cap);
    }

    /// Appends every row of `other` to this column.
    ///
    /// # Errors
    ///
    /// Returns `Error::TypeMismatch` if `other` is not a `ColumnDecimal` or has
    /// a different precision or scale; otherwise propagates any error from the
    /// backing column's `append_column`.
    ///
    /// # Panics
    ///
    /// Panics if this column's backing storage `Arc` has other references.
    fn append_column(&mut self, other: ColumnRef) -> Result<()> {
        let other = other
            .as_any()
            .downcast_ref::<ColumnDecimal>()
            .ok_or_else(|| Error::TypeMismatch {
                expected: self.type_.name(),
                actual: other.column_type().name(),
            })?;
        // Identical (precision, scale) guarantees identical backing storage and
        // identical interpretation of the unscaled integers.
        if self.precision != other.precision || self.scale != other.scale {
            return Err(Error::TypeMismatch {
                expected: format!(
                    "Decimal({}, {})",
                    self.precision, self.scale
                ),
                actual: format!(
                    "Decimal({}, {})",
                    other.precision, other.scale
                ),
            });
        }
        let data_mut =
            Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
        data_mut.append_column(Arc::clone(&other.data))?;
        Ok(())
    }

    /// Reads `rows` values into the backing column from `buffer`.
    ///
    /// # Panics
    ///
    /// Panics if the backing storage `Arc` has other references.
    fn load_from_buffer(
        &mut self,
        buffer: &mut &[u8],
        rows: usize,
    ) -> Result<()> {
        let data_mut =
            Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
        data_mut.load_from_buffer(buffer, rows)
    }

    /// Serializes the backing integer column into `buffer`.
    fn save_to_buffer(&self, buffer: &mut BytesMut) -> Result<()> {
        self.data.save_to_buffer(buffer)
    }

    /// Returns a new empty column of the same decimal type.
    fn clone_empty(&self) -> ColumnRef {
        Arc::new(ColumnDecimal::new(self.type_.clone()))
    }

    /// Returns a new column holding rows `begin..begin + len`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` if the range extends past the end of
    /// the column (including when `begin + len` overflows `usize`), and
    /// propagates any error from the backing column's `slice`.
    fn slice(&self, begin: usize, len: usize) -> Result<ColumnRef> {
        let size = self.data.size();
        // `begin + len` may overflow: that panics in debug builds and wraps
        // past this check in release builds, so treat overflow as out of bounds.
        if begin.checked_add(len).map_or(true, |end| end > size) {
            return Err(Error::InvalidArgument(format!(
                "Slice out of bounds: begin={}, len={}, size={}",
                begin, len, size
            )));
        }
        let sliced_data = self.data.slice(begin, len)?;
        let mut result = ColumnDecimal::new(self.type_.clone());
        result.data = sliced_data;
        Ok(Arc::new(result))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
fn parse_decimal(s: &str, scale: usize) -> Result<i128> {
let s = s.trim();
let (sign, s) = if let Some(stripped) = s.strip_prefix('-') {
(-1, stripped)
} else if let Some(stripped) = s.strip_prefix('+') {
(1, stripped)
} else {
(1, s)
};
let parts: Vec<&str> = s.split('.').collect();
if parts.len() > 2 {
return Err(Error::Protocol(format!("Invalid decimal format: {}", s)));
}
let integer_part = parts[0].parse::<i128>().map_err(|e| {
Error::Protocol(format!("Invalid decimal integer part: {}", e))
})?;
let fractional_part = if parts.len() == 2 {
let frac_str = parts[1];
if frac_str.len() > scale {
return Err(Error::Protocol(format!(
"Decimal fractional part exceeds scale: {} > {}",
frac_str.len(),
scale
)));
}
let mut padded = frac_str.to_string();
while padded.len() < scale {
padded.push('0');
}
padded.parse::<i128>().map_err(|e| {
Error::Protocol(format!("Invalid decimal fractional part: {}", e))
})?
} else {
0
};
let scale_multiplier = 10_i128.pow(scale as u32);
let scaled_value = integer_part * scale_multiplier + fractional_part;
Ok(sign * scaled_value)
}
/// Formats an unscaled decimal integer with exactly `scale` fractional digits
/// (for scale 2, `-12345` becomes `"-123.45"` and `50` becomes `"0.50"`).
///
/// With `scale == 0` no decimal point is emitted. Handles the full `i128`
/// range, including `i128::MIN`, and any `scale`.
fn format_decimal(value: i128, scale: usize) -> String {
    let sign = if value < 0 { "-" } else { "" };
    // `unsigned_abs` avoids the overflow `-value` hits for `i128::MIN`.
    let magnitude = value.unsigned_abs();
    // 10^scale, or None once it exceeds u128 (scale > 38). In that case every
    // |i128| is below the divisor, so the whole magnitude is fractional.
    let divisor = std::iter::repeat(10_u128)
        .take(scale)
        .try_fold(1_u128, u128::checked_mul);
    let (integer_part, fractional_part) = match divisor {
        Some(divisor) => (magnitude / divisor, magnitude % divisor),
        None => (0, magnitude),
    };
    if scale == 0 {
        format!("{}{}", sign, integer_part)
    } else {
        format!(
            "{}{}.{:0width$}",
            sign,
            integer_part,
            fractional_part,
            width = scale
        )
    }
}
// Unit tests for ColumnDecimal and the parse/format helpers. They cover
// string conversion, storage-width selection by precision, buffer round-trips,
// concatenation, slicing, clearing and the `with_data` constructor.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // At scale 2, fractional digits are right-padded, an integer-only literal
    // gains two zeros, and a leading '-' negates the whole value.
    #[test]
    fn test_parse_decimal() {
        assert_eq!(parse_decimal("123.45", 2).unwrap(), 12345);
        assert_eq!(parse_decimal("123", 2).unwrap(), 12300);
        assert_eq!(parse_decimal("0.5", 2).unwrap(), 50);
        assert_eq!(parse_decimal("-123.45", 2).unwrap(), -12345);
    }

    // Inverse of the cases above. Fractional digits are zero-padded to
    // `scale`, and scale 0 prints no decimal point.
    #[test]
    fn test_format_decimal() {
        assert_eq!(format_decimal(12345, 2), "123.45");
        assert_eq!(format_decimal(12300, 2), "123.00");
        assert_eq!(format_decimal(50, 2), "0.50");
        assert_eq!(format_decimal(-12345, 2), "-123.45");
        assert_eq!(format_decimal(123, 0), "123");
    }

    // End-to-end string round trip through append_from_string / as_string.
    #[test]
    fn test_decimal_column() {
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        col.append_from_string("123.45").unwrap();
        col.append_from_string("-56.78").unwrap();
        col.append_from_string("0.01").unwrap();
        assert_eq!(col.len(), 3);
        assert_eq!(col.as_string(0), "123.45");
        assert_eq!(col.as_string(1), "-56.78");
        assert_eq!(col.as_string(2), "0.01");
    }

    // The accessors report the parameters of the Decimal type.
    #[test]
    fn test_decimal_precision_scale() {
        let col = ColumnDecimal::new(Type::decimal(18, 4));
        assert_eq!(col.precision(), 18);
        assert_eq!(col.scale(), 4);
    }

    // Storage selection at the upper boundary of the 32-bit range
    // (precision 9). The `is` check and the downcast assert the same thing.
    #[test]
    fn test_decimal_uses_int32_for_precision_9() {
        let col = ColumnDecimal::new(Type::decimal(9, 2));
        assert!(col.data.as_any().is::<ColumnInt32>());
        let int32_col = col.data.as_any().downcast_ref::<ColumnInt32>();
        assert!(int32_col.is_some(), "Expected ColumnInt32 for precision 9");
    }

    // Storage selection at the upper boundary of the 64-bit range (precision 18).
    #[test]
    fn test_decimal_uses_int64_for_precision_18() {
        let col = ColumnDecimal::new(Type::decimal(18, 4));
        assert!(col.data.as_any().is::<ColumnInt64>());
        let int64_col = col.data.as_any().downcast_ref::<ColumnInt64>();
        assert!(int64_col.is_some(), "Expected ColumnInt64 for precision 18");
    }

    // Precision above 18 falls through to 128-bit storage.
    #[test]
    fn test_decimal_uses_int128_for_precision_38() {
        let col = ColumnDecimal::new(Type::decimal(38, 10));
        assert!(col.data.as_any().is::<ColumnInt128>());
        let int128_col = col.data.as_any().downcast_ref::<ColumnInt128>();
        assert!(
            int128_col.is_some(),
            "Expected ColumnInt128 for precision 38"
        );
    }

    // save_to_buffer writes exactly the backing width per row: 4, 8 and 16
    // bytes for precisions 9, 18 and 38.
    #[test]
    fn test_decimal_memory_efficiency() {
        let mut col9 = ColumnDecimal::new(Type::decimal(9, 2));
        for i in 0..1000 {
            col9.append(i * 100);
        }
        let mut buf9 = BytesMut::new();
        col9.save_to_buffer(&mut buf9).unwrap();
        assert_eq!(
            buf9.len(),
            1000 * 4,
            "Decimal(9,2) should use 4 bytes per value"
        );
        let mut col18 = ColumnDecimal::new(Type::decimal(18, 4));
        for i in 0..1000 {
            col18.append(i * 10000);
        }
        let mut buf18 = BytesMut::new();
        col18.save_to_buffer(&mut buf18).unwrap();
        assert_eq!(
            buf18.len(),
            1000 * 8,
            "Decimal(18,4) should use 8 bytes per value"
        );
        let mut col38 = ColumnDecimal::new(Type::decimal(38, 10));
        for i in 0..1000 {
            col38.append(i * 1000000000);
        }
        let mut buf38 = BytesMut::new();
        col38.save_to_buffer(&mut buf38).unwrap();
        assert_eq!(
            buf38.len(),
            1000 * 16,
            "Decimal(38,10) should use 16 bytes per value"
        );
    }

    // save_to_buffer -> load_from_buffer round trip with 32-bit storage,
    // including negative values and zero.
    #[test]
    fn test_decimal_bulk_copy_int32() {
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        let test_values = vec![12345, -67890, 0, 100, -200];
        for &val in &test_values {
            col.append(val);
        }
        let mut buf = BytesMut::new();
        col.save_to_buffer(&mut buf).unwrap();
        let mut col2 = ColumnDecimal::new(Type::decimal(9, 2));
        let mut reader = &buf[..];
        col2.load_from_buffer(&mut reader, test_values.len()).unwrap();
        assert_eq!(col2.len(), test_values.len());
        for (i, &expected) in test_values.iter().enumerate() {
            assert_eq!(col2.at(i), expected, "Value mismatch at index {}", i);
        }
    }

    // Same round trip with 64-bit storage and values beyond the i32 range.
    #[test]
    fn test_decimal_bulk_copy_int64() {
        let mut col = ColumnDecimal::new(Type::decimal(18, 4));
        let test_values =
            vec![1234567890123, -9876543210987, 0, 100000000, -200000000];
        for &val in &test_values {
            col.append(val);
        }
        let mut buf = BytesMut::new();
        col.save_to_buffer(&mut buf).unwrap();
        let mut col2 = ColumnDecimal::new(Type::decimal(18, 4));
        let mut reader = &buf[..];
        col2.load_from_buffer(&mut reader, test_values.len()).unwrap();
        assert_eq!(col2.len(), test_values.len());
        for (i, &expected) in test_values.iter().enumerate() {
            assert_eq!(col2.at(i), expected, "Value mismatch at index {}", i);
        }
    }

    // Same round trip with 128-bit storage and values beyond the i64 range.
    #[test]
    fn test_decimal_bulk_copy_int128() {
        let mut col = ColumnDecimal::new(Type::decimal(38, 10));
        let test_values = vec![
            123456789012345678901234567890_i128,
            -987654321098765432109876543210_i128,
            0,
            1000000000000000000,
            -2000000000000000000,
        ];
        for &val in &test_values {
            col.append(val);
        }
        let mut buf = BytesMut::new();
        col.save_to_buffer(&mut buf).unwrap();
        let mut col2 = ColumnDecimal::new(Type::decimal(38, 10));
        let mut reader = &buf[..];
        col2.load_from_buffer(&mut reader, test_values.len()).unwrap();
        assert_eq!(col2.len(), test_values.len());
        for (i, &expected) in test_values.iter().enumerate() {
            assert_eq!(col2.at(i), expected, "Value mismatch at index {}", i);
        }
    }

    // Round trip of 10,000 rows, spot-checking the first, middle and last.
    #[test]
    fn test_decimal_bulk_copy_large_dataset() {
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        for i in 0..10_000 {
            col.append(i * 100);
        }
        let mut buf = BytesMut::new();
        col.save_to_buffer(&mut buf).unwrap();
        let mut col2 = ColumnDecimal::new(Type::decimal(9, 2));
        let mut reader = &buf[..];
        col2.load_from_buffer(&mut reader, 10_000).unwrap();
        assert_eq!(col2.len(), 10_000);
        assert_eq!(col2.at(0), 0);
        assert_eq!(col2.at(5_000), 5_000 * 100);
        assert_eq!(col2.at(9_999), 9_999 * 100);
    }

    // append_column concatenates a same-typed column after the existing rows.
    #[test]
    fn test_decimal_append_column() {
        let mut col1 = ColumnDecimal::new(Type::decimal(9, 2));
        col1.append(12345);
        col1.append(67890);
        let mut col2 = ColumnDecimal::new(Type::decimal(9, 2));
        col2.append(11111);
        col2.append(22222);
        col1.append_column(Arc::new(col2)).unwrap();
        assert_eq!(col1.len(), 4);
        assert_eq!(col1.at(0), 12345);
        assert_eq!(col1.at(1), 67890);
        assert_eq!(col1.at(2), 11111);
        assert_eq!(col1.at(3), 22222);
    }

    // slice(2, 5) yields rows 2..=6 as a new ColumnDecimal.
    #[test]
    fn test_decimal_slice() {
        let mut col = ColumnDecimal::new(Type::decimal(18, 4));
        for i in 0..10 {
            col.append(i * 10000);
        }
        let sliced = col.slice(2, 5).unwrap();
        assert_eq!(sliced.size(), 5);
        let sliced_concrete =
            sliced.as_any().downcast_ref::<ColumnDecimal>().unwrap();
        assert_eq!(sliced_concrete.at(0), 2 * 10000);
        assert_eq!(sliced_concrete.at(4), 6 * 10000);
    }

    // clear empties the column, and it accepts new rows afterwards.
    #[test]
    fn test_decimal_clear_and_reuse() {
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        col.append(100);
        col.append(200);
        assert_eq!(col.len(), 2);
        col.clear();
        assert_eq!(col.len(), 0);
        assert!(col.is_empty());
        col.append(300);
        col.append(400);
        assert_eq!(col.len(), 2);
        assert_eq!(col.at(0), 300);
        assert_eq!(col.at(1), 400);
    }

    // with_data fills the column and picks the same storage width as `new`
    // for each precision band.
    #[test]
    fn test_decimal_with_data_constructor() {
        let data = vec![100, 200, 300];
        let col9 =
            ColumnDecimal::new(Type::decimal(9, 2)).with_data(data.clone());
        assert_eq!(col9.len(), 3);
        assert_eq!(col9.at(0), 100);
        assert!(col9.data.as_any().is::<ColumnInt32>());
        let col18 =
            ColumnDecimal::new(Type::decimal(18, 4)).with_data(data.clone());
        assert_eq!(col18.len(), 3);
        assert_eq!(col18.at(0), 100);
        assert!(col18.data.as_any().is::<ColumnInt64>());
        let col38 =
            ColumnDecimal::new(Type::decimal(38, 10)).with_data(data.clone());
        assert_eq!(col38.len(), 3);
        assert_eq!(col38.at(0), 100);
        assert!(col38.data.as_any().is::<ColumnInt128>());
    }
}