Skip to main content

heliosdb_proxy/backend/
types.rs

//! PostgreSQL type OIDs and text-format value encoding/decoding.
//!
//! This is intentionally narrow — it covers the eight common OIDs needed
//! for TR-management queries (`pg_is_in_recovery`, `pg_last_wal_replay_lsn`,
//! failover status, tenant quota config, etc.). Binary format is out of
//! scope; everything rides the simple-query text path.
7
8use super::error::{BackendError, BackendResult};
9use chrono::{DateTime, FixedOffset};
10
/// PostgreSQL type OIDs understood by the backend client.
///
/// Each constant is the `pg_type.oid` of the matching built-in type in the
/// system catalogs. Values whose column carries any other OID are still
/// surfaced as raw UTF-8 text; callers that need strict typing should check
/// the OID and decline to interpret ones they do not recognise.
pub mod oid {
    // --- boolean -----------------------------------------------------------
    /// `bool`.
    pub const BOOL: u32 = 16;

    // --- exact integers ----------------------------------------------------
    /// `int4` / `integer`.
    pub const INT4: u32 = 23;
    /// `int8` / `bigint`.
    pub const INT8: u32 = 20;

    // --- inexact / arbitrary-precision numbers -----------------------------
    /// `float8` / `double precision`.
    pub const FLOAT8: u32 = 701;
    /// `numeric`.
    pub const NUMERIC: u32 = 1700;

    // --- text and time -----------------------------------------------------
    /// `text`.
    pub const TEXT: u32 = 25;
    /// `timestamptz` / `timestamp with time zone`.
    pub const TIMESTAMPTZ: u32 = 1184;

