mysql_binlog_connector_rust/command/
authenticator.rs1use async_recursion::async_recursion;
2use percent_encoding::percent_decode_str;
3use url::Url;
4
5use crate::{
6 binlog_error::BinlogError,
7 constants::MysqlRespCode,
8 network::{
9 auth_plugin_switch_packet::AuthPluginSwitchPacket, greeting_packet::GreetingPacket,
10 packet_channel::PacketChannel,
11 },
12};
13
14use super::{
15 auth_native_password_command::AuthNativePasswordCommand, auth_plugin::AuthPlugin,
16 auth_sha2_password_command::AuthSha2PasswordCommand,
17 auth_sha2_rsa_password_command::AuthSha2RsaPasswordCommand, command_util::CommandUtil,
18};
19
/// Connection settings plus the handshake state needed to log in to a MySQL server.
///
/// Created from a connection URL by `Authenticator::new`. The `scramble` and
/// `collation` fields start empty/zero and are filled in from the server
/// greeting when `connect` runs.
pub struct Authenticator {
    // --- where to connect ---
    /// Server host (percent-decoded from the URL; empty if the URL had none).
    host: String,
    /// Server port, stored as text (3306 when the URL did not specify one).
    port: String,
    /// Timeout in seconds, passed through to `PacketChannel::new`.
    timeout_secs: u64,

    // --- who to log in as ---
    /// Login user name (percent-decoded).
    username: String,
    /// Login password (percent-decoded; empty if the URL had none).
    password: String,
    /// Default schema taken from the first URL path segment (may be empty).
    schema: String,

