mysql_connector/connection/auth.rs

1use {
2    super::{types::AuthPlugin, Connection, ParseBuf, Stream},
3    crate::{
4        error::ProtocolError,
5        packets::{AuthSwitchRequest, ErrPacket},
6        Deserialize, Error,
7    },
8    std::{future::Future, pin::Pin},
9};
10
11impl Connection {
12    pub(super) fn continue_auth<T: Stream>(
13        &mut self,
14    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + '_>> {
15        match self.options.auth_plugin().unwrap_or(self.data.auth_plugin) {
16            #[cfg(feature = "caching-sha2-password")]
17            AuthPlugin::Sha2 => Box::pin(self.continue_caching_sha2_password_auth::<T>()),
18            AuthPlugin::Native | AuthPlugin::Clear => {
19                Box::pin(self.continue_mysql_native_password_auth::<T>())
20            }
21        }
22    }
23
24    async fn continue_mysql_native_password_auth<T: Stream>(&mut self) -> Result<(), Error> {
25        let packet = self.read_packet().await?;
26        match packet.first() {
27            Some(0x00) => Ok(()),
28            Some(0xFE) if !self.data.auth_switched => {
29                let auth_switch = AuthSwitchRequest::deserialize(&mut ParseBuf(&packet), ())?;
30                self.perform_auth_switch::<T>(auth_switch).await
31            }
32            _ => Err(
33                match ErrPacket::deserialize(&mut ParseBuf(&packet), self.data.capabilities) {
34                    Ok(err) => err.into(),
35                    Err(_) => {
36                        ProtocolError::unexpected_packet(packet.to_vec(), Some("Ok or auth switch"))
37                            .into()
38                    }
39                },
40            ),
41        }
42    }
43
44    #[cfg(feature = "caching-sha2-password")]
45    #[cfg_attr(doc, doc(cfg(feature = "caching-sha2-password")))]
46    async fn continue_caching_sha2_password_auth<T: Stream>(&mut self) -> Result<(), Error> {
47        use {
48            crate::{
49                error::SerializeError,
50                utils::{OaepPadding, PublicKey},
51            },
52            rand::SeedableRng as _,
53        };
54        let packet = self.read_packet().await?;
55        match packet.first() {
56            Some(0x00) => {
57                // ok packet for empty password
58                Ok(())
59            }
60            Some(0x01) => match packet.get(1) {
61                Some(0x03) => {
62                    // auth ok
63                    self.read_packet().await?;
64                    Ok(())
65                }
66                Some(0x04) => {
67                    let mut pass = super::BUFFER_POOL.get();
68                    pass.extend_from_slice(self.options.password().as_bytes());
69                    pass.push(0);
70
71                    if T::SECURE {
72                        self.write_packet(&pass).await?;
73                    } else {
74                        let server_key = match &self.data.server_key {
75                            Some(key) => key.clone(),
76                            None => {
77                                self.write_packet(&[0x02]).await?;
78                                let packet = self.read_packet().await?;
79                                match packet.first() {
80                                    Some(0x01) => {
81                                        let server_key = std::sync::Arc::new(
82                                            PublicKey::try_from_pem(&packet[1..])
83                                                .map_err(SerializeError::from)?,
84                                        );
85                                        self.data.server_key = Some(server_key.clone());
86                                        server_key
87                                    }
88                                    Some(0xFF) => {
89                                        return Err(Error::Server(ErrPacket::deserialize(
90                                            &mut ParseBuf(&packet),
91                                            self.data.capabilities,
92                                        )?))
93                                    }
94                                    _ => {
95                                        return Err(Error::Protocol(
96                                            ProtocolError::unexpected_packet(
97                                                packet.to_vec(),
98                                                Some("Server key"),
99                                            ),
100                                        ))
101                                    }
102                                }
103                            }
104                        };
105                        for (i, byte) in pass.iter_mut().enumerate() {
106                            *byte ^= self.data.nonce[i % self.data.nonce.len()];
107                        }
108                        let padding = OaepPadding::new(rand::rngs::StdRng::from_entropy());
109                        let encrypted_pass = server_key
110                            .encrypt_padded(&pass, padding)
111                            .map_err(SerializeError::from)?;
112                        self.write_packet(&encrypted_pass).await?;
113                    }
114                    let res = self.read_packet().await?;
115                    match res.first() {
116                        Some(0x00) => Ok(()),
117                        Some(0xFF) => Err(Error::Server(ErrPacket::deserialize(
118                            &mut ParseBuf(&res),
119                            self.data.capabilities,
120                        )?)),
121                        _ => Err(Error::Protocol(ProtocolError::unexpected_packet(
122                            res.to_vec(),
123                            None,
124                        ))),
125                    }
126                }
127                _ => Err(ProtocolError::unexpected_packet(packet.to_vec(), None).into()),
128            },
129            Some(0xFE) if !self.data.auth_switched => {
130                let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
131                self.perform_auth_switch::<T>(auth_switch_request).await
132            }
133            _ => Err(
134                match ErrPacket::deserialize(&mut ParseBuf(&packet), self.data.capabilities) {
135                    Ok(err) => err.into(),
136                    Err(_) => {
137                        ProtocolError::unexpected_packet(packet.to_vec(), Some("Ok or auth switch"))
138                            .into()
139                    }
140                },
141            ),
142        }
143    }
144
145    async fn perform_auth_switch<T: Stream>(
146        &mut self,
147        auth_switch_request: AuthSwitchRequest,
148    ) -> Result<(), Error> {
149        assert!(
150            !self.data.auth_switched,
151            "auth_switched flag should be checked by caller"
152        );
153
154        self.data.auth_switched = true;
155        self.data.auth_plugin = auth_switch_request.plugin();
156        self.data.nonce = auth_switch_request.into_data();
157
158        let plugin_data = self.data.auth_plugin.gen_data(
159            self.options.password(),
160            &self.data.nonce,
161            &*self.options,
162        )?;
163
164        if let Some(plugin_data) = plugin_data {
165            self.write_struct(&plugin_data).await?;
166        } else {
167            self.write_packet(&[]).await?;
168        }
169
170        self.continue_auth::<T>().await
171    }
172}