    // --- WAL ---------------------------------------------------------------
    /// `pg_lsn` — used by WAL-position queries on replicas.
    pub const PG_LSN: u32 = 3220;
}
28
/// A single column's text-format value as received from the backend.
///
/// The value is held as an owned `String` (not a borrowed view into the row
/// bytes); callers typically turn it into a typed Rust value via the
/// `as_<type>` helpers.
#[derive(Debug, Clone, PartialEq)]
pub enum TextValue {
    /// The column is SQL NULL.
    Null,
    /// The column's value in PostgreSQL text format, as a UTF-8 `String`.
    Text(String),
}
40
41impl TextValue {
42    /// Return `true` if this value is SQL NULL.
43    pub fn is_null(&self) -> bool {
44        matches!(self, TextValue::Null)
45    }
46
47    /// Borrow the underlying string if not NULL.
48    pub fn as_str(&self) -> Option<&str> {
49        match self {
50            TextValue::Null => None,
51            TextValue::Text(s) => Some(s.as_str()),
52        }
53    }
54
55    /// Consume and return the inner `String`, or `None` for NULL.
56    pub fn into_string(self) -> Option<String> {
57        match self {
58            TextValue::Null => None,
59            TextValue::Text(s) => Some(s),
60        }
61    }
62
63    /// Decode as `bool`. PostgreSQL text format is `"t"` or `"f"`.
64    pub fn as_bool(&self, column: &str) -> BackendResult<Option<bool>> {
65        match self {
66            TextValue::Null => Ok(None),
67            TextValue::Text(s) => match s.as_str() {
68                "t" | "true" | "TRUE" => Ok(Some(true)),
69                "f" | "false" | "FALSE" => Ok(Some(false)),
70                other => Err(BackendError::ParseValue {
71                    column: column.to_string(),
72                    reason: format!("expected bool ('t'|'f'), got {:?}", other),
73                }),
74            },
75        }
76    }
77
78    /// Decode as `i64` (covers INT4 and INT8 at the wire level — text
79    /// format is the same).
80    pub fn as_i64(&self, column: &str) -> BackendResult<Option<i64>> {
81        match self {
82            TextValue::Null => Ok(None),
83            TextValue::Text(s) => s.parse::<i64>().map(Some).map_err(|e| {
84                BackendError::ParseValue {
85                    column: column.to_string(),
86                    reason: format!("i64: {}", e),
87                }
88            }),
89        }
90    }
91
92    /// Decode as `f64` (FLOAT8 / double precision).
93    pub fn as_f64(&self, column: &str) -> BackendResult<Option<f64>> {
94        match self {
95            TextValue::Null => Ok(None),
96            TextValue::Text(s) => s.parse::<f64>().map(Some).map_err(|e| {
97                BackendError::ParseValue {
98                    column: column.to_string(),
99                    reason: format!("f64: {}", e),
100                }
101            }),
102        }
103    }
104
105    /// Decode as `DateTime<FixedOffset>` (TIMESTAMPTZ). PG text format
106    /// with a timezone offset: `2026-04-24 12:34:56.789+00`.
107    pub fn as_timestamptz(
108        &self,
109        column: &str,
110    ) -> BackendResult<Option<DateTime<FixedOffset>>> {
111        match self {
112            TextValue::Null => Ok(None),
113            TextValue::Text(s) => {
114                // PG emits a space between date and time by default; RFC3339
115                // wants a 'T'. Support either.
116                let normalised = if s.contains(' ') && !s.contains('T') {
117                    s.replacen(' ', "T", 1)
118                } else {
119                    s.clone()
120                };
121                // Append minutes to a bare hour offset: "+00" -> "+00:00".
122                let normalised =
123                    if let Some(idx) = normalised.rfind(['+', '-']) {
124                        let off = &normalised[idx + 1..];
125                        if off.len() == 2 && off.bytes().all(|b| b.is_ascii_digit()) {
126                            format!("{}:00", normalised)
127                        } else {
128                            normalised
129                        }
130                    } else {
131                        normalised
132                    };
133                DateTime::parse_from_rfc3339(&normalised)
134                    .map(Some)
135                    .map_err(|e| BackendError::ParseValue {
136                        column: column.to_string(),
137                        reason: format!("timestamptz {:?}: {}", s, e),
138                    })
139            }
140        }
141    }
142
143    /// Decode as a textual pg_lsn (e.g. `"0/16B3758"`). We leave LSN
144    /// arithmetic to callers — string form is what `pg_last_wal_*_lsn()`
145    /// returns in text format, and the natural lex order on these
146    /// strings matches WAL ordering for positions in the same
147    /// timeline.
148    pub fn as_pg_lsn(&self, column: &str) -> BackendResult<Option<String>> {
149        match self {
150            TextValue::Null => Ok(None),
151            TextValue::Text(s) => {
152                // Validate shape: H[H..]/H[H..]
153                if let Some((hi, lo)) = s.split_once('/') {
154                    let hex_ok = |p: &str| {
155                        !p.is_empty() && p.bytes().all(|b| b.is_ascii_hexdigit())
156                    };
157                    if hex_ok(hi) && hex_ok(lo) {
158                        return Ok(Some(s.clone()));
159                    }
160                }
161                Err(BackendError::ParseValue {
162                    column: column.to_string(),
163                    reason: format!("pg_lsn {:?}: expected 'H/H' hex pair", s),
164                })
165            }
166        }
167    }
168
169    /// Decode as `NUMERIC` — text is PG's canonical form (e.g. "3.1415").
170    /// We return the raw string; callers that need arithmetic can route
171    /// to `rust_decimal` or parse further.
172    pub fn as_numeric(&self, column: &str) -> BackendResult<Option<String>> {
173        match self {
174            TextValue::Null => Ok(None),
175            TextValue::Text(s) => {
176                // Reject only obviously malformed values — a single optional
177                // sign, digits, optional single dot, optional exponent.
178                let bytes = s.as_bytes();
179                let mut i = 0;
180                if bytes.first().map_or(false, |&b| b == b'+' || b == b'-') {
181                    i += 1;
182                }
183                let mut saw_digit = false;
184                let mut saw_dot = false;
185                while i < bytes.len() {
186                    let b = bytes[i];
187                    if b.is_ascii_digit() {
188                        saw_digit = true;
189                    } else if b == b'.' && !saw_dot {
190                        saw_dot = true;
191                    } else if (b == b'e' || b == b'E') && saw_digit {
192                        // Exponent form — stop validating shape, accept.
193                        saw_digit = true;
194                        break;
195                    } else if s.eq_ignore_ascii_case("NaN") {
196                        return Ok(Some("NaN".to_string()));
197                    } else {
198                        return Err(BackendError::ParseValue {
199                            column: column.to_string(),
200                            reason: format!("numeric {:?}", s),
201                        });
202                    }
203                    i += 1;
204                }
205                if saw_digit {
206                    Ok(Some(s.clone()))
207                } else {
208                    Err(BackendError::ParseValue {
209                        column: column.to_string(),
210                        reason: format!("numeric {:?}: no digits", s),
211                    })
212                }
213            }
214        }
215    }
216}
217
/// Encode a Rust value as a PostgreSQL text-format SQL literal.
///
/// The backend client substitutes parameters into simple-query SQL as
/// literals; the extended protocol is not used here. This function returns
/// the complete, already-quoted literal text, for example:
///
/// * `ParamValue::Text("alice")` → `'alice'`
/// * `ParamValue::Text("a'b")` → `'a''b'`
/// * `ParamValue::Text("a\b")` (one backslash) → `E'a\\b'`
/// * `ParamValue::Int(42)` → `42`
/// * `ParamValue::Lsn("0/16B3758")` → `'0/16B3758'::pg_lsn`
///
/// String payloads (`Text` and `Lsn`) are escaped the way libpq's
/// `PQescapeLiteral` does it:
///
/// * single quotes are doubled;
/// * if the value contains a backslash, the `E'...'` escape-string form is
///   used and backslashes are doubled.
///
/// The result is therefore correct whatever `standard_conforming_strings`
/// is set to on the server.
///
/// Implementations are deliberately tight — enough to serialise the
/// argument values TR-management queries need. Callers with richer types
/// should extend this set.
///
/// NOTE(review): negative numbers are emitted bare (`-7`). A template that
/// puts a parameter directly after `-` (e.g. `x-$1`) would produce `x--7`,
/// which SQL lexes as a comment, so keep whitespace around operators in
/// templates. PostgreSQL text cannot contain NUL; callers should reject
/// such strings before encoding.
pub fn encode_literal(v: &ParamValue) -> String {
    match v {
        ParamValue::Null => "NULL".to_string(),
        ParamValue::Bool(true) => "TRUE".to_string(),
        ParamValue::Bool(false) => "FALSE".to_string(),
        ParamValue::Int(i) => i.to_string(),
        ParamValue::Float(f) => {
            // Finite values become bare numeric constants. `{:?}` yields the
            // shortest text that round-trips to the same f64. NaN/±Infinity
            // have no bare-constant spelling, so they use a quoted literal
            // cast to float8.
            if f.is_nan() {
                "'NaN'::float8".to_string()
            } else if f.is_infinite() {
                if *f > 0.0 {
                    "'Infinity'::float8".to_string()
                } else {
                    "'-Infinity'::float8".to_string()
                }
            } else {
                format!("{:?}", f)
            }
        }
        ParamValue::Text(s) => quote_string_literal(s),
        // Escaped like Text: an unvalidated LSN string must not be able to
        // break out of the literal.
        ParamValue::Lsn(s) => format!("{}::pg_lsn", quote_string_literal(s)),
    }
}

