systemprompt_database/services/
database.rs1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8 provider: Arc<dyn DatabaseProvider>,
9 write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.debug_struct("Database")
15 .field("backend", &"PostgreSQL")
16 .finish()
17 }
18}
19
20impl Database {
21 pub async fn new_postgres(url: &str) -> Result<Self> {
22 let provider = PostgresProvider::new(url).await?;
23 Ok(Self {
24 provider: Arc::new(provider),
25 write_provider: None,
26 })
27 }
28
29 pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30 match db_type.to_lowercase().as_str() {
31 "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32 other => Err(anyhow::anyhow!(
33 "Unsupported database type: {other}. Only PostgreSQL is supported."
34 )),
35 }
36 }
37
38 pub async fn from_config_with_write(
39 db_type: &str,
40 read_url: &str,
41 write_url: Option<&str>,
42 ) -> Result<Self> {
43 let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44 "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45 other => {
46 return Err(anyhow::anyhow!(
47 "Unsupported database type: {other}. Only PostgreSQL is supported."
48 ));
49 },
50 };
51
52 let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53 Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54 None => None,
55 };
56
57 Ok(Self {
58 provider,
59 write_provider,
60 })
61 }
62
63 pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64 self.pool_arc()
65 }
66
67 pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
68 self.write_provider.as_ref().map_or_else(
69 || self.get_postgres_pool_arc(),
70 |wp| {
71 wp.get_postgres_pool()
72 .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
73 },
74 )
75 }
76
77 #[must_use]
78 pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
79 self.write_provider
80 .as_ref()
81 .and_then(|wp| wp.get_postgres_pool())
82 .or_else(|| self.provider.get_postgres_pool())
83 }
84
85 #[must_use]
86 pub fn has_write_pool(&self) -> bool {
87 self.write_provider.is_some()
88 }
89
90 #[must_use]
91 pub fn write_provider(&self) -> &dyn DatabaseProvider {
92 self.write_provider
93 .as_deref()
94 .unwrap_or_else(|| self.provider.as_ref())
95 }
96
97 pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
98 self.provider.query_raw(sql).await
99 }
100
101 pub async fn query_with(
102 &self,
103 sql: &dyn crate::models::QuerySelector,
104 params: Vec<serde_json::Value>,
105 ) -> Result<QueryResult> {
106 self.provider.query_raw_with(sql, params).await
107 }
108
109 pub async fn execute_batch(&self, sql: &str) -> Result<()> {
110 self.provider.execute_batch(sql).await
111 }
112
113 pub async fn get_info(&self) -> Result<DatabaseInfo> {
114 self.provider.get_database_info().await
115 }
116
117 pub async fn test_connection(&self) -> Result<()> {
118 self.provider.test_connection().await?;
119 if let Some(wp) = &self.write_provider {
120 wp.test_connection().await?;
121 }
122 Ok(())
123 }
124
125 #[must_use]
126 pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
127 self.write_provider
128 .as_ref()
129 .and_then(|wp| wp.get_postgres_pool())
130 .or_else(|| self.provider.get_postgres_pool())
131 }
132
133 pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
134 self.get_postgres_pool()
135 .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
136 }
137
138 #[must_use]
139 pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
140 self.get_postgres_pool()
141 }
142
143 #[must_use]
144 pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
145 self.provider.get_postgres_pool()
146 }
147
148 pub fn read_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
149 self.provider
150 .get_postgres_pool()
151 .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
152 }
153
154 pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
155 let pool = self.write_pool_arc()?;
156 pool.begin().await.map_err(Into::into)
157 }
158}
159
160pub type DbPool = Arc<Database>;
161
162pub trait DatabaseExt {
163 fn database(&self) -> Arc<Database>;
164}
165
166impl DatabaseExt for Arc<Database> {
167 fn database(&self) -> Arc<Database> {
168 Self::clone(self)
169 }
170}
171
172#[async_trait::async_trait]
173impl DatabaseProvider for Database {
174 fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
175 self.write_provider
176 .as_ref()
177 .and_then(|wp| wp.get_postgres_pool())
178 .or_else(|| self.provider.get_postgres_pool())
179 }
180
181 async fn execute(
182 &self,
183 query: &dyn crate::models::QuerySelector,
184 params: &[&dyn crate::models::ToDbValue],
185 ) -> Result<u64> {
186 self.write_provider().execute(query, params).await
187 }
188
189 async fn execute_raw(&self, sql: &str) -> Result<()> {
190 self.write_provider().execute_raw(sql).await
191 }
192
193 async fn fetch_all(
194 &self,
195 query: &dyn crate::models::QuerySelector,
196 params: &[&dyn crate::models::ToDbValue],
197 ) -> Result<Vec<crate::models::JsonRow>> {
198 self.provider.fetch_all(query, params).await
199 }
200
201 async fn fetch_one(
202 &self,
203 query: &dyn crate::models::QuerySelector,
204 params: &[&dyn crate::models::ToDbValue],
205 ) -> Result<crate::models::JsonRow> {
206 self.provider.fetch_one(query, params).await
207 }
208
209 async fn fetch_optional(
210 &self,
211 query: &dyn crate::models::QuerySelector,
212 params: &[&dyn crate::models::ToDbValue],
213 ) -> Result<Option<crate::models::JsonRow>> {
214 self.provider.fetch_optional(query, params).await
215 }
216
217 async fn fetch_scalar_value(
218 &self,
219 query: &dyn crate::models::QuerySelector,
220 params: &[&dyn crate::models::ToDbValue],
221 ) -> Result<crate::models::DbValue> {
222 self.provider.fetch_scalar_value(query, params).await
223 }
224
225 async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
226 self.write_provider().begin_transaction().await
227 }
228
229 async fn get_database_info(&self) -> Result<DatabaseInfo> {
230 self.provider.get_database_info().await
231 }
232
233 async fn test_connection(&self) -> Result<()> {
234 self.provider.test_connection().await
235 }
236
237 async fn execute_batch(&self, sql: &str) -> Result<()> {
238 self.write_provider().execute_batch(sql).await
239 }
240
241 async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
242 self.provider.query_raw(query).await
243 }
244
245 async fn query_raw_with(
246 &self,
247 query: &dyn crate::models::QuerySelector,
248 params: Vec<serde_json::Value>,
249 ) -> Result<QueryResult> {
250 self.provider.query_raw_with(query, params).await
251 }
252}