mysql_binlog_connector_rust/command/authenticator.rs

1use async_recursion::async_recursion;
2use percent_encoding::percent_decode_str;
3use url::Url;
4
5use crate::{
6    binlog_error::BinlogError,
7    constants::MysqlRespCode,
8    network::{
9        auth_plugin_switch_packet::AuthPluginSwitchPacket, greeting_packet::GreetingPacket,
10        packet_channel::PacketChannel,
11    },
12};
13
14use super::{
15    auth_native_password_command::AuthNativePasswordCommand, auth_plugin::AuthPlugin,
16    auth_sha2_password_command::AuthSha2PasswordCommand,
17    auth_sha2_rsa_password_command::AuthSha2RsaPasswordCommand, command_util::CommandUtil,
18};
19
/// Connection settings plus per-handshake state used to log in to a MySQL server.
///
/// Created from a connection URL by `Authenticator::new`; `scramble` and `collation`
/// start empty/zero and are filled in from the server greeting by `connect`.
pub struct Authenticator {
    /// Server host, percent-decoded from the URL (empty if the URL has none).
    host: String,
    /// Server port kept as text; `new` falls back to 3306.
    port: String,
    /// Timeout in seconds handed to `PacketChannel::new` when connecting.
    timeout_secs: u64,
    /// Login user name, percent-decoded from the URL.
    username: String,
    /// Login password, percent-decoded from the URL (empty if absent).
    password: String,
    /// Default schema: the first URL path segment, or empty.
    schema: String,
    /// Auth challenge taken from the greeting or from an auth-switch request.
    scramble: String,
    /// Server collation id from the greeting, passed on to the auth commands.
    collation: u8,
}
30
31impl Authenticator {
32    pub fn new(url: &str, timeout_secs: u64) -> Result<Self, BinlogError> {
33        // url example: mysql://root:123456@127.0.0.1:3307/test_db?ssl-mode=disabled
34        let url_info = Url::parse(url)?;
35        let host = url_info.host_str().unwrap_or("");
36        let port = format!("{}", url_info.port().unwrap_or(3306));
37        let username = url_info.username();
38        let password = url_info.password().unwrap_or("");
39        let mut schema = "";
40        let pathes = url_info.path_segments().map(|c| c.collect::<Vec<_>>());
41        if let Some(vec) = pathes {
42            if !vec.is_empty() {
43                schema = vec[0];
44            }
45        }
46
47        Ok(Self {
48            host: percent_decode_str(host).decode_utf8_lossy().to_string(),
49            port,
50            username: percent_decode_str(username).decode_utf8_lossy().to_string(),
51            password: percent_decode_str(password).decode_utf8_lossy().to_string(),
52            schema: percent_decode_str(schema).decode_utf8_lossy().to_string(),
53            scramble: String::new(),
54            collation: 0,
55            timeout_secs,
56        })
57    }
58
59    pub async fn connect(&mut self) -> Result<PacketChannel, BinlogError> {
60        // connect to hostname:port
61        let mut channel = PacketChannel::new(&self.host, &self.port, self.timeout_secs).await?;
62
63        // read and parse greeting packet
64        let (greeting_buf, sequence) = channel.read_with_sequece().await?;
65        let greeting_packet = GreetingPacket::new(greeting_buf)?;
66
67        self.collation = greeting_packet.server_collation;
68        self.scramble = greeting_packet.scramble;
69
70        // authenticate
71        self.authenticate(
72            &mut channel,
73            &greeting_packet.plugin_provided_data,
74            sequence,
75        )
76        .await?;
77
78        Ok(channel)
79    }
80
81    async fn authenticate(
82        &mut self,
83        channel: &mut PacketChannel,
84        auth_plugin_name: &str,
85        sequence: u8,
86    ) -> Result<(), BinlogError> {
87        let command_buf = match AuthPlugin::from_name(auth_plugin_name) {
88            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
89                schema: self.schema.clone(),
90                username: self.username.clone(),
91                password: self.password.clone(),
92                scramble: self.scramble.clone(),
93                collation: self.collation,
94            }
95            .to_bytes()?,
96
97            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
98                schema: self.schema.clone(),
99                username: self.username.clone(),
100                password: self.password.clone(),
101                scramble: self.scramble.clone(),
102                collation: self.collation,
103            }
104            .to_bytes()?,
105
106            AuthPlugin::Unsupported => {
107                return Err(BinlogError::ConnectError("unsupported auth plugin".into()));
108            }
109        };
110
111        channel.write(&command_buf, sequence + 1).await?;
112        let (auth_res, sequence) = channel.read_with_sequece().await?;
113        self.handle_auth_result(channel, auth_plugin_name, sequence, &auth_res)
114            .await
115    }
116
117    async fn handle_auth_result(
118        &mut self,
119        channel: &mut PacketChannel,
120        auth_plugin_name: &str,
121        sequence: u8,
122        auth_res: &Vec<u8>,
123    ) -> Result<(), BinlogError> {
124        // parse result
125        match auth_res[0] {
126            MysqlRespCode::OK => return Ok(()),
127
128            MysqlRespCode::ERROR => return CommandUtil::check_error_packet(auth_res),
129
130            MysqlRespCode::AUTH_PLUGIN_SWITCH => {
131                return self
132                    .handle_auth_plugin_switch(channel, sequence, auth_res)
133                    .await;
134            }
135
136            _ => match AuthPlugin::from_name(auth_plugin_name) {
137                AuthPlugin::MySqlNativePassword => {
138                    return Err(BinlogError::ConnectError(format!(
139                        "unexpected auth result for mysql_native_password: {}",
140                        auth_res[0]
141                    )));
142                }
143
144                AuthPlugin::CachingSha2Password => {
145                    return self
146                        .handle_sha2_auth_result(channel, sequence, auth_res)
147                        .await;
148                }
149
150                // won't happen
151                _ => {}
152            },
153        };
154
155        Ok(())
156    }
157
158    #[async_recursion]
159    async fn handle_auth_plugin_switch(
160        &mut self,
161        channel: &mut PacketChannel,
162        sequence: u8,
163        auth_res: &Vec<u8>,
164    ) -> Result<(), BinlogError> {
165        let switch_packet = AuthPluginSwitchPacket::new(auth_res)?;
166        let auth_plugin_name = &switch_packet.auth_plugin_name;
167        self.scramble = switch_packet.scramble;
168
169        let encrypted_password = match AuthPlugin::from_name(auth_plugin_name) {
170            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
171                schema: self.schema.clone(),
172                username: self.username.clone(),
173                password: self.password.clone(),
174                scramble: self.scramble.clone(),
175                collation: self.collation,
176            }
177            .encrypted_password()?,
178
179            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
180                schema: self.schema.clone(),
181                username: self.username.clone(),
182                password: self.password.clone(),
183                scramble: self.scramble.clone(),
184                collation: self.collation,
185            }
186            .encrypted_password()?,
187
188            _ => {
189                return Err(BinlogError::ConnectError(format!(
190                    "unexpected auth plugin for auth plugin switch: {}",
191                    auth_plugin_name
192                )));
193            }
194        };
195
196        channel.write(&encrypted_password, sequence + 1).await?;
197        let (encrypted_auth_res, sequence) = channel.read_with_sequece().await?;
198        self.handle_auth_result(channel, auth_plugin_name, sequence, &encrypted_auth_res)
199            .await
200    }
201
202    async fn handle_sha2_auth_result(
203        &self,
204        channel: &mut PacketChannel,
205        sequence: u8,
206        auth_res: &[u8],
207    ) -> Result<(), BinlogError> {
208        // buf[0] is the length of buf, always 1
209        match auth_res[1] {
210            0x03 => Ok(()),
211
212            0x04 => self.sha2_rsa_authenticate(channel, sequence).await,
213
214            _ => Err(BinlogError::ConnectError(format!(
215                "unexpected auth result for caching_sha2_password: {}",
216                auth_res[1]
217            ))),
218        }
219    }
220
221    async fn sha2_rsa_authenticate(
222        &self,
223        channel: &mut PacketChannel,
224        sequence: u8,
225    ) -> Result<(), BinlogError> {
226        // refer: https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
227        // try to get RSA key from server
228        channel.write(&[0x02], sequence + 1).await?;
229        let (rsa_res, sequence) = channel.read_with_sequece().await?;
230        match rsa_res[0] {
231            0x01 => {
232                // try sha2 authentication with rsa
233                let mut command = AuthSha2RsaPasswordCommand {
234                    rsa_res: rsa_res[1..].to_vec(),
235                    password: self.password.clone(),
236                    scramble: self.scramble.clone(),
237                };
238                channel.write(&command.to_bytes()?, sequence + 1).await?;
239
240                let (auth_res, _) = channel.read_with_sequece().await?;
241                CommandUtil::parse_result(&auth_res)
242            }
243
244            _ => Err(BinlogError::ConnectError(format!(
245                "failed to get RSA key from server for caching_sha2_password: {}",
246                rsa_res[0]
247            ))),
248        }
249    }
250}