1use std::sync::Arc;
4
5use super::{
6 backend::DatabaseBackend,
7 error::Result,
8 query_builder::{DeleteBuilder, InsertBuilder, SelectBuilder, UpdateBuilder},
9};
10
11#[cfg(feature = "postgres")]
12use super::dialect::PostgresBackend;
13
/// SQLSTATE `invalid_catalog_name`: the server reports that the requested database does
/// not exist. `connect_postgres_or_create_with_pool_size` matches on it to decide whether
/// the database should be auto-created.
#[cfg(feature = "postgres")]
const SQLSTATE_INVALID_CATALOG_NAME: &str = "3D000";
17
18#[cfg(feature = "sqlite")]
19use super::dialect::SqliteBackend;
20
21#[cfg(feature = "mysql")]
22use super::dialect::MySqlBackend;
23
24#[derive(Clone)]
26pub struct DatabaseConnection {
27 backend: Arc<dyn DatabaseBackend>,
28 is_cockroachdb: bool,
37}
38
#[cfg(feature = "di")]
#[async_trait::async_trait]
impl reinhardt_di::Injectable for DatabaseConnection {
    /// Resolves a `DatabaseConnection` previously registered in the injection context.
    ///
    /// The singleton scope is consulted first, then the request scope. The stored value is
    /// moved out when this is the last reference to it and cloned otherwise. A clone only
    /// bumps the backend's reference count.
    ///
    /// # Errors
    ///
    /// Returns `DiError::NotRegistered` when neither scope holds a connection.
    async fn inject(ctx: &reinhardt_di::InjectionContext) -> reinhardt_di::DiResult<Self> {
        let registered = ctx
            .get_singleton::<Self>()
            .or_else(|| ctx.get_request::<Self>());

        match registered {
            Some(shared) => Ok(std::sync::Arc::unwrap_or_clone(shared)),
            None => Err(reinhardt_di::DiError::NotRegistered {
                type_name: std::any::type_name::<Self>().to_string(),
                hint: "Use InjectionContextBuilder::singleton(db_connection) to register a \
                       DatabaseConnection. Create it with DatabaseConnection::connect_postgres(), \
                       connect_sqlite(), or connect_mysql()."
                    .to_string(),
            }),
        }
    }

