// mockforge_http/database.rs
database.rs1#[cfg(feature = "database")]
7use anyhow::Result as AnyhowResult;
8#[cfg(feature = "database")]
9use sqlx::{postgres::PgPoolOptions, PgPool};
10#[cfg(feature = "database")]
11use std::sync::Arc;
12
/// Optional PostgreSQL connection pool handle.
///
/// With the `database` feature enabled this holds an optional `PgPool` behind an `Arc`, so
/// clones share the same underlying pool. With the feature disabled it is a zero-sized
/// placeholder.
#[derive(Clone, Debug)]
pub struct Database {
    /// Shared connection pool; `None` when no database URL was configured.
    #[cfg(feature = "database")]
    pool: Option<Arc<PgPool>>,
    /// Private placeholder field used when the `database` feature is disabled.
    #[cfg(not(feature = "database"))]
    _phantom: std::marker::PhantomData<()>,
}
21
22impl Database {
23 #[cfg(feature = "database")]
25 pub const DEFAULT_MAX_CONNECTIONS: u32 = 10;
26
27 #[cfg(feature = "database")]
36 pub async fn connect_optional(database_url: Option<&str>) -> AnyhowResult<Self> {
37 Self::connect_optional_with_pool_size(database_url, None).await
38 }
39
40 #[cfg(feature = "database")]
45 pub async fn connect_optional_with_pool_size(
46 database_url: Option<&str>,
47 max_connections: Option<u32>,
48 ) -> AnyhowResult<Self> {
49 let pool = if let Some(url) = database_url {
50 if url.is_empty() {
51 None
52 } else {
53 let max_conn = max_connections.unwrap_or_else(|| {
54 std::env::var("MOCKFORGE_DB_MAX_CONNECTIONS")
55 .ok()
56 .and_then(|s| s.parse().ok())
57 .unwrap_or(Self::DEFAULT_MAX_CONNECTIONS)
58 });
59 tracing::info!("Connecting to database with max_connections={}", max_conn);
60 let pool = PgPoolOptions::new().max_connections(max_conn).connect(url).await?;
61 Some(Arc::new(pool))
62 }
63 } else {
64 None
65 };
66
67 Ok(Self { pool })
68 }
69
70 #[cfg(not(feature = "database"))]
72 pub async fn connect_optional(_database_url: Option<&str>) -> anyhow::Result<Self> {
73 Ok(Self {
74 _phantom: std::marker::PhantomData,
75 })
76 }
77
78 #[cfg(feature = "database")]
80 pub async fn migrate_if_connected(&self) -> AnyhowResult<()> {
81 if let Some(ref pool) = self.pool {
82 match sqlx::migrate!("./migrations").run(pool.as_ref()).await {
85 Ok(_) => {
86 tracing::info!("Database migrations completed successfully");
87 Ok(())
88 }
89 Err(e) => {
90 if e.to_string().contains("previously applied but is missing") {
92 tracing::warn!(
93 "Migration tracking issue (manually applied migration): {:?}",
94 e
95 );
96 tracing::info!(
97 "Continuing despite migration tracking issue - database is up to date"
98 );
99 Ok(())
100 } else {
101 Err(e.into())
102 }
103 }
104 }
105 } else {
106 tracing::debug!("No database connection, skipping migrations");
107 Ok(())
108 }
109 }
110
111 #[cfg(not(feature = "database"))]
113 pub async fn migrate_if_connected(&self) -> anyhow::Result<()> {
114 tracing::debug!("Database feature not enabled, skipping migrations");
115 Ok(())
116 }
117
118 #[cfg(feature = "database")]
120 pub fn pool(&self) -> Option<&PgPool> {
121 self.pool.as_deref()
122 }
123
124 #[cfg(not(feature = "database"))]
126 pub fn pool(&self) -> Option<()> {
127 None
128 }
129
130 pub fn is_connected(&self) -> bool {
132 #[cfg(feature = "database")]
133 {
134 self.pool.is_some()
135 }
136 #[cfg(not(feature = "database"))]
137 {
138 false
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a disconnected `Database` directly, so the synchronous tests below exercise
    /// real assertions under both `database` feature configurations.
    fn disconnected() -> Database {
        #[cfg(feature = "database")]
        {
            Database { pool: None }
        }
        #[cfg(not(feature = "database"))]
        {
            Database {
                _phantom: std::marker::PhantomData,
            }
        }
    }

    #[tokio::test]
    async fn test_database_connect_optional_none() {
        let db = Database::connect_optional(None).await.unwrap();
        assert!(!db.is_connected());
    }

    #[tokio::test]
    async fn test_database_connect_optional_empty_string() {
        let db = Database::connect_optional(Some("")).await.unwrap();
        assert!(!db.is_connected());
    }

    #[cfg(feature = "database")]
    #[tokio::test]
    async fn test_database_connect_optional_with_pool_size_none() {
        let db = Database::connect_optional_with_pool_size(None, Some(5)).await.unwrap();
        assert!(!db.is_connected());
    }

    #[tokio::test]
    async fn test_database_pool_returns_none_when_not_connected() {
        let db = Database::connect_optional(None).await.unwrap();
        assert!(db.pool().is_none());
    }

    #[tokio::test]
    async fn test_database_migrate_skips_when_not_connected() {
        let db = Database::connect_optional(None).await.unwrap();
        let result = db.migrate_if_connected().await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_database_is_connected_returns_false_by_default() {
        assert!(!disconnected().is_connected());
    }

    #[test]
    fn test_database_clone() {
        let cloned = disconnected().clone();
        assert!(!cloned.is_connected());
        assert!(cloned.pool().is_none());
    }
}