Skip to main content

decimal_bytes/
encoding.rs

1//! Byte encoding for decimal values.
2//!
3//! This module implements a lexicographically sortable encoding for decimal numbers.
4//! The encoding ensures that byte-wise comparison yields the same result as numerical comparison.
5//!
6//! ## Encoding Format
7//!
8//! ```text
9//! [sign byte] [exponent bytes] [mantissa bytes]
10//! ```
11//!
12//! - **Sign byte**: 0x00 for negative, 0x80 for zero, 0xFF for positive
//! - **Exponent**: Fixed-width (2 bytes, big-endian), biased encoding (inverted for negative numbers)
14//! - **Mantissa**: BCD-encoded digits, 2 per byte (inverted for negative numbers)
15//!
16//! ## Special Values (PostgreSQL compatible)
17//!
18//! - **-Infinity**: Sorts less than all negative numbers
19//! - **+Infinity**: Sorts greater than all positive numbers
20//! - **NaN**: Sorts greater than +Infinity (per PostgreSQL semantics)
21//!
22//! ## Sort Order
23//!
24//! ```text
25//! -Infinity < negatives < zero < positives < +Infinity < NaN
26//! ```
27
28use thiserror::Error;
29
/// Sign byte for negative regular numbers (first byte of the encoding).
pub(crate) const SIGN_NEGATIVE: u8 = 0x00;
/// Sign byte for zero; zero is encoded as this single byte with nothing after it.
pub(crate) const SIGN_ZERO: u8 = 0x80;
/// Sign byte for positive regular numbers (first byte of the encoding).
pub(crate) const SIGN_POSITIVE: u8 = 0xFF;

/// Special value encodings (designed for correct lexicographic ordering)
/// -Infinity: [0x00, 0x00, 0x00] - sorts before all negative numbers
pub const ENCODING_NEG_INFINITY: [u8; 3] = [0x00, 0x00, 0x00];
/// +Infinity: [0xFF, 0xFF, 0xFE] - sorts after all positive numbers
pub const ENCODING_POS_INFINITY: [u8; 3] = [0xFF, 0xFF, 0xFE];
/// NaN: [0xFF, 0xFF, 0xFF] - sorts after +Infinity (PostgreSQL semantics)
pub const ENCODING_NAN: [u8; 3] = [0xFF, 0xFF, 0xFF];

/// Reserved 16-bit exponent fields (the two bytes following the sign byte of the
/// special encodings above). `decode_exponent` rejects them for regular numbers.
/// Exponent field of `ENCODING_NEG_INFINITY`.
const RESERVED_NEG_INFINITY_EXP: u16 = 0x0000; // Only meaningful after SIGN_NEGATIVE
/// Exponent field of `ENCODING_POS_INFINITY`.
const RESERVED_POS_INFINITY_EXP: u16 = 0xFFFE; // Only meaningful after SIGN_POSITIVE
/// Exponent field of `ENCODING_NAN`.
const RESERVED_NAN_EXP: u16 = 0xFFFF; // Only meaningful after SIGN_POSITIVE

