systemprompt_database/services/
database.rs1use super::postgres::PostgresProvider;
6use super::provider::DatabaseProvider;
7use crate::error::{DatabaseResult, RepositoryError};
8use crate::models::{DatabaseInfo, QueryResult};
9use std::sync::Arc;
10
11pub struct Database {
12 provider: Arc<dyn DatabaseProvider>,
13 write_provider: Option<Arc<dyn DatabaseProvider>>,
14}
15
16impl std::fmt::Debug for Database {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("Database")
19 .field("backend", &"PostgreSQL")
20 .finish()
21 }
22}
23
24impl Database {
25 pub async fn new_postgres(url: &str) -> DatabaseResult<Self> {
26 let provider = PostgresProvider::new(url).await?;
27 Ok(Self {
28 provider: Arc::new(provider),
29 write_provider: None,
30 })
31 }
32
33 pub async fn from_config(db_type: &str, url: &str) -> DatabaseResult<Self> {
34 match db_type.to_lowercase().as_str() {
35 "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
36 other => Err(RepositoryError::invalid_argument(format!(
37 "Unsupported database type: {other}. Only PostgreSQL is supported."
38 ))),
39 }
40 }
41
42 pub async fn from_config_with_write(
43 db_type: &str,
44 read_url: &str,
45 write_url: Option<&str>,
46 ) -> DatabaseResult<Self> {
47 let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
48 "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
49 other => {
50 return Err(RepositoryError::invalid_argument(format!(
51 "Unsupported database type: {other}. Only PostgreSQL is supported."
52 )));
53 },
54 };
55
56 let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
57 Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
58 None => None,
59 };
60
61 Ok(Self {
62 provider,
63 write_provider,
64 })
65 }
66
67 pub fn get_postgres_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
68 self.pool_arc()
69 }
70
71 pub fn write_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
72 self.write_provider.as_ref().map_or_else(
73 || self.get_postgres_pool_arc(),
74 |wp| {
75 wp.get_postgres_pool().ok_or_else(|| {
76 RepositoryError::invalid_state("Write database is not PostgreSQL")
77 })
78 },
79 )
80 }
81
82 #[must_use]
83 pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
84 self.write_provider
85 .as_ref()
86 .and_then(|wp| wp.get_postgres_pool())
87 .or_else(|| self.provider.get_postgres_pool())
88 }
89
90 #[must_use]
91 pub fn has_write_pool(&self) -> bool {
92 self.write_provider.is_some()
93 }
94
95 #[must_use]
96 pub fn write_provider(&self) -> &dyn DatabaseProvider {
97 self.write_provider
98 .as_deref()
99 .unwrap_or_else(|| self.provider.as_ref())
100 }
101
102 pub async fn query(
103 &self,
104 sql: &dyn crate::models::QuerySelector,
105 ) -> DatabaseResult<QueryResult> {
106 self.provider.query_raw(sql).await
107 }
108
109 pub async fn query_with(
110 &self,
111 sql: &dyn crate::models::QuerySelector,
112 params: Vec<serde_json::Value>,
113 ) -> DatabaseResult<QueryResult> {
114 self.provider.query_raw_with(sql, params).await
115 }
116
117 pub async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
118 self.provider.execute_batch(sql).await
119 }
120
121 pub async fn get_info(&self) -> DatabaseResult<DatabaseInfo> {
122 self.provider.get_database_info().await
123 }
124
125 pub async fn test_connection(&self) -> DatabaseResult<()> {
126 self.provider.test_connection().await?;
127 if let Some(wp) = &self.write_provider {
128 wp.test_connection().await?;
129 }
130 Ok(())
131 }
132
133 #[must_use]
134 pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
135 self.write_provider
136 .as_ref()
137 .and_then(|wp| wp.get_postgres_pool())
138 .or_else(|| self.provider.get_postgres_pool())
139 }
140
141 pub fn pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
142 self.get_postgres_pool()
143 .ok_or_else(|| RepositoryError::invalid_state("Database is not PostgreSQL"))
144 }
145
146 #[must_use]
147 pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
148 self.get_postgres_pool()
149 }
150
151 #[must_use]
152 pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
153 self.provider.get_postgres_pool()
154 }
155
156 pub fn read_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
157 self.provider
158 .get_postgres_pool()
159 .ok_or_else(|| RepositoryError::invalid_state("Database is not PostgreSQL"))
160 }
161
162 pub async fn begin(&self) -> DatabaseResult<sqlx::Transaction<'_, sqlx::Postgres>> {
163 let pool = self.write_pool_arc()?;
164 pool.begin().await.map_err(Into::into)
165 }
166}
167
168pub type DbPool = Arc<Database>;
169
170pub trait DatabaseExt {
171 fn database(&self) -> Arc<Database>;
172}
173
174impl DatabaseExt for Arc<Database> {
175 fn database(&self) -> Arc<Database> {
176 Self::clone(self)
177 }
178}
179
180#[async_trait::async_trait]
181impl DatabaseProvider for Database {
182 fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
183 self.write_provider
184 .as_ref()
185 .and_then(|wp| wp.get_postgres_pool())
186 .or_else(|| self.provider.get_postgres_pool())
187 }
188
189 async fn execute(
190 &self,
191 query: &dyn crate::models::QuerySelector,
192 params: &[&dyn crate::models::ToDbValue],
193 ) -> DatabaseResult<u64> {
194 self.write_provider().execute(query, params).await
195 }
196
197 async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
198 self.write_provider().execute_raw(sql).await
199 }
200
201 async fn fetch_all(
202 &self,
203 query: &dyn crate::models::QuerySelector,
204 params: &[&dyn crate::models::ToDbValue],
205 ) -> DatabaseResult<Vec<crate::models::JsonRow>> {
206 self.provider.fetch_all(query, params).await
207 }
208
209 async fn fetch_one(
210 &self,
211 query: &dyn crate::models::QuerySelector,
212 params: &[&dyn crate::models::ToDbValue],
213 ) -> DatabaseResult<crate::models::JsonRow> {
214 self.provider.fetch_one(query, params).await
215 }
216
217 async fn fetch_optional(
218 &self,
219 query: &dyn crate::models::QuerySelector,
220 params: &[&dyn crate::models::ToDbValue],
221 ) -> DatabaseResult<Option<crate::models::JsonRow>> {
222 self.provider.fetch_optional(query, params).await
223 }
224
225 async fn fetch_scalar_value(
226 &self,
227 query: &dyn crate::models::QuerySelector,
228 params: &[&dyn crate::models::ToDbValue],
229 ) -> DatabaseResult<crate::models::DbValue> {
230 self.provider.fetch_scalar_value(query, params).await
231 }
232
233 async fn begin_transaction(
234 &self,
235 ) -> DatabaseResult<Box<dyn crate::models::DatabaseTransaction>> {
236 self.write_provider().begin_transaction().await
237 }
238
239 async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
240 self.provider.get_database_info().await
241 }
242
243 async fn test_connection(&self) -> DatabaseResult<()> {
244 self.provider.test_connection().await
245 }
246
247 async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
248 self.write_provider().execute_batch(sql).await
249 }
250
251 async fn query_raw(
252 &self,
253 query: &dyn crate::models::QuerySelector,
254 ) -> DatabaseResult<QueryResult> {
255 self.provider.query_raw(query).await
256 }
257
258 async fn query_raw_with(
259 &self,
260 query: &dyn crate::models::QuerySelector,
261 params: Vec<serde_json::Value>,
262 ) -> DatabaseResult<QueryResult> {
263 self.provider.query_raw_with(query, params).await
264 }
265}