mysql_binlog_connector_rust/command/authenticator.rs

1use async_recursion::async_recursion;
2use percent_encoding::percent_decode_str;
3use url::Url;
4
5use crate::{
6    binlog_error::BinlogError,
7    constants::MysqlRespCode,
8    network::{
9        auth_plugin_switch_packet::AuthPluginSwitchPacket,
10        greeting_packet::GreetingPacket,
11        packet_channel::{KeepAliveConfig, PacketChannel},
12    },
13};
14
15use super::{
16    auth_native_password_command::AuthNativePasswordCommand, auth_plugin::AuthPlugin,
17    auth_sha2_password_command::AuthSha2PasswordCommand,
18    auth_sha2_rsa_password_command::AuthSha2RsaPasswordCommand, command_util::CommandUtil,
19};
20
21pub struct Authenticator {
22    host: String,
23    port: String,
24    username: String,
25    password: String,
26    schema: String,
27    scramble: String,
28    collation: u8,
29    timeout_secs: u64,
30    keepalive_config: Option<KeepAliveConfig>,
31}
32
33impl Authenticator {
34    pub fn new(
35        url: &str,
36        timeout_secs: u64,
37        keepalive_config: Option<KeepAliveConfig>,
38    ) -> Result<Self, BinlogError> {
39        // url example: mysql://root:123456@127.0.0.1:3307/test_db?ssl-mode=disabled
40        let url_info = Url::parse(url)?;
41        let host = url_info.host_str().unwrap_or("");
42        let port = format!("{}", url_info.port().unwrap_or(3306));
43        let username = url_info.username();
44        let password = url_info.password().unwrap_or("");
45        let mut schema = "";
46        let pathes = url_info.path_segments().map(|c| c.collect::<Vec<_>>());
47        if let Some(vec) = pathes {
48            if !vec.is_empty() {
49                schema = vec[0];
50            }
51        }
52
53        Ok(Self {
54            host: percent_decode_str(host).decode_utf8_lossy().to_string(),
55            port,
56            username: percent_decode_str(username).decode_utf8_lossy().to_string(),
57            password: percent_decode_str(password).decode_utf8_lossy().to_string(),
58            schema: percent_decode_str(schema).decode_utf8_lossy().to_string(),
59            scramble: String::new(),
60            collation: 0,
61            timeout_secs,
62            keepalive_config,
63        })
64    }
65
66    pub async fn connect(&mut self) -> Result<PacketChannel, BinlogError> {
67        // connect to hostname:port
68        let mut channel = PacketChannel::new(
69            &self.host,
70            &self.port,
71            self.timeout_secs,
72            &self.keepalive_config,
73        )
74        .await?;
75
76        // read and parse greeting packet
77        let (greeting_buf, sequence) = channel.read_with_sequece().await?;
78        let greeting_packet = GreetingPacket::new(greeting_buf)?;
79
80        self.collation = greeting_packet.server_collation;
81        self.scramble = greeting_packet.scramble;
82
83        // authenticate
84        self.authenticate(
85            &mut channel,
86            &greeting_packet.plugin_provided_data,
87            sequence,
88        )
89        .await?;
90
91        Ok(channel)
92    }
93
94    async fn authenticate(
95        &mut self,
96        channel: &mut PacketChannel,
97        auth_plugin_name: &str,
98        sequence: u8,
99    ) -> Result<(), BinlogError> {
100        let command_buf = match AuthPlugin::from_name(auth_plugin_name) {
101            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
102                schema: self.schema.clone(),
103                username: self.username.clone(),
104                password: self.password.clone(),
105                scramble: self.scramble.clone(),
106                collation: self.collation,
107            }
108            .to_bytes()?,
109
110            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
111                schema: self.schema.clone(),
112                username: self.username.clone(),
113                password: self.password.clone(),
114                scramble: self.scramble.clone(),
115                collation: self.collation,
116            }
117            .to_bytes()?,
118
119            AuthPlugin::Unsupported => {
120                return Err(BinlogError::ConnectError("unsupported auth plugin".into()));
121            }
122        };
123
124        channel.write(&command_buf, sequence + 1).await?;
125        let (auth_res, sequence) = channel.read_with_sequece().await?;
126        self.handle_auth_result(channel, auth_plugin_name, sequence, &auth_res)
127            .await
128    }
129
130    async fn handle_auth_result(
131        &mut self,
132        channel: &mut PacketChannel,
133        auth_plugin_name: &str,
134        sequence: u8,
135        auth_res: &Vec<u8>,
136    ) -> Result<(), BinlogError> {
137        // parse result
138        match auth_res[0] {
139            MysqlRespCode::OK => return Ok(()),
140
141            MysqlRespCode::ERROR => return CommandUtil::check_error_packet(auth_res),
142
143            MysqlRespCode::AUTH_PLUGIN_SWITCH => {
144                return self
145                    .handle_auth_plugin_switch(channel, sequence, auth_res)
146                    .await;
147            }
148
149            _ => match AuthPlugin::from_name(auth_plugin_name) {
150                AuthPlugin::MySqlNativePassword => {
151                    return Err(BinlogError::ConnectError(format!(
152                        "unexpected auth result for mysql_native_password: {}",
153                        auth_res[0]
154                    )));
155                }
156
157                AuthPlugin::CachingSha2Password => {
158                    return self
159                        .handle_sha2_auth_result(channel, sequence, auth_res)
160                        .await;
161                }
162
163                // won't happen
164                _ => {}
165            },
166        };
167
168        Ok(())
169    }
170
171    #[async_recursion]
172    async fn handle_auth_plugin_switch(
173        &mut self,
174        channel: &mut PacketChannel,
175        sequence: u8,
176        auth_res: &Vec<u8>,
177    ) -> Result<(), BinlogError> {
178        let switch_packet = AuthPluginSwitchPacket::new(auth_res)?;
179        let auth_plugin_name = &switch_packet.auth_plugin_name;
180        self.scramble = switch_packet.scramble;
181
182        let encrypted_password = match AuthPlugin::from_name(auth_plugin_name) {
183            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
184                schema: self.schema.clone(),
185                username: self.username.clone(),
186                password: self.password.clone(),
187                scramble: self.scramble.clone(),
188                collation: self.collation,
189            }
190            .encrypted_password()?,
191
192            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
193                schema: self.schema.clone(),
194                username: self.username.clone(),
195                password: self.password.clone(),
196                scramble: self.scramble.clone(),
197                collation: self.collation,
198            }
199            .encrypted_password()?,
200
201            _ => {
202                return Err(BinlogError::ConnectError(format!(
203                    "unexpected auth plugin for auth plugin switch: {}",
204                    auth_plugin_name
205                )));
206            }
207        };
208
209        channel.write(&encrypted_password, sequence + 1).await?;
210        let (encrypted_auth_res, sequence) = channel.read_with_sequece().await?;
211        self.handle_auth_result(channel, auth_plugin_name, sequence, &encrypted_auth_res)
212            .await
213    }
214
215    async fn handle_sha2_auth_result(
216        &self,
217        channel: &mut PacketChannel,
218        sequence: u8,
219        auth_res: &[u8],
220    ) -> Result<(), BinlogError> {
221        // buf[0] is the length of buf, always 1
222        match auth_res[1] {
223            0x03 => Ok(()),
224
225            0x04 => self.sha2_rsa_authenticate(channel, sequence).await,
226
227            _ => Err(BinlogError::ConnectError(format!(
228                "unexpected auth result for caching_sha2_password: {}",
229                auth_res[1]
230            ))),
231        }
232    }
233
234    async fn sha2_rsa_authenticate(
235        &self,
236        channel: &mut PacketChannel,
237        sequence: u8,
238    ) -> Result<(), BinlogError> {
239        // refer: https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
240        // try to get RSA key from server
241        channel.write(&[0x02], sequence + 1).await?;
242        let (rsa_res, sequence) = channel.read_with_sequece().await?;
243        match rsa_res[0] {
244            0x01 => {
245                // try sha2 authentication with rsa
246                let mut command = AuthSha2RsaPasswordCommand {
247                    rsa_res: rsa_res[1..].to_vec(),
248                    password: self.password.clone(),
249                    scramble: self.scramble.clone(),
250                };
251                channel.write(&command.to_bytes()?, sequence + 1).await?;
252
253                let (auth_res, _) = channel.read_with_sequece().await?;
254                CommandUtil::parse_result(&auth_res)
255            }
256
257            _ => Err(BinlogError::ConnectError(format!(
258                "failed to get RSA key from server for caching_sha2_password: {}",
259                rsa_res[0]
260            ))),
261        }
262    }
263}