qail_pg/types/
numeric.rs

1//! NUMERIC/DECIMAL type support for PostgreSQL.
2//!
3//! PostgreSQL NUMERIC is a variable-precision type stored in a complex binary format.
4//! For simplicity, we use String representation and convert on demand.
5
6use super::{FromPg, ToPg, TypeError};
7use crate::protocol::types::oid;
8
/// NUMERIC/DECIMAL value, stored as its decimal text so no precision is lost.
///
/// Equality and hashing compare the stored text, not the numeric value:
/// `Numeric::new("1.0")` and `Numeric::new("1.00")` are *not* equal even
/// though they denote the same number.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Numeric(pub String);

impl Numeric {
    /// Create from a string representation.
    ///
    /// The text is stored as-is; it is not validated or normalized here.
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into())
    }

    /// Parse as `f64`.
    ///
    /// Any value without an exact binary-double representation (most decimal
    /// fractions, integers above 2^53, very large magnitudes) is rounded to
    /// the nearest `f64`, so precision can be lost for ordinary values too.
    ///
    /// # Errors
    ///
    /// Returns [`std::num::ParseFloatError`] if the text is not a valid float literal.
    pub fn to_f64(&self) -> Result<f64, std::num::ParseFloatError> {
        self.0.parse()
    }

    /// Parse the integer part as `i64`, discarding any fractional digits
    /// (i.e. truncating toward zero: `"-12.9"` becomes `-12`).
    ///
    /// A missing integer part in front of a fraction (`".5"`, `"-.5"`) is read as `0`.
    ///
    /// # Errors
    ///
    /// Returns [`std::num::ParseIntError`] if the integer part is not a valid
    /// `i64`, e.g. `"NaN"`, `"Infinity"`, or a value outside the `i64` range.
    pub fn to_i64(&self) -> Result<i64, std::num::ParseIntError> {
        let (int_part, has_fraction) = match self.0.split_once('.') {
            Some((int_part, frac_part)) => (int_part, !frac_part.is_empty()),
            None => (self.0.as_str(), false),
        };
        // `str::parse::<i64>` rejects an empty or sign-only string, but ".5" / "-.5"
        // are well-formed decimals whose integer part is zero.
        let int_part = if has_fraction && matches!(int_part, "" | "+" | "-") {
            "0"
        } else {
            int_part
        };
        int_part.parse()
    }

    /// Get the string representation.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