    /// There is no separate uncached construction path for connections.
    ///
    /// This delegates to `inject`, so it also returns the registered connection.
    async fn inject_uncached(ctx: &reinhardt_di::InjectionContext) -> reinhardt_di::DiResult<Self> {
        <Self as reinhardt_di::Injectable>::inject(ctx).await
    }
}
97
98impl DatabaseConnection {
99 pub fn new(backend: Arc<dyn DatabaseBackend>) -> Self {
107 Self::new_with_flavor(backend, false)
108 }
109
110 pub fn new_with_flavor(backend: Arc<dyn DatabaseBackend>, is_cockroachdb: bool) -> Self {
116 Self {
117 backend,
118 is_cockroachdb,
119 }
120 }
121
122 #[cfg(feature = "postgres")]
123 pub async fn connect_postgres(url: &str) -> Result<Self> {
125 Self::connect_postgres_with_pool_size(url, None).await
126 }
127
128 #[cfg(feature = "postgres")]
129 pub async fn connect_postgres_with_pool_size(
131 url: &str,
132 pool_size: Option<u32>,
133 ) -> Result<Self> {
134 let pool = Self::build_postgres_pool(url, pool_size).await?;
135 let is_cockroachdb = Self::probe_cockroachdb(&pool).await;
136
137 Ok(Self {
138 backend: Arc::new(PostgresBackend::new(pool)),
139 is_cockroachdb,
140 })
141 }
142
143 #[cfg(feature = "postgres")]
158 async fn probe_cockroachdb(pool: &sqlx::PgPool) -> bool {
159 sqlx::query_scalar::<_, bool>("SELECT version() LIKE 'CockroachDB%'")
160 .fetch_one(pool)
161 .await
162 .unwrap_or(false)
163 }
164
165 #[cfg(feature = "postgres")]
191 pub async fn connect_postgres_or_create(url: &str) -> Result<Self> {
192 Self::connect_postgres_or_create_with_pool_size(url, None).await
193 }
194
195 #[cfg(feature = "postgres")]
200 async fn build_postgres_pool(
201 url: &str,
202 pool_size: Option<u32>,
203 ) -> std::result::Result<sqlx::PgPool, sqlx::Error> {
204 use sqlx::postgres::PgPoolOptions;
205 use std::time::Duration;
206
207 let max_connections = pool_size
209 .or_else(|| {
210 std::env::var("DATABASE_POOL_MAX_CONNECTIONS")
211 .ok()
212 .and_then(|v| v.parse::<u32>().ok())
213 })
214 .unwrap_or(20); PgPoolOptions::new()
217 .max_connections(max_connections)
218 .min_connections(1) .acquire_timeout(Duration::from_secs(10)) .idle_timeout(Some(Duration::from_secs(10))) .max_lifetime(Some(Duration::from_secs(30 * 60))) .connect(url)
223 .await
224 }
225
226 #[cfg(feature = "postgres")]
230 pub async fn connect_postgres_or_create_with_pool_size(
231 url: &str,
232 pool_size: Option<u32>,
233 ) -> Result<Self> {
234 match Self::build_postgres_pool(url, pool_size).await {
237 Ok(pool) => {
238 let is_cockroachdb = Self::probe_cockroachdb(&pool).await;
239 return Ok(Self {
240 backend: Arc::new(PostgresBackend::new(pool)),
241 is_cockroachdb,
242 });
243 }
244 Err(e) => {
245 let is_db_not_found = matches!(
248 &e,
249 sqlx::Error::Database(db_err) if db_err.code().as_deref() == Some(SQLSTATE_INVALID_CATALOG_NAME)
250 );
251 if !is_db_not_found {
252 return Err(e.into());
253 }
254 }
256 }
257
258 let (admin_url, db_name) = Self::parse_postgres_url_for_creation(url)?;
260
261 use sqlx::postgres::PgPoolOptions;
263 use std::time::Duration;
264
265 let admin_pool = PgPoolOptions::new()
266 .max_connections(1)
267 .acquire_timeout(Duration::from_secs(10))
268 .connect(&admin_url)
269 .await
270 .map_err(|e| {
271 super::error::DatabaseError::ConnectionError(format!(
272 "Failed to connect to postgres database for auto-creation: {}",
273 e
274 ))
275 })?;
276
277 let create_sql = format!("CREATE DATABASE \"{}\"", db_name.replace('"', "\"\""));
279 sqlx::query(&create_sql)
280 .execute(&admin_pool)
281 .await
282 .map_err(|e| {
283 super::error::DatabaseError::QueryError(format!(
284 "Failed to create database '{}': {}",
285 db_name, e
286 ))
287 })?;
288
289 admin_pool.close().await;
291
292 Self::connect_postgres_with_pool_size(url, pool_size).await
294 }
295
296 #[cfg(feature = "postgres")]
298 fn parse_postgres_url_for_creation(url: &str) -> Result<(String, String)> {
299 let url_without_scheme = url
304 .strip_prefix("postgres://")
305 .or_else(|| url.strip_prefix("postgresql://"))
306 .ok_or_else(|| {
307 super::error::DatabaseError::ConnectionError(
308 "Invalid PostgreSQL URL: must start with postgres:// or postgresql://"
309 .to_string(),
310 )
311 })?;
312
313 let (path_part, query_part) = match url_without_scheme.find('?') {
315 Some(pos) => (&url_without_scheme[..pos], Some(&url_without_scheme[pos..])),
316 None => (url_without_scheme, None),
317 };
318
319 let last_slash_pos = path_part.rfind('/').ok_or_else(|| {
321 super::error::DatabaseError::ConnectionError(
322 "Invalid PostgreSQL URL: no database name found".to_string(),
323 )
324 })?;
325
326 let host_part = &path_part[..last_slash_pos];
327 let db_name = &path_part[last_slash_pos + 1..];
328
329 if db_name.is_empty() {
330 return Err(super::error::DatabaseError::ConnectionError(
331 "Invalid PostgreSQL URL: database name is empty".to_string(),
332 ));
333 }
334
335 let admin_url = match query_part {
337 Some(params) => format!("postgres://{}/postgres{}", host_part, params),
338 None => format!("postgres://{}/postgres", host_part),
339 };
340
341 Ok((admin_url, db_name.to_string()))
342 }
343
344 #[cfg(feature = "sqlite")]
346 pub async fn connect_sqlite(url: &str) -> Result<Self> {
347 use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
348 use std::path::Path;
349 use std::str::FromStr;
350
351 if url == "sqlite::memory:" {
353 let pool = SqlitePool::connect(url).await?;
354 return Ok(Self {
355 backend: Arc::new(SqliteBackend::new(pool)),
356 is_cockroachdb: false,
357 });
358 }
359
360 let file_path = if url.starts_with("sqlite:///") {
362 url.trim_start_matches("sqlite:///").to_string()
364 } else if url.starts_with("sqlite://") {
365 let rel_path = url.trim_start_matches("sqlite://");
368 std::env::current_dir()
369 .map_err(|e| {
370 super::error::DatabaseError::ConnectionError(format!(
371 "Failed to get current directory: {}",
372 e
373 ))
374 })?
375 .join(rel_path)
376 .to_string_lossy()
377 .to_string()
378 } else if url.starts_with("sqlite:") {
379 let rel_path = url.trim_start_matches("sqlite:");
382 std::env::current_dir()
383 .map_err(|e| {
384 super::error::DatabaseError::ConnectionError(format!(
385 "Failed to get current directory: {}",
386 e
387 ))
388 })?
389 .join(rel_path)
390 .to_string_lossy()
391 .to_string()
392 } else {
393 url.to_string()
394 };
395
396 let db_path = Path::new(&file_path);
398 let normalized_path = if db_path.exists() {
399 db_path.canonicalize().map_err(|e| {
401 super::error::DatabaseError::ConnectionError(format!(
402 "Failed to canonicalize path {}: {}",
403 db_path.display(),
404 e
405 ))
406 })?
407 } else {
408 if db_path.is_absolute() {
410 db_path.to_path_buf()
411 } else {
412 std::env::current_dir()
414 .map_err(|e| {
415 super::error::DatabaseError::ConnectionError(format!(
416 "Failed to get current directory: {}",
417 e
418 ))
419 })?
420 .join(db_path)
421 }
422 };
423
424 if let Some(parent) = normalized_path.parent()
426 && !parent.as_os_str().is_empty()
427 && !parent.exists()
428 {
429 std::fs::create_dir_all(parent).map_err(|e| {
430 super::error::DatabaseError::ConnectionError(format!(
431 "Failed to create database directory {}: {}",
432 parent.display(),
433 e
434 ))
435 })?;
436 }
437
438 let path_str = normalized_path.to_string_lossy().replace('\\', "/");
441 let absolute_url = format!("sqlite:///{}", path_str);
442
443 let options = SqliteConnectOptions::from_str(&absolute_url)
445 .map_err(|e| {
446 super::error::DatabaseError::ConnectionError(format!(
447 "Invalid SQLite URL '{}': {}",
448 absolute_url, e
449 ))
450 })?
451 .create_if_missing(true);
452
453 let pool = SqlitePool::connect_with(options).await?;
454
455 Ok(Self {
456 backend: Arc::new(SqliteBackend::new(pool)),
457 is_cockroachdb: false,
458 })
459 }
460
461 #[cfg(feature = "sqlite")]
463 pub fn from_sqlite_pool(pool: sqlx::SqlitePool) -> Self {
464 Self {
465 backend: Arc::new(SqliteBackend::new(pool)),
466 is_cockroachdb: false,
467 }
468 }
469
470 #[cfg(feature = "mysql")]
472 pub async fn connect_mysql(url: &str) -> Result<Self> {
473 use sqlx::MySqlPool;
474 let pool = MySqlPool::connect(url).await?;
475 Ok(Self {
476 backend: Arc::new(MySqlBackend::new(pool)),
477 is_cockroachdb: false,
478 })
479 }
480
481 pub fn backend(&self) -> Arc<dyn DatabaseBackend> {
483 self.backend.clone()
484 }
485
486 pub fn database_type(&self) -> super::types::DatabaseType {
488 self.backend.database_type()
489 }
490
491 pub fn is_cockroachdb(&self) -> bool {
502 self.is_cockroachdb
503 }
504
505 pub fn insert(&self, table: impl Into<String>) -> InsertBuilder {
507 InsertBuilder::new(self.backend.clone(), table)
508 }
509
510 pub fn update(&self, table: impl Into<String>) -> UpdateBuilder {
512 UpdateBuilder::new(self.backend.clone(), table)
513 }
514
515 pub fn select(&self) -> SelectBuilder {
517 SelectBuilder::new(self.backend.clone())
518 }
519
520 pub fn delete(&self, table: impl Into<String>) -> DeleteBuilder {
522 DeleteBuilder::new(self.backend.clone(), table)
523 }
524
525 #[cfg(feature = "settings")]
565 pub fn database_url_from<S>(settings: &S, env_override: Option<&str>) -> Result<String>
566 where
567 S: reinhardt_conf::HasCoreSettings + ?Sized,
568 {
569 if let Some(url) = env_override {
570 return Ok(url.to_string());
571 }
572
573 let core = settings.core();
574 let db_config = core.databases.get("default").ok_or_else(|| {
575 super::error::DatabaseError::ConnectionError(
576 "Database configuration `core.databases.default` not found in settings."
577 .to_string(),
578 )
579 })?;
580
581 Ok(db_config.to_url())
582 }
583
584 #[cfg(feature = "settings")]
620 #[deprecated(
621 since = "0.1.0-rc.29",
622 note = "use `DatabaseConnection::database_url_from` with a pre-built ProjectSettings instead"
623 )]
624 pub fn get_database_url_from_env_or_settings(
625 base_dir: Option<std::path::PathBuf>,
626 ) -> Result<String> {
627 use std::env;
628
629 if let Ok(url) = env::var("DATABASE_URL") {
631 return Ok(url);
632 }
633
634 let profile_str = env::var("REINHARDT_ENV").unwrap_or_else(|_| "local".to_string());
636 let profile = reinhardt_conf::settings::profile::Profile::parse(&profile_str);
637
638 let base_dir = base_dir.unwrap_or_else(|| {
639 env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."))
640 });
641 let settings_dir = base_dir.join("settings");
642
643 let merged = reinhardt_conf::settings::builder::SettingsBuilder::new()
645 .profile(profile)
646 .add_source(
647 reinhardt_conf::settings::sources::DefaultSource::new()
648 .with_value("debug", serde_json::Value::Bool(false))
649 .with_value(
650 "language_code",
651 serde_json::Value::String("en-us".to_string()),
652 )
653 .with_value("time_zone", serde_json::Value::String("UTC".to_string())),
654 )
655 .add_source(
656 reinhardt_conf::settings::sources::LowPriorityEnvSource::new()
657 .with_prefix("REINHARDT_"),
658 )
659 .add_source(
666 reinhardt_conf::settings::sources::TomlFileSource::new(settings_dir.join("base.toml"))
667 .with_interpolation(),
668 )
669 .add_source(
670 reinhardt_conf::settings::sources::TomlFileSource::new(
671 settings_dir.join(format!("{}.toml", profile_str)),
672 )
673 .with_interpolation(),
674 )
675 .build()
676 .map_err(|e| {
677 super::error::DatabaseError::ConnectionError(format!(
678 "Failed to load settings: {}. Please ensure settings files exist in the settings/ directory.",
679 e
680 ))
681 })?;
682
683 let db_config: reinhardt_conf::settings::DatabaseConfig = {
686 if let Some(db_val) = merged.get_raw("database") {
688 serde_json::from_value(db_val.clone())
690 .ok()
691 .or_else(|| {
692 if let serde_json::Value::Object(db_map) = db_val {
694 let engine = db_map
696 .get("engine")
697 .and_then(|v| v.as_str())
698 .unwrap_or("sqlite")
699 .to_string();
700 let name = db_map
701 .get("name")
702 .and_then(|v| v.as_str())
703 .map(|s| s.to_string())
704 .unwrap_or_else(|| "db.sqlite3".to_string());
705
706 let mut config =
707 reinhardt_conf::settings::DatabaseConfig::new(engine, name);
708 if let Some(user) = db_map
709 .get("user")
710 .and_then(|v| v.as_str())
711 {
712 config = config.with_user(user);
713 }
714 if let Some(password) = db_map
715 .get("password")
716 .and_then(|v| v.as_str())
717 {
718 config = config.with_password(password);
719 }
720 if let Some(host) = db_map
721 .get("host")
722 .and_then(|v| v.as_str())
723 {
724 config = config.with_host(host);
725 }
726 if let Some(port) = db_map
727 .get("port")
728 .and_then(|v| v.as_u64())
729 {
730 config = config.with_port(port as u16);
731 }
732 Some(config)
733 } else {
734 None
735 }
736 })
737 } else {
738 merged
740 .get_optional::<serde_json::Value>("databases")
741 .and_then(|dbs| {
742 if let serde_json::Value::Object(dbs_map) = dbs {
743 dbs_map
745 .get("default")
746 .or_else(|| dbs_map.get("database"))
747 .and_then(|db_val| serde_json::from_value(db_val.clone()).ok())
748 } else {
749 None
750 }
751 })
752 }
753 }
754 .ok_or_else(|| {
755 super::error::DatabaseError::ConnectionError(
756 "Database configuration not found in settings. Please configure [database] in your settings file or set DATABASE_URL environment variable.".to_string(),
757 )
758 })?;
759
760 Ok(db_config.to_url())
761 }
762
763 pub async fn execute(
765 &self,
766 sql: &str,
767 params: Vec<super::types::QueryValue>,
768 ) -> Result<super::types::QueryResult> {
769 self.backend.execute(sql, params).await
770 }
771
772 pub async fn fetch_one(
774 &self,
775 sql: &str,
776 params: Vec<super::types::QueryValue>,
777 ) -> Result<super::types::Row> {
778 self.backend.fetch_one(sql, params).await
779 }
780
781 pub async fn fetch_all(
783 &self,
784 sql: &str,
785 params: Vec<super::types::QueryValue>,
786 ) -> Result<Vec<super::types::Row>> {
787 self.backend.fetch_all(sql, params).await
788 }
789
790 pub async fn fetch_optional(
792 &self,
793 sql: &str,
794 params: Vec<super::types::QueryValue>,
795 ) -> Result<Option<super::types::Row>> {
796 self.backend.fetch_optional(sql, params).await
797 }
798
799 pub async fn begin(&self) -> Result<Box<dyn super::types::TransactionExecutor>> {
826 self.backend.begin().await
827 }
828
829 pub async fn begin_with_isolation(
847 &self,
848 level: super::types::IsolationLevel,
849 ) -> Result<Box<dyn super::types::TransactionExecutor>> {
850 self.backend.begin_with_isolation(level).await
851 }
852
853 #[cfg(feature = "postgres")]
854 pub fn into_postgres(&self) -> Option<sqlx::PgPool> {
856 self.backend
857 .as_any()
858 .downcast_ref::<super::dialect::PostgresBackend>()
859 .map(|backend| backend.pool().clone())
860 }
861
862 #[cfg(feature = "sqlite")]
864 pub fn into_sqlite(&self) -> Option<sqlx::SqlitePool> {
865 self.backend
866 .as_any()
867 .downcast_ref::<super::dialect::SqliteBackend>()
868 .map(|backend| backend.pool().clone())
869 }
870
871 #[cfg(feature = "mysql")]
873 pub fn into_mysql(&self) -> Option<sqlx::MySqlPool> {
874 self.backend
875 .as_any()
876 .downcast_ref::<super::dialect::MySqlBackend>()
877 .map(|backend| backend.pool().clone())
878 }
879}
880
#[cfg(test)]
mod tests {
    use rstest::rstest;

