mysql_binlog_connector_rust/command/
authenticator.rs1use async_recursion::async_recursion;
2use percent_encoding::percent_decode_str;
3use url::Url;
4
5use crate::{
6 binlog_error::BinlogError,
7 constants::MysqlRespCode,
8 network::{
9 auth_plugin_switch_packet::AuthPluginSwitchPacket, greeting_packet::GreetingPacket,
10 packet_channel::PacketChannel,
11 },
12};
13
14use super::{
15 auth_native_password_command::AuthNativePasswordCommand, auth_plugin::AuthPlugin,
16 auth_sha2_password_command::AuthSha2PasswordCommand,
17 auth_sha2_rsa_password_command::AuthSha2RsaPasswordCommand, command_util::CommandUtil,
18};
19
/// Connection settings and handshake state used to log in to a MySQL server.
pub struct Authenticator {
    /// Server host name or address (percent-decoded from the connection URL).
    host: String,
    /// Server port, kept as text.
    port: String,
    /// Default schema passed to the auth commands (may be empty).
    schema: String,

    /// Login user name (percent-decoded).
    username: String,
    /// Login password (percent-decoded).
    password: String,

    /// Scramble taken from the server greeting, replaced on an auth plugin switch.
    scramble: String,
    /// Server collation id taken from the greeting.
    collation: u8,
}
29
30impl Authenticator {
31 pub fn new(url: &str) -> Result<Self, BinlogError> {
32 let url_info = Url::parse(url)?;
34 let host = url_info.host_str().unwrap_or("");
35 let port = format!("{}", url_info.port().unwrap_or(3306));
36 let username = url_info.username();
37 let password = url_info.password().unwrap_or("");
38 let mut schema = "";
39 let pathes = url_info.path_segments().map(|c| c.collect::<Vec<_>>());
40 if let Some(vec) = pathes {
41 if !vec.is_empty() {
42 schema = vec[0];
43 }
44 }
45
46 Ok(Self {
47 host: percent_decode_str(host).decode_utf8_lossy().to_string(),
48 port,
49 username: percent_decode_str(username).decode_utf8_lossy().to_string(),
50 password: percent_decode_str(password).decode_utf8_lossy().to_string(),
51 schema: percent_decode_str(schema).decode_utf8_lossy().to_string(),
52 scramble: String::new(),
53 collation: 0,
54 })
55 }
56
57 pub async fn connect(&mut self) -> Result<PacketChannel, BinlogError> {
58 let mut channel = PacketChannel::new(&self.host, &self.port).await?;
60
61 let (greeting_buf, sequence) = channel.read_with_sequece().await?;
63 let greeting_packet = GreetingPacket::new(greeting_buf)?;
64
65 self.collation = greeting_packet.server_collation;
66 self.scramble = greeting_packet.scramble;
67
68 self.authenticate(
70 &mut channel,
71 &greeting_packet.plugin_provided_data,
72 sequence,
73 )
74 .await?;
75
76 Ok(channel)
77 }
78
79 async fn authenticate(
80 &mut self,
81 channel: &mut PacketChannel,
82 auth_plugin_name: &str,
83 sequence: u8,
84 ) -> Result<(), BinlogError> {
85 let command_buf = match AuthPlugin::from_name(auth_plugin_name) {
86 AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
87 schema: self.schema.clone(),
88 username: self.username.clone(),
89 password: self.password.clone(),
90 scramble: self.scramble.clone(),
91 collation: self.collation,
92 }
93 .to_bytes()?,
94
95 AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
96 schema: self.schema.clone(),
97 username: self.username.clone(),
98 password: self.password.clone(),
99 scramble: self.scramble.clone(),
100 collation: self.collation,
101 }
102 .to_bytes()?,
103
104 AuthPlugin::Unsupported => {
105 return Err(BinlogError::ConnectError("unsupported auth plugin".into()));
106 }
107 };
108
109 channel.write(&command_buf, sequence + 1).await?;
110 let (auth_res, sequence) = channel.read_with_sequece().await?;
111 self.handle_auth_result(channel, auth_plugin_name, sequence, &auth_res)
112 .await
113 }
114
115 async fn handle_auth_result(
116 &mut self,
117 channel: &mut PacketChannel,
118 auth_plugin_name: &str,
119 sequence: u8,
120 auth_res: &Vec<u8>,
121 ) -> Result<(), BinlogError> {
122 match auth_res[0] {
124 MysqlRespCode::OK => return Ok(()),
125
126 MysqlRespCode::ERROR => return CommandUtil::check_error_packet(auth_res),
127
128 MysqlRespCode::AUTH_PLUGIN_SWITCH => {
129 return self
130 .handle_auth_plugin_switch(channel, sequence, auth_res)
131 .await;
132 }
133
134 _ => match AuthPlugin::from_name(auth_plugin_name) {
135 AuthPlugin::MySqlNativePassword => {
136 return Err(BinlogError::ConnectError(format!(
137 "unexpected auth result for mysql_native_password: {}",
138 auth_res[0]
139 )));
140 }
141
142 AuthPlugin::CachingSha2Password => {
143 return self
144 .handle_sha2_auth_result(channel, sequence, auth_res)
145 .await;
146 }
147
148 _ => {}
150 },
151 };
152
153 Ok(())
154 }
155
156 #[async_recursion]
157 async fn handle_auth_plugin_switch(
158 &mut self,
159 channel: &mut PacketChannel,
160 sequence: u8,
161 auth_res: &Vec<u8>,
162 ) -> Result<(), BinlogError> {
163 let switch_packet = AuthPluginSwitchPacket::new(auth_res)?;
164 let auth_plugin_name = &switch_packet.auth_plugin_name;
165 self.scramble = switch_packet.scramble;
166
167 let encrypted_password = match AuthPlugin::from_name(auth_plugin_name) {
168 AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
169 schema: self.schema.clone(),
170 username: self.username.clone(),
171 password: self.password.clone(),
172 scramble: self.scramble.clone(),
173 collation: self.collation,
174 }
175 .encrypted_password()?,
176
177 AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
178 schema: self.schema.clone(),
179 username: self.username.clone(),
180 password: self.password.clone(),
181 scramble: self.scramble.clone(),
182 collation: self.collation,
183 }
184 .encrypted_password()?,
185
186 _ => {
187 return Err(BinlogError::ConnectError(format!(
188 "unexpected auth plugin for auth plugin switch: {}",
189 auth_plugin_name
190 )));
191 }
192 };
193
194 channel.write(&encrypted_password, sequence + 1).await?;
195 let (encrypted_auth_res, sequence) = channel.read_with_sequece().await?;
196 self.handle_auth_result(channel, auth_plugin_name, sequence, &encrypted_auth_res)
197 .await
198 }
199
200 async fn handle_sha2_auth_result(
201 &self,
202 channel: &mut PacketChannel,
203 sequence: u8,
204 auth_res: &[u8],
205 ) -> Result<(), BinlogError> {
206 match auth_res[1] {
208 0x03 => Ok(()),
209
210 0x04 => self.sha2_rsa_authenticate(channel, sequence).await,
211
212 _ => Err(BinlogError::ConnectError(format!(
213 "unexpected auth result for caching_sha2_password: {}",
214 auth_res[1]
215 ))),
216 }
217 }
218
219 async fn sha2_rsa_authenticate(
220 &self,
221 channel: &mut PacketChannel,
222 sequence: u8,
223 ) -> Result<(), BinlogError> {
224 channel.write(&[0x02], sequence + 1).await?;
227 let (rsa_res, sequence) = channel.read_with_sequece().await?;
228 match rsa_res[0] {
229 0x01 => {
230 let mut command = AuthSha2RsaPasswordCommand {
232 rsa_res: rsa_res[1..].to_vec(),
233 password: self.password.clone(),
234 scramble: self.scramble.clone(),
235 };
236 channel.write(&command.to_bytes()?, sequence + 1).await?;
237
238 let (auth_res, _) = channel.read_with_sequece().await?;
239 CommandUtil::parse_result(&auth_res)
240 }
241
242 _ => Err(BinlogError::ConnectError(format!(
243 "failed to get RSA key from server for caching_sha2_password: {}",
244 rsa_res[0]
245 ))),
246 }
247 }
248}