1use std::sync::Arc;
12
13use futures_core::future::BoxFuture;
14use futures_core::stream::BoxStream;
15use sqlx_core::HashMap;
16use sqlx_core::connection::Connection;
17use sqlx_core::error::Error;
18use sqlx_core::executor::Executor;
19use sqlx_core::transaction::Transaction;
20
21use spg_embedded::QueryResult as EngineQueryResult;
22use spg_embedded_tokio::AsyncDatabase;
23
24use crate::column::SpgColumn;
25use crate::database::Spg;
26use crate::error::engine_to_sqlx;
27use crate::options::SpgConnectOptions;
28use crate::query_result::SpgQueryResult;
29use crate::row::SpgRow;
30use crate::type_info::SpgTypeInfo;
31
/// An sqlx connection backed by the embedded SPG engine's async wrapper.
///
/// Cloning (derived) clones the engine handle and copies the transaction bookkeeping
/// fields as they are at the time of the clone.
#[derive(Debug, Clone)]
pub struct SpgConnection {
    /// Async handle to the engine; every query in this module is issued through it.
    pub(crate) inner: AsyncDatabase,
    /// Presumably the current transaction nesting depth (0 = no open transaction);
    /// it is only initialised here and maintained elsewhere — TODO confirm.
    pub(crate) tx_depth: usize,
    /// Presumably set when a transaction must be rolled back before further use;
    /// only initialised here and maintained elsewhere — TODO confirm.
    pub(crate) pending_rollback: bool,
}
39
40impl SpgConnection {
41 pub fn new(inner: AsyncDatabase) -> Self {
45 Self {
46 inner,
47 tx_depth: 0,
48 pending_rollback: false,
49 }
50 }
51
52 #[must_use]
56 pub const fn engine(&self) -> &AsyncDatabase {
57 &self.inner
58 }
59}
60
61impl Connection for SpgConnection {
62 type Database = Spg;
63 type Options = SpgConnectOptions;
64
65 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
66 Box::pin(async move { Ok(()) })
69 }
70
71 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
72 Box::pin(async move { Ok(()) })
73 }
74
75 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
76 Box::pin(async move {
79 self.inner
80 .execute("SELECT 1")
81 .await
82 .map_err(engine_to_sqlx)?;
83 Ok(())
84 })
85 }
86
87 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
88 where
89 Self: Sized,
90 {
91 Transaction::begin(self, None)
92 }
93
94 fn shrink_buffers(&mut self) {
95 }
97
98 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
99 Box::pin(async move { Ok(()) })
100 }
101
102 fn should_flush(&self) -> bool {
103 false
104 }
105}
106
107impl<'c> Executor<'c> for &'c mut SpgConnection {
113 type Database = Spg;
114
115 fn fetch_many<'e, 'q: 'e, E>(
116 self,
117 mut query: E,
118 ) -> BoxStream<
119 'e,
120 Result<
121 either::Either<
122 <Self::Database as sqlx_core::database::Database>::QueryResult,
123 crate::SpgRow,
124 >,
125 Error,
126 >,
127 >
128 where
129 'c: 'e,
130 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
131 {
132 use futures_util::stream::{self, StreamExt};
133 let sql = query.sql().to_string();
134 let arguments = match query.take_arguments() {
135 Ok(args) => args,
136 Err(e) => {
137 return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
138 }
139 };
140 let inner = self.inner.clone();
141 let outcome_fut = async move { run_one(&inner, &sql, arguments).await };
142 Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
143 let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
144 Ok(Outcome::Affected(qr)) => vec![Ok(either::Either::Left(qr))],
145 Ok(Outcome::Rows(rows)) => rows
146 .into_iter()
147 .map(|r| Ok(either::Either::Right(r)))
148 .collect(),
149 Err(e) => vec![Err(e)],
150 };
151 stream::iter(items)
152 }))
153 }
154
155 fn fetch_optional<'e, 'q: 'e, E>(
156 self,
157 mut query: E,
158 ) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
159 where
160 'c: 'e,
161 E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
162 {
163 let sql = query.sql().to_string();
164 let arguments = query.take_arguments();
165 let inner = self.inner.clone();
166 Box::pin(async move {
167 let args = arguments.map_err(Error::Encode)?;
168 match run_one(&inner, &sql, args).await? {
169 Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
170 Outcome::Affected(_) => Ok(None),
171 }
172 })
173 }
174
175 fn prepare_with<'e, 'q: 'e>(
176 self,
177 sql: &'q str,
178 _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
179 ) -> BoxFuture<
180 'e,
181 Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
182 >
183 where
184 'c: 'e,
185 {
186 let inner = self.inner.clone();
187 let sql_str = sql.to_string();
188 Box::pin(async move {
189 let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
190 let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
196 Ok(crate::SpgStatement {
197 sql: std::borrow::Cow::Owned(sql_str),
198 inner: Some(inner_stmt),
199 columns: std::sync::Arc::new(Vec::new()),
200 by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
201 })
202 })
203 }
204
205 fn describe<'e, 'q: 'e>(
206 self,
207 _sql: &'q str,
208 ) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
209 where
210 'c: 'e,
211 {
212 Box::pin(async move {
213 Err(Error::Protocol(
214 "describe is v7.17 — compile-time sqlx::query!() macros need offline mode in the meantime".into(),
215 ))
216 })
217 }
218}
219
/// Fully collected result of running one statement through `run_one`.
enum Outcome {
    /// The statement did not produce a row set. This carries the affected-row count, or a
    /// default `SpgQueryResult` for result kinds without a count.
    Affected(SpgQueryResult),
    /// The statement produced a row set. All rows are already held in memory.
    Rows(Vec<SpgRow>),
}
230
231async fn run_one(
232 db: &AsyncDatabase,
233 sql: &str,
234 arguments: Option<crate::SpgArguments<'_>>,
235) -> Result<Outcome, Error> {
236 let result: EngineQueryResult = if let Some(args) = arguments {
243 let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
244 db.execute_prepared(&stmt, args.into_engine_values())
245 .await
246 .map_err(engine_to_sqlx)?
247 } else {
248 db.execute(sql).await.map_err(engine_to_sqlx)?
249 };
250 match result {
251 EngineQueryResult::Rows { columns, rows } => {
252 let row_values: Vec<Vec<spg_embedded::Value>> =
253 rows.into_iter().map(|r| r.values).collect();
254 Ok(Outcome::Rows(build_rows(&columns, row_values)))
255 }
256 EngineQueryResult::CommandOk { affected, .. } => Ok(Outcome::Affected(
257 SpgQueryResult::new(u64::try_from(affected).unwrap_or(0)),
258 )),
259 _ => Ok(Outcome::Affected(SpgQueryResult::default())),
260 }
261}
262
263#[allow(dead_code)]
264fn affected_from(qr: &EngineQueryResult) -> u64 {
265 match qr {
266 EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).unwrap_or(0),
267 EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).unwrap_or(0),
268 _ => 0,
269 }
270}
271
272fn build_rows(
273 cols: &[spg_embedded::ColumnSchema],
274 rows: Vec<Vec<spg_embedded::Value>>,
275) -> Vec<SpgRow> {
276 let columns: Arc<Vec<SpgColumn>> = Arc::new(
277 cols.iter()
278 .enumerate()
279 .map(|(i, c)| SpgColumn::new(i, c.name.clone(), SpgTypeInfo::from_data_type(c.ty)))
280 .collect(),
281 );
282 let mut by_name: HashMap<String, usize> = HashMap::new();
283 for (i, c) in cols.iter().enumerate() {
284 by_name.insert(c.name.clone(), i);
285 }
286 let by_name = Arc::new(by_name);
287 rows.into_iter()
288 .map(|values| SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values))
289 .collect()
290}