1use std::str::FromStr;
2
3use crate::{
4 config::{DatabaseConfig, DatabaseType},
5 error::Error,
6};
7
8pub struct DatabaseManager {
9 pool: sqlx::Pool<sqlx::Any>,
10 db_type: DatabaseType,
11}
12
13impl DatabaseManager {
14 pub async fn new(db_config: DatabaseConfig) -> Result<Self, Error> {
15 let db_type = DatabaseType::from_str(&db_config.db_type)?;
16 let connect_str = match db_type {
17 DatabaseType::Sqlite => {
18 format!("sqlite://{}", db_config.database)
19 }
20 DatabaseType::Postgres => {
21 let host = db_config.host.unwrap_or("localhost".to_string());
22 let port = db_config.port.unwrap_or(5432);
23 let username = db_config.username.unwrap_or("postgres".to_string());
24 let password = db_config.password.unwrap_or("password".to_string());
25 format!(
26 "postgres://{}:{}@{}:{}/{}",
27 username, password, host, port, db_config.database
28 )
29 }
30 DatabaseType::MySql => {
31 let host = db_config.host.unwrap_or("localhost".to_string());
32 let port = db_config.port.unwrap_or(3306);
33 let username = db_config.username.unwrap_or("root".to_string());
34 let password = db_config.password.unwrap_or("password".to_string());
35 format!(
36 "mysql://{}:{}@{}:{}/{}",
37 username, password, host, port, db_config.database
38 )
39 }
40 };
41
42 let max_connections = db_config.max_connections.unwrap_or(12);
43
44 sqlx::any::install_default_drivers();
45 let pool = sqlx::any::AnyPoolOptions::new()
46 .max_connections(max_connections)
47 .connect(&connect_str)
48 .await?;
49
50 Ok(Self { pool, db_type })
51 }
52
53 pub fn pool(&self) -> &sqlx::Pool<sqlx::Any> {
54 &self.pool
55 }
56
57 pub fn db_type(&self) -> &DatabaseType {
58 &self.db_type
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AppConfig;

    /// Connects with the database section of the loaded application config
    /// and checks that a trivial `SELECT 1` round-trips.
    ///
    /// Ignored by default because it needs a live database reachable with
    /// that configuration.
    #[tokio::test]
    #[ignore = "Requires a database to be set up"]
    async fn test_database_manager() {
        let app_config = AppConfig::init().expect("Failed to load config");
        let manager = DatabaseManager::new(app_config.botcat_capoo.database)
            .await
            .expect("Failed to create DatabaseManager");

        println!("Database Type: {:?}", manager.db_type());

        let (one,): (i32,) = sqlx::query_as("SELECT 1")
            .fetch_one(manager.pool())
            .await
            .expect("Failed to execute query");
        assert_eq!(one, 1);
    }
}