// mysql_binlog_connector_rust/command/authenticator.rs

1use async_recursion::async_recursion;
2use percent_encoding::percent_decode_str;
3use url::Url;
4
5use crate::{
6    binlog_error::BinlogError,
7    constants::MysqlRespCode,
8    network::{
9        auth_plugin_switch_packet::AuthPluginSwitchPacket, greeting_packet::GreetingPacket,
10        packet_channel::PacketChannel,
11    },
12};
13
14use super::{
15    auth_native_password_command::AuthNativePasswordCommand, auth_plugin::AuthPlugin,
16    auth_sha2_password_command::AuthSha2PasswordCommand,
17    auth_sha2_rsa_password_command::AuthSha2RsaPasswordCommand, command_util::CommandUtil,
18};
19
/// Connection target, login credentials and handshake state used to log in to a MySQL server.
///
/// The target and credentials are parsed from a connection URL; `scramble` and
/// `collation` start empty and are filled in while connecting (from the server
/// greeting, and `scramble` again from any auth-plugin switch request).
pub struct Authenticator {
    // Login credentials and the default schema to select.
    username: String,
    password: String,
    schema: String,

    // Network target; the port is kept as text.
    host: String,
    port: String,

    // Handshake state received from the server.
    collation: u8,
    scramble: String,
}
29
30impl Authenticator {
31    pub fn new(url: &str) -> Result<Self, BinlogError> {
32        // url example: mysql://root:123456@127.0.0.1:3307/test_db?ssl-mode=disabled
33        let url_info = Url::parse(url)?;
34        let host = url_info.host_str().unwrap_or("");
35        let port = format!("{}", url_info.port().unwrap_or(3306));
36        let username = url_info.username();
37        let password = url_info.password().unwrap_or("");
38        let mut schema = "";
39        let pathes = url_info.path_segments().map(|c| c.collect::<Vec<_>>());
40        if let Some(vec) = pathes {
41            if !vec.is_empty() {
42                schema = vec[0];
43            }
44        }
45
46        Ok(Self {
47            host: percent_decode_str(host).decode_utf8_lossy().to_string(),
48            port,
49            username: percent_decode_str(username).decode_utf8_lossy().to_string(),
50            password: percent_decode_str(password).decode_utf8_lossy().to_string(),
51            schema: percent_decode_str(schema).decode_utf8_lossy().to_string(),
52            scramble: String::new(),
53            collation: 0,
54        })
55    }
56
57    pub async fn connect(&mut self) -> Result<PacketChannel, BinlogError> {
58        // connect to hostname:port
59        let mut channel = PacketChannel::new(&self.host, &self.port).await?;
60
61        // read and parse greeting packet
62        let (greeting_buf, sequence) = channel.read_with_sequece().await?;
63        let greeting_packet = GreetingPacket::new(greeting_buf)?;
64
65        self.collation = greeting_packet.server_collation;
66        self.scramble = greeting_packet.scramble;
67
68        // authenticate
69        self.authenticate(
70            &mut channel,
71            &greeting_packet.plugin_provided_data,
72            sequence,
73        )
74        .await?;
75
76        Ok(channel)
77    }
78
79    async fn authenticate(
80        &mut self,
81        channel: &mut PacketChannel,
82        auth_plugin_name: &str,
83        sequence: u8,
84    ) -> Result<(), BinlogError> {
85        let command_buf = match AuthPlugin::from_name(auth_plugin_name) {
86            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
87                schema: self.schema.clone(),
88                username: self.username.clone(),
89                password: self.password.clone(),
90                scramble: self.scramble.clone(),
91                collation: self.collation,
92            }
93            .to_bytes()?,
94
95            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
96                schema: self.schema.clone(),
97                username: self.username.clone(),
98                password: self.password.clone(),
99                scramble: self.scramble.clone(),
100                collation: self.collation,
101            }
102            .to_bytes()?,
103
104            AuthPlugin::Unsupported => {
105                return Err(BinlogError::ConnectError("unsupported auth plugin".into()));
106            }
107        };
108
109        channel.write(&command_buf, sequence + 1).await?;
110        let (auth_res, sequence) = channel.read_with_sequece().await?;
111        self.handle_auth_result(channel, auth_plugin_name, sequence, &auth_res)
112            .await
113    }
114
115    async fn handle_auth_result(
116        &mut self,
117        channel: &mut PacketChannel,
118        auth_plugin_name: &str,
119        sequence: u8,
120        auth_res: &Vec<u8>,
121    ) -> Result<(), BinlogError> {
122        // parse result
123        match auth_res[0] {
124            MysqlRespCode::OK => return Ok(()),
125
126            MysqlRespCode::ERROR => return CommandUtil::check_error_packet(auth_res),
127
128            MysqlRespCode::AUTH_PLUGIN_SWITCH => {
129                return self
130                    .handle_auth_plugin_switch(channel, sequence, auth_res)
131                    .await;
132            }
133
134            _ => match AuthPlugin::from_name(auth_plugin_name) {
135                AuthPlugin::MySqlNativePassword => {
136                    return Err(BinlogError::ConnectError(format!(
137                        "unexpected auth result for mysql_native_password: {}",
138                        auth_res[0]
139                    )));
140                }
141
142                AuthPlugin::CachingSha2Password => {
143                    return self
144                        .handle_sha2_auth_result(channel, sequence, auth_res)
145                        .await;
146                }
147
148                // won't happen
149                _ => {}
150            },
151        };
152
153        Ok(())
154    }
155
156    #[async_recursion]
157    async fn handle_auth_plugin_switch(
158        &mut self,
159        channel: &mut PacketChannel,
160        sequence: u8,
161        auth_res: &Vec<u8>,
162    ) -> Result<(), BinlogError> {
163        let switch_packet = AuthPluginSwitchPacket::new(auth_res)?;
164        let auth_plugin_name = &switch_packet.auth_plugin_name;
165        self.scramble = switch_packet.scramble;
166
167        let encrypted_password = match AuthPlugin::from_name(auth_plugin_name) {
168            AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
169                schema: self.schema.clone(),
170                username: self.username.clone(),
171                password: self.password.clone(),
172                scramble: self.scramble.clone(),
173                collation: self.collation,
174            }
175            .encrypted_password()?,
176
177            AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
178                schema: self.schema.clone(),
179                username: self.username.clone(),
180                password: self.password.clone(),
181                scramble: self.scramble.clone(),
182                collation: self.collation,
183            }
184            .encrypted_password()?,
185
186            _ => {
187                return Err(BinlogError::ConnectError(format!(
188                    "unexpected auth plugin for auth plugin switch: {}",
189                    auth_plugin_name
190                )));
191            }
192        };
193
194        channel.write(&encrypted_password, sequence + 1).await?;
195        let (encrypted_auth_res, sequence) = channel.read_with_sequece().await?;
196        self.handle_auth_result(channel, auth_plugin_name, sequence, &encrypted_auth_res)
197            .await
198    }
199
200    async fn handle_sha2_auth_result(
201        &self,
202        channel: &mut PacketChannel,
203        sequence: u8,
204        auth_res: &[u8],
205    ) -> Result<(), BinlogError> {
206        // buf[0] is the length of buf, always 1
207        match auth_res[1] {
208            0x03 => Ok(()),
209
210            0x04 => self.sha2_rsa_authenticate(channel, sequence).await,
211
212            _ => Err(BinlogError::ConnectError(format!(
213                "unexpected auth result for caching_sha2_password: {}",
214                auth_res[1]
215            ))),
216        }
217    }
218
219    async fn sha2_rsa_authenticate(
220        &self,
221        channel: &mut PacketChannel,
222        sequence: u8,
223    ) -> Result<(), BinlogError> {
224        // refer: https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
225        // try to get RSA key from server
226        channel.write(&[0x02], sequence + 1).await?;
227        let (rsa_res, sequence) = channel.read_with_sequece().await?;
228        match rsa_res[0] {
229            0x01 => {
230                // try sha2 authentication with rsa
231                let mut command = AuthSha2RsaPasswordCommand {
232                    rsa_res: rsa_res[1..].to_vec(),
233                    password: self.password.clone(),
234                    scramble: self.scramble.clone(),
235                };
236                channel.write(&command.to_bytes()?, sequence + 1).await?;
237
238                let (auth_res, _) = channel.read_with_sequece().await?;
239                CommandUtil::parse_result(&auth_res)
240            }
241
242            _ => Err(BinlogError::ConnectError(format!(
243                "failed to get RSA key from server for caching_sha2_password: {}",
244                rsa_res[0]
245            ))),
246        }
247    }
248}