clickhouse_postgres_client/
connect_options.rs

1use core::{
2    ops::{Deref, DerefMut},
3    str::FromStr,
4};
5use std::env::{set_var, var, VarError};
6
7use sqlx_clickhouse_ext::sqlx_core::{error::Error, postgres::PgConnectOptions};
8use url::Url;
9
10#[derive(Debug, Clone)]
11pub struct ClickhousePgConnectOptions {
12    pub(crate) inner: PgConnectOptions,
13}
14impl ClickhousePgConnectOptions {
15    pub fn new() -> Self {
16        update_env();
17
18        Self {
19            inner: PgConnectOptions::new(),
20        }
21    }
22
23    pub fn into_inner(self) -> PgConnectOptions {
24        self.inner
25    }
26}
27
28impl Default for ClickhousePgConnectOptions {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl Deref for ClickhousePgConnectOptions {
35    type Target = PgConnectOptions;
36
37    fn deref(&self) -> &Self::Target {
38        &self.inner
39    }
40}
41impl DerefMut for ClickhousePgConnectOptions {
42    fn deref_mut(&mut self) -> &mut Self::Target {
43        &mut self.inner
44    }
45}
46
47impl FromStr for ClickhousePgConnectOptions {
48    type Err = Error;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        update_env();
52
53        let s = update_url(s)?;
54
55        PgConnectOptions::from_str(&s).map(|inner| Self { inner })
56    }
57}
58
// Defaults substituted for libpq-style settings when talking to ClickHouse.

/// Value written to `PGPORT` when that variable is absent
/// (presumably ClickHouse's default `postgresql_port` — confirm against server config).
const PORT_DEFAULT_STR: &str = "9005";
/// SSL mode that is downgraded to [`SSL_MODE_DISABLE`] wherever it is found.
const SSL_MODE_PREFER: &str = "prefer";
/// SSL mode used in place of `prefer`, and as the default when none is given.
const SSL_MODE_DISABLE: &str = "disable";

/// Adjusts the `PGPORT` / `PGSSLMODE` environment variables before a
/// `PgConnectOptions` is built from the environment.
///
/// * `PGPORT` is set to [`PORT_DEFAULT_STR`] only when it is not present at all.
/// * `PGSSLMODE` is set to [`SSL_MODE_DISABLE`] when it is not present, or when
///   it equals `prefer` ignoring ASCII case. sqlx parses the SSL mode
///   case-insensitively, so `PREFER` would otherwise still select prefer mode.
///
/// Variables that are present but not valid Unicode, and any other SSL mode,
/// are left untouched.
///
/// NOTE(review): this mutates process-wide environment state via `set_var`.
fn update_env() {
    if let Err(VarError::NotPresent) = var("PGPORT") {
        set_var("PGPORT", PORT_DEFAULT_STR);
    }

    match var("PGSSLMODE") {
        Ok(mode) if mode.eq_ignore_ascii_case(SSL_MODE_PREFER) => {
            set_var("PGSSLMODE", SSL_MODE_DISABLE)
        }
        Err(VarError::NotPresent) => set_var("PGSSLMODE", SSL_MODE_DISABLE),
        _ => (),
    }
}
75
76fn update_url(s: &str) -> Result<String, Error> {
77    let mut url: Url = s
78        .parse()
79        .map_err(|err: url::ParseError| Error::Configuration(err.into()))?;
80
81    url.query_pairs()
82        .map(|(k, v)| (k.to_string(), v.to_string()))
83        .collect::<Vec<_>>()
84        .into_iter()
85        .fold(url.query_pairs_mut().clear(), |ser, (key, value)| {
86            match key.as_ref() {
87                "sslmode" => {
88                    if value == SSL_MODE_PREFER {
89                        ser.append_pair(&key, SSL_MODE_DISABLE);
90                    } else {
91                        ser.append_pair(&key, &value);
92                    }
93                }
94                "ssl-mode" => {
95                    if value == SSL_MODE_PREFER {
96                        ser.append_pair(&key, SSL_MODE_DISABLE);
97                    } else {
98                        ser.append_pair(&key, &value);
99                    }
100                }
101                _ => {
102                    ser.append_pair(&key, &value);
103                }
104            };
105            ser
106        });
107
108    Ok(url.to_string())
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use std::env::remove_var;

    /// Unsets both environment variables that `update_env` manages.
    fn clear_pg_env() {
        remove_var("PGPORT");
        remove_var("PGSSLMODE");
    }

    /// Asserts that both variables hold the ClickHouse defaults.
    fn assert_clickhouse_defaults() {
        assert_eq!(var("PGPORT").unwrap(), "9005");
        assert_eq!(var("PGSSLMODE").unwrap(), "disable");
    }

    #[test]
    fn test_update_env() {
        // Neither variable present: both defaults are filled in.
        clear_pg_env();
        update_env();
        assert_clickhouse_defaults();

        // Port absent and SSL mode explicitly `prefer`: same outcome.
        clear_pg_env();
        set_var("PGSSLMODE", "prefer");
        update_env();
        assert_clickhouse_defaults();
    }

    #[test]
    fn test_update_url() {
        // (input URL, expected rewritten URL) for each accepted key spelling.
        let cases = [
            ("postgres:///?sslmode=prefer", "postgres:///?sslmode=disable"),
            ("postgres:///?ssl-mode=prefer", "postgres:///?ssl-mode=disable"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(update_url(input).unwrap(), expected);
        }
    }
}