Skip to main content

systemprompt_database/services/
database.rs

1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8    provider: Arc<dyn DatabaseProvider>,
9    write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.debug_struct("Database")
15            .field("backend", &"PostgreSQL")
16            .finish()
17    }
18}
19
20impl Database {
21    pub async fn new_postgres(url: &str) -> Result<Self> {
22        let provider = PostgresProvider::new(url).await?;
23        Ok(Self {
24            provider: Arc::new(provider),
25            write_provider: None,
26        })
27    }
28
29    pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30        match db_type.to_lowercase().as_str() {
31            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32            other => Err(anyhow::anyhow!(
33                "Unsupported database type: {other}. Only PostgreSQL is supported."
34            )),
35        }
36    }
37
38    pub async fn from_config_with_write(
39        db_type: &str,
40        read_url: &str,
41        write_url: Option<&str>,
42    ) -> Result<Self> {
43        let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44            "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45            other => {
46                return Err(anyhow::anyhow!(
47                    "Unsupported database type: {other}. Only PostgreSQL is supported."
48                ));
49            },
50        };
51
52        let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53            Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54            None => None,
55        };
56
57        Ok(Self {
58            provider,
59            write_provider,
60        })
61    }
62
63    pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64        self.pool_arc()
65    }
66
67    pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
68        self.write_provider.as_ref().map_or_else(
69            || self.get_postgres_pool_arc(),
70            |wp| {
71                wp.get_postgres_pool()
72                    .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
73            },
74        )
75    }
76
77    #[must_use]
78    pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
79        self.write_provider
80            .as_ref()
81            .and_then(|wp| wp.get_postgres_pool())
82            .or_else(|| self.provider.get_postgres_pool())
83    }
84
85    #[must_use]
86    pub fn has_write_pool(&self) -> bool {
87        self.write_provider.is_some()
88    }
89
90    #[must_use]
91    pub fn write_provider(&self) -> &dyn DatabaseProvider {
92        self.write_provider
93            .as_deref()
94            .unwrap_or_else(|| self.provider.as_ref())
95    }
96
97    pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
98        self.provider.query_raw(sql).await
99    }
100
101    pub async fn query_with(
102        &self,
103        sql: &dyn crate::models::QuerySelector,
104        params: Vec<serde_json::Value>,
105    ) -> Result<QueryResult> {
106        self.provider.query_raw_with(sql, params).await
107    }
108
109    pub async fn execute_batch(&self, sql: &str) -> Result<()> {
110        self.provider.execute_batch(sql).await
111    }
112
113    pub async fn get_info(&self) -> Result<DatabaseInfo> {
114        self.provider.get_database_info().await
115    }
116
117    pub async fn test_connection(&self) -> Result<()> {
118        self.provider.test_connection().await?;
119        if let Some(wp) = &self.write_provider {
120            wp.test_connection().await?;
121        }
122        Ok(())
123    }
124
125    #[must_use]
126    pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
127        self.write_provider
128            .as_ref()
129            .and_then(|wp| wp.get_postgres_pool())
130            .or_else(|| self.provider.get_postgres_pool())
131    }
132
133    pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
134        self.get_postgres_pool()
135            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
136    }
137
138    #[must_use]
139    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
140        self.get_postgres_pool()
141    }
142
143    #[must_use]
144    pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
145        self.provider.get_postgres_pool()
146    }
147
148    pub fn read_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
149        self.provider
150            .get_postgres_pool()
151            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
152    }
153
154    pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
155        let pool = self.write_pool_arc()?;
156        pool.begin().await.map_err(Into::into)
157    }
158}
159
160pub type DbPool = Arc<Database>;
161
162pub trait DatabaseExt {
163    fn database(&self) -> Arc<Database>;
164}
165
166impl DatabaseExt for Arc<Database> {
167    fn database(&self) -> Arc<Database> {
168        Self::clone(self)
169    }
170}
171
172#[async_trait::async_trait]
173impl DatabaseProvider for Database {
174    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
175        self.write_provider
176            .as_ref()
177            .and_then(|wp| wp.get_postgres_pool())
178            .or_else(|| self.provider.get_postgres_pool())
179    }
180
181    async fn execute(
182        &self,
183        query: &dyn crate::models::QuerySelector,
184        params: &[&dyn crate::models::ToDbValue],
185    ) -> Result<u64> {
186        self.write_provider().execute(query, params).await
187    }
188
189    async fn execute_raw(&self, sql: &str) -> Result<()> {
190        self.write_provider().execute_raw(sql).await
191    }
192
193    async fn fetch_all(
194        &self,
195        query: &dyn crate::models::QuerySelector,
196        params: &[&dyn crate::models::ToDbValue],
197    ) -> Result<Vec<crate::models::JsonRow>> {
198        self.provider.fetch_all(query, params).await
199    }
200
201    async fn fetch_one(
202        &self,
203        query: &dyn crate::models::QuerySelector,
204        params: &[&dyn crate::models::ToDbValue],
205    ) -> Result<crate::models::JsonRow> {
206        self.provider.fetch_one(query, params).await
207    }
208
209    async fn fetch_optional(
210        &self,
211        query: &dyn crate::models::QuerySelector,
212        params: &[&dyn crate::models::ToDbValue],
213    ) -> Result<Option<crate::models::JsonRow>> {
214        self.provider.fetch_optional(query, params).await
215    }
216
217    async fn fetch_scalar_value(
218        &self,
219        query: &dyn crate::models::QuerySelector,
220        params: &[&dyn crate::models::ToDbValue],
221    ) -> Result<crate::models::DbValue> {
222        self.provider.fetch_scalar_value(query, params).await
223    }
224
225    async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
226        self.write_provider().begin_transaction().await
227    }
228
229    async fn get_database_info(&self) -> Result<DatabaseInfo> {
230        self.provider.get_database_info().await
231    }
232
233    async fn test_connection(&self) -> Result<()> {
234        self.provider.test_connection().await
235    }
236
237    async fn execute_batch(&self, sql: &str) -> Result<()> {
238        self.write_provider().execute_batch(sql).await
239    }
240
241    async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
242        self.provider.query_raw(query).await
243    }
244
245    async fn query_raw_with(
246        &self,
247        query: &dyn crate::models::QuerySelector,
248        params: Vec<serde_json::Value>,
249    ) -> Result<QueryResult> {
250        self.provider.query_raw_with(query, params).await
251    }
252}