    // --- state learned during the handshake ---
    /// Challenge data from the greeting or from an auth-plugin-switch request.
    scramble: String,
    /// Server collation id reported in the greeting.
    collation: u8,
}
30
31impl Authenticator {
32 pub fn new(url: &str, timeout_secs: u64) -> Result<Self, BinlogError> {
33 let url_info = Url::parse(url)?;
35 let host = url_info.host_str().unwrap_or("");
36 let port = format!("{}", url_info.port().unwrap_or(3306));
37 let username = url_info.username();
38 let password = url_info.password().unwrap_or("");
39 let mut schema = "";
40 let pathes = url_info.path_segments().map(|c| c.collect::<Vec<_>>());
41 if let Some(vec) = pathes {
42 if !vec.is_empty() {
43 schema = vec[0];
44 }
45 }
46
47 Ok(Self {
48 host: percent_decode_str(host).decode_utf8_lossy().to_string(),
49 port,
50 username: percent_decode_str(username).decode_utf8_lossy().to_string(),
51 password: percent_decode_str(password).decode_utf8_lossy().to_string(),
52 schema: percent_decode_str(schema).decode_utf8_lossy().to_string(),
53 scramble: String::new(),
54 collation: 0,
55 timeout_secs,
56 })
57 }
58
59 pub async fn connect(&mut self) -> Result<PacketChannel, BinlogError> {
60 let mut channel = PacketChannel::new(&self.host, &self.port, self.timeout_secs).await?;
62
63 let (greeting_buf, sequence) = channel.read_with_sequece().await?;
65 let greeting_packet = GreetingPacket::new(greeting_buf)?;
66
67 self.collation = greeting_packet.server_collation;
68 self.scramble = greeting_packet.scramble;
69
70 self.authenticate(
72 &mut channel,
73 &greeting_packet.plugin_provided_data,
74 sequence,
75 )
76 .await?;
77
78 Ok(channel)
79 }
80
81 async fn authenticate(
82 &mut self,
83 channel: &mut PacketChannel,
84 auth_plugin_name: &str,
85 sequence: u8,
86 ) -> Result<(), BinlogError> {
87 let command_buf = match AuthPlugin::from_name(auth_plugin_name) {
88 AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
89 schema: self.schema.clone(),
90 username: self.username.clone(),
91 password: self.password.clone(),
92 scramble: self.scramble.clone(),
93 collation: self.collation,
94 }
95 .to_bytes()?,
96
97 AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
98 schema: self.schema.clone(),
99 username: self.username.clone(),
100 password: self.password.clone(),
101 scramble: self.scramble.clone(),
102 collation: self.collation,
103 }
104 .to_bytes()?,
105
106 AuthPlugin::Unsupported => {
107 return Err(BinlogError::ConnectError("unsupported auth plugin".into()));
108 }
109 };
110
111 channel.write(&command_buf, sequence + 1).await?;
112 let (auth_res, sequence) = channel.read_with_sequece().await?;
113 self.handle_auth_result(channel, auth_plugin_name, sequence, &auth_res)
114 .await
115 }
116
117 async fn handle_auth_result(
118 &mut self,
119 channel: &mut PacketChannel,
120 auth_plugin_name: &str,
121 sequence: u8,
122 auth_res: &Vec<u8>,
123 ) -> Result<(), BinlogError> {
124 match auth_res[0] {
126 MysqlRespCode::OK => return Ok(()),
127
128 MysqlRespCode::ERROR => return CommandUtil::check_error_packet(auth_res),
129
130 MysqlRespCode::AUTH_PLUGIN_SWITCH => {
131 return self
132 .handle_auth_plugin_switch(channel, sequence, auth_res)
133 .await;
134 }
135
136 _ => match AuthPlugin::from_name(auth_plugin_name) {
137 AuthPlugin::MySqlNativePassword => {
138 return Err(BinlogError::ConnectError(format!(
139 "unexpected auth result for mysql_native_password: {}",
140 auth_res[0]
141 )));
142 }
143
144 AuthPlugin::CachingSha2Password => {
145 return self
146 .handle_sha2_auth_result(channel, sequence, auth_res)
147 .await;
148 }
149
150 _ => {}
152 },
153 };
154
155 Ok(())
156 }
157
158 #[async_recursion]
159 async fn handle_auth_plugin_switch(
160 &mut self,
161 channel: &mut PacketChannel,
162 sequence: u8,
163 auth_res: &Vec<u8>,
164 ) -> Result<(), BinlogError> {
165 let switch_packet = AuthPluginSwitchPacket::new(auth_res)?;
166 let auth_plugin_name = &switch_packet.auth_plugin_name;
167 self.scramble = switch_packet.scramble;
168
169 let encrypted_password = match AuthPlugin::from_name(auth_plugin_name) {
170 AuthPlugin::CachingSha2Password => AuthSha2PasswordCommand {
171 schema: self.schema.clone(),
172 username: self.username.clone(),
173 password: self.password.clone(),
174 scramble: self.scramble.clone(),
175 collation: self.collation,
176 }
177 .encrypted_password()?,
178
179 AuthPlugin::MySqlNativePassword => AuthNativePasswordCommand {
180 schema: self.schema.clone(),
181 username: self.username.clone(),
182 password: self.password.clone(),
183 scramble: self.scramble.clone(),
184 collation: self.collation,
185 }
186 .encrypted_password()?,
187
188 _ => {
189 return Err(BinlogError::ConnectError(format!(
190 "unexpected auth plugin for auth plugin switch: {}",
191 auth_plugin_name
192 )));
193 }
194 };
195
196 channel.write(&encrypted_password, sequence + 1).await?;
197 let (encrypted_auth_res, sequence) = channel.read_with_sequece().await?;
198 self.handle_auth_result(channel, auth_plugin_name, sequence, &encrypted_auth_res)
199 .await
200 }
201
202 async fn handle_sha2_auth_result(
203 &self,
204 channel: &mut PacketChannel,
205 sequence: u8,
206 auth_res: &[u8],
207 ) -> Result<(), BinlogError> {
208 match auth_res[1] {
210 0x03 => Ok(()),
211
212 0x04 => self.sha2_rsa_authenticate(channel, sequence).await,
213
214 _ => Err(BinlogError::ConnectError(format!(
215 "unexpected auth result for caching_sha2_password: {}",
216 auth_res[1]
217 ))),
218 }
219 }
220
221 async fn sha2_rsa_authenticate(
222 &self,
223 channel: &mut PacketChannel,
224 sequence: u8,
225 ) -> Result<(), BinlogError> {
226 channel.write(&[0x02], sequence + 1).await?;
229 let (rsa_res, sequence) = channel.read_with_sequece().await?;
230 match rsa_res[0] {
231 0x01 => {
232 let mut command = AuthSha2RsaPasswordCommand {
234 rsa_res: rsa_res[1..].to_vec(),
235 password: self.password.clone(),
236 scramble: self.scramble.clone(),
237 };
238 channel.write(&command.to_bytes()?, sequence + 1).await?;
239
240 let (auth_res, _) = channel.read_with_sequece().await?;
241 CommandUtil::parse_result(&auth_res)
242 }
243
244 _ => Err(BinlogError::ConnectError(format!(
245 "failed to get RSA key from server for caching_sha2_password: {}",
246 rsa_res[0]
247 ))),
248 }
249 }
250}