1use super::{FromPg, ToPg, TypeError};
7use crate::protocol::types::oid;
8
/// A PostgreSQL `NUMERIC` value kept in its exact decimal text form.
///
/// The string is stored verbatim (e.g. `"123.4500"`, `"-0.001"`, `"NaN"`), so no
/// precision is lost; conversions to Rust numeric types are explicit and fallible.
#[derive(Debug, Clone, PartialEq)]
pub struct Numeric(pub String);

impl Numeric {
    /// Wraps a decimal string as-is; no validation is performed.
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into())
    }

    /// Parses the value as an `f64` (lossy for values beyond `f64` precision).
    ///
    /// # Errors
    /// Returns the `ParseFloatError` produced by `str::parse` when the text is
    /// not a float literal Rust accepts.
    pub fn to_f64(&self) -> Result<f64, std::num::ParseFloatError> {
        self.0.parse()
    }

    /// Returns the integer part, discarding any fractional digits
    /// (`"12345.67"` → `12345`, `"-0.5"` → `0`, `".5"` → `0`).
    ///
    /// # Errors
    /// Returns a `ParseIntError` when the integer part is not a valid `i64`
    /// (non-numeric text such as `"NaN"`, or a value out of range).
    pub fn to_i64(&self) -> Result<i64, std::num::ParseIntError> {
        self.integer_part().parse()
    }

    /// Converts to `i64` only if no non-zero fractional digits would be lost.
    ///
    /// # Errors
    /// Returns a `ParseIntError` when the value has a non-zero fraction, is not
    /// a finite decimal (e.g. `"NaN"`), or does not fit in an `i64`.
    pub fn to_i64_exact(&self) -> Result<i64, std::num::ParseIntError> {
        if !self.is_integral() {
            // `ParseIntError` has no public constructor. Any text that parses as
            // an `i64` is integral, so parsing the full non-integral text is
            // guaranteed to yield an error.
            return self.0.parse();
        }
        self.integer_part().parse()
    }

    /// Reports whether the value is a finite decimal with no non-zero fractional
    /// digits (`"42"`, `"42.000"`, `"42."`, `"-0.0"`).
    ///
    /// Returns `false` for non-finite or malformed text such as `"NaN"`,
    /// `"Infinity"`, `"."` or `""`.
    pub fn is_integral(&self) -> bool {
        let (int_part, fraction) = self.0.split_once('.').unwrap_or((self.0.as_str(), ""));
        let int_digits = int_part.strip_prefix(['-', '+']).unwrap_or(int_part);
        // At least one digit must be present on one side of the point.
        let has_digits = !int_digits.is_empty() || !fraction.is_empty();
        has_digits
            && int_digits.bytes().all(|b| b.is_ascii_digit())
            && fraction.bytes().all(|b| b == b'0')
    }

    /// Returns the raw decimal text.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Text before the decimal point, or `"0"` when the value is written with an
    /// empty integer part such as `".5"` or `"-.5"`.
    fn integer_part(&self) -> &str {
        match self.0.split_once('.') {
            Some(("" | "+" | "-", fraction)) if !fraction.is_empty() => "0",
            Some((int_part, _)) => int_part,
            None => self.0.as_str(),
        }
    }
}
52
53impl FromPg for Numeric {
54 fn from_pg(bytes: &[u8], oid_val: u32, format: i16) -> Result<Self, TypeError> {
55 if oid_val != oid::NUMERIC {
56 return Err(TypeError::UnexpectedOid {
57 expected: "numeric",
58 got: oid_val,
59 });
60 }
61
62 if format == 1 {
63 decode_numeric_binary(bytes)
67 } else {
68 let s =
70 std::str::from_utf8(bytes).map_err(|e| TypeError::InvalidData(e.to_string()))?;
71 Ok(Numeric(s.to_string()))
72 }
73 }
74}
75
76impl ToPg for Numeric {
77 fn to_pg(&self) -> (Vec<u8>, u32, i16) {
78 (self.0.as_bytes().to_vec(), oid::NUMERIC, 0)
80 }
81}
82
83fn decode_numeric_binary(bytes: &[u8]) -> Result<Numeric, TypeError> {
85 if bytes.len() < 8 {
86 return Err(TypeError::InvalidData("NUMERIC too short".to_string()));
87 }
88
89 let ndigits = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
97 let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
98 let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
99 let dscale = u16::from_be_bytes([bytes[6], bytes[7]]) as usize;
100
101 if bytes.len() < 8 + ndigits * 2 {
102 return Err(TypeError::InvalidData("NUMERIC truncated".to_string()));
103 }
104
105 if sign == 0xC000 {
106 return Ok(Numeric("NaN".to_string()));
107 }
108 if !matches!(sign, 0 | 0x4000) {
109 return Err(TypeError::InvalidData(format!(
110 "NUMERIC sign out of range: {sign:#06x}"
111 )));
112 }
113
114 if ndigits == 0 {
115 return Ok(Numeric("0".to_string()));
116 }
117
118 let mut digits = Vec::with_capacity(ndigits);
119 for i in 0..ndigits {
120 let d = u16::from_be_bytes([bytes[8 + i * 2], bytes[9 + i * 2]]);
121 if d > 9999 {
122 return Err(TypeError::InvalidData(format!(
123 "NUMERIC digit out of range: {}",
124 d
125 )));
126 }
127 digits.push(d);
128 }
129
130 let mut result = String::new();
131 if sign == 0x4000 {
132 result.push('-');
133 }
134
135 let int_digits = i32::from(weight) + 1;
138 if int_digits > 0 {
139 let int_digits = int_digits as usize;
140 for i in 0..int_digits {
141 let digit = digits.get(i).copied().unwrap_or(0);
142 if i == 0 {
143 result.push_str(&digit.to_string());
144 } else {
145 result.push_str(&format!("{:04}", digit));
146 }
147 }
148 }
149
150 if result.is_empty() || result == "-" {
151 result.push('0');
152 }
153
154 if dscale > 0 {
156 result.push('.');
157 let mut fractional = String::new();
158 if int_digits < 0 {
159 for _ in 0..(-int_digits) {
160 fractional.push_str("0000");
161 }
162 }
163
164 let start = int_digits.max(0) as usize;
165 for digit in digits.iter().skip(start) {
166 fractional.push_str(&format!("{:04}", digit));
167 }
168
169 if fractional.len() < dscale {
170 fractional.extend(std::iter::repeat_n('0', dscale - fractional.len()));
171 } else {
172 fractional.truncate(dscale);
173 }
174 result.push_str(&fractional);
175 }
176
177 Ok(Numeric(result))
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binary-format NUMERIC payload: the 8-byte header
    /// (ndigits, weight, sign, dscale) followed by the base-10000 digit groups.
    fn numeric_bytes(weight: i16, sign: u16, dscale: u16, digits: &[u16]) -> Vec<u8> {
        let ndigits = u16::try_from(digits.len()).expect("digit count fits in u16");
        let mut bytes = Vec::with_capacity(8 + 2 * digits.len());
        bytes.extend_from_slice(&ndigits.to_be_bytes());
        bytes.extend_from_slice(&weight.to_be_bytes());
        bytes.extend_from_slice(&sign.to_be_bytes());
        bytes.extend_from_slice(&dscale.to_be_bytes());
        for digit in digits {
            bytes.extend_from_slice(&digit.to_be_bytes());
        }
        bytes
    }

    #[test]
    fn test_numeric_from_text() {
        let n = Numeric::from_pg(b"123.456", oid::NUMERIC, 0).unwrap();
        assert_eq!(n.0, "123.456");
        assert!((n.to_f64().unwrap() - 123.456).abs() < 0.0001);
    }

    #[test]
    fn test_numeric_to_i64() {
        let n = Numeric::new("12345.67");
        assert_eq!(n.to_i64().unwrap(), 12345);
    }

    #[test]
    fn test_numeric_to_i64_exact_rejects_fractional_values() {
        assert_eq!(Numeric::new("12345.00").to_i64_exact().unwrap(), 12345);
        assert!(Numeric::new("12345.67").to_i64_exact().is_err());
    }

    #[test]
    fn test_numeric_negative() {
        let n = Numeric::new("-999.99");
        assert_eq!(n.to_f64().unwrap(), -999.99);
    }

    #[test]
    fn test_numeric_rejects_unexpected_oid() {
        let err = Numeric::from_pg(b"1", 0, 0).unwrap_err();
        assert!(matches!(err, TypeError::UnexpectedOid { got: 0, .. }));
    }

    #[test]
    fn test_numeric_binary_decodes_negative_weight() {
        // One digit group (0001) at weight -2, i.e. 10^-8, shown with 8 decimals.
        let bytes = numeric_bytes(-2, 0, 8, &[1]);
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "0.00000001");
    }

    #[test]
    fn test_numeric_binary_decodes_integer_and_pads_scale() {
        // 12 * 10000 + 3456 = 123456, displayed with dscale 2.
        let bytes = numeric_bytes(1, 0, 2, &[12, 3456]);
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "123456.00");
    }

    #[test]
    fn test_numeric_binary_decodes_negative_fraction() {
        let bytes = numeric_bytes(0, 0x4000, 1, &[1, 5000]);
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "-1.5");
    }

    #[test]
    fn test_numeric_binary_decodes_nan() {
        let bytes = numeric_bytes(0, 0xC000, 0, &[]);
        let n = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap();
        assert_eq!(n.as_str(), "NaN");
    }

    #[test]
    fn test_numeric_binary_rejects_short_and_truncated_input() {
        let err = Numeric::from_pg(&[0u8; 4], oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("too short")));

        let mut bytes = numeric_bytes(0, 0, 0, &[1, 2]);
        bytes.pop();
        let err = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("truncated")));
    }

    #[test]
    fn test_numeric_binary_rejects_out_of_range_digits() {
        let bytes = numeric_bytes(0, 0, 0, &[10000]);
        let err = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("out of range")));
    }

    #[test]
    fn test_numeric_binary_rejects_unknown_sign_code() {
        let bytes = numeric_bytes(0, 0x2000, 0, &[123]);
        let err = Numeric::from_pg(&bytes, oid::NUMERIC, 1).unwrap_err();
        assert!(matches!(err, TypeError::InvalidData(msg) if msg.contains("sign out of range")));
    }
}