    /// Builds a `CREATE DATABASE` statement: the identifier is double-quoted and embedded
    /// double quotes are doubled.
    ///
    /// NOTE: this duplicates the expression used in
    /// `connect_postgres_or_create_with_pool_size`, so it must be kept in sync by hand.
    fn build_create_database_sql(db_name: &str) -> String {
        format!("CREATE DATABASE \"{}\"", db_name.replace('"', "\"\""))
    }

    #[rstest]
    fn test_create_database_sql_normal_name() {
        let db_name = "my_database";

        let sql = build_create_database_sql(db_name);

        assert_eq!(sql, "CREATE DATABASE \"my_database\"");
    }

    #[rstest]
    fn test_create_database_sql_injection_with_double_quotes() {
        let db_name = "test\"; DROP TABLE users; --";

        let sql = build_create_database_sql(db_name);

        assert_eq!(sql, "CREATE DATABASE \"test\"\"; DROP TABLE users; --\"");
    }

    #[rstest]
    fn test_create_database_sql_injection_with_multiple_quotes() {
        let db_name = "db\"\"injection";

        let sql = build_create_database_sql(db_name);

        assert_eq!(sql, "CREATE DATABASE \"db\"\"\"\"injection\"");
    }

    #[cfg(feature = "postgres")]
    #[rstest]
    fn test_parse_postgres_url_extracts_db_name() {
        let url = "postgres://user:pass@localhost:5432/testdb";

        let (admin_url, db_name) =
            super::DatabaseConnection::parse_postgres_url_for_creation(url).unwrap();

        assert_eq!(db_name, "testdb");
        assert_eq!(admin_url, "postgres://user:pass@localhost:5432/postgres");
    }

    #[cfg(feature = "postgres")]
    #[rstest]
    fn test_parse_postgres_url_preserves_query_and_accepts_postgresql_scheme() {
        let url = "postgresql://user:pass@localhost:5432/testdb?sslmode=require";

        let (admin_url, db_name) =
            super::DatabaseConnection::parse_postgres_url_for_creation(url).unwrap();

        assert_eq!(db_name, "testdb");
        assert_eq!(
            admin_url,
            "postgres://user:pass@localhost:5432/postgres?sslmode=require"
        );
    }

    #[cfg(feature = "postgres")]
    #[rstest]
    #[case::wrong_scheme("mysql://user:pass@localhost:3306/testdb")]
    #[case::no_database("postgres://user:pass@localhost:5432")]
    #[case::empty_database("postgres://user:pass@localhost:5432/")]
    #[case::empty_database_with_query("postgres://localhost/?sslmode=disable")]
    fn test_parse_postgres_url_rejects_invalid_urls(#[case] url: &str) {
        assert!(super::DatabaseConnection::parse_postgres_url_for_creation(url).is_err());
    }
}