Skip to main content

rustauth_deadpool_postgres/
builder.rs

1use deadpool_postgres::{Config, Pool};
2use rustauth_core::db::{auth_schema, AuthSchemaOptions, DbSchema};
3use rustauth_core::error::RustAuthError;
4use tokio_postgres::{
5    tls::{MakeTlsConnect, TlsConnect},
6    NoTls, Socket,
7};
8
9use crate::adapter::DeadpoolPostgresAdapter;
10use crate::config::{apply_default_pool_config, create_pool, DEFAULT_POOL_MAX_SIZE};
11
12/// Configures and connects a [`DeadpoolPostgresAdapter`].
13///
14/// Prefer the [`DeadpoolPostgresStoresBuilder`] name when building
15/// [`DeadpoolPostgresStores`]; both names refer to the same type.
16#[derive(Debug, Clone)]
17pub struct DeadpoolPostgresBuilder {
18    schema: DbSchema,
19    max_size: usize,
20    checked: bool,
21    database_url: Option<String>,
22    config: Option<Config>,
23}
24
25/// Preferred name for [`DeadpoolPostgresBuilder`] when configuring
26/// [`DeadpoolPostgresStores`].
27pub type DeadpoolPostgresStoresBuilder = DeadpoolPostgresBuilder;
28
29impl Default for DeadpoolPostgresBuilder {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl DeadpoolPostgresBuilder {
36    pub fn new() -> Self {
37        Self {
38            schema: auth_schema(AuthSchemaOptions::default()),
39            max_size: DEFAULT_POOL_MAX_SIZE,
40            checked: false,
41            database_url: None,
42            config: None,
43        }
44    }
45
46    #[must_use]
47    pub fn schema(mut self, schema: DbSchema) -> Self {
48        self.schema = schema;
49        self
50    }
51
52    #[must_use]
53    pub fn max_size(mut self, max_size: usize) -> Self {
54        self.max_size = max_size;
55        self
56    }
57
58    /// Validates the pool with `SELECT 1` after connecting.
59    #[must_use]
60    pub fn checked(mut self, checked: bool) -> Self {
61        self.checked = checked;
62        self
63    }
64
65    #[must_use]
66    pub fn database_url(mut self, database_url: impl Into<String>) -> Self {
67        self.database_url = Some(database_url.into());
68        self
69    }
70
71    #[must_use]
72    pub fn config(mut self, config: Config) -> Self {
73        self.config = Some(config);
74        self
75    }
76
77    /// Builds the adapter without validating the pool connection.
78    pub fn build_adapter(self) -> Result<DeadpoolPostgresAdapter, RustAuthError> {
79        self.build(NoTls)
80    }
81
82    pub fn build_adapter_tls<T>(self, tls: T) -> Result<DeadpoolPostgresAdapter, RustAuthError>
83    where
84        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
85        T::Stream: Sync + Send,
86        T::TlsConnect: Sync + Send,
87        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
88    {
89        self.build(tls)
90    }
91
92    pub async fn connect(self) -> Result<DeadpoolPostgresAdapter, RustAuthError> {
93        let checked = self.checked;
94        let adapter = self.build_adapter()?;
95        if checked {
96            adapter.validate_connection().await?;
97        }
98        Ok(adapter)
99    }
100
101    pub async fn connect_tls<T>(self, tls: T) -> Result<DeadpoolPostgresAdapter, RustAuthError>
102    where
103        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
104        T::Stream: Sync + Send,
105        T::TlsConnect: Sync + Send,
106        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
107    {
108        let checked = self.checked;
109        let adapter = self.build_adapter_tls(tls)?;
110        if checked {
111            adapter.validate_connection().await?;
112        }
113        Ok(adapter)
114    }
115
116    fn build<T>(self, tls: T) -> Result<DeadpoolPostgresAdapter, RustAuthError>
117    where
118        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
119        T::Stream: Sync + Send,
120        T::TlsConnect: Sync + Send,
121        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
122    {
123        let mut config = self.config.unwrap_or_default();
124        if let Some(database_url) = self.database_url {
125            if config.url.is_some() {
126                return Err(RustAuthError::InvalidConfig(
127                    "deadpool-postgres builder: set either `database_url` or `config`, not both"
128                        .to_owned(),
129                ));
130            }
131            config.url = Some(database_url);
132        }
133        if config.url.is_none() && config.host.is_none() {
134            return Err(RustAuthError::InvalidConfig(
135                "deadpool-postgres builder: `database_url` or `config` is required".to_owned(),
136            ));
137        }
138        apply_default_pool_config(&mut config, self.max_size);
139        let pool = create_pool(config, tls)?;
140        Ok(DeadpoolPostgresAdapter::with_schema(pool, self.schema))
141    }
142}
143
144/// Database adapter and matching SQL-backed rate-limit store sharing one pool.
145#[derive(Clone)]
146pub struct DeadpoolPostgresStores {
147    pub adapter: DeadpoolPostgresAdapter,
148    pub rate_limit: crate::DeadpoolPostgresRateLimitStore,
149}
150
151impl std::fmt::Debug for DeadpoolPostgresStores {
152    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        formatter
154            .debug_struct("DeadpoolPostgresStores")
155            .field("adapter", &self.adapter)
156            .field("rate_limit", &self.rate_limit)
157            .finish()
158    }
159}
160
161impl DeadpoolPostgresStores {
162    pub fn builder() -> DeadpoolPostgresBuilder {
163        DeadpoolPostgresBuilder::new()
164    }
165
166    pub async fn connect(database_url: &str) -> Result<Self, RustAuthError> {
167        Self::builder()
168            .database_url(database_url)
169            .build_stores()
170            .await
171    }
172
173    pub async fn connect_with_schema(
174        database_url: &str,
175        schema: DbSchema,
176    ) -> Result<Self, RustAuthError> {
177        Self::builder()
178            .database_url(database_url)
179            .schema(schema)
180            .build_stores()
181            .await
182    }
183
184    pub async fn connect_checked(database_url: &str) -> Result<Self, RustAuthError> {
185        Self::builder()
186            .database_url(database_url)
187            .checked(true)
188            .build_stores()
189            .await
190    }
191
192    pub async fn connect_with_schema_checked(
193        database_url: &str,
194        schema: DbSchema,
195    ) -> Result<Self, RustAuthError> {
196        Self::builder()
197            .database_url(database_url)
198            .schema(schema)
199            .checked(true)
200            .build_stores()
201            .await
202    }
203
204    /// Wires the SQL-backed rate-limit store into [`RustAuthOptions`].
205    #[must_use]
206    pub fn apply_to_options(
207        &self,
208        options: rustauth_core::options::RustAuthOptions,
209    ) -> rustauth_core::options::RustAuthOptions {
210        use rustauth_core::options::RateLimitOptions;
211        options.rate_limit(RateLimitOptions::database(self.rate_limit.clone()))
212    }
213
214    pub fn adapter(&self) -> std::sync::Arc<dyn rustauth_core::db::DbAdapter> {
215        std::sync::Arc::new(self.adapter.clone())
216    }
217
218    pub fn adapter_ref(&self) -> &DeadpoolPostgresAdapter {
219        &self.adapter
220    }
221
222    pub fn pool(&self) -> &Pool {
223        self.adapter.pool()
224    }
225}
226
227impl DeadpoolPostgresBuilder {
228    pub async fn build_stores(self) -> Result<DeadpoolPostgresStores, RustAuthError> {
229        let adapter = self.connect().await?;
230        let rate_limit = crate::DeadpoolPostgresRateLimitStore::from(&adapter);
231        Ok(DeadpoolPostgresStores {
232            adapter,
233            rate_limit,
234        })
235    }
236
237    pub async fn build_stores_tls<T>(self, tls: T) -> Result<DeadpoolPostgresStores, RustAuthError>
238    where
239        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
240        T::Stream: Sync + Send,
241        T::TlsConnect: Sync + Send,
242        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
243    {
244        let checked = self.checked;
245        let adapter = self.build_adapter_tls(tls)?;
246        if checked {
247            adapter.validate_connection().await?;
248        }
249        let rate_limit = crate::DeadpoolPostgresRateLimitStore::from(&adapter);
250        Ok(DeadpoolPostgresStores {
251            adapter,
252            rate_limit,
253        })
254    }
255}