mysql_connector/connection/init.rs

1use {
2    super::{
3        types::AuthPlugin, Command, Connection, ConnectionData, ConnectionOptions, ParseBuf,
4        Stream, BUFFER_POOL, DEFAULT_MAX_ALLOWED_PACKET,
5    },
6    crate::{
7        packets::{HandshakePacket, HandshakeResponse},
8        ConnectionOptionsTrait, Error, Serialize, StreamRequirements,
9    },
10    std::sync::Arc,
11};
12
13impl Connection {
14    pub async fn connect<T: Stream>(options: Arc<ConnectionOptions<T>>) -> Result<Self, Error> {
15        let mut stream = T::connect(&options.connection).await?;
16        let mut seq_id = 0;
17
18        let data = Self::handle_handshake(&mut stream, &mut seq_id, options.clone()).await?;
19        let mut this = Self {
20            stream: Box::new(stream),
21            seq_id,
22            data,
23            options,
24            pending_result: false,
25        };
26        this.do_handshake_response().await?;
27        this.continue_auth::<T>().await?;
28        this.read_settings().await?;
29        Ok(this)
30    }
31
32    pub async fn disconnect(mut self) -> Result<(), Error> {
33        self.execute_command(Command::Quit, &[]).await
34    }
35
36    async fn handle_handshake(
37        stream: &mut dyn StreamRequirements,
38        seq_id: &mut u8,
39        options: Arc<dyn ConnectionOptionsTrait>,
40    ) -> Result<ConnectionData, Error> {
41        #[cfg(feature = "time")]
42        fn sleep(duration: std::time::Duration) -> crate::TimeoutFuture {
43            Box::pin(tokio::time::sleep(duration))
44        }
45        #[cfg(not(feature = "time"))]
46        let sleep = match options.sleep() {
47            Some(x) => x,
48            None => panic!(concat!(
49                "No `sleep` function provided.\n",
50                "You have to either provide a custom `sleep` function by setting `ConnectionData::sleep` or enable the feature `time`.",
51            )),
52        };
53        let mut packet = BUFFER_POOL.get();
54        Self::read_packet_to_buf(stream, seq_id, packet.as_mut(), &sleep, options.timeout())
55            .await?;
56        let handshake = ParseBuf(&packet).parse::<HandshakePacket>(()).unwrap();
57
58        let (version, is_mariadb) = handshake
59            .parse_server_version()
60            .unwrap_or(((0, 0, 0), false));
61        let auth_plugin = handshake.auth_plugin().unwrap_or(AuthPlugin::Native);
62
63        Ok(ConnectionData {
64            id: handshake.connection_id(),
65            is_mariadb,
66            version,
67            capabilities: handshake.capabilities() & options.get_capabilities(),
68            nonce: handshake.into_nonce(),
69            #[cfg(feature = "caching-sha2-password")]
70            server_key: options.server_key(),
71            auth_plugin,
72            auth_switched: false,
73            max_allowed_packet: options
74                .max_allowed_packet()
75                .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
76            #[cfg(feature = "time")]
77            sleep: &sleep,
78            #[cfg(not(feature = "time"))]
79            sleep,
80        })
81    }
82
83    async fn do_handshake_response(&mut self) -> Result<(), Error> {
84        let auth_plugin = self.options.auth_plugin().unwrap_or(self.data.auth_plugin);
85        let auth_data =
86            auth_plugin.gen_data(self.options.password(), &self.data.nonce, &*self.options)?;
87
88        let handshake_response = HandshakeResponse::new(
89            auth_data.as_deref().unwrap_or_default(),
90            self.data.version,
91            self.options.user().as_bytes(),
92            self.options.db_name().map(|x| x.as_bytes()),
93            Some(auth_plugin),
94            self.data.capabilities,
95            Default::default(),
96            self.data.max_allowed_packet as u32,
97        );
98
99        let mut buf = BUFFER_POOL.get();
100        handshake_response.serialize(buf.as_mut());
101        self.write_packet(&buf).await
102    }
103}