Skip to main content

qail_pg/types/
numeric.rs

1//! NUMERIC/DECIMAL type support for PostgreSQL.
2//!
3//! PostgreSQL NUMERIC is a variable-precision type stored in a complex binary format.
4//! For simplicity, we use String representation and convert on demand.
5
6use super::{FromPg, ToPg, TypeError};
7use crate::protocol::types::oid;
8
/// PostgreSQL NUMERIC/DECIMAL value, kept as its decimal string so no precision is lost.
///
/// Equality is textual (derived from the inner `String`), so `"1.0"` and `"1.00"`
/// compare unequal. The `to_*` helpers parse the string on demand.
#[derive(Debug, Clone, PartialEq)]
pub struct Numeric(pub String);

impl Numeric {
    /// Wrap any string-like decimal representation without validating it.
    pub fn new(s: impl Into<String>) -> Self {
        Numeric(s.into())
    }

    /// Parse the value as `f64`; precision may be lost for very large or very
    /// precise values.
    pub fn to_f64(&self) -> Result<f64, std::num::ParseFloatError> {
        self.0.as_str().parse::<f64>()
    }

    /// Parse the part before the first `.` as `i64`, discarding (not rounding)
    /// any fractional digits.
    pub fn to_i64(&self) -> Result<i64, std::num::ParseIntError> {
        let whole = match self.0.split_once('.') {
            Some((head, _)) => head,
            None => self.0.as_str(),
        };
        whole.parse::<i64>()
    }

    /// Parse as `i64` only when [`Numeric::is_integral`] holds; otherwise return an error.
    pub fn to_i64_exact(&self) -> Result<i64, std::num::ParseIntError> {
        if self.is_integral() {
            self.to_i64()
        } else {
            // `ParseIntError` has no public constructor. The full string contains a
            // '.', so parsing it as an integer always fails and yields one.
            self.0.parse::<i64>()
        }
    }

    /// Whether the value has no non-zero fractional digits.
    ///
    /// A string without `.` counts as integral. A trailing `.` with no digits
    /// after it (e.g. `"7."`) is treated as non-integral.
    pub fn is_integral(&self) -> bool {
        match self.0.split_once('.') {
            None => true,
            Some((_, frac)) => !frac.is_empty() && frac.trim_end_matches('0').is_empty(),
        }
    }

