Skip to main content

heliosdb_proxy/backend/
types.rs

//! PostgreSQL type OIDs and text-format value encoding/decoding.
//!
//! This is intentionally narrow: it covers the eight common OIDs needed
//! for TR-management queries (`pg_is_in_recovery`, `pg_last_wal_replay_lsn`,
//! failover status, tenant quota config, etc.). Binary format is out of
//! scope; everything rides the simple-query text path.
7
8use super::error::{BackendError, BackendResult};
9use chrono::{DateTime, FixedOffset};
10
/// PostgreSQL type OIDs understood by the backend client.
///
/// Each constant equals the `pg_type.oid` value of the corresponding
/// built-in type. Column values with any other OID can still be read as
/// raw text via `TextValue::as_str`; callers that need strict typing
/// should check the column OID before interpreting the value.
pub mod oid {
    /// `boolean` — text form `t` / `f`.
    pub const BOOL: u32 = 16;
    /// `integer` (32-bit).
    pub const INT4: u32 = 23;
    /// `bigint` (64-bit).
    pub const INT8: u32 = 20;
    /// `double precision`.
    pub const FLOAT8: u32 = 701;
    /// Arbitrary-precision `numeric`.
    pub const NUMERIC: u32 = 1700;
    /// `text`.
    pub const TEXT: u32 = 25;
    /// `timestamp with time zone`.
    pub const TIMESTAMPTZ: u32 = 1184;
    /// `pg_lsn` — WAL positions, as returned by replica WAL-position queries.
    pub const PG_LSN: u32 = 3220;
}
28
/// A single column's text-format value as received from the backend.
///
/// The text is owned (a `String` copied out of the row data), so a
/// `TextValue` does not borrow from the wire buffer. Callers typically
/// turn it into a typed Rust value via the `as_<type>` decoding methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextValue {
    /// The column is SQL NULL.
    Null,
    /// The column's text-format value, as UTF-8.
    Text(String),
}
40
41impl TextValue {
42    /// Return `true` if this value is SQL NULL.
43    pub fn is_null(&self) -> bool {
44        matches!(self, TextValue::Null)
45    }
46
47    /// Borrow the underlying string if not NULL.
48    pub fn as_str(&self) -> Option<&str> {
49        match self {
50            TextValue::Null => None,
51            TextValue::Text(s) => Some(s.as_str()),
52        }
53    }
54
55    /// Consume and return the inner `String`, or `None` for NULL.
56    pub fn into_string(self) -> Option<String> {
57        match self {
58            TextValue::Null => None,
59            TextValue::Text(s) => Some(s),
60        }
61    }
62
63    /// Decode as `bool`. PostgreSQL text format is `"t"` or `"f"`.
64    pub fn as_bool(&self, column: &str) -> BackendResult<Option<bool>> {
65        match self {
66            TextValue::Null => Ok(None),
67            TextValue::Text(s) => match s.as_str() {
68                "t" | "true" | "TRUE" => Ok(Some(true)),
69                "f" | "false" | "FALSE" => Ok(Some(false)),
70                other => Err(BackendError::ParseValue {
71                    column: column.to_string(),
72                    reason: format!("expected bool ('t'|'f'), got {:?}", other),
73                }),
74            },
75        }
76    }
77
78    /// Decode as `i64` (covers INT4 and INT8 at the wire level — text
79    /// format is the same).
80    pub fn as_i64(&self, column: &str) -> BackendResult<Option<i64>> {
81        match self {
82            TextValue::Null => Ok(None),
83            TextValue::Text(s) => {
84                s.parse::<i64>()
85                    .map(Some)
86                    .map_err(|e| BackendError::ParseValue {
87                        column: column.to_string(),
88                        reason: format!("i64: {}", e),
89                    })
90            }
91        }
92    }
93
94    /// Decode as `f64` (FLOAT8 / double precision).
95    pub fn as_f64(&self, column: &str) -> BackendResult<Option<f64>> {
96        match self {
97            TextValue::Null => Ok(None),
98            TextValue::Text(s) => {
99                s.parse::<f64>()
100                    .map(Some)
101                    .map_err(|e| BackendError::ParseValue {
102                        column: column.to_string(),
103                        reason: format!("f64: {}", e),
104                    })
105            }
106        }
107    }
108
109    /// Decode as `DateTime<FixedOffset>` (TIMESTAMPTZ). PG text format
110    /// with a timezone offset: `2026-04-24 12:34:56.789+00`.
111    pub fn as_timestamptz(&self, column: &str) -> BackendResult<Option<DateTime<FixedOffset>>> {
112        match self {
113            TextValue::Null => Ok(None),
114            TextValue::Text(s) => {
115                // PG emits a space between date and time by default; RFC3339
116                // wants a 'T'. Support either.
117                let normalised = if s.contains(' ') && !s.contains('T') {
118                    s.replacen(' ', "T", 1)
119                } else {
120                    s.clone()
121                };
122                // Append minutes to a bare hour offset: "+00" -> "+00:00".
123                let normalised = if let Some(idx) = normalised.rfind(['+', '-']) {
124                    let off = &normalised[idx + 1..];
125                    if off.len() == 2 && off.bytes().all(|b| b.is_ascii_digit()) {
126                        format!("{}:00", normalised)
127                    } else {
128                        normalised
129                    }
130                } else {
131                    normalised
132                };
133                DateTime::parse_from_rfc3339(&normalised)
134                    .map(Some)
135                    .map_err(|e| BackendError::ParseValue {
136                        column: column.to_string(),
137                        reason: format!("timestamptz {:?}: {}", s, e),
138                    })
139            }
140        }
141    }
142
143    /// Decode as a textual pg_lsn (e.g. `"0/16B3758"`). We leave LSN
144    /// arithmetic to callers — string form is what `pg_last_wal_*_lsn()`
145    /// returns in text format, and the natural lex order on these
146    /// strings matches WAL ordering for positions in the same
147    /// timeline.
148    pub fn as_pg_lsn(&self, column: &str) -> BackendResult<Option<String>> {
149        match self {
150            TextValue::Null => Ok(None),
151            TextValue::Text(s) => {
152                // Validate shape: H[H..]/H[H..]
153                if let Some((hi, lo)) = s.split_once('/') {
154                    let hex_ok =
155                        |p: &str| !p.is_empty() && p.bytes().all(|b| b.is_ascii_hexdigit());
156                    if hex_ok(hi) && hex_ok(lo) {
157                        return Ok(Some(s.clone()));
158                    }
159                }
160                Err(BackendError::ParseValue {
161                    column: column.to_string(),
162                    reason: format!("pg_lsn {:?}: expected 'H/H' hex pair", s),
163                })
164            }
165        }
166    }
167
168    /// Decode as `NUMERIC` — text is PG's canonical form (e.g. "3.1415").
169    /// We return the raw string; callers that need arithmetic can route
170    /// to `rust_decimal` or parse further.
171    pub fn as_numeric(&self, column: &str) -> BackendResult<Option<String>> {
172        match self {
173            TextValue::Null => Ok(None),
174            TextValue::Text(s) => {
175                // Reject only obviously malformed values — a single optional
176                // sign, digits, optional single dot, optional exponent.
177                let bytes = s.as_bytes();
178                let mut i = 0;
179                if bytes.first().is_some_and(|&b| b == b'+' || b == b'-') {
180                    i += 1;
181                }
182                let mut saw_digit = false;
183                let mut saw_dot = false;
184                while i < bytes.len() {
185                    let b = bytes[i];
186                    if b.is_ascii_digit() {
187                        saw_digit = true;
188                    } else if b == b'.' && !saw_dot {
189                        saw_dot = true;
190                    } else if (b == b'e' || b == b'E') && saw_digit {
191                        // Exponent form — stop validating shape, accept.
192                        saw_digit = true;
193                        break;
194                    } else if s.eq_ignore_ascii_case("NaN") {
195                        return Ok(Some("NaN".to_string()));
196                    } else {
197                        return Err(BackendError::ParseValue {
198                            column: column.to_string(),
199                            reason: format!("numeric {:?}", s),
200                        });
201                    }
202                    i += 1;
203                }
204                if saw_digit {
205                    Ok(Some(s.clone()))
206                } else {
207                    Err(BackendError::ParseValue {
208                        column: column.to_string(),
209                        reason: format!("numeric {:?}: no digits", s),
210                    })
211                }
212            }
213        }
214    }
215}
216
217/// Encode a Rust value as a PostgreSQL text-format parameter.
218///
219/// The backend client substitutes parameters into simple-query SQL by
220/// properly-quoting literals; we do not use the extended protocol here.
221/// This function produces the already-quoted literal (e.g. `'alice'::text`
222/// for a string, `42` for an i64).
223///
224/// Implementations are deliberately tight — enough to serialise the
225/// argument values TR-management queries need. Callers with richer
226/// types should extend this set.
227pub fn encode_literal(v: &ParamValue) -> String {
228    match v {
229        ParamValue::Null => "NULL".to_string(),
230        ParamValue::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
231        ParamValue::Int(i) => i.to_string(),
232        ParamValue::Float(f) => {
233            // Match PG conventions: `NaN` and `Infinity` are unquoted
234            // identifiers inside numeric context, but for parameters we
235            // pass them via the text format cast.
236            if f.is_nan() {
237                "'NaN'::float8".to_string()
238            } else if f.is_infinite() {
239                if *f > 0.0 {
240                    "'Infinity'::float8".to_string()
241                } else {
242                    "'-Infinity'::float8".to_string()
243                }
244            } else {
245                format!("{:?}", f) // {:?} preserves precision round-trip
246            }
247        }
248        ParamValue::Text(s) => {
249            // PG simple-query string literal: wrap in single quotes,
250            // escape embedded single quotes by doubling.
251            let mut out = String::with_capacity(s.len() + 2);
252            out.push('\'');
253            for ch in s.chars() {
254                if ch == '\'' {
255                    out.push_str("''");
256                } else {
257                    out.push(ch);
258                }
259            }
260            out.push('\'');
261            out
262        }
263        ParamValue::Lsn(s) => format!("'{}'::pg_lsn", s),
264    }
265}
266
/// Minimal parameter-value enum for literals produced by `encode_literal`.
///
/// Kept intentionally small: this is for TR-management queries, not
/// general-purpose query execution. Integers of any width travel as
/// `Int(i64)`; `NUMERIC` and `TIMESTAMPTZ` have no dedicated variant.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// SQL NULL.
    Null,
    /// A boolean.
    Bool(bool),
    /// An integer (INT4 / INT8).
    Int(i64),
    /// A double-precision float, including NaN and the infinities.
    Float(f64),
    /// Arbitrary text.
    Text(String),
    /// A WAL position in pg_lsn text form, e.g. `"0/16B3758"`.
    Lsn(String),
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a non-NULL text value.
    fn text(s: &str) -> TextValue {
        TextValue::Text(s.to_string())
    }

    #[test]
    fn test_text_value_accessors() {
        assert!(TextValue::Null.is_null());
        assert!(!text("x").is_null());
        assert_eq!(TextValue::Null.as_str(), None);
        assert_eq!(text("abc").as_str(), Some("abc"));
        assert_eq!(TextValue::Null.into_string(), None);
        assert_eq!(text("abc").into_string(), Some("abc".to_string()));
    }

    #[test]
    fn test_decoders_pass_null_through() {
        let n = TextValue::Null;
        assert_eq!(n.as_bool("x").unwrap(), None);
        assert_eq!(n.as_i64("x").unwrap(), None);
        assert_eq!(n.as_f64("x").unwrap(), None);
        assert!(n.as_timestamptz("x").unwrap().is_none());
        assert_eq!(n.as_pg_lsn("x").unwrap(), None);
        assert_eq!(n.as_numeric("x").unwrap(), None);
    }

    #[test]
    fn test_text_value_bool() {
        assert_eq!(text("t").as_bool("x").unwrap(), Some(true));
        assert_eq!(text("f").as_bool("x").unwrap(), Some(false));
        assert!(text("maybe").as_bool("x").is_err());
    }

    #[test]
    fn test_text_value_i64() {
        assert_eq!(text("42").as_i64("x").unwrap(), Some(42));
        assert_eq!(text("-1").as_i64("x").unwrap(), Some(-1));
        assert!(text("abc").as_i64("x").is_err());
    }

    #[test]
    fn test_text_value_f64() {
        // 2.5 is exactly representable, so `==` is safe. A PI-like literal
        // such as 3.14 would trip clippy's deny-by-default `approx_constant`.
        assert_eq!(text("2.5").as_f64("x").unwrap(), Some(2.5));
        assert!(text("oops").as_f64("x").is_err());
    }

    #[test]
    fn test_text_value_timestamptz_pg_format() {
        let parsed = text("2026-04-24 12:34:56.789+00")
            .as_timestamptz("ts")
            .unwrap()
            .expect("some");
        assert!(parsed.to_rfc3339().starts_with("2026-04-24T12:34:56.789"));
    }

    #[test]
    fn test_text_value_timestamptz_negative_bare_offset() {
        let parsed = text("2026-04-24 12:34:56-07")
            .as_timestamptz("ts")
            .unwrap()
            .expect("some");
        assert_eq!(parsed.offset().local_minus_utc(), -7 * 3600);
    }

    #[test]
    fn test_text_value_timestamptz_rfc3339() {
        assert!(text("2026-04-24T12:34:56+02:00")
            .as_timestamptz("ts")
            .unwrap()
            .is_some());
    }

    #[test]
    fn test_text_value_pg_lsn_roundtrip() {
        assert_eq!(
            text("0/16B3758").as_pg_lsn("x").unwrap(),
            Some("0/16B3758".to_string())
        );
        assert!(text("nope").as_pg_lsn("x").is_err());
        assert!(text("/abc").as_pg_lsn("x").is_err());
    }

    #[test]
    fn test_text_value_numeric_accepts_valid() {
        for s in ["0", "1", "-42", "3.14", "+1.0", "1e10", "-2.5E-3", "NaN"] {
            assert!(
                text(s).as_numeric("x").unwrap().is_some(),
                "should accept {:?}",
                s
            );
        }
    }

    #[test]
    fn test_text_value_numeric_rejects_invalid() {
        for s in ["", "abc", "1..2", "-", "+"] {
            assert!(text(s).as_numeric("x").is_err(), "should reject {:?}", s);
        }
    }

    #[test]
    fn test_encode_literal_null_bool_int() {
        assert_eq!(encode_literal(&ParamValue::Null), "NULL");
        assert_eq!(encode_literal(&ParamValue::Bool(true)), "TRUE");
        assert_eq!(encode_literal(&ParamValue::Bool(false)), "FALSE");
        assert_eq!(encode_literal(&ParamValue::Int(-7)), "-7");
    }

    #[test]
    fn test_encode_literal_text_escapes_single_quote() {
        assert_eq!(encode_literal(&ParamValue::Text("a'b".to_string())), "'a''b'");
        assert_eq!(encode_literal(&ParamValue::Text("plain".to_string())), "'plain'");
    }

    #[test]
    fn test_encode_literal_lsn() {
        assert_eq!(
            encode_literal(&ParamValue::Lsn("0/16B3758".to_string())),
            "'0/16B3758'::pg_lsn"
        );
    }

    #[test]
    fn test_encode_literal_float_finite() {
        assert_eq!(encode_literal(&ParamValue::Float(1.5)), "1.5");
        assert_eq!(encode_literal(&ParamValue::Float(-0.25)), "-0.25");
    }

    #[test]
    fn test_encode_literal_float_special() {
        assert_eq!(encode_literal(&ParamValue::Float(f64::NAN)), "'NaN'::float8");
        assert_eq!(
            encode_literal(&ParamValue::Float(f64::INFINITY)),
            "'Infinity'::float8"
        );
        assert_eq!(
            encode_literal(&ParamValue::Float(f64::NEG_INFINITY)),
            "'-Infinity'::float8"
        );
    }
}