/// Quote `s` as a PostgreSQL string literal.
///
/// Single quotes are doubled. If `s` contains a backslash, the `E'...'` form
/// is used with backslashes doubled, which is safe regardless of
/// `standard_conforming_strings`.
fn quote_string_literal(s: &str) -> String {
    let escape_string = s.contains('\\');
    let mut out = String::with_capacity(s.len() + 3);
    if escape_string {
        out.push('E');
    }
    out.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => out.push_str("''"),
            // Only reachable in the E'...' form, where '\' is an escape char.
            '\\' => out.push_str("\\\\"),
            _ => out.push(ch),
        }
    }
    out.push('\'');
    out
}

/// Minimal parameter-value enum for [`encode_literal`].
///
/// The variants cover SQL NULL plus `bool`, integer, `float8`, `text` and
/// `pg_lsn` values. Kept intentionally small — this is for TR-management
/// queries, not general-purpose query execution.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    Null,
    Bool(bool),
    Int(i64),
    Float(f64),
    Text(String),
    Lsn(String),
}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_value_bool() {
        let t = TextValue::Text("t".to_string());
        assert_eq!(t.as_bool("x").unwrap(), Some(true));
        let f = TextValue::Text("f".to_string());
        assert_eq!(f.as_bool("x").unwrap(), Some(false));
        let n = TextValue::Null;
        assert_eq!(n.as_bool("x").unwrap(), None);
        let bad = TextValue::Text("maybe".to_string());
        assert!(bad.as_bool("x").is_err());
    }

    #[test]
    fn test_text_value_i64() {
        assert_eq!(
            TextValue::Text("42".to_string()).as_i64("x").unwrap(),
            Some(42)
        );
        assert_eq!(
            TextValue::Text("-1".to_string()).as_i64("x").unwrap(),
            Some(-1)
        );
        assert!(TextValue::Text("abc".to_string()).as_i64("x").is_err());
    }

    #[test]
    fn test_text_value_f64() {
        // 2.5 is exactly representable and avoids clippy's deny-by-default
        // `approx_constant` lint, which flags 3.14 as an approximation of PI.
        assert_eq!(
            TextValue::Text("2.5".to_string()).as_f64("x").unwrap(),
            Some(2.5)
        );
        assert!(TextValue::Text("oops".to_string()).as_f64("x").is_err());
    }

    #[test]
    fn test_text_value_timestamptz_pg_format() {
        let v = TextValue::Text("2026-04-24 12:34:56.789+00".to_string());
        let parsed = v.as_timestamptz("ts").unwrap().expect("some");
        assert!(parsed.to_rfc3339().starts_with("2026-04-24T12:34:56.789"));
    }

    #[test]
    fn test_text_value_timestamptz_rfc3339() {
        let v = TextValue::Text("2026-04-24T12:34:56+02:00".to_string());
        assert!(v.as_timestamptz("ts").unwrap().is_some());
    }

    #[test]
    fn test_text_value_pg_lsn_roundtrip() {
        assert_eq!(
            TextValue::Text("0/16B3758".to_string())
                .as_pg_lsn("x")
                .unwrap(),
            Some("0/16B3758".to_string())
        );
        assert!(TextValue::Text("nope".to_string()).as_pg_lsn("x").is_err());
        assert!(TextValue::Text("/abc".to_string()).as_pg_lsn("x").is_err());
    }

    #[test]
    fn test_text_value_numeric_accepts_valid() {
        for s in ["0", "1", "-42", "3.14", "+1.0", "1e10", "-2.5E-3", "NaN"] {
            assert!(
                TextValue::Text(s.to_string()).as_numeric("x").unwrap().is_some(),
                "should accept {:?}",
                s
            );
        }
    }

    #[test]
    fn test_text_value_numeric_rejects_invalid() {
        for s in ["", "abc", "1..2", "-", "+"] {
            assert!(
                TextValue::Text(s.to_string()).as_numeric("x").is_err(),
                "should reject {:?}",
                s
            );
        }
    }

    #[test]
    fn test_encode_literal_null_bool_int() {
        assert_eq!(encode_literal(&ParamValue::Null), "NULL");
        assert_eq!(encode_literal(&ParamValue::Bool(true)), "TRUE");
        assert_eq!(encode_literal(&ParamValue::Bool(false)), "FALSE");
        assert_eq!(encode_literal(&ParamValue::Int(-7)), "-7");
    }

    #[test]
    fn test_encode_literal_text_escapes_single_quote() {
        assert_eq!(encode_literal(&ParamValue::Text("a'b".to_string())), "'a''b'");
        assert_eq!(encode_literal(&ParamValue::Text("plain".to_string())), "'plain'");
    }

    #[test]
    fn test_encode_literal_lsn() {
        assert_eq!(
            encode_literal(&ParamValue::Lsn("0/16B3758".to_string())),
            "'0/16B3758'::pg_lsn"
        );
    }

    #[test]
    fn test_encode_literal_float_special() {
        assert_eq!(
            encode_literal(&ParamValue::Float(f64::NAN)),
            "'NaN'::float8"
        );
        assert_eq!(
            encode_literal(&ParamValue::Float(f64::INFINITY)),
            "'Infinity'::float8"
        );
        assert_eq!(
            encode_literal(&ParamValue::Float(f64::NEG_INFINITY)),
            "'-Infinity'::float8"
        );
    }
}