mysql_async/conn/routines/
exec.rs1use std::mem;
2
3use futures_core::future::BoxFuture;
4use futures_util::FutureExt;
5use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params};
6#[cfg(feature = "tracing")]
7use tracing::{field, info_span, Level, Span};
8
9use crate::{BinaryProtocol, Conn, DriverError, Statement};
10
11use super::Routine;
12
13#[derive(Debug, Clone)]
15pub struct ExecRoutine<'a> {
16 stmt: &'a Statement,
17 params: Params,
18}
19
20impl<'a> ExecRoutine<'a> {
21 pub fn new(stmt: &'a Statement, params: Params) -> Self {
22 Self { stmt, params }
23 }
24}
25
26impl Routine<()> for ExecRoutine<'_> {
27 fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> {
28 #[cfg(feature = "tracing")]
29 let span = info_span!(
30 "mysql_async::exec",
31 mysql_async.connection.id = conn.id(),
32 mysql_async.statement.id = self.stmt.id(),
33 mysql_async.query.params = field::Empty,
34 );
35
36 let fut = async move {
37 loop {
38 match self.params {
39 Params::Positional(ref params) => {
40 #[cfg(feature = "tracing")]
41 if tracing::span_enabled!(Level::DEBUG) {
42 let sep = std::iter::repeat(", ");
46 let ps = params
47 .iter()
48 .map(|p| p.as_sql(true))
49 .zip(sep)
50 .map(|(val, sep)| val + sep)
51 .collect::<String>();
52 Span::current().record("mysql_async.query.params", ps);
53 }
54
55 if self.stmt.num_params() as usize != params.len() {
56 Err(DriverError::StmtParamsMismatch {
57 required: self.stmt.num_params(),
58 supplied: params.len() as u16,
59 })?
60 }
61
62 let (body, as_long_data) =
63 ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&*params);
64
65 if as_long_data {
66 conn.send_long_data(self.stmt.id(), params.iter()).await?;
67 }
68
69 conn.write_command(&body).await?;
70 conn.read_result_set::<BinaryProtocol>(true).await?;
71 break;
72 }
73 Params::Named(_) => {
74 if self.stmt.named_params.is_none() {
75 let error = DriverError::NamedParamsForPositionalQuery.into();
76 return Err(error);
77 }
78
79 let named = mem::replace(&mut self.params, Params::Empty);
80 self.params =
81 named.into_positional(self.stmt.named_params.as_ref().unwrap())?;
82
83 continue;
84 }
85 Params::Empty => {
86 if self.stmt.num_params() > 0 {
87 let error = DriverError::StmtParamsMismatch {
88 required: self.stmt.num_params(),
89 supplied: 0,
90 }
91 .into();
92 return Err(error);
93 }
94
95 let (body, _) =
96 ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&[]);
97 conn.write_command(&body).await?;
98 conn.read_result_set::<BinaryProtocol>(true).await?;
99 break;
100 }
101 }
102 }
103 Ok(())
104 };
105
106 #[cfg(feature = "tracing")]
107 let fut = instrument_result!(fut, span);
108
109 fut.boxed()
110 }
111}