1use core::{ops::Deref, sync::atomic::Ordering};
4
5use std::sync::Arc;
6
7use super::{
8 column::Column,
9 driver::codec::{AsParams, encode::StatementCancel},
10 query::Query,
11 types::{ToSql, Type},
12};
13
14pub struct StatementGuarded<'a, C>
20where
21 C: Query,
22{
23 stmt: Option<Statement>,
24 cli: &'a C,
25}
26
27impl<C> AsRef<Statement> for StatementGuarded<'_, C>
28where
29 C: Query,
30{
31 #[inline]
32 fn as_ref(&self) -> &Statement {
33 self
34 }
35}
36
37impl<C> Deref for StatementGuarded<'_, C>
38where
39 C: Query,
40{
41 type Target = Statement;
42
43 fn deref(&self) -> &Self::Target {
44 self.stmt.as_ref().unwrap()
45 }
46}
47
48impl<C> Drop for StatementGuarded<'_, C>
49where
50 C: Query,
51{
52 fn drop(&mut self) {
53 if let Some(stmt) = self.stmt.take() {
54 let _ = self.cli._send_encode_query(StatementCancel { name: stmt.name() });
55 }
56 }
57}
58
59impl<C> StatementGuarded<'_, C>
60where
61 C: Query,
62{
63 pub fn leak(mut self) -> Statement {
66 self.stmt.take().unwrap()
67 }
68}
69
70#[derive(Default)]
83pub struct Statement {
84 name: Arc<str>,
85 params: Arc<[Type]>,
86 columns: Arc<[Column]>,
87}
88
89impl Statement {
90 pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Self {
91 Self {
92 name: name.into(),
93 params: params.into(),
94 columns: columns.into(),
95 }
96 }
97
98 pub(crate) fn duplicate(&self) -> Self {
100 Self {
101 name: self.name.clone(),
102 params: self.params.clone(),
103 columns: self.columns.clone(),
104 }
105 }
106
107 pub(crate) fn name(&self) -> &str {
108 &self.name
109 }
110
111 pub(crate) fn columns_owned(&self) -> Arc<[Column]> {
112 self.columns.clone()
113 }
114
115 #[inline]
120 pub const fn named<'a>(stmt: &'a str, types: &'a [Type]) -> StatementNamed<'a> {
121 StatementNamed { stmt, types }
122 }
123
124 #[inline]
139 pub fn bind<P>(&self, params: P) -> StatementPreparedQuery<'_, P>
140 where
141 P: AsParams,
142 {
143 StatementPreparedQuery { stmt: self, params }
144 }
145
146 #[inline]
156 pub fn bind_dyn<'p, 't>(
157 &self,
158 params: &'p [&'t (dyn ToSql + Sync)],
159 ) -> StatementPreparedQuery<'_, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + Clone + 'p> {
160 self.bind(params.iter().cloned())
161 }
162
163 #[inline]
166 pub fn bind_none(&self) -> StatementPreparedQuery<'_, [bool; 0]> {
167 self.bind([])
168 }
169
170 #[inline]
172 pub fn params(&self) -> &[Type] {
173 &self.params
174 }
175
176 #[inline]
178 pub fn columns(&self) -> &[Column] {
179 &self.columns
180 }
181
182 #[inline]
184 pub fn into_guarded<C>(self, cli: &C) -> StatementGuarded<'_, C>
185 where
186 C: Query,
187 {
188 StatementGuarded { stmt: Some(self), cli }
189 }
190}
191
192pub struct StatementNamed<'a> {
194 pub(crate) stmt: &'a str,
195 pub(crate) types: &'a [Type],
196}
197
198impl<'a> StatementNamed<'a> {
199 fn name() -> String {
200 let id = crate::NEXT_ID.fetch_add(1, Ordering::Relaxed);
201 format!("s{id}")
202 }
203
204 #[inline]
206 pub fn bind<P>(self, params: P) -> StatementQuery<'a, P> {
207 StatementQuery {
208 stmt: self.stmt,
209 types: self.types,
210 params,
211 }
212 }
213
214 #[inline]
216 pub fn bind_dyn<'p, 't>(
217 self,
218 params: &'p [&'t (dyn ToSql + Sync)],
219 ) -> StatementQuery<'a, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + Clone + 'p> {
220 self.bind(params.iter().cloned())
221 }
222
223 #[inline]
225 pub fn bind_none(self) -> StatementQuery<'a, [bool; 0]> {
226 StatementQuery {
227 stmt: self.stmt,
228 types: self.types,
229 params: [],
230 }
231 }
232}
233
234pub(crate) struct StatementCreate<'a, 'c, C> {
235 pub(crate) name: String,
236 pub(crate) stmt: &'a str,
237 pub(crate) types: &'a [Type],
238 pub(crate) cli: &'c C,
239}
240
241impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreate<'a, 'c, C> {
242 fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
243 Self {
244 name: StatementNamed::name(),
245 stmt: stmt.stmt,
246 types: stmt.types,
247 cli,
248 }
249 }
250}
251
252pub(crate) struct StatementCreateBlocking<'a, 'c, C> {
253 pub(crate) name: String,
254 pub(crate) stmt: &'a str,
255 pub(crate) types: &'a [Type],
256 pub(crate) cli: &'c C,
257}
258
259impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreateBlocking<'a, 'c, C> {
260 fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
261 Self {
262 name: StatementNamed::name(),
263 stmt: stmt.stmt,
264 types: stmt.types,
265 cli,
266 }
267 }
268}
269
270pub struct StatementPreparedQuery<'a, P> {
277 pub(crate) stmt: &'a Statement,
278 pub(crate) params: P,
279}
280
281impl<'a, P> StatementPreparedQuery<'a, P> {
282 #[inline]
283 pub fn into_owned(self) -> StatementPreparedQueryOwned<'a, P> {
284 StatementPreparedQueryOwned {
285 stmt: self.stmt,
286 params: self.params,
287 }
288 }
289}
290
291pub struct StatementPreparedQueryOwned<'a, P> {
298 pub(crate) stmt: &'a Statement,
299 pub(crate) params: P,
300}
301
302pub struct StatementQuery<'a, P> {
318 pub(crate) stmt: &'a str,
319 pub(crate) types: &'a [Type],
320 pub(crate) params: P,
321}
322
323impl<'a, P> StatementQuery<'a, P> {
324 pub fn into_single_rtt(self) -> StatementSingleRTTQuery<'a, P> {
328 StatementSingleRTTQuery { query: self }
329 }
330}
331
332pub struct StatementSingleRTTQuery<'a, P> {
335 query: StatementQuery<'a, P>,
336}
337
338impl<'a, P> StatementSingleRTTQuery<'a, P> {
339 pub(crate) fn into_with_cli<'c, C>(self, cli: &'c C) -> StatementSingleRTTQueryWithCli<'a, 'c, P, C> {
340 StatementSingleRTTQueryWithCli { query: self.query, cli }
341 }
342}
343
344pub(crate) struct StatementSingleRTTQueryWithCli<'a, 'c, P, C> {
345 pub(crate) query: StatementQuery<'a, P>,
346 pub(crate) cli: &'c C,
347}
348
349pub struct StatementGuardedOwned<C>
353where
354 C: Query,
355{
356 stmt: Statement,
357 cli: C,
358}
359
360impl<C> Clone for StatementGuardedOwned<C>
361where
362 C: Query + Clone,
363{
364 fn clone(&self) -> Self {
365 Self {
366 stmt: self.stmt.duplicate(),
367 cli: self.cli.clone(),
368 }
369 }
370}
371
372impl<C> Drop for StatementGuardedOwned<C>
373where
374 C: Query,
375{
376 fn drop(&mut self) {
377 if Arc::strong_count(&self.stmt.name) == 1 {
379 debug_assert_eq!(Arc::strong_count(&self.stmt.params), 1);
380 debug_assert_eq!(Arc::strong_count(&self.stmt.columns), 1);
381 let _ = self.cli._send_encode_query(StatementCancel { name: self.stmt.name() });
382 }
383 }
384}
385
386impl<C> Deref for StatementGuardedOwned<C>
387where
388 C: Query,
389{
390 type Target = Statement;
391
392 fn deref(&self) -> &Self::Target {
393 &self.stmt
394 }
395}
396
397impl<C> AsRef<Statement> for StatementGuardedOwned<C>
398where
399 C: Query,
400{
401 fn as_ref(&self) -> &Statement {
402 &self.stmt
403 }
404}
405
406impl<C> StatementGuardedOwned<C>
407where
408 C: Query,
409{
410 pub fn new(stmt: Statement, cli: C) -> Self {
412 Self { stmt, cli }
413 }
414
415 pub fn client(&self) -> &C {
418 &self.cli
419 }
420}
421
#[cfg(test)]
mod test {
    use core::future::IntoFuture;

