systemprompt_database/services/
database.rs1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8 provider: Arc<dyn DatabaseProvider>,
9 write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.debug_struct("Database")
15 .field("backend", &"PostgreSQL")
16 .finish()
17 }
18}
19
20impl Database {
21 pub async fn new_postgres(url: &str) -> Result<Self> {
22 let provider = PostgresProvider::new(url).await?;
23 Ok(Self {
24 provider: Arc::new(provider),
25 write_provider: None,
26 })
27 }
28
29 pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30 match db_type.to_lowercase().as_str() {
31 "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32 other => Err(anyhow::anyhow!(
33 "Unsupported database type: {other}. Only PostgreSQL is supported."
34 )),
35 }
36 }
37
38 pub async fn from_config_with_write(
39 db_type: &str,
40 read_url: &str,
41 write_url: Option<&str>,
42 ) -> Result<Self> {
43 let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44 "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45 other => {
46 return Err(anyhow::anyhow!(
47 "Unsupported database type: {other}. Only PostgreSQL is supported."
48 ));
49 },
50 };
51
52 let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53 Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54 None => None,
55 };
56
57 Ok(Self {
58 provider,
59 write_provider,
60 })
61 }
62
63 pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64 self.pool_arc()
65 }
66
67 pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
68 self.write_provider.as_ref().map_or_else(
69 || self.get_postgres_pool_arc(),
70 |wp| {
71 wp.get_postgres_pool()
72 .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
73 },
74 )
75 }
76
77 #[must_use]
78 pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
79 self.write_provider
80 .as_ref()
81 .and_then(|wp| wp.get_postgres_pool())
82 .or_else(|| self.provider.get_postgres_pool())
83 }
84
85 #[must_use]
86 pub fn has_write_pool(&self) -> bool {
87 self.write_provider.is_some()
88 }
89
90 #[must_use]
91 pub fn write_provider(&self) -> &dyn DatabaseProvider {
92 self.write_provider
93 .as_deref()
94 .unwrap_or_else(|| self.provider.as_ref())
95 }
96
97 pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
98 self.provider.query_raw(sql).await
99 }
100
101 pub async fn query_with(
103 &self,
104 sql: &dyn crate::models::QuerySelector,
105 params: Vec<serde_json::Value>,
106 ) -> Result<QueryResult> {
107 self.provider.query_raw_with(sql, params).await
108 }
109
110 pub async fn execute_batch(&self, sql: &str) -> Result<()> {
111 self.provider.execute_batch(sql).await
112 }
113
114 pub async fn get_info(&self) -> Result<DatabaseInfo> {
115 self.provider.get_database_info().await
116 }
117
118 pub async fn test_connection(&self) -> Result<()> {
119 self.provider.test_connection().await?;
120 if let Some(wp) = &self.write_provider {
121 wp.test_connection().await?;
122 }
123 Ok(())
124 }
125
126 #[must_use]
127 pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
128 self.write_provider
129 .as_ref()
130 .and_then(|wp| wp.get_postgres_pool())
131 .or_else(|| self.provider.get_postgres_pool())
132 }
133
134 pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
135 self.get_postgres_pool()
136 .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
137 }
138
139 #[must_use]
140 pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
141 self.get_postgres_pool()
142 }
143
144 #[must_use]
145 pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
146 self.provider.get_postgres_pool()
147 }
148
149 pub fn read_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
150 self.provider
151 .get_postgres_pool()
152 .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
153 }
154
155 pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
156 let pool = self.write_pool_arc()?;
157 pool.begin().await.map_err(Into::into)
158 }
159}
160
161pub type DbPool = Arc<Database>;
162
163pub trait DatabaseExt {
164 fn database(&self) -> Arc<Database>;
165}
166
167impl DatabaseExt for Arc<Database> {
168 fn database(&self) -> Arc<Database> {
169 Self::clone(self)
170 }
171}
172
173#[async_trait::async_trait]
174impl DatabaseProvider for Database {
175 fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
176 self.write_provider
177 .as_ref()
178 .and_then(|wp| wp.get_postgres_pool())
179 .or_else(|| self.provider.get_postgres_pool())
180 }
181
182 async fn execute(
183 &self,
184 query: &dyn crate::models::QuerySelector,
185 params: &[&dyn crate::models::ToDbValue],
186 ) -> Result<u64> {
187 self.write_provider().execute(query, params).await
188 }
189
190 async fn execute_raw(&self, sql: &str) -> Result<()> {
191 self.write_provider().execute_raw(sql).await
192 }
193
194 async fn fetch_all(
195 &self,
196 query: &dyn crate::models::QuerySelector,
197 params: &[&dyn crate::models::ToDbValue],
198 ) -> Result<Vec<crate::models::JsonRow>> {
199 self.provider.fetch_all(query, params).await
200 }
201
202 async fn fetch_one(
203 &self,
204 query: &dyn crate::models::QuerySelector,
205 params: &[&dyn crate::models::ToDbValue],
206 ) -> Result<crate::models::JsonRow> {
207 self.provider.fetch_one(query, params).await
208 }
209
210 async fn fetch_optional(
211 &self,
212 query: &dyn crate::models::QuerySelector,
213 params: &[&dyn crate::models::ToDbValue],
214 ) -> Result<Option<crate::models::JsonRow>> {
215 self.provider.fetch_optional(query, params).await
216 }
217
218 async fn fetch_scalar_value(
219 &self,
220 query: &dyn crate::models::QuerySelector,
221 params: &[&dyn crate::models::ToDbValue],
222 ) -> Result<crate::models::DbValue> {
223 self.provider.fetch_scalar_value(query, params).await
224 }
225
226 async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
227 self.write_provider().begin_transaction().await
228 }
229
230 async fn get_database_info(&self) -> Result<DatabaseInfo> {
231 self.provider.get_database_info().await
232 }
233
234 async fn test_connection(&self) -> Result<()> {
235 self.provider.test_connection().await
236 }
237
238 async fn execute_batch(&self, sql: &str) -> Result<()> {
239 self.write_provider().execute_batch(sql).await
240 }
241
242 async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
243 self.provider.query_raw(query).await
244 }
245
246 async fn query_raw_with(
247 &self,
248 query: &dyn crate::models::QuerySelector,
249 params: Vec<serde_json::Value>,
250 ) -> Result<QueryResult> {
251 self.provider.query_raw_with(query, params).await
252 }
253}