1use super::{
2 Column,
3 ColumnInt128,
4 ColumnInt32,
5 ColumnInt64,
6 ColumnRef,
7};
8use crate::{
9 types::Type,
10 Error,
11 Result,
12};
13use bytes::BytesMut;
14use std::sync::Arc;
15
16pub struct ColumnDecimal {
24 type_: Type,
25 precision: usize,
26 scale: usize,
27 data: ColumnRef, }
29
30impl ColumnDecimal {
31 pub fn new(type_: Type) -> Self {
37 let (precision, scale) = match &type_ {
38 Type::Decimal { precision, scale } => (*precision, *scale),
39 _ => panic!("ColumnDecimal requires Decimal type"),
40 };
41
42 let data: ColumnRef = if precision <= 9 {
45 Arc::new(ColumnInt32::new())
46 } else if precision <= 18 {
47 Arc::new(ColumnInt64::new())
48 } else {
49 Arc::new(ColumnInt128::new())
50 };
51
52 Self { type_, precision, scale, data }
53 }
54
55 pub fn with_data(mut self, data: Vec<i128>) -> Self {
57 if self.precision <= 9 {
59 let mut col = ColumnInt32::new();
60 for value in data {
61 col.append(value as i32);
62 }
63 self.data = Arc::new(col);
64 } else if self.precision <= 18 {
65 let mut col = ColumnInt64::new();
66 for value in data {
67 col.append(value as i64);
68 }
69 self.data = Arc::new(col);
70 } else {
71 let mut col = ColumnInt128::new();
72 for value in data {
73 col.append(value);
74 }
75 self.data = Arc::new(col);
76 }
77 self
78 }
79
80 pub fn append_from_string(&mut self, s: &str) -> Result<()> {
87 let value = parse_decimal(s, self.scale)?;
88 self.append(value);
89 Ok(())
90 }
91
92 pub fn append(&mut self, value: i128) {
94 let data_mut =
96 Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
97
98 if self.precision <= 9 {
99 let col = data_mut
100 .as_any_mut()
101 .downcast_mut::<ColumnInt32>()
102 .expect("Expected ColumnInt32");
103 col.append(value as i32);
104 } else if self.precision <= 18 {
105 let col = data_mut
106 .as_any_mut()
107 .downcast_mut::<ColumnInt64>()
108 .expect("Expected ColumnInt64");
109 col.append(value as i64);
110 } else {
111 let col = data_mut
112 .as_any_mut()
113 .downcast_mut::<ColumnInt128>()
114 .expect("Expected ColumnInt128");
115 col.append(value);
116 }
117 }
118
119 pub fn at(&self, index: usize) -> i128 {
121 if self.precision <= 9 {
122 let col = self
123 .data
124 .as_any()
125 .downcast_ref::<ColumnInt32>()
126 .expect("Expected ColumnInt32");
127 col.at(index) as i128
128 } else if self.precision <= 18 {
129 let col = self
130 .data
131 .as_any()
132 .downcast_ref::<ColumnInt64>()
133 .expect("Expected ColumnInt64");
134 col.at(index) as i128
135 } else {
136 let col = self
137 .data
138 .as_any()
139 .downcast_ref::<ColumnInt128>()
140 .expect("Expected ColumnInt128");
141 col.at(index)
142 }
143 }
144
145 pub fn as_string(&self, index: usize) -> String {
147 format_decimal(self.at(index), self.scale)
148 }
149
150 pub fn precision(&self) -> usize {
152 self.precision
153 }
154
155 pub fn scale(&self) -> usize {
158 self.scale
159 }
160
161 pub fn len(&self) -> usize {
163 self.data.size()
164 }
165
166 pub fn is_empty(&self) -> bool {
168 self.data.size() == 0
169 }
170}
171
172impl Column for ColumnDecimal {
173 fn column_type(&self) -> &Type {
174 &self.type_
175 }
176
177 fn size(&self) -> usize {
178 self.data.size()
179 }
180
181 fn clear(&mut self) {
182 let data_mut =
183 Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
184 data_mut.clear();
185 }
186
187 fn reserve(&mut self, new_cap: usize) {
188 let data_mut =
189 Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
190 data_mut.reserve(new_cap);
191 }
192
193 fn append_column(&mut self, other: ColumnRef) -> Result<()> {
194 let other = other
195 .as_any()
196 .downcast_ref::<ColumnDecimal>()
197 .ok_or_else(|| Error::TypeMismatch {
198 expected: self.type_.name(),
199 actual: other.column_type().name(),
200 })?;
201
202 if self.precision != other.precision || self.scale != other.scale {
203 return Err(Error::TypeMismatch {
204 expected: format!(
205 "Decimal({}, {})",
206 self.precision, self.scale
207 ),
208 actual: format!(
209 "Decimal({}, {})",
210 other.precision, other.scale
211 ),
212 });
213 }
214
215 let data_mut =
217 Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
218 data_mut.append_column(other.data.clone())?;
219 Ok(())
220 }
221
222 fn load_from_buffer(
223 &mut self,
224 buffer: &mut &[u8],
225 rows: usize,
226 ) -> Result<()> {
227 let data_mut =
230 Arc::get_mut(&mut self.data).expect("Cannot modify shared column");
231 data_mut.load_from_buffer(buffer, rows)
232 }
233
234 fn save_to_buffer(&self, buffer: &mut BytesMut) -> Result<()> {
235 self.data.save_to_buffer(buffer)
238 }
239
240 fn clone_empty(&self) -> ColumnRef {
241 Arc::new(ColumnDecimal::new(self.type_.clone()))
242 }
243
244 fn slice(&self, begin: usize, len: usize) -> Result<ColumnRef> {
245 if begin + len > self.data.size() {
246 return Err(Error::InvalidArgument(format!(
247 "Slice out of bounds: begin={}, len={}, size={}",
248 begin,
249 len,
250 self.data.size()
251 )));
252 }
253
254 let sliced_data = self.data.slice(begin, len)?;
256 let mut result = ColumnDecimal::new(self.type_.clone());
257 result.data = sliced_data;
258 Ok(Arc::new(result))
259 }
260
261 fn as_any(&self) -> &dyn std::any::Any {
262 self
263 }
264
265 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
266 self
267 }
268}
269
270fn parse_decimal(s: &str, scale: usize) -> Result<i128> {
273 let s = s.trim();
274 let (sign, s) = if let Some(stripped) = s.strip_prefix('-') {
275 (-1, stripped)
276 } else if let Some(stripped) = s.strip_prefix('+') {
277 (1, stripped)
278 } else {
279 (1, s)
280 };
281
282 let parts: Vec<&str> = s.split('.').collect();
283 if parts.len() > 2 {
284 return Err(Error::Protocol(format!("Invalid decimal format: {}", s)));
285 }
286
287 let integer_part = parts[0].parse::<i128>().map_err(|e| {
288 Error::Protocol(format!("Invalid decimal integer part: {}", e))
289 })?;
290
291 let fractional_part = if parts.len() == 2 {
292 let frac_str = parts[1];
293 if frac_str.len() > scale {
294 return Err(Error::Protocol(format!(
295 "Decimal fractional part exceeds scale: {} > {}",
296 frac_str.len(),
297 scale
298 )));
299 }
300
301 let mut padded = frac_str.to_string();
303 while padded.len() < scale {
304 padded.push('0');
305 }
306
307 padded.parse::<i128>().map_err(|e| {
308 Error::Protocol(format!("Invalid decimal fractional part: {}", e))
309 })?
310 } else {
311 0
312 };
313
314 let scale_multiplier = 10_i128.pow(scale as u32);
316 let scaled_value = integer_part * scale_multiplier + fractional_part;
317
318 Ok(sign * scaled_value)
319}
320
/// Formats an unscaled decimal `value` with `scale` fractional digits
/// (e.g. `12345` at scale 2 -> `"123.45"`, `-5` at scale 3 -> `"-0.005"`).
///
/// Works for every `i128`, including `i128::MIN`, and for any `scale`.
/// When `scale > 38` the divisor `10^scale` exceeds every `i128`
/// magnitude, so the integer part is always `0`.
fn format_decimal(value: i128, scale: usize) -> String {
    let sign = if value < 0 { "-" } else { "" };
    // unsigned_abs: `-value` would overflow for i128::MIN.
    let magnitude = value.unsigned_abs();

    if scale == 0 {
        return format!("{}{}", sign, magnitude);
    }

    // 10^38 is the largest power of ten that fits in u128.
    let (integer_part, fractional_part) = if scale > 38 {
        (0, magnitude)
    } else {
        let divisor = 10_u128.pow(scale as u32);
        (magnitude / divisor, magnitude % divisor)
    };

    format!(
        "{}{}.{:0width$}",
        sign,
        integer_part,
        fractional_part,
        width = scale
    )
}
343
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Fills a fresh column of `type_` with `values`, serializes it, and
    /// deserializes the bytes into a second fresh column of the same type.
    fn round_trip(type_: Type, values: &[i128]) -> ColumnDecimal {
        let mut source = ColumnDecimal::new(type_.clone());
        values.iter().for_each(|&v| source.append(v));

        let mut encoded = BytesMut::new();
        source.save_to_buffer(&mut encoded).unwrap();

        let mut decoded = ColumnDecimal::new(type_);
        let mut cursor = &encoded[..];
        decoded.load_from_buffer(&mut cursor, values.len()).unwrap();
        decoded
    }

    /// Asserts that `column` holds exactly `expected`, row by row.
    fn assert_rows(column: &ColumnDecimal, expected: &[i128]) {
        assert_eq!(column.len(), expected.len());
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(column.at(idx), want, "Value mismatch at index {}", idx);
        }
    }

    #[test]
    fn test_parse_decimal() {
        let cases: [(&str, i128); 4] =
            [("123.45", 12345), ("123", 12300), ("0.5", 50), ("-123.45", -12345)];
        for &(text, expected) in cases.iter() {
            assert_eq!(parse_decimal(text, 2).unwrap(), expected);
        }
    }

    #[test]
    fn test_format_decimal() {
        let cases: [(i128, usize, &str); 5] = [
            (12345, 2, "123.45"),
            (12300, 2, "123.00"),
            (50, 2, "0.50"),
            (-12345, 2, "-123.45"),
            (123, 0, "123"),
        ];
        for &(value, scale, expected) in cases.iter() {
            assert_eq!(format_decimal(value, scale), expected);
        }
    }

    #[test]
    fn test_decimal_column() {
        let inputs = ["123.45", "-56.78", "0.01"];
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        for text in inputs.iter() {
            col.append_from_string(text).unwrap();
        }

        assert_eq!(col.len(), inputs.len());
        for (idx, text) in inputs.iter().enumerate() {
            assert_eq!(col.as_string(idx), *text);
        }
    }

    #[test]
    fn test_decimal_precision_scale() {
        let col = ColumnDecimal::new(Type::decimal(18, 4));
        assert_eq!(col.precision(), 18);
        assert_eq!(col.scale(), 4);
    }

    #[test]
    fn test_decimal_uses_int32_for_precision_9() {
        let col = ColumnDecimal::new(Type::decimal(9, 2));
        let storage = col.data.as_any();
        assert!(storage.is::<ColumnInt32>());
        assert!(
            storage.downcast_ref::<ColumnInt32>().is_some(),
            "Expected ColumnInt32 for precision 9"
        );
    }

    #[test]
    fn test_decimal_uses_int64_for_precision_18() {
        let col = ColumnDecimal::new(Type::decimal(18, 4));
        let storage = col.data.as_any();
        assert!(storage.is::<ColumnInt64>());
        assert!(
            storage.downcast_ref::<ColumnInt64>().is_some(),
            "Expected ColumnInt64 for precision 18"
        );
    }

    #[test]
    fn test_decimal_uses_int128_for_precision_38() {
        let col = ColumnDecimal::new(Type::decimal(38, 10));
        let storage = col.data.as_any();
        assert!(storage.is::<ColumnInt128>());
        assert!(
            storage.downcast_ref::<ColumnInt128>().is_some(),
            "Expected ColumnInt128 for precision 38"
        );
    }

    #[test]
    fn test_decimal_memory_efficiency() {
        // (type, step between values, expected bytes per value, label)
        let cases = [
            (Type::decimal(9, 2), 100_i128, 4_usize, "Decimal(9,2)"),
            (Type::decimal(18, 4), 10_000_i128, 8_usize, "Decimal(18,4)"),
            (Type::decimal(38, 10), 1_000_000_000_i128, 16_usize, "Decimal(38,10)"),
        ];

        for (type_, step, width, label) in cases.iter() {
            let mut col = ColumnDecimal::new(type_.clone());
            for i in 0..1000_i128 {
                col.append(i * *step);
            }

            let mut encoded = BytesMut::new();
            col.save_to_buffer(&mut encoded).unwrap();
            assert_eq!(
                encoded.len(),
                1000 * *width,
                "{} should use {} bytes per value",
                label,
                width
            );
        }
    }

    #[test]
    fn test_decimal_bulk_copy_int32() {
        let values: [i128; 5] = [12345, -67890, 0, 100, -200];
        let decoded = round_trip(Type::decimal(9, 2), &values);
        assert_rows(&decoded, &values);
    }

    #[test]
    fn test_decimal_bulk_copy_int64() {
        let values: [i128; 5] = [
            1_234_567_890_123,
            -9_876_543_210_987,
            0,
            100_000_000,
            -200_000_000,
        ];
        let decoded = round_trip(Type::decimal(18, 4), &values);
        assert_rows(&decoded, &values);
    }

    #[test]
    fn test_decimal_bulk_copy_int128() {
        let values: [i128; 5] = [
            123456789012345678901234567890_i128,
            -987654321098765432109876543210_i128,
            0,
            1_000_000_000_000_000_000,
            -2_000_000_000_000_000_000,
        ];
        let decoded = round_trip(Type::decimal(38, 10), &values);
        assert_rows(&decoded, &values);
    }

    #[test]
    fn test_decimal_bulk_copy_large_dataset() {
        let values: Vec<i128> = (0..10_000_i128).map(|i| i * 100).collect();
        let decoded = round_trip(Type::decimal(9, 2), &values);

        assert_eq!(decoded.len(), 10_000);
        assert_eq!(decoded.at(0), 0);
        assert_eq!(decoded.at(5_000), 5_000 * 100);
        assert_eq!(decoded.at(9_999), 9_999 * 100);
    }

    #[test]
    fn test_decimal_append_column() {
        let mut head = ColumnDecimal::new(Type::decimal(9, 2));
        for &v in [12345_i128, 67890].iter() {
            head.append(v);
        }

        let mut tail = ColumnDecimal::new(Type::decimal(9, 2));
        for &v in [11111_i128, 22222].iter() {
            tail.append(v);
        }

        head.append_column(Arc::new(tail)).unwrap();
        assert_rows(&head, &[12345, 67890, 11111, 22222]);
    }

    #[test]
    fn test_decimal_slice() {
        let mut col = ColumnDecimal::new(Type::decimal(18, 4));
        (0..10_i128).for_each(|i| col.append(i * 10000));

        let sliced = col.slice(2, 5).unwrap();
        assert_eq!(sliced.size(), 5);

        let view = sliced.as_any().downcast_ref::<ColumnDecimal>().unwrap();
        assert_eq!(view.at(0), 2 * 10000);
        assert_eq!(view.at(4), 6 * 10000);
    }

    #[test]
    fn test_decimal_clear_and_reuse() {
        let mut col = ColumnDecimal::new(Type::decimal(9, 2));
        col.append(100);
        col.append(200);
        assert_eq!(col.len(), 2);

        col.clear();
        assert_eq!(col.len(), 0);
        assert!(col.is_empty());

        col.append(300);
        col.append(400);
        assert_rows(&col, &[300, 400]);
    }

    #[test]
    fn test_decimal_with_data_constructor() {
        let data: Vec<i128> = vec![100, 200, 300];

        let narrow =
            ColumnDecimal::new(Type::decimal(9, 2)).with_data(data.clone());
        assert_eq!(narrow.len(), 3);
        assert_eq!(narrow.at(0), 100);
        assert!(narrow.data.as_any().is::<ColumnInt32>());

        let medium =
            ColumnDecimal::new(Type::decimal(18, 4)).with_data(data.clone());
        assert_eq!(medium.len(), 3);
        assert_eq!(medium.at(0), 100);
        assert!(medium.data.as_any().is::<ColumnInt64>());

        let wide =
            ColumnDecimal::new(Type::decimal(38, 10)).with_data(data.clone());
        assert_eq!(wide.len(), 3);
        assert_eq!(wide.at(0), 100);
        assert!(wide.data.as_any().is::<ColumnInt128>());
    }
}