36
37impl FromPg for Numeric {
38    fn from_pg(bytes: &[u8], oid_val: u32, format: i16) -> Result<Self, TypeError> {
39        if oid_val != oid::NUMERIC {
40            return Err(TypeError::UnexpectedOid {
41                expected: "numeric",
42                got: oid_val,
43            });
44        }
45
46        if format == 1 {
47            // Binary format: complex packed decimal format
48            // For now, we don't support binary NUMERIC - it requires unpacking
49            // the PostgreSQL packed decimal format (ndigits, weight, sign, dscale, digits)
50            decode_numeric_binary(bytes)
51        } else {
52            // Text format: just the string
53            let s =
54                std::str::from_utf8(bytes).map_err(|e| TypeError::InvalidData(e.to_string()))?;
55            Ok(Numeric(s.to_string()))
56        }
57    }
58}
59
60impl ToPg for Numeric {
61    fn to_pg(&self) -> (Vec<u8>, u32, i16) {
62        // Send as text for simplicity
63        (self.0.as_bytes().to_vec(), oid::NUMERIC, 0)
64    }
65}
66
67/// Decode PostgreSQL binary NUMERIC format
68fn decode_numeric_binary(bytes: &[u8]) -> Result<Numeric, TypeError> {
69    if bytes.len() < 8 {
70        return Err(TypeError::InvalidData("NUMERIC too short".to_string()));
71    }
72
73    // PostgreSQL NUMERIC binary format:
74    // 2 bytes: ndigits (number of base-10000 digits)
75    // 2 bytes: weight (position of first digit relative to decimal point)
76    // 2 bytes: sign (0=pos, 0x4000=neg, 0xC000=NaN)
77    // 2 bytes: dscale (number of decimal digits after decimal point)
78    // ndigits * 2 bytes: digits (each 0-9999)
79
80    let ndigits = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
81    let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
82    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
83    let dscale = u16::from_be_bytes([bytes[6], bytes[7]]) as usize;
84
85    if bytes.len() < 8 + ndigits * 2 {
86        return Err(TypeError::InvalidData("NUMERIC truncated".to_string()));
87    }
88
89    // Handle special cases
90    if sign == 0xC000 {
91        return Ok(Numeric("NaN".to_string()));
92    }
93
94    if ndigits == 0 {
95        return Ok(Numeric("0".to_string()));
96    }
97
98    // Extract digits
99    let mut digits = Vec::with_capacity(ndigits);
100    for i in 0..ndigits {
101        let d = u16::from_be_bytes([bytes[8 + i * 2], bytes[9 + i * 2]]);
102        digits.push(d);
103    }
104
105    // Build string representation
106    let mut result = String::new();
107    if sign == 0x4000 {
108        result.push('-');
109    }
110
111    // Integer part
112    let int_digits = (weight + 1) as usize;
113    for (i, digit) in digits.iter().enumerate().take(int_digits.min(ndigits)) {
114        if i == 0 {
115            result.push_str(&digit.to_string());
116        } else {
117            result.push_str(&format!("{:04}", digit));
118        }
119    }
120    // Pad with zeros if weight > ndigits
121    for _ in ndigits..int_digits {
122        result.push_str("0000");
123    }
124
125    if result.is_empty() || result == "-" {
126        result.push('0');
127    }
128
129    // Decimal part
130    if dscale > 0 {
131        result.push('.');
132        let start = int_digits.max(0);
133        let mut decimal_digits = 0;
134        for digit in digits.iter().skip(start) {
135            let s = format!("{:04}", digit);
136            for c in s.chars() {
137                if decimal_digits >= dscale {
138                    break;
139                }
140                result.push(c);
141                decimal_digits += 1;
142            }
143        }
144        // Pad with zeros if needed
145        while decimal_digits < dscale {
146            result.push('0');
147            decimal_digits += 1;
148        }
149    }
150
151    Ok(Numeric(result))
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a binary-format NUMERIC payload: the four header words followed by
    /// the base-10000 digit groups.
    fn numeric_wire(weight: i16, sign: u16, dscale: u16, digits: &[u16]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + digits.len() * 2);
        buf.extend_from_slice(&(digits.len() as u16).to_be_bytes());
        buf.extend_from_slice(&weight.to_be_bytes());
        buf.extend_from_slice(&sign.to_be_bytes());
        buf.extend_from_slice(&dscale.to_be_bytes());
        for group in digits {
            buf.extend_from_slice(&group.to_be_bytes());
        }
        buf
    }

    /// Decode a binary payload through the `FromPg` entry point and return its text.
    fn decode_wire(bytes: &[u8]) -> String {
        Numeric::from_pg(bytes, oid::NUMERIC, 1).unwrap().0
    }

    #[test]
    fn test_numeric_from_text() {
        let n = Numeric::from_pg(b"123.456", oid::NUMERIC, 0).unwrap();
        assert_eq!(n.0, "123.456");
        assert!((n.to_f64().unwrap() - 123.456).abs() < 0.0001);
    }

    #[test]
    fn test_numeric_to_i64() {
        let n = Numeric::new("12345.67");
        assert_eq!(n.to_i64().unwrap(), 12345);
        // Truncation is toward zero for negative values as well.
        assert_eq!(Numeric::new("-12.9").to_i64().unwrap(), -12);
    }

    #[test]
    fn test_numeric_negative() {
        let n = Numeric::new("-999.99");
        assert_eq!(n.to_f64().unwrap(), -999.99);
    }

    #[test]
    fn test_numeric_rejects_wrong_oid() {
        assert!(Numeric::from_pg(b"1", oid::NUMERIC + 1, 0).is_err());
    }

    #[test]
    fn test_numeric_binary_basic() {
        assert_eq!(decode_wire(&numeric_wire(0, 0x0000, 3, &[123, 4560])), "123.456");
        assert_eq!(decode_wire(&numeric_wire(0, 0x0000, 4, &[1234, 5678])), "1234.5678");
        assert_eq!(decode_wire(&numeric_wire(1, 0x0000, 0, &[1])), "10000");
        assert_eq!(decode_wire(&numeric_wire(-1, 0x4000, 1, &[5000])), "-0.5");
        assert_eq!(decode_wire(&numeric_wire(0, 0xC000, 0, &[])), "NaN");
    }

    #[test]
    fn test_numeric_binary_malformed() {
        assert!(Numeric::from_pg(&[0, 0, 0], oid::NUMERIC, 1).is_err());
        // Header announces two digit groups but only one is present.
        let mut truncated = numeric_wire(0, 0x0000, 0, &[1, 2]);
        truncated.truncate(10);
        assert!(Numeric::from_pg(&truncated, oid::NUMERIC, 1).is_err());
    }
}