/// Exponent bias to make all exponents positive for encoding
const EXPONENT_BIAS: i32 = 16384;
/// Largest accepted exponent (16381); its biased form is 0x7FFD.
const MAX_EXPONENT: i32 = 32767 - EXPONENT_BIAS - 2; // Biased values stay well below 0xFFFE/0xFFFF (Infinity/NaN)
/// Smallest accepted exponent (-16383); its biased form is 0x0001.
const MIN_EXPONENT: i32 = -EXPONENT_BIAS + 1; // Biased 0 is never produced, so the inverted form never collides with -Infinity's 0x0000
52
/// Errors that can occur during decimal encoding/decoding.
///
/// Returned by the string parsing/encoding functions and by the byte decoding
/// functions in this module.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum DecimalError {
    /// The input string format is invalid.
    ///
    /// The payload is a human-readable description (for example a repeated
    /// decimal point, an unexpected character, or a missing/unparsable exponent).
    #[error("Invalid format: {0}")]
    InvalidFormat(String),

    /// The number exceeds the supported precision range.
    ///
    /// Raised when the normalized exponent falls outside
    /// `MIN_EXPONENT..=MAX_EXPONENT`.
    #[error("Precision overflow: exponent out of range")]
    PrecisionOverflow,

    /// The encoded bytes are invalid.
    ///
    /// Raised by the decoders for empty input, an unknown sign byte, a
    /// truncated or reserved exponent, or a mantissa nibble greater than 9.
    #[error("Invalid encoding")]
    InvalidEncoding,
}
68
/// Special decimal values (IEEE 754 / PostgreSQL compatible)
///
/// Each variant has a fixed three-byte encoding (`ENCODING_POS_INFINITY`,
/// `ENCODING_NEG_INFINITY`, `ENCODING_NAN`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpecialValue {
    /// Positive infinity
    Infinity,
    /// Negative infinity
    NegInfinity,
    /// Not a Number
    NaN,
}
79
80/// Encodes a decimal string to sortable bytes.
81pub fn encode_decimal(value: &str) -> Result<Vec<u8>, DecimalError> {
82    // Check for special values first
83    if let Some(special) = parse_special_value(value) {
84        return Ok(encode_special_value(special));
85    }
86
87    let (is_negative, digits, exponent) = parse_decimal(value)?;
88
89    // Handle zero
90    if digits.is_empty() {
91        return Ok(vec![SIGN_ZERO]);
92    }
93
94    let mut result = Vec::with_capacity(1 + 2 + digits.len().div_ceil(2));
95
96    // Sign byte
97    result.push(if is_negative {
98        SIGN_NEGATIVE
99    } else {
100        SIGN_POSITIVE
101    });
102
103    // Encode exponent
104    encode_exponent(&mut result, exponent, is_negative);
105
106    // Encode mantissa (BCD, 2 digits per byte)
107    encode_mantissa(&mut result, &digits, is_negative);
108
109    Ok(result)
110}
111
112/// Parses special value strings (case-insensitive).
113fn parse_special_value(value: &str) -> Option<SpecialValue> {
114    let trimmed = value.trim();
115    let lower = trimmed.to_lowercase();
116
117    match lower.as_str() {
118        "infinity" | "inf" | "+infinity" | "+inf" => Some(SpecialValue::Infinity),
119        "-infinity" | "-inf" => Some(SpecialValue::NegInfinity),
120        "nan" | "-nan" | "+nan" => Some(SpecialValue::NaN), // PostgreSQL treats all NaN as equal
121        _ => None,
122    }
123}
124
125/// Encodes a special value to bytes.
126pub fn encode_special_value(special: SpecialValue) -> Vec<u8> {
127    match special {
128        SpecialValue::NegInfinity => ENCODING_NEG_INFINITY.to_vec(),
129        SpecialValue::Infinity => ENCODING_POS_INFINITY.to_vec(),
130        SpecialValue::NaN => ENCODING_NAN.to_vec(),
131    }
132}
133
134/// Checks if bytes represent a special value.
135pub fn decode_special_value(bytes: &[u8]) -> Option<SpecialValue> {
136    if bytes.len() == 3 {
137        if bytes == ENCODING_NEG_INFINITY {
138            return Some(SpecialValue::NegInfinity);
139        }
140        if bytes == ENCODING_POS_INFINITY {
141            return Some(SpecialValue::Infinity);
142        }
143        if bytes == ENCODING_NAN {
144            return Some(SpecialValue::NaN);
145        }
146    }
147    None
148}
149
150/// Encodes a decimal string with precision and scale constraints.
151///
152/// # Arguments
153/// * `value` - The decimal string to encode
154/// * `precision` - Maximum total significant digits (None = unlimited)
155/// * `scale` - Number of digits after decimal point; negative values round to left of decimal
156///
157/// # PostgreSQL Compatibility
158/// Supports negative scale (rounds to powers of 10):
159/// - scale = -3 rounds to nearest 1000
160/// - NUMERIC(2, -3) allows values like -99000 to 99000
161pub fn encode_decimal_with_constraints(
162    value: &str,
163    precision: Option<u32>,
164    scale: Option<i32>,
165) -> Result<Vec<u8>, DecimalError> {
166    // Handle special values - they ignore precision/scale
167    if parse_special_value(value).is_some() {
168        return encode_decimal(value);
169    }
170
171    let truncated = truncate_decimal(value, precision, scale)?;
172    encode_decimal(&truncated)
173}
174
175/// Decodes bytes back to a decimal string.
176pub fn decode_to_string(bytes: &[u8]) -> Result<String, DecimalError> {
177    if bytes.is_empty() {
178        return Err(DecimalError::InvalidEncoding);
179    }
180
181    // Check for special values first
182    if let Some(special) = decode_special_value(bytes) {
183        return Ok(match special {
184            SpecialValue::NegInfinity => "-Infinity".to_string(),
185            SpecialValue::Infinity => "Infinity".to_string(),
186            SpecialValue::NaN => "NaN".to_string(),
187        });
188    }
189
190    let sign_byte = bytes[0];
191
192    // Handle zero
193    if sign_byte == SIGN_ZERO {
194        return Ok("0".to_string());
195    }
196
197    let is_negative = sign_byte == SIGN_NEGATIVE;
198
199    if sign_byte != SIGN_NEGATIVE && sign_byte != SIGN_POSITIVE {
200        return Err(DecimalError::InvalidEncoding);
201    }
202
203    // Decode exponent (also validates it's not a reserved value)
204    let (exponent, mantissa_start) = decode_exponent(&bytes[1..], is_negative)?;
205
206    // Decode mantissa
207    let mantissa_bytes = &bytes[1 + mantissa_start..];
208    let digits = decode_mantissa(mantissa_bytes, is_negative)?;
209
210    // Format as string
211    format_decimal(is_negative, &digits, exponent)
212}
213
214/// Parses a decimal string into sign, digits, and exponent.
215fn parse_decimal(value: &str) -> Result<(bool, Vec<u8>, i32), DecimalError> {
216    let value = value.trim();
217    let mut chars = value.chars().peekable();
218
219    // Handle sign
220    let is_negative = if chars.peek() == Some(&'-') {
221        chars.next();
222        true
223    } else if chars.peek() == Some(&'+') {
224        chars.next();
225        false
226    } else {
227        false
228    };
229
230    // Collect the numeric part (before 'e' or 'E')
231    let mut integer_part = String::new();
232    let mut fractional_part = String::new();
233    let mut seen_decimal = false;
234    let mut seen_exponent_marker = false;
235
236    while let Some(&c) = chars.peek() {
237        if c == '.' {
238            if seen_decimal {
239                return Err(DecimalError::InvalidFormat(
240                    "Multiple decimal points".to_string(),
241                ));
242            }
243            seen_decimal = true;
244            chars.next();
245        } else if c.is_ascii_digit() {
246            if seen_decimal {
247                fractional_part.push(c);
248            } else {
249                integer_part.push(c);
250            }
251            chars.next();
252        } else if c == 'e' || c == 'E' {
253            seen_exponent_marker = true;
254            chars.next();
255            break;
256        } else {
257            return Err(DecimalError::InvalidFormat(format!(
258                "Invalid character: {}",
259                c
260            )));
261        }
262    }
263
264    // Parse exponent (required if 'e' or 'E' was seen)
265    let mut exp_offset: i32 = 0;
266    if seen_exponent_marker {
267        if chars.peek().is_none() {
268            return Err(DecimalError::InvalidFormat(
269                "Missing exponent after 'e'".to_string(),
270            ));
271        }
272        let exp_str: String = chars.collect();
273        exp_offset = exp_str
274            .parse()
275            .map_err(|_| DecimalError::InvalidFormat(format!("Invalid exponent: {}", exp_str)))?;
276    }
277
278    // Handle empty input
279    if integer_part.is_empty() && fractional_part.is_empty() {
280        return Ok((false, vec![], 0));
281    }
282
283    // If only fractional part, integer part is "0"
284    if integer_part.is_empty() {
285        integer_part.push('0');
286    }
287
288    // Remember where the decimal point was before combining
289    let decimal_position = integer_part.len();
290
291    // Combine all digits by appending fractional part (avoids extra allocation)
292    integer_part.push_str(&fractional_part);
293    let all_digits = integer_part;
294
295    // Find the first and last non-zero digit positions
296    let first_nonzero = all_digits.chars().position(|c| c != '0');
297    let last_nonzero = all_digits.chars().rev().position(|c| c != '0');
298
299    // If all zeros, return zero
300    if first_nonzero.is_none() {
301        return Ok((false, vec![], 0));
302    }
303
304    let first_nonzero = first_nonzero.unwrap();
305    let last_nonzero = all_digits.len() - 1 - last_nonzero.unwrap();
306
307    // Extract the significant digits
308    let significant = &all_digits[first_nonzero..=last_nonzero];
309
310    // Calculate the exponent
311    let exponent = (decimal_position as i32) - (first_nonzero as i32) + exp_offset;
312
313    // Convert significant digits to bytes
314    let digits: Vec<u8> = significant
315        .chars()
316        .map(|c| c.to_digit(10).unwrap() as u8)
317        .collect();
318
319    // Validate exponent range
320    if !(MIN_EXPONENT..=MAX_EXPONENT).contains(&exponent) {
321        return Err(DecimalError::PrecisionOverflow);
322    }
323
324    Ok((is_negative, digits, exponent))
325}
326
327/// Encodes the exponent as variable-length bytes.
328fn encode_exponent(result: &mut Vec<u8>, exponent: i32, is_negative: bool) {
329    // Bias the exponent to make it always positive
330    // Note: We add 1 to reserve 0x0000 for -Infinity on negative side
331    let biased = (exponent + EXPONENT_BIAS) as u16;
332
333    // For negative numbers, invert the exponent so larger negative numbers sort first
334    let encoded = if is_negative { !biased } else { biased };
335
336    // Use 2 bytes for the exponent (big-endian)
337    result.push((encoded >> 8) as u8);
338    result.push((encoded & 0xFF) as u8);
339}
340
341/// Decodes the exponent from bytes.
342fn decode_exponent(bytes: &[u8], is_negative: bool) -> Result<(i32, usize), DecimalError> {
343    if bytes.len() < 2 {
344        return Err(DecimalError::InvalidEncoding);
345    }
346
347    let encoded = ((bytes[0] as u16) << 8) | (bytes[1] as u16);
348
349    // Check for reserved values (should have been caught by decode_special_value)
350    if is_negative && encoded == RESERVED_NEG_INFINITY_EXP {
351        return Err(DecimalError::InvalidEncoding);
352    }
353    if !is_negative && (encoded == RESERVED_POS_INFINITY_EXP || encoded == RESERVED_NAN_EXP) {
354        return Err(DecimalError::InvalidEncoding);
355    }
356
357    let biased = if is_negative { !encoded } else { encoded };
358    let exponent = (biased as i32) - EXPONENT_BIAS;
359
360    Ok((exponent, 2))
361}
362
/// Appends the mantissa to `result`, packing two decimal digits into each byte.
///
/// The first digit of each pair goes in the high nibble. A trailing unpaired
/// digit is padded with a zero low nibble. For negative numbers each packed
/// byte is bit-inverted to reverse its sort order.
///
/// NOTE(review): no terminator follows the inverted bytes, so for negative
/// values a mantissa that is a byte-wise prefix of a longer one (e.g. -1.2 ->
/// `ED` vs -1.23 -> `ED CF`) compares as smaller even though the longer value
/// is numerically smaller. Confirm whether this ordering gap is acceptable.
fn encode_mantissa(result: &mut Vec<u8>, digits: &[u8], is_negative: bool) {
    let packed_bytes = digits.chunks(2).map(|pair| {
        let high = pair[0];
        // A final chunk of one digit gets a zero low nibble.
        let low = pair.get(1).copied().unwrap_or(0);
        let packed = (high << 4) | low;
        if is_negative {
            !packed
        } else {
            packed
        }
    });
    result.extend(packed_bytes);
}
383
384/// Decodes the mantissa from BCD bytes.
385fn decode_mantissa(bytes: &[u8], is_negative: bool) -> Result<Vec<u8>, DecimalError> {
386    let mut digits = Vec::with_capacity(bytes.len() * 2);
387
388    for &byte in bytes {
389        let byte = if is_negative { !byte } else { byte };
390        let high = (byte >> 4) & 0x0F;
391        let low = byte & 0x0F;
392
393        if high > 9 || low > 9 {
394            return Err(DecimalError::InvalidEncoding);
395        }
396
397        digits.push(high);
398        digits.push(low);
399    }
400
401    // Remove trailing zeros (padding)
402    while digits.last() == Some(&0) && digits.len() > 1 {
403        digits.pop();
404    }
405
406    Ok(digits)
407}
408
409/// Formats digits and exponent back to a decimal string.
410fn format_decimal(is_negative: bool, digits: &[u8], exponent: i32) -> Result<String, DecimalError> {
411    if digits.is_empty() {
412        return Ok("0".to_string());
413    }
414
415    let mut result = String::new();
416
417    if is_negative {
418        result.push('-');
419    }
420
421    let num_digits = digits.len() as i32;
422
423    if exponent >= num_digits {
424        // All digits are before the decimal point (integer part)
425        for d in digits {
426            result.push(char::from_digit(*d as u32, 10).unwrap());
427        }
428        // Add trailing zeros if needed
429        for _ in 0..(exponent - num_digits) {
430            result.push('0');
431        }
432    } else if exponent <= 0 {
433        // All digits are after the decimal point
434        result.push('0');
435        result.push('.');
436        for _ in 0..(-exponent) {
437            result.push('0');
438        }
439        for d in digits {
440            result.push(char::from_digit(*d as u32, 10).unwrap());
441        }
442    } else {
443        // Some digits before decimal, some after
444        let decimal_pos = exponent as usize;
445        for (i, d) in digits.iter().enumerate() {
446            if i == decimal_pos {
447                result.push('.');
448            }
449            result.push(char::from_digit(*d as u32, 10).unwrap());
450        }
451    }
452
453    Ok(result)
454}
455
456/// Truncates a decimal string to fit precision and scale constraints.
457///
458/// # PostgreSQL Compatibility
459/// - Positive scale: digits after decimal point
460/// - Negative scale: rounds to left of decimal (e.g., -3 rounds to nearest 1000)
461/// - Precision: total significant (non-rounded) digits
462fn truncate_decimal(
463    value: &str,
464    precision: Option<u32>,
465    scale: Option<i32>,
466) -> Result<String, DecimalError> {
467    // Parse to get sign and parts
468    let value = value.trim();
469    let is_negative = value.starts_with('-');
470    let value = value.trim_start_matches(['-', '+']);
471
472    // Split into integer and fractional parts
473    let (integer_part, fractional_part) = if let Some(dot_pos) = value.find('.') {
474        (&value[..dot_pos], &value[dot_pos + 1..])
475    } else {
476        (value, "")
477    };
478
479    // Trim leading zeros from integer part (but keep at least one digit)
480    let integer_part = integer_part.trim_start_matches('0');
481    let integer_part = if integer_part.is_empty() {
482        "0"
483    } else {
484        integer_part
485    };
486
487    let scale_val = scale.unwrap_or(0);
488
489    // Handle negative scale (round to left of decimal point)
490    if scale_val < 0 {
491        let round_digits = (-scale_val) as usize;
492
493        // Remove all fractional digits when scale is negative
494        let mut int_str = integer_part.to_string();
495
496        if int_str.len() <= round_digits {
497            // Number is smaller than the rounding unit
498            // Round: if the number >= half the rounding unit, round up to the unit
499            let num_val: u64 = int_str.parse().unwrap_or(0);
500            let rounding_unit = 10u64.pow(round_digits as u32);
501            let half_unit = rounding_unit / 2;
502
503            let result = if num_val >= half_unit {
504                rounding_unit.to_string()
505            } else {
506                "0".to_string()
507            };
508
509            return if is_negative && result != "0" {
510                Ok(format!("-{}", result))
511            } else {
512                Ok(result)
513            };
514        }
515
516        // Round the integer part
517        let keep_len = int_str.len() - round_digits;
518        let keep_part = &int_str[..keep_len];
519        let round_part = &int_str[keep_len..];
520
521        // Check if we need to round up
522        let first_rounded_digit = round_part.chars().next().unwrap_or('0');
523        let mut result_int = keep_part.to_string();
524
525        if first_rounded_digit >= '5' {
526            result_int = add_one_to_integer(&result_int);
527        }
528
529        // Add trailing zeros
530        int_str = format!("{}{}", result_int, "0".repeat(round_digits));
531
532        // Apply precision constraint for negative scale
533        if let Some(p) = precision {
534            let max_significant = p as usize;
535            let significant_len = result_int.trim_start_matches('0').len();
536            if significant_len > max_significant && max_significant > 0 {
537                // Truncate from left (keep least significant digits)
538                let trimmed = &result_int[result_int.len().saturating_sub(max_significant)..];
539                int_str = format!("{}{}", trimmed, "0".repeat(round_digits));
540            }
541        }
542
543        return if is_negative && int_str != "0" {
544            Ok(format!("-{}", int_str))
545        } else {
546            Ok(int_str)
547        };
548    }
549
550    // Handle positive scale (normal case - digits after decimal)
551    let scale_usize = scale_val as usize;
552
553    // Apply scale constraint (truncate/round fractional part)
554    let (mut integer_part, fractional_part) = if fractional_part.len() > scale_usize {
555        // Round the last digit
556        let truncated = &fractional_part[..scale_usize];
557        let next_digit = fractional_part.chars().nth(scale_usize).unwrap_or('0');
558
559        if next_digit >= '5' {
560            // Round up - this may carry into integer part
561            if scale_usize == 0 {
562                // Rounding to integer
563                (add_one_to_integer(integer_part), String::new())
564            } else {
565                let rounded = round_up(truncated);
566                if rounded.len() > scale_usize {
567                    // Carry into integer part
568                    let new_int = add_one_to_integer(integer_part);
569                    (new_int, "0".repeat(scale_usize))
570                } else {
571                    (integer_part.to_string(), rounded)
572                }
573            }
574        } else {
575            (integer_part.to_string(), truncated.to_string())
576        }
577    } else {
578        (integer_part.to_string(), fractional_part.to_string())
579    };
580
581    // Apply precision constraint
582    if let Some(p) = precision {
583        let max_integer_digits = if (p as i32) > scale_val {
584            (p as i32 - scale_val) as usize
585        } else {
586            0
587        };
588
589        if integer_part.len() > max_integer_digits && max_integer_digits > 0 {
590            // Truncate from the left (keep least significant digits)
591            integer_part = integer_part[integer_part.len() - max_integer_digits..].to_string();
592        } else if max_integer_digits == 0 {
593            integer_part = "0".to_string();
594        }
595    }
596
597    // Reconstruct
598    let result = if fractional_part.is_empty() || fractional_part.chars().all(|c| c == '0') {
599        integer_part
600    } else {
601        format!("{}.{}", integer_part, fractional_part.trim_end_matches('0'))
602    };
603
604    if is_negative && result != "0" {
605        Ok(format!("-{}", result))
606    } else {
607        Ok(result)
608    }
609}
610
/// Increments a string of decimal digits by one.
///
/// The result is one character longer when every digit was `9` (an empty
/// string yields `"1"`). Panics if a non-digit character has to be incremented.
fn add_one_to_integer(s: &str) -> String {
    let mut digits: Vec<char> = s.chars().collect();

    // Walk from the least significant digit, turning nines into zeros until a
    // digit can absorb the carry.
    let mut idx = digits.len();
    while idx > 0 {
        idx -= 1;
        if digits[idx] == '9' {
            digits[idx] = '0';
            continue;
        }
        let bumped = digits[idx].to_digit(10).unwrap() + 1;
        digits[idx] = char::from_digit(bumped, 10).unwrap();
        return digits.into_iter().collect();
    }

    // The carry ran off the front: the number grows by one digit.
    let mut grown = String::with_capacity(s.len() + 1);
    grown.push('1');
    grown.extend(digits);
    grown
}
633
/// Adds one to the last digit of a digit string, propagating the carry leftwards.
///
/// When every digit is `9` the result is `1` followed by zeros, one character
/// longer than the input; callers can detect the overflow by comparing lengths.
/// Panics if a non-digit character has to be incremented.
fn round_up(s: &str) -> String {
    let mut digits: Vec<char> = s.chars().collect();

    // The rightmost non-nine digit absorbs the carry; every nine after it wraps.
    match digits.iter().rposition(|&c| c != '9') {
        Some(pivot) => {
            let bumped = digits[pivot].to_digit(10).unwrap() + 1;
            digits[pivot] = char::from_digit(bumped, 10).unwrap();
            for wrapped in &mut digits[pivot + 1..] {
                *wrapped = '0';
            }
            digits.into_iter().collect()
        }
        None => {
            // All 9s (or empty input): every digit wraps and a 1 is prepended.
            let mut grown = String::with_capacity(s.len() + 1);
            grown.push('1');
            grown.push_str(&"0".repeat(digits.len()));
            grown
        }
    }
}
657
// Unit tests: round-tripping, byte ordering, special values, and
// precision/scale truncation (including negative scale).
#[cfg(test)]
mod tests {
    use super::*;