    use crate::{
        Postgres,
        error::{DbError, SqlState},
        execute::Execute,
        iter::AsyncLendingIterator,
        statement::Statement,
    };

    /// After the guard returned by `execute` is dropped, a second (unguarded) handle to the same
    /// statement must fail with `INVALID_SQL_STATEMENT_NAME`, because the statement has been
    /// cancelled on the server.
    #[tokio::test]
    async fn cancel_statement() {
        // Needs a PostgreSQL server reachable at this URL.
        let (cli, drv) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();

        // Run the connection driver in the background alongside the client.
        tokio::task::spawn(drv.into_future());

        std::path::Path::new("./samples/test.sql").execute(&cli).await.unwrap();

        let guarded = Statement::named("SELECT id, name FROM foo ORDER BY id", &[])
            .execute(&cli)
            .await
            .unwrap();

        // A second handle to the same server-side statement, not tied to the guard.
        let raw = guarded.duplicate();

        // Dropping the guard cancels the shared statement name on the server.
        drop(guarded);

        let mut stream = raw.query(&cli).await.unwrap();
        let err = stream.try_next().await.err().unwrap();
        let db_err = err.downcast_ref::<DbError>().unwrap();

        assert_eq!(db_err.code(), &SqlState::INVALID_SQL_STATEMENT_NAME);
    }
}
462}