Skip to main content

ngb_postgres_helper/
lib.rs

1mod try_execution;
2mod execution;
3
4use std::collections::HashMap;
5use std::sync::LazyLock;
6use regex::Regex;
7use tokio_postgres::types::{ToSql};
8pub use execution::*;
9pub use try_execution::*;
10
11pub struct Params<'q>(HashMap<&'q str, &'q (dyn ToSql + Sync)>);
12const REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"@(\w+)").unwrap());
13pub enum Error{
14    Postgres(tokio_postgres::Error),
15    TryFrom(Box<dyn std::error::Error + 'static>),
16}
17
18impl<'q> Params<'q> {
19    pub fn new() -> Self {
20        Self(HashMap::new())
21    }
22
23    pub fn with_capacity(capacity: usize) -> Self {
24        Self(HashMap::with_capacity(capacity))
25    }
26
27    pub fn from(arr: &'q [(&str, &(dyn ToSql + Sync))]) -> Self {
28        let mut params = HashMap::with_capacity(arr.len());
29        for (k, v) in arr {
30            params.insert(*k, *v);
31        }
32        Self(params)
33    }
34
35    pub fn add(&mut self, key: &'q str, value: &'q (dyn ToSql + Sync)) -> &mut Self {
36        self.0.insert(key, value);
37        self
38    }
39
40    pub fn len(&self) -> usize {
41        self.0.len()
42    }
43}
44
45impl<'q> Into<Params<'q>> for HashMap<&'q str, &'q (dyn ToSql + Sync)> {
46    fn into(self) -> Params<'q> {
47        Params(self)
48    }
49}
50
51fn compile_sql<'q>(sql: &str, params: &Params<'q>) -> (String, Vec<&'q (dyn ToSql + Sync)>) {
52    let mut clean_sql: Vec<u8> = Vec::with_capacity(sql.len());
53    let mut char_idx = 0;
54    let mut count = 1;
55    let mut clean_params = Vec::with_capacity(params.len());
56
57    let bytes = sql.as_bytes();
58
59    for found in REGEX.find_iter(sql) {
60        let key = &found.as_str()[1..];
61        let start = found.start();
62        let end = found.end();
63
64        clean_sql.extend_from_slice(&bytes[char_idx..start]);
65        let p = params.0.get(key);
66
67        if let Some(value) = p {
68            clean_sql.push('$' as u8);
69            append_index(&mut clean_sql, count);
70            clean_params.push(*value);
71            count = count + 1;
72        } else {
73            clean_sql.extend_from_slice(&found.as_str().as_bytes());
74            // clean_sql.push('@' as u8);
75            // clean_sql.extend_from_slice(key.as_bytes());
76        }
77        char_idx = end;
78    }
79    clean_sql.extend_from_slice(&bytes[char_idx..]);
80
81    unsafe {
82        let result = String::from_utf8_unchecked(clean_sql);
83        (result, clean_params)
84    }
85}
86
/// Appends the decimal ASCII representation of `i` to `s`.
///
/// Digits are produced least-significant first into a fixed stack buffer that
/// is filled from the back, then the used tail is copied to `s` in one go.
/// Zero is written as a single `0`.
fn append_index(s: &mut Vec<u8>, i: usize) {
    // 40 bytes is enough for the decimal form of any integer up to 128 bits.
    let mut digits = [0u8; 40];
    let mut cursor = digits.len();
    let mut rest = i;

    loop {
        cursor -= 1;
        digits[cursor] = b'0' + (rest % 10) as u8;
        rest /= 10;
        if rest == 0 {
            break;
        }
    }

    s.extend_from_slice(&digits[cursor..]);
}