1use std::sync::Arc;
18
19use futures_core::future::BoxFuture;
20use futures_core::stream::BoxStream;
21use sqlx_core::HashMap;
22use sqlx_core::connection::Connection;
23use sqlx_core::error::Error;
24use sqlx_core::executor::Executor;
25use sqlx_core::transaction::Transaction;
26
27use spg_embedded::QueryResult as EngineQueryResult;
28use spg_embedded_tokio::AsyncDatabase;
29
30use crate::column::SpgColumn;
31use crate::database::Spg;
32use crate::error::engine_to_sqlx;
33use crate::options::SpgConnectOptions;
34use crate::query_result::SpgQueryResult;
35use crate::row::SpgRow;
36use crate::type_info::SpgTypeInfo;
37
38#[derive(Debug, Clone)]
44pub(crate) struct CachedStmt {
45 pub(crate) readonly: bool,
46 pub(crate) stmt: std::sync::Arc<spg_embedded::Statement>,
47}
48
/// Maximum number of entries kept in a connection's statement cache.
///
/// `SpgConnection::cached_stmt` clears the whole cache before inserting a new
/// entry once the cache already holds this many statements.
const STMT_CACHE_CAP: usize = 256;
54
/// sqlx connection for the `Spg` database, wrapping an [`AsyncDatabase`] handle.
#[derive(Debug)]
pub struct SpgConnection {
    /// Engine handle that all statements on this connection are executed through.
    pub(crate) inner: AsyncDatabase,
    /// Prepared statements keyed by their exact SQL text; populated by
    /// `cached_stmt` and bounded by `STMT_CACHE_CAP`.
    pub(crate) stmt_cache: HashMap<String, CachedStmt>,
    /// Transaction depth counter. `run_one` skips the statement cache whenever
    /// this is non-zero. Presumably maintained by the transaction manager
    /// (not visible in this file) — confirm.
    pub(crate) tx_depth: usize,
    /// NOTE(review): only initialised and copied in this file; presumably set
    /// by the transaction layer to request a deferred rollback — confirm.
    pub(crate) pending_rollback: bool,
}
74
75impl Clone for SpgConnection {
76 fn clone(&self) -> Self {
77 Self {
80 inner: self.inner.clone(),
81 stmt_cache: HashMap::new(),
82 tx_depth: self.tx_depth,
83 pending_rollback: self.pending_rollback,
84 }
85 }
86}
87
88impl SpgConnection {
89 pub fn new(inner: AsyncDatabase) -> Self {
93 Self {
94 inner,
95 stmt_cache: HashMap::new(),
96 tx_depth: 0,
97 pending_rollback: false,
98 }
99 }
100
101 #[must_use]
104 pub const fn engine(&self) -> &AsyncDatabase {
105 &self.inner
106 }
107
108 pub(crate) async fn cached_stmt(
112 &mut self,
113 sql: &str,
114 ) -> Result<CachedStmt, spg_embedded::EngineError> {
115 if let Some(c) = self.stmt_cache.get(sql) {
116 return Ok(c.clone());
117 }
118 let readonly = spg_embedded::Engine::is_readonly_sql(sql);
119 let snap = self.inner.clone_snapshot_inline().await;
125 let stmt = spg_embedded::Database::prepare_on_snapshot(&snap, sql)?;
126 let cached = CachedStmt {
127 readonly,
128 stmt: std::sync::Arc::new(stmt),
129 };
130 if self.stmt_cache.len() >= STMT_CACHE_CAP {
131 self.stmt_cache.clear();
132 }
133 self.stmt_cache.insert(sql.to_string(), cached.clone());
134 Ok(cached)
135 }
136}
137
138impl Connection for SpgConnection {
139 type Database = Spg;
140 type Options = SpgConnectOptions;
141
142 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
143 Box::pin(async move { Ok(()) })
146 }
147
148 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
149 Box::pin(async move { Ok(()) })
150 }
151
152 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
153 Box::pin(async move {
156 self.inner
157 .execute("SELECT 1")
158 .await
159 .map_err(engine_to_sqlx)?;
160 Ok(())
161 })
162 }
163
164 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
165 where
166 Self: Sized,
167 {
168 Transaction::begin(self, None)
169 }
170
171 fn shrink_buffers(&mut self) {
172 }
174
175 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
176 Box::pin(async move { Ok(()) })
177 }
178
179 fn should_flush(&self) -> bool {
180 false
181 }
182}
183
184impl<'c> Executor<'c> for &'c mut SpgConnection {
195 type Database = Spg;
196
197 fn fetch_many<'e, 'q: 'e, E>(
198 self,
199 mut query: E,
200 ) -> BoxStream<
201 'e,
202 Result<
203 either::Either<
204 <Self::Database as sqlx_core::database::Database>::QueryResult,
205 crate::SpgRow,
206 >,
207 Error,
208 >,
209 >
210 where
211 'c: 'e,
212 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
213 {
214 use futures_util::stream::{self, StreamExt};
215 let sql = query.sql().to_string();
216 let arguments = match query.take_arguments() {
217 Ok(args) => args,
218 Err(e) => {
219 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
220 }
221 };
222 let outcome_fut = async move { run_one(self, &sql, arguments).await };
223 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
224 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
225 Ok(Outcome::Affected(qr)) => vec![Ok(either::Either::Left(qr))],
226 Ok(Outcome::Rows(rows)) => rows
227 .into_iter()
228 .map(|r| Ok(either::Either::Right(r)))
229 .collect(),
230 Err(e) => vec![Err(e)],
231 };
232 stream::iter(items)
233 }))
234 }
235
236 fn fetch_optional<'e, 'q: 'e, E>(
237 self,
238 mut query: E,
239 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
240 where
241 'c: 'e,
242 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
243 {
244 let sql = query.sql().to_string();
245 let arguments = query.take_arguments();
246 Box::pin(async move {
247 let args = arguments.map_err(Error::Encode)?;
248 match run_one(self, &sql, args).await? {
249 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
250 Outcome::Affected(_) => Ok(None),
251 }
252 })
253 }
254
255 fn prepare_with<'e, 'q: 'e>(
256 self,
257 sql: &'q str,
258 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
259 ) -> BoxFuture<
260 'e,
261 Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
262 >
263 where
264 'c: 'e,
265 {
266 let inner = self.inner.clone();
267 let sql_str = sql.to_string();
268 Box::pin(async move {
269 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
270 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
276 Ok(crate::SpgStatement {
277 sql: std::borrow::Cow::Owned(sql_str),
278 inner: Some(inner_stmt),
279 columns: std::sync::Arc::new(Vec::new()),
280 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
281 })
282 })
283 }
284
285 fn describe<'e, 'q: 'e>(
286 self,
287 sql: &'q str,
288 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
289 where
290 'c: 'e,
291 {
292 let inner = self.inner.clone();
302 let sql_str = sql.to_string();
303 Box::pin(async move {
304 let (params, cols) = inner.describe(&sql_str).await.map_err(engine_to_sqlx)?;
305 let nullable: Vec<Option<bool>> = cols.iter().map(|c| Some(c.nullable)).collect();
306 let columns: Vec<SpgColumn> = cols
307 .iter()
308 .enumerate()
309 .map(|(i, c)| {
310 let ti = SpgTypeInfo::from_data_type(c.ty);
311 SpgColumn::new(i, c.name.clone(), ti)
312 })
313 .collect();
314 let parameters = if params.is_empty() {
315 None
316 } else {
317 Some(either::Either::Right(params.len()))
318 };
319 Ok(sqlx_core::describe::Describe {
320 columns,
321 parameters,
322 nullable,
323 })
324 })
325 }
326}
327
/// Buffered result of executing one statement via `run_one`.
enum Outcome {
    /// The statement produced no row set; carries the query-result summary.
    Affected(SpgQueryResult),
    /// The statement produced a row set; all rows are materialised up front.
    Rows(Vec<SpgRow>),
}
338
339async fn run_one(
340 conn: &mut SpgConnection,
341 sql: &str,
342 arguments: Option<crate::SpgArguments<'_>>,
343) -> Result<Outcome, Error> {
344 let in_tx = conn.tx_depth > 0;
352 let cached = if in_tx {
353 None
354 } else {
355 conn.cached_stmt(sql).await.ok()
357 };
358
359 let result: EngineQueryResult = if let Some(c) = cached.filter(|c| c.readonly) {
360 let snap = conn.inner.clone_snapshot_inline().await;
365 let params = arguments.map(crate::SpgArguments::into_engine_values);
366 spg_embedded::Database::execute_prepared_on_snapshot(
367 &snap,
368 &c.stmt,
369 params.as_deref().unwrap_or(&[]),
370 )
371 .map_err(engine_to_sqlx)?
372 } else {
373 let db = &conn.inner;
374 if let Some(args) = arguments {
375 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
376 db.execute_prepared(&stmt, args.into_engine_values())
377 .await
378 .map_err(engine_to_sqlx)?
379 } else {
380 db.execute(sql).await.map_err(engine_to_sqlx)?
381 }
382 };
383 match result {
384 EngineQueryResult::Rows { columns, rows } => {
385 let row_values: Vec<Vec<spg_embedded::Value>> =
386 rows.into_iter().map(|r| r.values).collect();
387 Ok(Outcome::Rows(build_rows(&columns, row_values)))
388 }
389 EngineQueryResult::CommandOk { affected, .. } => Ok(Outcome::Affected(
390 SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)),
391 )),
392 _ => Ok(Outcome::Affected(SpgQueryResult::default())),
393 }
394}
395
396#[allow(dead_code)]
397fn affected_from(qr: &EngineQueryResult) -> u64 {
398 match qr {
399 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
400 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
401 _ => 0,
402 }
403}
404
405fn build_rows(
406 cols: &[spg_embedded::ColumnSchema],
407 rows: Vec<Vec<spg_embedded::Value>>,
408) -> Vec<SpgRow> {
409 let columns: Arc<Vec<SpgColumn>> = Arc::new(
410 cols.iter()
411 .enumerate()
412 .map(|(i, c)| SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty)))
413 .collect(),
414 );
415 let mut by_name: HashMap<String, usize> = HashMap::new();
416 for (i, c) in cols.iter().enumerate() {
417 by_name.insert(c.name.clone(), i);
418 }
419 let by_name = Arc::new(by_name);
420 rows.into_iter()
421 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
422 .collect()
423}