mysql_connector/connection/
prepared_statement.rs

1use {
2    super::{types::BinaryProtocol, Command, Connection, ParseBuf, ResultSet},
3    crate::{
4        error::{ProtocolError, RuntimeError},
5        model::FromQueryResult,
6        packets::{ErrPacket, OkPacket, Stmt, StmtExecuteRequest},
7        types::SimpleValue,
8        Deserialize, Error,
9    },
10};
11
12#[derive(Debug)]
13pub struct PreparedStatement<'a> {
14    id: u32,
15    conn: &'a mut Connection,
16    params: usize,
17}
18
19impl<'a> PreparedStatement<'a> {
20    pub async fn query<V: SimpleValue, R: FromQueryResult>(
21        &mut self,
22        values: &[V],
23    ) -> Result<ResultSet<'_, BinaryProtocol, R>, Error> {
24        if values.len() != self.params {
25            return Err(RuntimeError::ParameterCountMismatch.into());
26        }
27
28        self.conn
29            .query_prepared_statement_unchecked(self.id, values)
30            .await
31    }
32
33    pub async fn execute<V: SimpleValue>(&mut self, values: &[V]) -> Result<OkPacket, Error> {
34        if values.len() != self.params {
35            return Err(RuntimeError::ParameterCountMismatch.into());
36        }
37
38        self.conn
39            .execute_prepared_statement_unchecked(self.id, values)
40            .await
41    }
42}
43
44impl Connection {
45    pub async fn prepare_statement(&mut self, stmt: &str) -> Result<PreparedStatement, Error> {
46        self.execute_command(Command::StmtPrepare, stmt).await?;
47        let packet = self.read_packet().await?;
48        let stmt = match packet.first() {
49            Some(0x00) => Stmt::deserialize(&mut ParseBuf(&packet), ())?,
50            Some(0xFF) => {
51                return Err(
52                    ErrPacket::deserialize(&mut ParseBuf(&packet), self.data.capabilities)?.into(),
53                )
54            }
55            _ => return Err(ProtocolError::unexpected_packet(Vec::clone(&packet), None).into()),
56        };
57
58        for _ in 0..(stmt.params_len as usize + stmt.columns_len as usize) {
59            self.read_packet().await?;
60        }
61
62        Ok(PreparedStatement {
63            id: stmt.id,
64            conn: self,
65            params: stmt.params_len as usize,
66        })
67    }
68
69    async fn query_prepared_statement_unchecked<V: SimpleValue, R: FromQueryResult>(
70        &mut self,
71        id: u32,
72        params: &[V],
73    ) -> Result<ResultSet<'_, BinaryProtocol, R>, Error> {
74        let request = StmtExecuteRequest::new(id, params);
75
76        if request.as_long_data() {
77            self.send_long_data(id, params.iter()).await?;
78        }
79
80        self.write_command(&request).await?;
81        ResultSet::read(self).await
82    }
83
84    async fn execute_prepared_statement_unchecked<V: SimpleValue>(
85        &mut self,
86        id: u32,
87        params: &[V],
88    ) -> Result<OkPacket, Error> {
89        let request = StmtExecuteRequest::new(id, params);
90
91        if request.as_long_data() {
92            self.send_long_data(id, params.iter()).await?;
93        }
94
95        self.write_command(&request).await?;
96        self.read_response().await?.map_err(Into::into)
97    }
98}