    /// Borrow the underlying decimal string.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
52
53impl FromPg for Numeric {
54    fn from_pg(bytes: &[u8], oid_val: u32, format: i16) -> Result<Self, TypeError> {
55        if oid_val != oid::NUMERIC {
56            return Err(TypeError::UnexpectedOid {
57                expected: "numeric",
58                got: oid_val,
59            });
60        }
61
62        if format == 1 {
63            // Binary format: complex packed decimal format
64            // For now, we don't support binary NUMERIC - it requires unpacking
65            // the PostgreSQL packed decimal format (ndigits, weight, sign, dscale, digits)
66            decode_numeric_binary(bytes)
67        } else {
68            // Text format: just the string
69            let s =
70                std::str::from_utf8(bytes).map_err(|e| TypeError::InvalidData(e.to_string()))?;
71            Ok(Numeric(s.to_string()))
72        }
73    }
74}
75
76impl ToPg for Numeric {
77    fn to_pg(&self) -> (Vec<u8>, u32, i16) {
78        // Send as text for simplicity
79        (self.0.as_bytes().to_vec(), oid::NUMERIC, 0)
80    }
81}
82
83/// Decode PostgreSQL binary NUMERIC format
84fn decode_numeric_binary(bytes: &[u8]) -> Result<Numeric, TypeError> {
85    if bytes.len() < 8 {
86        return Err(TypeError::InvalidData("NUMERIC too short".to_string()));
87    }
88
89    // PostgreSQL NUMERIC binary format:
90    // 2 bytes: ndigits (number of base-10000 digits)
91    // 2 bytes: weight (position of first digit relative to decimal point)
92    // 2 bytes: sign (0=pos, 0x4000=neg, 0xC000=NaN)
93    // 2 bytes: dscale (number of decimal digits after decimal point)
94    // ndigits * 2 bytes: digits (each 0-9999)
95
96    let ndigits = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
97    let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
98    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
99    let dscale = u16::from_be_bytes([bytes[6], bytes[7]]) as usize;
100
101    if bytes.len() < 8 + ndigits * 2 {
102        return Err(TypeError::InvalidData("NUMERIC truncated".to_string()));
103    }
104
105    if sign == 0xC000 {
106        return Ok(Numeric("NaN".to_string()));
107    }
108    if !matches!(sign, 0 | 0x4000) {
109        return Err(TypeError::InvalidData(format!(
110            "NUMERIC sign out of range: {sign:#06x}"
111        )));
112    }
113
114    if ndigits == 0 {
115        return Ok(Numeric("0".to_string()));
116    }
117
118    let mut digits = Vec::with_capacity(ndigits);
119    for i in 0..ndigits {
120        let d = u16::from_be_bytes([bytes[8 + i * 2], bytes[9 + i * 2]]);
121        if d > 9999 {
122            return Err(TypeError::InvalidData(format!(
123                "NUMERIC digit out of range: {}",
124                d
125            )));
126        }
127        digits.push(d);
128    }
129
130    let mut result = String::new();
131    if sign == 0x4000 {
132        result.push('-');
133    }
134
135    // Integer part. A negative weight means every stored base-10000 digit is
136    // fractional; do not cast it to usize or it wraps to a huge group count.
137    let int_digits = i32::from(weight) + 1;
138    if int_digits > 0 {
139        let int_digits = int_digits as usize;
140        for i in 0..int_digits {
141            let digit = digits.get(i).copied().unwrap_or(0);
142            if i == 0 {
143                result.push_str(&digit.to_string());
144            } else {
145                result.push_str(&format!("{:04}", digit));
146            }
147        }
148    }
149
150    if result.is_empty() || result == "-" {
151        result.push('0');
152    }
153
154    // Decimal part
155    if dscale > 0 {
156        result.push('.');
157        let mut fractional = String::new();
158        if int_digits < 0 {
159            for _ in 0..(-int_digits) {
160                fractional.push_str("0000");
161            }
162        }
163
164        let start = int_digits.max(0) as usize;
165        for digit in digits.iter().skip(start) {
166            fractional.push_str(&format!("{:04}", digit));
167        }
168
169        if fractional.len() < dscale {
170            fractional.extend(std::iter::repeat_n('0', dscale - fractional.len()));
171        } else {
172            fractional.truncate(dscale);
173        }
174        result.push_str(&fractional);
175    }
176
177    Ok(Numeric(result))
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a binary NUMERIC payload; `ndigits` is taken from `digits.len()`.
    fn binary_numeric(weight: i16, sign: u16, dscale: u16, digits: &[u16]) -> Vec<u8> {
        let ndigits = u16::try_from(digits.len()).expect("test payload has too many digits");
        // Reinterpret the signed weight's bits; the big-endian bytes are identical.
        let header = [ndigits, weight as u16, sign, dscale];
        header
            .iter()
            .chain(digits)
            .flat_map(|word| word.to_be_bytes())
            .collect()
    }

    #[test]
    fn test_numeric_from_text() {
        let parsed = Numeric::from_pg(b"123.456", oid::NUMERIC, 0).unwrap();
        assert_eq!(parsed.as_str(), "123.456");
        let approx = parsed.to_f64().unwrap();
        assert!((approx - 123.456).abs() < 0.0001);
    }

    #[test]
    fn test_numeric_to_i64() {
        assert_eq!(Numeric::new("12345.67").to_i64().unwrap(), 12345);
    }

    #[test]
    fn test_numeric_to_i64_exact_rejects_fractional_values() {
        let whole = Numeric::new("12345.00");
        let fractional = Numeric::new("12345.67");
        assert_eq!(whole.to_i64_exact().unwrap(), 12345);
        assert!(fractional.to_i64_exact().is_err());
    }

    #[test]
    fn test_numeric_negative() {
        let value = Numeric::new("-999.99").to_f64().unwrap();
        assert_eq!(value, -999.99);
    }

    #[test]
    fn test_numeric_binary_decodes_negative_weight() {
        // One group at weight -2 is 1 x 10000^-2, shown with 8 fractional digits.
        let bytes = binary_numeric(-2, 0, 8, &[1]);
        let decoded = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(decoded.as_str(), "0.00000001");
    }

    #[test]
    fn test_numeric_binary_rejects_out_of_range_digits() {
        // Base-10000 digit groups must not exceed 9999.
        let bytes = binary_numeric(0, 0, 0, &[10000]);
        let err = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("out of range")));
    }

    #[test]
    fn test_numeric_binary_rejects_unknown_sign_code() {
        let bytes = binary_numeric(0, 0x2000, 0, &[123]);
        let err = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("sign out of range")));
    }
}