1use std::sync::Arc;
18
19use futures_core::future::BoxFuture;
20use futures_core::stream::BoxStream;
21use sqlx_core::HashMap;
22use sqlx_core::connection::Connection;
23use sqlx_core::error::Error;
24use sqlx_core::executor::Executor;
25use sqlx_core::transaction::Transaction;
26
27use spg_embedded::QueryResult as EngineQueryResult;
28use spg_embedded_tokio::AsyncDatabase;
29
30use crate::column::SpgColumn;
31use crate::database::Spg;
32use crate::error::engine_to_sqlx;
33use crate::options::SpgConnectOptions;
34use crate::query_result::SpgQueryResult;
35use crate::row::SpgRow;
36use crate::type_info::SpgTypeInfo;
37
38#[derive(Debug, Clone)]
44pub(crate) struct CachedStmt {
45 pub(crate) readonly: bool,
46 pub(crate) stmt: std::sync::Arc<spg_embedded::Statement>,
47}
48
/// Upper bound on entries in one connection's statement cache (256).
/// Once it is reached, the whole cache is emptied before the next insert.
const STMT_CACHE_CAP: usize = 1 << 8;
54
/// A sqlx connection backed by a handle to the embedded spg engine.
///
/// Besides the engine handle it holds two things: a per-connection cache of
/// prepared statements, and the bookkeeping consulted when deciding whether
/// a statement runs inside a transaction.
#[derive(Debug)]
pub struct SpgConnection {
    /// Async handle to the embedded engine; snapshots, prepares and executions
    /// all go through it.
    pub(crate) inner: AsyncDatabase,
    /// Prepared statements keyed by their exact SQL text (filled by
    /// `cached_stmt`); bounded by `STMT_CACHE_CAP`. Clones start with an empty cache.
    pub(crate) stmt_cache: HashMap<String, CachedStmt>,
    /// Transaction nesting depth. `run_one` treats any non-zero value as
    /// "inside a transaction" and then bypasses the statement cache.
    pub(crate) tx_depth: usize,
    /// NOTE(review): not read anywhere in this section of the file. Presumably
    /// the transaction manager sets it when a rollback is still owed; confirm.
    pub(crate) pending_rollback: bool,
}
74
75impl Clone for SpgConnection {
76 fn clone(&self) -> Self {
77 Self {
80 inner: self.inner.clone(),
81 stmt_cache: HashMap::new(),
82 tx_depth: self.tx_depth,
83 pending_rollback: self.pending_rollback,
84 }
85 }
86}
87
88impl SpgConnection {
89 pub fn new(inner: AsyncDatabase) -> Self {
93 Self {
94 inner,
95 stmt_cache: HashMap::new(),
96 tx_depth: 0,
97 pending_rollback: false,
98 }
99 }
100
101 #[must_use]
104 pub const fn engine(&self) -> &AsyncDatabase {
105 &self.inner
106 }
107
108 pub(crate) async fn cached_stmt(
112 &mut self,
113 sql: &str,
114 ) -> Result<CachedStmt, spg_embedded::EngineError> {
115 if let Some(c) = self.stmt_cache.get(sql) {
116 return Ok(c.clone());
117 }
118 let readonly = spg_embedded::Engine::is_readonly_sql(sql);
119 let snap = self.inner.clone_snapshot_inline().await;
125 let stmt = spg_embedded::Database::prepare_on_snapshot(&snap, sql)?;
126 let cached = CachedStmt {
127 readonly,
128 stmt: std::sync::Arc::new(stmt),
129 };
130 if self.stmt_cache.len() >= STMT_CACHE_CAP {
131 self.stmt_cache.clear();
132 }
133 self.stmt_cache.insert(sql.to_string(), cached.clone());
134 Ok(cached)
135 }
136}
137
138impl Connection for SpgConnection {
139 type Database = Spg;
140 type Options = SpgConnectOptions;
141
142 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
143 Box::pin(async move { Ok(()) })
146 }
147
148 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
149 Box::pin(async move { Ok(()) })
150 }
151
152 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
153 Box::pin(async move {
156 self.inner
157 .execute("SELECT 1")
158 .await
159 .map_err(engine_to_sqlx)?;
160 Ok(())
161 })
162 }
163
164 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
165 where
166 Self: Sized,
167 {
168 Transaction::begin(self, None)
169 }
170
171 fn shrink_buffers(&mut self) {
172 }
174
175 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
176 Box::pin(async move { Ok(()) })
177 }
178
179 fn should_flush(&self) -> bool {
180 false
181 }
182}
183
184impl<'c> Executor<'c> for &'c mut SpgConnection {
195 type Database = Spg;
196
197 fn fetch_many<'e, 'q: 'e, E>(
198 self,
199 mut query: E,
200 ) -> BoxStream<
201 'e,
202 Result<
203 either::Either<
204 <Self::Database as sqlx_core::database::Database>::QueryResult,
205 crate::SpgRow,
206 >,
207 Error,
208 >,
209 >
210 where
211 'c: 'e,
212 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
213 {
214 use futures_util::stream::{self, StreamExt};
215 let sql = query.sql().to_string();
216 let arguments = match query.take_arguments() {
217 Ok(args) => args,
218 Err(e) => {
219 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
220 }
221 };
222 let outcome_fut = async move {
223 match arguments {
224 Some(args) => run_one(self, &sql, Some(args)).await.map(|o| vec![o]),
227 None => Ok(self
234 .inner
235 .execute_script(&sql)
236 .await
237 .map_err(engine_to_sqlx)?
238 .into_iter()
239 .map(outcome_from)
240 .collect()),
241 }
242 };
243 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
244 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
245 Ok(outcomes) => outcomes
246 .into_iter()
247 .flat_map(|o| match o {
248 Outcome::Affected(qr) => vec![Ok(either::Either::Left(qr))],
249 Outcome::Rows(rows) => rows
250 .into_iter()
251 .map(|r| Ok(either::Either::Right(r)))
252 .collect::<Vec<_>>(),
253 })
254 .collect(),
255 Err(e) => vec![Err(e)],
256 };
257 stream::iter(items)
258 }))
259 }
260
261 fn fetch_optional<'e, 'q: 'e, E>(
262 self,
263 mut query: E,
264 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
265 where
266 'c: 'e,
267 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
268 {
269 let sql = query.sql().to_string();
270 let arguments = query.take_arguments();
271 Box::pin(async move {
272 let args = arguments.map_err(Error::Encode)?;
273 match run_one(self, &sql, args).await? {
274 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
275 Outcome::Affected(_) => Ok(None),
276 }
277 })
278 }
279
280 fn prepare_with<'e, 'q: 'e>(
281 self,
282 sql: &'q str,
283 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
284 ) -> BoxFuture<
285 'e,
286 Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
287 >
288 where
289 'c: 'e,
290 {
291 let inner = self.inner.clone();
292 let sql_str = sql.to_string();
293 Box::pin(async move {
294 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
295 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
301 Ok(crate::SpgStatement {
302 sql: std::borrow::Cow::Owned(sql_str),
303 inner: Some(inner_stmt),
304 columns: std::sync::Arc::new(Vec::new()),
305 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
306 })
307 })
308 }
309
310 fn describe<'e, 'q: 'e>(
311 self,
312 sql: &'q str,
313 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
314 where
315 'c: 'e,
316 {
317 let inner = self.inner.clone();
327 let sql_str = sql.to_string();
328 Box::pin(async move {
329 let (params, cols) = inner.describe(&sql_str).await.map_err(engine_to_sqlx)?;
330 let nullable: Vec<Option<bool>> = cols.iter().map(|c| Some(c.nullable)).collect();
331 let columns: Vec<SpgColumn> = cols
332 .iter()
333 .enumerate()
334 .map(|(i, c)| {
335 let ti = SpgTypeInfo::from_data_type(c.ty);
336 SpgColumn::new(i, c.name.clone(), ti)
337 })
338 .collect();
339 let parameters = if params.is_empty() {
340 None
341 } else {
342 Some(either::Either::Right(params.len()))
343 };
344 Ok(sqlx_core::describe::Describe {
345 columns,
346 parameters,
347 nullable,
348 })
349 })
350 }
351}
352
/// The translated result of one engine statement, in the shape `fetch_many`
/// turns into stream items (produced by `outcome_from`).
enum Outcome {
    /// A non-row result: `CommandOk` carries its affected-row count, and any
    /// other non-row engine variant maps to `SpgQueryResult::default()`.
    Affected(SpgQueryResult),
    /// Rows from a row-producing statement, already converted to `SpgRow`s.
    Rows(Vec<SpgRow>),
}
363
364async fn run_one(
365 conn: &mut SpgConnection,
366 sql: &str,
367 arguments: Option<crate::SpgArguments<'_>>,
368) -> Result<Outcome, Error> {
369 let in_tx = conn.tx_depth > 0;
377 let cached = if in_tx {
378 None
379 } else {
380 conn.cached_stmt(sql).await.ok()
382 };
383
384 let result: EngineQueryResult = if let Some(c) = cached.filter(|c| c.readonly) {
385 let snap = conn.inner.clone_snapshot_inline().await;
390 let params = arguments.map(crate::SpgArguments::into_engine_values);
391 spg_embedded::Database::execute_prepared_on_snapshot(
392 &snap,
393 &c.stmt,
394 params.as_deref().unwrap_or(&[]),
395 )
396 .map_err(engine_to_sqlx)?
397 } else {
398 let db = &conn.inner;
399 if let Some(args) = arguments {
400 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
401 db.execute_prepared(&stmt, args.into_engine_values())
402 .await
403 .map_err(engine_to_sqlx)?
404 } else {
405 db.execute(sql).await.map_err(engine_to_sqlx)?
406 }
407 };
408 Ok(outcome_from(result))
409}
410
411fn outcome_from(result: EngineQueryResult) -> Outcome {
412 match result {
413 EngineQueryResult::Rows { columns, rows } => {
414 let row_values: Vec<Vec<spg_embedded::Value>> =
415 rows.into_iter().map(|r| r.values).collect();
416 Outcome::Rows(build_rows(&columns, row_values))
417 }
418 EngineQueryResult::CommandOk { affected, .. } => {
419 Outcome::Affected(SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)))
420 }
421 _ => Outcome::Affected(SpgQueryResult::default()),
422 }
423}
424
425#[allow(dead_code)]
426fn affected_from(qr: &EngineQueryResult) -> u64 {
427 match qr {
428 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
429 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
430 _ => 0,
431 }
432}
433
434fn build_rows(
435 cols: &[spg_embedded::ColumnSchema],
436 rows: Vec<Vec<spg_embedded::Value>>,
437) -> Vec<SpgRow> {
438 let columns: Arc<Vec<SpgColumn>> = Arc::new(
439 cols.iter()
440 .enumerate()
441 .map(|(i, c)| SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty)))
442 .collect(),
443 );
444 let mut by_name: HashMap<String, usize> = HashMap::new();
445 for (i, c) in cols.iter().enumerate() {
446 by_name.insert(c.name.clone(), i);
447 }
448 let by_name = Arc::new(by_name);
449 rows.into_iter()
450 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
451 .collect()
452}