1use super::{FromPg, ToPg, TypeError};
7use crate::protocol::types::oid;
8
/// A PostgreSQL `NUMERIC` value kept as its decimal text representation.
///
/// Holding the text (rather than a float) means no precision is lost between
/// the server and the caller; use [`Numeric::to_f64`] or [`Numeric::to_i64`]
/// for lossy conversions into native numbers.
#[derive(Debug, Clone, PartialEq)]
pub struct Numeric(pub String);

impl Numeric {
    /// Wraps any string-like value as a `Numeric` without validating it.
    pub fn new(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Numeric(text)
    }

    /// Parses the stored text as an `f64`, which may round very precise values.
    ///
    /// # Errors
    /// Returns the parse error when the text is not a valid float literal.
    pub fn to_f64(&self) -> Result<f64, std::num::ParseFloatError> {
        self.as_str().parse::<f64>()
    }

    /// Parses the portion before the first `.` as an `i64`, i.e. truncates any
    /// fractional part toward zero (`"-999.99"` becomes `-999`).
    ///
    /// # Errors
    /// Fails when there are no digits before the point (e.g. `".5"`), for
    /// non-numeric text such as `"NaN"`, or when the value exceeds `i64`.
    pub fn to_i64(&self) -> Result<i64, std::num::ParseIntError> {
        let text = self.as_str();
        let whole = match text.find('.') {
            Some(dot) => &text[..dot],
            None => text,
        };
        whole.parse::<i64>()
    }

    /// Borrows the underlying decimal text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
36
37impl FromPg for Numeric {
38 fn from_pg(bytes: &[u8], oid_val: u32, format: i16) -> Result<Self, TypeError> {
39 if oid_val != oid::NUMERIC {
40 return Err(TypeError::UnexpectedOid {
41 expected: "numeric",
42 got: oid_val,
43 });
44 }
45
46 if format == 1 {
47 decode_numeric_binary(bytes)
51 } else {
52 let s =
54 std::str::from_utf8(bytes).map_err(|e| TypeError::InvalidData(e.to_string()))?;
55 Ok(Numeric(s.to_string()))
56 }
57 }
58}
59
60impl ToPg for Numeric {
61 fn to_pg(&self) -> (Vec<u8>, u32, i16) {
62 (self.0.as_bytes().to_vec(), oid::NUMERIC, 0)
64 }
65}
66
67fn decode_numeric_binary(bytes: &[u8]) -> Result<Numeric, TypeError> {
69 if bytes.len() < 8 {
70 return Err(TypeError::InvalidData("NUMERIC too short".to_string()));
71 }
72
73 let ndigits = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
81 let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
82 let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
83 let dscale = u16::from_be_bytes([bytes[6], bytes[7]]) as usize;
84
85 if bytes.len() < 8 + ndigits * 2 {
86 return Err(TypeError::InvalidData("NUMERIC truncated".to_string()));
87 }
88
89 if sign == 0xC000 {
90 return Ok(Numeric("NaN".to_string()));
91 }
92
93 if ndigits == 0 {
94 return Ok(Numeric("0".to_string()));
95 }
96
97 let mut digits = Vec::with_capacity(ndigits);
98 for i in 0..ndigits {
99 let d = u16::from_be_bytes([bytes[8 + i * 2], bytes[9 + i * 2]]);
100 digits.push(d);
101 }
102
103 let mut result = String::new();
104 if sign == 0x4000 {
105 result.push('-');
106 }
107
108 let int_digits = (weight + 1) as usize;
110 for (i, digit) in digits.iter().enumerate().take(int_digits.min(ndigits)) {
111 if i == 0 {
112 result.push_str(&digit.to_string());
113 } else {
114 result.push_str(&format!("{:04}", digit));
115 }
116 }
117 for _ in ndigits..int_digits {
119 result.push_str("0000");
120 }
121
122 if result.is_empty() || result == "-" {
123 result.push('0');
124 }
125
126 if dscale > 0 {
128 result.push('.');
129 let start = int_digits.max(0);
130 let mut decimal_digits = 0;
131 for digit in digits.iter().skip(start) {
132 let s = format!("{:04}", digit);
133 for c in s.chars() {
134 if decimal_digits >= dscale {
135 break;
136 }
137 result.push(c);
138 decimal_digits += 1;
139 }
140 }
141 while decimal_digits < dscale {
143 result.push('0');
144 decimal_digits += 1;
145 }
146 }
147
148 Ok(Numeric(result))
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numeric_from_text() {
        let n = Numeric::from_pg(b"123.456", oid::NUMERIC, 0).unwrap();
        assert_eq!(n.0, "123.456");
        assert!((n.to_f64().unwrap() - 123.456).abs() < 0.0001);
    }

    #[test]
    fn test_numeric_to_i64() {
        let n = Numeric::new("12345.67");
        assert_eq!(n.to_i64().unwrap(), 12345);
    }

    #[test]
    fn test_numeric_negative() {
        let n = Numeric::new("-999.99");
        assert_eq!(n.to_f64().unwrap(), -999.99);
    }

    #[test]
    fn test_numeric_from_binary() {
        // 123.456: ndigits=2, weight=0, sign=positive, dscale=3, groups [123, 4560].
        let bytes = [0, 2, 0, 0, 0, 0, 0, 3, 0, 123, 0x11, 0xD0];
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "123.456");
    }

    #[test]
    fn test_numeric_binary_nan() {
        // Sign code 0xC000 marks NaN; no digit groups follow.
        let bytes = [0, 0, 0, 0, 0xC0, 0x00, 0, 0];
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "NaN");
    }

    #[test]
    fn test_numeric_binary_truncated() {
        // Header declares one digit group, but none follow.
        let bytes = [0, 1, 0, 0, 0, 0, 0, 0];
        assert!(matches!(
            Numeric::from_pg(&bytes, oid::NUMERIC, 1),
            Err(TypeError::InvalidData(_))
        ));
    }

    #[test]
    fn test_numeric_wrong_oid() {
        assert!(matches!(
            Numeric::from_pg(b"1", 23, 0),
            Err(TypeError::UnexpectedOid { got: 23, .. })
        ));
    }

    #[test]
    fn test_numeric_to_pg_text() {
        let (bytes, oid_val, format) = Numeric::new("1.50").to_pg();
        assert_eq!(bytes, b"1.50".to_vec());
        assert_eq!(oid_val, oid::NUMERIC);
        assert_eq!(format, 0);
    }
}