1use std::sync::Arc;
18
19use futures_core::future::BoxFuture;
20use futures_core::stream::BoxStream;
21use sqlx_core::HashMap;
22use sqlx_core::connection::Connection;
23use sqlx_core::error::Error;
24use sqlx_core::executor::Executor;
25use sqlx_core::transaction::Transaction;
26
27use spg_embedded::QueryResult as EngineQueryResult;
28use spg_embedded_tokio::{AsyncDatabase, AsyncReadHandle};
29
30use crate::column::SpgColumn;
31use crate::database::Spg;
32use crate::error::engine_to_sqlx;
33use crate::options::SpgConnectOptions;
34use crate::query_result::SpgQueryResult;
35use crate::row::SpgRow;
36use crate::type_info::SpgTypeInfo;
37
38#[derive(Debug)]
56pub struct SpgConnection {
57 pub(crate) inner: AsyncDatabase,
58 pub(crate) read_handle: Option<AsyncReadHandle>,
63 pub(crate) tx_depth: usize,
64 pub(crate) pending_rollback: bool,
65}
66
67impl Clone for SpgConnection {
68 fn clone(&self) -> Self {
69 Self {
75 inner: self.inner.clone(),
76 read_handle: None,
77 tx_depth: self.tx_depth,
78 pending_rollback: self.pending_rollback,
79 }
80 }
81}
82
83impl SpgConnection {
84 pub fn new(inner: AsyncDatabase) -> Self {
88 Self {
89 inner,
90 read_handle: None,
91 tx_depth: 0,
92 pending_rollback: false,
93 }
94 }
95
96 #[must_use]
102 pub const fn engine(&self) -> &AsyncDatabase {
103 &self.inner
104 }
105
106 pub(crate) async fn fresh_read_handle(&mut self) -> &mut AsyncReadHandle {
115 if let Some(rh) = self.read_handle.as_mut() {
116 rh.refresh().await;
117 } else {
118 self.read_handle = Some(self.inner.read_handle().await);
119 }
120 self.read_handle.as_mut().expect("set above")
121 }
122}
123
124impl Connection for SpgConnection {
125 type Database = Spg;
126 type Options = SpgConnectOptions;
127
128 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
129 Box::pin(async move { Ok(()) })
132 }
133
134 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
135 Box::pin(async move { Ok(()) })
136 }
137
138 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
139 Box::pin(async move {
142 self.inner
143 .execute("SELECT 1")
144 .await
145 .map_err(engine_to_sqlx)?;
146 Ok(())
147 })
148 }
149
150 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
151 where
152 Self: Sized,
153 {
154 Transaction::begin(self, None)
155 }
156
157 fn shrink_buffers(&mut self) {
158 }
160
161 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
162 Box::pin(async move { Ok(()) })
163 }
164
165 fn should_flush(&self) -> bool {
166 false
167 }
168}
169
170impl<'c> Executor<'c> for &'c mut SpgConnection {
181 type Database = Spg;
182
183 fn fetch_many<'e, 'q: 'e, E>(
184 self,
185 mut query: E,
186 ) -> BoxStream<
187 'e,
188 Result<
189 either::Either<
190 <Self::Database as sqlx_core::database::Database>::QueryResult,
191 crate::SpgRow,
192 >,
193 Error,
194 >,
195 >
196 where
197 'c: 'e,
198 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
199 {
200 use futures_util::stream::{self, StreamExt};
201 let sql = query.sql().to_string();
202 let arguments = match query.take_arguments() {
203 Ok(args) => args,
204 Err(e) => {
205 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
206 }
207 };
208 let outcome_fut = async move { run_one(self, &sql, arguments).await };
209 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
210 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
211 Ok(Outcome::Affected(qr)) => vec![Ok(either::Either::Left(qr))],
212 Ok(Outcome::Rows(rows)) => rows
213 .into_iter()
214 .map(|r| Ok(either::Either::Right(r)))
215 .collect(),
216 Err(e) => vec![Err(e)],
217 };
218 stream::iter(items)
219 }))
220 }
221
222 fn fetch_optional<'e, 'q: 'e, E>(
223 self,
224 mut query: E,
225 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
226 where
227 'c: 'e,
228 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
229 {
230 let sql = query.sql().to_string();
231 let arguments = query.take_arguments();
232 Box::pin(async move {
233 let args = arguments.map_err(Error::Encode)?;
234 match run_one(self, &sql, args).await? {
235 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
236 Outcome::Affected(_) => Ok(None),
237 }
238 })
239 }
240
241 fn prepare_with<'e, 'q: 'e>(
242 self,
243 sql: &'q str,
244 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
245 ) -> BoxFuture<
246 'e,
247 Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
248 >
249 where
250 'c: 'e,
251 {
252 let inner = self.inner.clone();
253 let sql_str = sql.to_string();
254 Box::pin(async move {
255 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
256 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
262 Ok(crate::SpgStatement {
263 sql: std::borrow::Cow::Owned(sql_str),
264 inner: Some(inner_stmt),
265 columns: std::sync::Arc::new(Vec::new()),
266 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
267 })
268 })
269 }
270
271 fn describe<'e, 'q: 'e>(
272 self,
273 sql: &'q str,
274 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
275 where
276 'c: 'e,
277 {
278 let inner = self.inner.clone();
288 let sql_str = sql.to_string();
289 Box::pin(async move {
290 let (params, cols) = inner.describe(&sql_str).await.map_err(engine_to_sqlx)?;
291 let nullable: Vec<Option<bool>> = cols.iter().map(|c| Some(c.nullable)).collect();
292 let columns: Vec<SpgColumn> = cols
293 .iter()
294 .enumerate()
295 .map(|(i, c)| {
296 let ti = SpgTypeInfo::from_data_type(c.ty);
297 SpgColumn::new(i, c.name.clone(), ti)
298 })
299 .collect();
300 let parameters = if params.is_empty() {
301 None
302 } else {
303 Some(either::Either::Right(params.len()))
304 };
305 Ok(sqlx_core::describe::Describe {
306 columns,
307 parameters,
308 nullable,
309 })
310 })
311 }
312}
313
314enum Outcome {
318 Affected(SpgQueryResult),
320 Rows(Vec<SpgRow>),
323}
324
325async fn run_one(
326 conn: &mut SpgConnection,
327 sql: &str,
328 arguments: Option<crate::SpgArguments<'_>>,
329) -> Result<Outcome, Error> {
330 let in_tx = conn.tx_depth > 0;
339 let use_readonly = !in_tx && spg_embedded::Engine::is_readonly_sql(sql);
340
341 let result: EngineQueryResult = if use_readonly {
342 let rh = conn.fresh_read_handle().await;
343 if let Some(args) = arguments {
344 let stmt = rh.prepare(sql).await.map_err(engine_to_sqlx)?;
345 rh.execute_prepared(&stmt, args.into_engine_values())
346 .await
347 .map_err(engine_to_sqlx)?
348 } else {
349 rh.query(sql).await.map_err(engine_to_sqlx)?
350 }
351 } else {
352 let db = &conn.inner;
353 if let Some(args) = arguments {
354 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
355 db.execute_prepared(&stmt, args.into_engine_values())
356 .await
357 .map_err(engine_to_sqlx)?
358 } else {
359 db.execute(sql).await.map_err(engine_to_sqlx)?
360 }
361 };
362 match result {
363 EngineQueryResult::Rows { columns, rows } => {
364 let row_values: Vec<Vec<spg_embedded::Value>> =
365 rows.into_iter().map(|r| r.values).collect();
366 Ok(Outcome::Rows(build_rows(&columns, row_values)))
367 }
368 EngineQueryResult::CommandOk { affected, .. } => Ok(Outcome::Affected(
369 SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)),
370 )),
371 _ => Ok(Outcome::Affected(SpgQueryResult::default())),
372 }
373}
374
375#[allow(dead_code)]
376fn affected_from(qr: &EngineQueryResult) -> u64 {
377 match qr {
378 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
379 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
380 _ => 0,
381 }
382}
383
384fn build_rows(
385 cols: &[spg_embedded::ColumnSchema],
386 rows: Vec<Vec<spg_embedded::Value>>,
387) -> Vec<SpgRow> {
388 let columns: Arc<Vec<SpgColumn>> = Arc::new(
389 cols.iter()
390 .enumerate()
391 .map(|(i, c)| SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty)))
392 .collect(),
393 );
394 let mut by_name: HashMap<String, usize> = HashMap::new();
395 for (i, c) in cols.iter().enumerate() {
396 by_name.insert(c.name.clone(), i);
397 }
398 let by_name = Arc::new(by_name);
399 rows.into_iter()
400 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
401 .collect()
402}