zero_postgres/
opts.rs

1//! Connection options.
2
3use std::sync::Arc;
4
5use url::Url;
6
7use crate::buffer_pool::{BufferPool, GLOBAL_BUFFER_POOL};
8use crate::error::Error;
9
/// Policy for negotiating SSL/TLS when opening a connection.
///
/// The default policy is [`SslMode::Prefer`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Never attempt SSL; always use an unencrypted connection.
    Disable,
    /// Attempt SSL first and fall back to an unencrypted connection when the server
    /// does not support it.
    #[default]
    Prefer,
    /// Demand an SSL connection.
    Require,
}
21
22/// Connection options for PostgreSQL.
23#[derive(Debug, Clone)]
24pub struct Opts {
25    /// Hostname or IP address.
26    ///
27    /// Default: `""`
28    pub host: String,
29
30    /// Port number for the PostgreSQL server.
31    ///
32    /// Default: `5432`
33    pub port: u16,
34
35    /// Unix socket path.
36    ///
37    /// Default: `None`
38    pub socket: Option<String>,
39
40    /// Username for authentication.
41    ///
42    /// Default: `""`
43    pub user: String,
44
45    /// Database name to use.
46    ///
47    /// Default: `None`
48    pub database: Option<String>,
49
50    /// Password for authentication.
51    ///
52    /// Default: `None`
53    pub password: Option<String>,
54
55    /// Application name to report to the server.
56    ///
57    /// Default: `None`
58    pub application_name: Option<String>,
59
60    /// SSL connection mode.
61    ///
62    /// Default: `SslMode::Prefer`
63    pub ssl_mode: SslMode,
64
65    /// Additional connection parameters.
66    ///
67    /// Default: `[]`
68    pub params: Vec<(String, String)>,
69
70    /// When connected via TCP to loopback, upgrade to Unix socket for better performance.
71    ///
72    /// Default: `true`
73    pub upgrade_to_unix_socket: bool,
74
75    /// Maximum number of idle connections in the pool.
76    ///
77    /// Default: `100`
78    pub pool_max_idle_conn: usize,
79
80    /// Maximum number of concurrent connections (None = unlimited).
81    ///
82    /// Default: `None`
83    pub pool_max_concurrency: Option<usize>,
84
85    /// Buffer pool for reusing buffers across connections.
86    ///
87    /// Default: `GLOBAL_BUFFER_POOL`
88    pub buffer_pool: Arc<BufferPool>,
89}
90
91impl Default for Opts {
92    fn default() -> Self {
93        Self {
94            host: String::new(),
95            port: 5432,
96            socket: None,
97            user: String::new(),
98            database: None,
99            password: None,
100            application_name: None,
101            ssl_mode: SslMode::Prefer,
102            params: Vec::new(),
103            upgrade_to_unix_socket: true,
104            pool_max_idle_conn: 100,
105            pool_max_concurrency: None,
106            buffer_pool: Arc::clone(&GLOBAL_BUFFER_POOL),
107        }
108    }
109}
110
111impl TryFrom<&Url> for Opts {
112    type Error = Error;
113
114    /// Parse a PostgreSQL connection URL.
115    ///
116    /// Format: `(pg|postgres|postgresql)://[user[:password]@]host[:port][/database][?param1=value1&param2=value2&..]`
117    ///
118    /// Supported query parameters:
119    /// - `sslmode` or `ssl_mode`: disable, prefer, require
120    /// - `application_name`: application name
121    /// - `upgrade_to_unix_socket`: true/True/1/yes/on or false/False/0/no/off
122    /// - `pool_max_idle_conn`: maximum idle connections (positive integer)
123    /// - `pool_max_concurrency`: maximum concurrent connections (positive integer)
124    fn try_from(url: &Url) -> Result<Self, Self::Error> {
125        if !["pg", "postgres", "postgresql"].contains(&url.scheme()) {
126            return Err(Error::InvalidUsage(format!(
127                "Invalid scheme: expected 'pg://', 'postgres://', or 'postgresql://', got '{}://'",
128                url.scheme()
129            )));
130        }
131
132        let mut opts = Opts {
133            host: url.host_str().unwrap_or("localhost").to_string(),
134            port: url.port().unwrap_or(5432),
135            user: url.username().to_string(),
136            password: url.password().map(|s| s.to_string()),
137            database: url.path().strip_prefix('/').and_then(|s| {
138                if s.is_empty() {
139                    None
140                } else {
141                    Some(s.to_string())
142                }
143            }),
144            ..Opts::default()
145        };
146
147        for (key, value) in url.query_pairs() {
148            match key.as_ref() {
149                "sslmode" | "ssl_mode" => {
150                    opts.ssl_mode = match value.as_ref() {
151                        "disable" => SslMode::Disable,
152                        "prefer" => SslMode::Prefer,
153                        "require" => SslMode::Require,
154                        _ => {
155                            return Err(Error::InvalidUsage(format!(
156                                "Invalid sslmode: expected one of ['disable', 'prefer', 'require'], got {}",
157                                value
158                            )));
159                        }
160                    };
161                }
162                "application_name" => {
163                    opts.application_name = Some(value.to_string());
164                }
165                "upgrade_to_unix_socket" => {
166                    opts.upgrade_to_unix_socket = match value.as_ref() {
167                        "true" | "True" | "1" | "yes" | "on" => true,
168                        "false" | "False" | "0" | "no" | "off" => false,
169                        _ => {
170                            return Err(Error::InvalidUsage(format!(
171                                "Invalid upgrade_to_unix_socket parameter: {}",
172                                value
173                            )));
174                        }
175                    };
176                }
177                "pool_max_idle_conn" => {
178                    opts.pool_max_idle_conn = value.parse().map_err(|_| {
179                        Error::InvalidUsage(format!("Invalid pool_max_idle_conn: {}", value))
180                    })?;
181                }
182                "pool_max_concurrency" => {
183                    opts.pool_max_concurrency = Some(value.parse().map_err(|_| {
184                        Error::InvalidUsage(format!("Invalid pool_max_concurrency: {}", value))
185                    })?);
186                }
187                _ => {
188                    opts.params.push((key.to_string(), value.to_string()));
189                }
190            }
191        }
192
193        Ok(opts)
194    }
195}
196
197impl TryFrom<&str> for Opts {
198    type Error = Error;
199
200    fn try_from(s: &str) -> Result<Self, Self::Error> {
201        let url = Url::parse(s).map_err(|e| Error::InvalidUsage(format!("Invalid URL: {}", e)))?;
202        Self::try_from(&url)
203    }
204}