1use futures_util::FutureExt;
10use mysql_common::{
11 io::ParseBuf,
12 named_params::parse_named_params,
13 packets::{ComStmtClose, StmtPacket},
14};
15
16use std::{borrow::Cow, sync::Arc};
17
18use crate::{
19 conn::routines::{ExecRoutine, PrepareRoutine},
20 consts::CapabilityFlags,
21 error::*,
22 Column, Params,
23};
24
25use super::AsQuery;
26
/// Outcome of converting a [`StatementLike`] value into a [`Statement`].
pub enum ToStatementResult<'a> {
    /// The statement is available right away; no work on the connection is needed.
    Immediate(Statement),
    /// The statement must be obtained by awaiting this future, which borrows the
    /// connection for `'a` (e.g. to prepare the statement or consult its cache).
    Mediate(crate::BoxFuture<'a, Statement>),
}
34
/// Types that can be turned into a [`Statement`] usable on a given connection.
///
/// In this module, query text (string and byte-buffer types) goes through
/// named-parameter parsing and statement preparation or cache lookup, while an
/// existing [`Statement`] is handed back as-is.
pub trait StatementLike: Send + Sync {
    /// Converts `self` into a [`Statement`] for use on `conn`.
    ///
    /// Returns [`ToStatementResult::Immediate`] when no connection work is
    /// required, or [`ToStatementResult::Mediate`] holding a future that must be
    /// awaited to obtain the statement.
    fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
    where
        Self: 'a;
}
41
42fn to_statement_move<'a, T: AsQuery + 'a>(
43 stmt: T,
44 conn: &'a mut crate::Conn,
45) -> ToStatementResult<'a> {
46 let fut = async move {
47 let query = stmt.as_query();
48 let (named_params, raw_query) = parse_named_params(query.as_ref())?;
49 let inner_stmt = match conn.get_cached_stmt(&*raw_query) {
50 Some(inner_stmt) => inner_stmt,
51 None => conn.prepare_statement(raw_query).await?,
52 };
53 Ok(Statement::new(inner_stmt, named_params))
54 }
55 .boxed();
56 ToStatementResult::Mediate(fut)
57}
58
59impl StatementLike for Cow<'_, str> {
60 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
61 where
62 Self: 'a,
63 {
64 to_statement_move(self, conn)
65 }
66}
67
68impl StatementLike for &'_ str {
69 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
70 where
71 Self: 'a,
72 {
73 to_statement_move(self, conn)
74 }
75}
76
77impl StatementLike for String {
78 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
79 where
80 Self: 'a,
81 {
82 to_statement_move(self, conn)
83 }
84}
85
86impl StatementLike for Box<str> {
87 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
88 where
89 Self: 'a,
90 {
91 to_statement_move(self, conn)
92 }
93}
94
95impl StatementLike for Arc<str> {
96 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
97 where
98 Self: 'a,
99 {
100 to_statement_move(self, conn)
101 }
102}
103
104impl StatementLike for Cow<'_, [u8]> {
105 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
106 where
107 Self: 'a,
108 {
109 to_statement_move(self, conn)
110 }
111}
112
113impl StatementLike for &'_ [u8] {
114 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
115 where
116 Self: 'a,
117 {
118 to_statement_move(self, conn)
119 }
120}
121
122impl StatementLike for Vec<u8> {
123 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
124 where
125 Self: 'a,
126 {
127 to_statement_move(self, conn)
128 }
129}
130
131impl StatementLike for Box<[u8]> {
132 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
133 where
134 Self: 'a,
135 {
136 to_statement_move(self, conn)
137 }
138}
139
140impl StatementLike for Arc<[u8]> {
141 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
142 where
143 Self: 'a,
144 {
145 to_statement_move(self, conn)
146 }
147}
148
149impl StatementLike for Statement {
150 fn to_statement<'a>(self, _conn: &'a mut crate::Conn) -> ToStatementResult<'static>
151 where
152 Self: 'a,
153 {
154 ToStatementResult::Immediate(self.clone())
155 }
156}
157
158impl<T: StatementLike + Clone> StatementLike for &'_ T {
159 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
160 where
161 Self: 'a,
162 {
163 self.clone().to_statement(conn)
164 }
165}
166
167#[derive(Debug, Clone, Eq, PartialEq)]
169pub struct StmtInner {
170 pub(crate) raw_query: Arc<[u8]>,
171 columns: Option<Box<[Column]>>,
172 params: Option<Box<[Column]>>,
173 stmt_packet: StmtPacket,
174 connection_id: u32,
175}
176
177impl StmtInner {
178 pub(crate) fn from_payload(
179 pld: &[u8],
180 connection_id: u32,
181 raw_query: Arc<[u8]>,
182 ) -> std::io::Result<Self> {
183 let stmt_packet = ParseBuf(pld).parse(())?;
184
185 Ok(Self {
186 raw_query,
187 columns: None,
188 params: None,
189 stmt_packet,
190 connection_id,
191 })
192 }
193
194 pub(crate) fn with_params(mut self, params: Vec<Column>) -> Self {
195 self.params = if params.is_empty() {
196 None
197 } else {
198 Some(params.into_boxed_slice())
199 };
200 self
201 }
202
203 pub(crate) fn with_columns(mut self, columns: Vec<Column>) -> Self {
204 self.columns = if columns.is_empty() {
205 None
206 } else {
207 Some(columns.into_boxed_slice())
208 };
209 self
210 }
211
212 pub(crate) fn columns(&self) -> &[Column] {
213 self.columns.as_ref().map(AsRef::as_ref).unwrap_or(&[])
214 }
215
216 pub(crate) fn params(&self) -> &[Column] {
217 self.params.as_ref().map(AsRef::as_ref).unwrap_or(&[])
218 }
219
220 pub(crate) fn id(&self) -> u32 {
221 self.stmt_packet.statement_id()
222 }
223
224 pub(crate) const fn connection_id(&self) -> u32 {
225 self.connection_id
226 }
227
228 pub(crate) fn num_params(&self) -> u16 {
229 self.stmt_packet.num_params()
230 }
231
232 pub(crate) fn num_columns(&self) -> u16 {
233 self.stmt_packet.num_columns()
234 }
235}
236
237#[derive(Debug, Clone, Eq, PartialEq)]
241pub struct Statement {
242 pub(crate) inner: Arc<StmtInner>,
243 pub(crate) named_params: Option<Vec<Vec<u8>>>,
244}
245
246impl Statement {
247 pub(crate) fn new(inner: Arc<StmtInner>, named_params: Option<Vec<Vec<u8>>>) -> Self {
248 Self {
249 inner,
250 named_params,
251 }
252 }
253
254 pub fn columns(&self) -> &[Column] {
256 self.inner.columns()
257 }
258
259 pub fn params(&self) -> &[Column] {
261 self.inner.params()
262 }
263
264 pub fn id(&self) -> u32 {
266 self.inner.id()
267 }
268
269 pub fn connection_id(&self) -> u32 {
271 self.inner.connection_id()
272 }
273
274 pub fn num_params(&self) -> u16 {
276 self.inner.num_params()
277 }
278
279 pub fn num_columns(&self) -> u16 {
281 self.inner.num_columns()
282 }
283}
284
285impl crate::Conn {
286 pub(crate) async fn read_column_defs<U>(&mut self, num: U) -> Result<Vec<Column>>
290 where
291 U: Into<usize>,
292 {
293 let num = num.into();
294 debug_assert!(num > 0);
295 let packets = self.read_packets(num).await?;
296 let defs = packets
297 .into_iter()
298 .map(|x| ParseBuf(&*x).parse(()))
299 .collect::<std::result::Result<Vec<Column>, _>>()
300 .map_err(Error::from)?;
301
302 if !self
303 .capabilities()
304 .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
305 {
306 self.read_packet().await?;
307 }
308
309 Ok(defs)
310 }
311
312 pub(crate) async fn get_statement<U>(&mut self, stmt_like: U) -> Result<Statement>
314 where
315 U: StatementLike,
316 {
317 match stmt_like.to_statement(self) {
318 ToStatementResult::Immediate(statement) => Ok(statement),
319 ToStatementResult::Mediate(statement) => statement.await,
320 }
321 }
322
323 async fn prepare_statement(&mut self, raw_query: Cow<'_, [u8]>) -> Result<Arc<StmtInner>> {
327 let inner_stmt = self.routine(PrepareRoutine::new(raw_query)).await?;
328
329 if let Some(old_stmt) = self.cache_stmt(&inner_stmt) {
330 self.close_statement(old_stmt.id()).await?;
331 }
332
333 Ok(inner_stmt)
334 }
335
336 pub(crate) async fn execute_statement<P>(
338 &mut self,
339 statement: &Statement,
340 params: P,
341 ) -> Result<()>
342 where
343 P: Into<Params>,
344 {
345 self.routine(ExecRoutine::new(statement, params.into()))
346 .await?;
347 Ok(())
348 }
349
350 pub(crate) async fn close_statement(&mut self, id: u32) -> Result<()> {
352 self.stmt_cache_mut().remove(id);
353 self.write_command(&ComStmtClose::new(id)).await
354 }
355}