    // Decoding and re-encoding must reproduce the original bytes.
    #[test]
    fn test_encode_decode_roundtrip() {
        let values = vec![
            "0",
            "1",
            "-1",
            "123.456",
            "-123.456",
            "0.001",
            "0.1",
            "10",
            "100",
            "1000",
            "-0.001",
            "999999999999999999",
        ];

        for s in values {
            let encoded = encode_decimal(s).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            // Re-encode to normalize
            let re_encoded = encode_decimal(&decoded).unwrap();
            assert_eq!(encoded, re_encoded, "Roundtrip failed for {}", s);
        }
    }

    // Byte-wise comparison must follow numeric order for these values.
    #[test]
    fn test_lexicographic_ordering() {
        let values = vec![
            "-1000", "-100", "-10", "-1", "-0.1", "-0.01", "0", "0.01", "0.1", "1", "10", "100",
            "1000",
        ];

        let encoded: Vec<Vec<u8>> = values.iter().map(|s| encode_decimal(s).unwrap()).collect();

        // Verify encoding preserves order
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Ordering failed: {} should be < {}",
                values[i],
                values[i + 1]
            );
        }
    }

    // Every spelling of zero encodes to the single SIGN_ZERO byte.
    #[test]
    fn test_zero_encoding() {
        let encoded = encode_decimal("0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);

        let encoded = encode_decimal("0.0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);

        let encoded = encode_decimal("-0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);
    }

    // Positive scale rounds half-up and carries into the integer part.
    #[test]
    fn test_truncate_scale() {
        assert_eq!(
            truncate_decimal("123.456", None, Some(2)).unwrap(),
            "123.46"
        );
        assert_eq!(
            truncate_decimal("123.454", None, Some(2)).unwrap(),
            "123.45"
        );
        assert_eq!(truncate_decimal("123.995", None, Some(2)).unwrap(), "124");
        assert_eq!(truncate_decimal("9.999", None, Some(2)).unwrap(), "10");
    }

    // Upper bounds on encoded length for small inputs.
    #[test]
    fn test_storage_efficiency() {
        // 9 digit number: should be ~1 sign + 2 exp + 5 mantissa = 8 bytes
        let encoded = encode_decimal("123456789").unwrap();
        assert!(
            encoded.len() <= 8,
            "Expected <= 8 bytes, got {}",
            encoded.len()
        );

        // Small decimal
        let encoded = encode_decimal("0.1").unwrap();
        assert!(
            encoded.len() <= 4,
            "Expected <= 4 bytes, got {}",
            encoded.len()
        );
    }

    // ==================== Special Values Tests ====================

    #[test]
    fn test_special_value_encoding() {
        // Test encoding special values
        let pos_inf = encode_decimal("Infinity").unwrap();
        assert_eq!(pos_inf, ENCODING_POS_INFINITY.to_vec());

        let neg_inf = encode_decimal("-Infinity").unwrap();
        assert_eq!(neg_inf, ENCODING_NEG_INFINITY.to_vec());

        let nan = encode_decimal("NaN").unwrap();
        assert_eq!(nan, ENCODING_NAN.to_vec());
    }

    #[test]
    fn test_special_value_decoding() {
        // Test decoding special values
        assert_eq!(
            decode_to_string(&ENCODING_POS_INFINITY).unwrap(),
            "Infinity"
        );
        assert_eq!(
            decode_to_string(&ENCODING_NEG_INFINITY).unwrap(),
            "-Infinity"
        );
        assert_eq!(decode_to_string(&ENCODING_NAN).unwrap(), "NaN");
    }

    #[test]
    fn test_special_value_parsing_variants() {
        // Test various ways to write special values (case-insensitive)
        let variants = vec![
            ("infinity", "Infinity"),
            ("Infinity", "Infinity"),
            ("INFINITY", "Infinity"),
            ("inf", "Infinity"),
            ("Inf", "Infinity"),
            ("+infinity", "Infinity"),
            ("+inf", "Infinity"),
            ("-infinity", "-Infinity"),
            ("-inf", "-Infinity"),
            ("-Infinity", "-Infinity"),
            ("nan", "NaN"),
            ("NaN", "NaN"),
            ("NAN", "NaN"),
            ("-nan", "NaN"), // PostgreSQL treats -NaN as NaN
            ("+nan", "NaN"),
        ];

        for (input, expected) in variants {
            let encoded = encode_decimal(input).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            assert_eq!(decoded, expected, "Failed for input: {}", input);
        }
    }

    #[test]
    fn test_special_value_ordering() {
        // PostgreSQL order: -Infinity < negatives < zero < positives < Infinity < NaN
        let values = vec![
            "-Infinity",
            "-1000000",
            "-1",
            "-0.001",
            "0",
            "0.001",
            "1",
            "1000000",
            "Infinity",
            "NaN",
        ];

        let encoded: Vec<Vec<u8>> = values.iter().map(|s| encode_decimal(s).unwrap()).collect();

        // Verify ordering
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Special value ordering failed: {} should be < {} (bytes: {:?} < {:?})",
                values[i],
                values[i + 1],
                encoded[i],
                encoded[i + 1]
            );
        }
    }

    // Special values survive an encode -> decode -> encode cycle unchanged.
    #[test]
    fn test_special_value_roundtrip() {
        let values = vec!["Infinity", "-Infinity", "NaN"];

        for s in values {
            let encoded = encode_decimal(s).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            let re_encoded = encode_decimal(&decoded).unwrap();
            assert_eq!(
                encoded, re_encoded,
                "Special value roundtrip failed for {}",
                s
            );
        }
    }

    // decode_special_value recognises only the three fixed encodings.
    #[test]
    fn test_decode_special_value_helper() {
        assert_eq!(
            decode_special_value(&ENCODING_POS_INFINITY),
            Some(SpecialValue::Infinity)
        );
        assert_eq!(
            decode_special_value(&ENCODING_NEG_INFINITY),
            Some(SpecialValue::NegInfinity)
        );
        assert_eq!(decode_special_value(&ENCODING_NAN), Some(SpecialValue::NaN));

        // Regular values should return None
        let regular = encode_decimal("123.456").unwrap();
        assert_eq!(decode_special_value(&regular), None);

        let zero = encode_decimal("0").unwrap();
        assert_eq!(decode_special_value(&zero), None);
    }

    // ==================== Negative Scale Tests ====================

    #[test]
    fn test_negative_scale_basic() {
        // Round to nearest 10
        assert_eq!(truncate_decimal("123", None, Some(-1)).unwrap(), "120");
        assert_eq!(truncate_decimal("125", None, Some(-1)).unwrap(), "130");
        assert_eq!(truncate_decimal("124", None, Some(-1)).unwrap(), "120");

        // Round to nearest 100
        assert_eq!(truncate_decimal("1234", None, Some(-2)).unwrap(), "1200");
        assert_eq!(truncate_decimal("1250", None, Some(-2)).unwrap(), "1300");
        assert_eq!(truncate_decimal("1249", None, Some(-2)).unwrap(), "1200");

        // Round to nearest 1000
        assert_eq!(truncate_decimal("12345", None, Some(-3)).unwrap(), "12000");
        assert_eq!(truncate_decimal("12500", None, Some(-3)).unwrap(), "13000");
    }

    #[test]
    fn test_negative_scale_small_numbers() {
        // When number is smaller than rounding unit
        assert_eq!(truncate_decimal("499", None, Some(-3)).unwrap(), "0");
        assert_eq!(truncate_decimal("500", None, Some(-3)).unwrap(), "1000");
        assert_eq!(truncate_decimal("999", None, Some(-3)).unwrap(), "1000");

        assert_eq!(truncate_decimal("49", None, Some(-2)).unwrap(), "0");
        assert_eq!(truncate_decimal("50", None, Some(-2)).unwrap(), "100");
    }

    #[test]
    fn test_negative_scale_with_precision() {
        // NUMERIC(2, -3): max 2 significant digits, round to nearest 1000
        assert_eq!(
            truncate_decimal("12345", Some(2), Some(-3)).unwrap(),
            "12000"
        );
        // 99999 rounded to nearest 1000 = 100000
        // "100" significant part exceeds precision 2, truncated from left to "00"
        // Final: "00" + "000" trailing zeros = "00000"
        // Note: PostgreSQL would error here; we truncate instead
        assert_eq!(
            truncate_decimal("99999", Some(2), Some(-3)).unwrap(),
            "00000"
        );
    }

    // The sign is preserved while the magnitude is rounded.
    #[test]
    fn test_negative_scale_negative_numbers() {
        assert_eq!(truncate_decimal("-123", None, Some(-1)).unwrap(), "-120");
        assert_eq!(truncate_decimal("-125", None, Some(-1)).unwrap(), "-130");
        assert_eq!(truncate_decimal("-1234", None, Some(-2)).unwrap(), "-1200");
    }

    #[test]
    fn test_negative_scale_with_decimal_input() {
        // Fractional part is ignored with negative scale
        assert_eq!(truncate_decimal("123.456", None, Some(-1)).unwrap(), "120");
        assert_eq!(
            truncate_decimal("1234.999", None, Some(-2)).unwrap(),
            "1200"
        );
    }

    #[test]
    fn test_negative_scale_encoding_ordering() {
        // Verify ordering is preserved with negative scale rounding
        let values = vec!["-1000", "-100", "0", "100", "1000"];

        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|s| encode_decimal_with_constraints(s, None, Some(-2)).unwrap())
            .collect();

        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Negative scale ordering failed: {} should be < {}",
                values[i],
                values[i + 1]
            );
        }
    }

    #[test]
    fn test_special_values_ignore_precision_scale() {
        // Special values should pass through unchanged regardless of precision/scale
        let inf = encode_decimal_with_constraints("Infinity", Some(5), Some(2)).unwrap();
        assert_eq!(inf, ENCODING_POS_INFINITY.to_vec());

        let neg_inf = encode_decimal_with_constraints("-Infinity", Some(5), Some(2)).unwrap();
        assert_eq!(neg_inf, ENCODING_NEG_INFINITY.to_vec());

        let nan = encode_decimal_with_constraints("NaN", Some(5), Some(2)).unwrap();
        assert_eq!(nan, ENCODING_NAN.to_vec());
    }
}
973}