1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3mod mysql;
10use std::io::{Read, Write};
11
12use mysql::authentication::AsyncAuthentication;
13use mysql::protos::drop_packet;
14use mysql::protos::drop_packet_sync;
15use mysql::protos::request;
16use mysql::protos::request_async;
17use mysql::protos::write_packet;
18use mysql::protos::write_packet_sync;
19use mysql::protos::AsyncReceivePacket;
20use mysql::protos::CapabilityFlags;
21use mysql::protos::ErrPacket;
22use mysql::protos::QueryCommand;
23use mysql::protos::QueryCommandResponse;
24use mysql::protos::QuitCommand;
25use mysql::protos::StmtCloseCommand;
26use mysql::protos::StmtExecuteCommand;
27use mysql::protos::StmtExecuteFlags;
28use mysql::protos::StmtExecuteResult;
29use mysql::protos::StmtPrepareCommand;
30use mysql::protos::StmtResetCommand;
31use mysql::protos::Value;
32use parking_lot::MutexGuard;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncWriteExt;
35
36use crate::authentication::Authentication;
37use crate::protos::Handshake;
38
39pub use self::mysql::*;
40
41mod resultset_stream;
42pub use self::resultset_stream::*;
43
44#[derive(Debug)]
46pub enum CommunicationError {
47 IO(std::io::Error),
49 Server(ErrPacket),
51 UnexpectedOKPacket,
53}
54impl From<std::io::Error> for CommunicationError {
55 fn from(e: std::io::Error) -> Self {
56 Self::IO(e)
57 }
58}
59impl From<ErrPacket> for CommunicationError {
60 fn from(e: ErrPacket) -> Self {
61 Self::Server(e)
62 }
63}
64impl std::fmt::Display for CommunicationError {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Self::IO(io) => write!(f, "IO Error: {io}"),
68 Self::Server(e) => write!(f, "Server Error: {}", e.error_message),
69 Self::UnexpectedOKPacket => write!(f, "Unexpected OK Packet was returned"),
70 }
71 }
72}
73impl std::error::Error for CommunicationError {}
74
/// Credentials and session settings sent to the server in the handshake response.
///
/// Construct with [`ConnectInfo::new`] and adjust with the chainable setters
/// [`database`](Self::database), [`max_packet_size`](Self::max_packet_size) and
/// [`character_set`](Self::character_set).
pub struct ConnectInfo<'s> {
    username: &'s str,
    password: &'s str,
    /// Initial schema; when set, the handshake also requests `CLIENT_CONNECT_WITH_DB`.
    database: Option<&'s str>,
    max_packet_size: u32,
    /// Character set / collation id advertised to the server.
    character_set: u8,
}
impl<'s> ConnectInfo<'s> {
    /// Maximum packet size advertised by default: 16 MiB (16 777 216 bytes).
    const DEFAULT_MAX_PACKET_SIZE: u32 = 16 * 1024 * 1024;
    /// Character set / collation id advertised by default.
    // NOTE(review): 0xff (255) is presumably utf8mb4_0900_ai_ci on MySQL 8.x — confirm.
    const DEFAULT_CHARACTER_SET: u8 = 0xff;

    /// Creates connection info for `username` / `password` with no initial database, a 16 MiB
    /// maximum packet size and character set id `0xff`.
    #[must_use]
    pub fn new(username: &'s str, password: &'s str) -> Self {
        Self {
            username,
            password,
            database: None,
            max_packet_size: Self::DEFAULT_MAX_PACKET_SIZE,
            character_set: Self::DEFAULT_CHARACTER_SET,
        }
    }

    /// Selects `db_name` as the initial database for the session.
    #[must_use]
    pub fn database(mut self, db_name: &'s str) -> Self {
        self.database = Some(db_name);
        self
    }

    /// Overrides the maximum packet size advertised to the server.
    #[must_use]
    pub fn max_packet_size(mut self, packet_size: u32) -> Self {
        self.max_packet_size = packet_size;
        self
    }

    /// Overrides the character set / collation id advertised to the server.
    #[must_use]
    pub fn character_set(mut self, character_set: u8) -> Self {
        self.character_set = character_set;
        self
    }
}
113
114pub struct BlockingClient<Stream: Write> {
118 stream: Stream,
119 capability: CapabilityFlags,
120}
121impl<Stream: Write> BlockingClient<Stream> {
122 pub fn handshake(mut stream: Stream, connect_info: &ConnectInfo) -> Result<Self, CommunicationError>
124 where
125 Stream: Read,
126 {
127 let (server_handshake, sequence_id) = Handshake::read_packet_sync(&mut stream)?;
128
129 let server_caps = match server_handshake {
130 Handshake::V10Long(ref p) => p.short.capability_flags,
131 Handshake::V10Short(ref p) => p.capability_flags,
132 _ => CapabilityFlags::new(),
133 };
134 let mut required_caps = CapabilityFlags::new();
135 required_caps
136 .set_support_41_protocol()
137 .set_support_secure_connection()
138 .set_use_long_password()
139 .set_support_deprecate_eof()
140 .set_client_plugin_auth()
141 .set_support_plugin_auth_lenenc_client_data();
142 if connect_info.database.is_some() {
143 required_caps.set_connect_with_db();
144 }
145 let capability = required_caps & server_caps;
146
147 let con_info = authentication::ConnectionInfo {
148 client_capabilities: capability,
149 max_packet_size: connect_info.max_packet_size,
150 character_set: connect_info.character_set,
151 username: connect_info.username,
152 password: connect_info.password,
153 database: connect_info.database,
154 };
155
156 let (auth_plugin_name, auth_data_1, auth_data_2) = match server_handshake {
157 Handshake::V10Long(ref p) => (
158 p.auth_plugin_name.as_deref(),
159 &p.short.auth_plugin_data_part_1[..],
160 p.auth_plugin_data_part_2.as_deref(),
161 ),
162 Handshake::V10Short(ref p) => (None, &p.auth_plugin_data_part_1[..], None),
163 Handshake::V9(ref p) => (None, p.scramble.as_bytes(), None),
164 };
165 match auth_plugin_name {
166 Some(x) if x == authentication::Native41::NAME => authentication::Native41 {
167 server_data_1: auth_data_1,
168 server_data_2: auth_data_2.expect("no extra data passed from server"),
169 }
170 .run_sync(&mut stream, &con_info, sequence_id + 1)?,
171 Some(x) if x == authentication::ClearText::NAME => {
172 authentication::ClearText.run_sync(&mut stream, &con_info, sequence_id + 1)?
173 }
174 Some(x) if x == authentication::SHA256::NAME => authentication::SHA256 {
175 server_spki_der: None,
176 scramble_buffer_1: auth_data_1,
177 scramble_buffer_2: auth_data_2.unwrap_or(&[]),
178 }
179 .run_sync(&mut stream, &con_info, sequence_id + 1)?,
180 Some(x) if x == authentication::CachedSHA256::NAME => {
181 authentication::CachedSHA256(authentication::SHA256 {
182 server_spki_der: None,
183 scramble_buffer_1: auth_data_1,
184 scramble_buffer_2: auth_data_2.unwrap_or(&[]),
185 })
186 .run_sync(&mut stream, &con_info, sequence_id + 1)?
187 }
188 Some(x) => unreachable!("unknown auth plugin: {x}"),
189 None => unreachable!("auth plugin is not specified"),
190 };
191
192 Ok(unsafe { Self::new(stream, capability) })
193 }
194
195 pub unsafe fn new(stream: Stream, capability: CapabilityFlags) -> Self {
199 Self { stream, capability }
200 }
201
202 pub fn quit(&mut self) -> std::io::Result<()> {
204 write_packet_sync(&mut self.stream, QuitCommand, 0)
205 }
206
207 pub fn query(&mut self, query: &str) -> std::io::Result<QueryCommandResponse>
209 where
210 Stream: Read,
211 {
212 request(QueryCommand(query), &mut self.stream, 0, self.capability)
213 }
214
215 pub fn fetch_all<'s>(&'s mut self, query: &str) -> Result<TextResultsetIterator<&'s mut Stream>, CommunicationError>
217 where
218 Stream: Read,
219 {
220 match self.query(query)? {
221 QueryCommandResponse::Resultset { column_count } => {
222 self.text_resultset_iterator(column_count as _).map_err(From::from)
223 }
224 QueryCommandResponse::Err(e) => Err(CommunicationError::from(e)),
225 QueryCommandResponse::Ok(_) => Err(CommunicationError::UnexpectedOKPacket),
226 QueryCommandResponse::LocalInfileRequest { filename } => {
227 todo!("local infile request: {filename}")
228 }
229 }
230 }
231
232 pub fn text_resultset_iterator(
234 &mut self,
235 column_count: usize,
236 ) -> std::io::Result<TextResultsetIterator<&mut Stream>>
237 where
238 Stream: Read,
239 {
240 TextResultsetIterator::new(&mut self.stream, column_count, self.capability)
241 }
242
243 pub fn binary_resultset_iterator(
245 &mut self,
246 column_count: usize,
247 ) -> std::io::Result<BinaryResultsetIterator<&mut Stream>>
248 where
249 Stream: Read,
250 {
251 BinaryResultsetIterator::new(&mut self.stream, column_count, self.capability)
252 }
253}
254impl<Stream: Write> Drop for BlockingClient<Stream> {
255 fn drop(&mut self) {
256 self.quit().expect("Failed to send quit packet at drop")
257 }
258}
259impl BlockingClient<std::net::TcpStream> {
260 pub fn into_async(self) -> Client<tokio::io::BufStream<tokio::net::TcpStream>> {
262 let stream = unsafe { std::ptr::read(&self.stream as *const std::net::TcpStream) };
263 let capability = self.capability;
264 std::mem::forget(self);
265
266 stream.set_nonblocking(true).expect("Failed to switch blocking mode");
267
268 Client {
269 stream: tokio::io::BufStream::new(
270 tokio::net::TcpStream::from_std(stream).expect("Failed to wrap std stream"),
271 ),
272 capability,
273 }
274 }
275}
276
277pub struct Client<Stream: AsyncWriteExt + Send + Sync + Unpin> {
279 stream: Stream,
280 capability: CapabilityFlags,
281}
282impl<Stream: AsyncWriteExt + Send + Sync + Unpin> Client<Stream> {
283 pub async fn handshake(mut stream: Stream, connect_info: &ConnectInfo<'_>) -> Result<Self, CommunicationError>
285 where
286 Stream: AsyncRead + 'static,
287 {
288 let (server_handshake, sequence_id) = Handshake::read_packet(&mut stream).await?;
289
290 let server_caps = match server_handshake {
291 Handshake::V10Long(ref p) => p.short.capability_flags,
292 Handshake::V10Short(ref p) => p.capability_flags,
293 _ => CapabilityFlags::new(),
294 };
295 let mut required_caps = CapabilityFlags::new();
296 required_caps
297 .set_support_41_protocol()
298 .set_support_secure_connection()
299 .set_use_long_password()
300 .set_support_deprecate_eof()
301 .set_client_plugin_auth()
302 .set_support_plugin_auth_lenenc_client_data();
303 if connect_info.database.is_some() {
304 required_caps.set_connect_with_db();
305 }
306 let capability = required_caps & server_caps;
307
308 let con_info = authentication::ConnectionInfo {
309 client_capabilities: capability,
310 max_packet_size: connect_info.max_packet_size,
311 character_set: connect_info.character_set,
312 username: connect_info.username,
313 password: connect_info.password,
314 database: connect_info.database,
315 };
316
317 let (auth_plugin_name, auth_data_1, auth_data_2) = match server_handshake {
318 Handshake::V10Long(ref p) => (
319 p.auth_plugin_name.as_deref(),
320 &p.short.auth_plugin_data_part_1[..],
321 p.auth_plugin_data_part_2.as_deref(),
322 ),
323 Handshake::V10Short(ref p) => (None, &p.auth_plugin_data_part_1[..], None),
324 Handshake::V9(ref p) => (None, p.scramble.as_bytes(), None),
325 };
326 match auth_plugin_name {
327 Some(x) if x == authentication::Native41::NAME => authentication::Native41 {
328 server_data_1: auth_data_1,
329 server_data_2: auth_data_2.expect("no extra data passed from server"),
330 }
331 .run(&mut stream, &con_info, sequence_id + 1)
332 .await
333 .expect("Failed to authenticate"),
334 Some(x) if x == authentication::ClearText::NAME => authentication::ClearText
335 .run(&mut stream, &con_info, sequence_id + 1)
336 .await
337 .expect("Failed to authenticate"),
338 Some(x) if x == authentication::SHA256::NAME => authentication::SHA256 {
339 server_spki_der: None,
340 scramble_buffer_1: auth_data_1,
341 scramble_buffer_2: auth_data_2.unwrap_or(&[]),
342 }
343 .run(&mut stream, &con_info, sequence_id + 1)
344 .await
345 .expect("Failed to authenticate"),
346 Some(x) if x == authentication::CachedSHA256::NAME => {
347 authentication::CachedSHA256(authentication::SHA256 {
348 server_spki_der: None,
349 scramble_buffer_1: auth_data_1,
350 scramble_buffer_2: auth_data_2.unwrap_or(&[]),
351 })
352 .run(&mut stream, &con_info, sequence_id + 1)
353 .await
354 .expect("Failed to authenticate")
355 }
356 Some(x) => unreachable!("unknown auth plugin: {x}"),
357 None => unreachable!("auth plugin is not specified"),
358 };
359
360 Ok(unsafe { Self::new(stream, capability) })
361 }
362
363 pub unsafe fn new(stream: Stream, capability: CapabilityFlags) -> Self {
367 Self { stream, capability }
368 }
369
370 pub async fn quit(&mut self) -> std::io::Result<()> {
372 write_packet(&mut self.stream, QuitCommand, 0).await?;
373 Ok(())
374 }
375
376 pub async fn query(&mut self, query: &str) -> std::io::Result<QueryCommandResponse>
378 where
379 Stream: AsyncRead,
380 {
381 write_packet(&mut self.stream, QueryCommand(query), 0).await?;
382 self.stream.flush().await?;
383 QueryCommandResponse::read_packet_async(&mut self.stream, self.capability).await
384 }
385
386 pub async fn fetch_all<'s>(
388 &'s mut self,
389 query: &'s str,
390 ) -> Result<TextResultsetStream<'s, Stream>, CommunicationError>
391 where
392 Stream: AsyncRead,
393 {
394 match self.query(query).await? {
395 QueryCommandResponse::Resultset { column_count } => {
396 self.text_resultset_stream(column_count as _).await.map_err(From::from)
397 }
398 QueryCommandResponse::Err(e) => Err(CommunicationError::from(e)),
399 QueryCommandResponse::Ok(_) => Err(CommunicationError::UnexpectedOKPacket),
400 QueryCommandResponse::LocalInfileRequest { filename } => {
401 todo!("local infile request: {filename}")
402 }
403 }
404 }
405
406 pub async fn text_resultset_stream<'s>(
408 &'s mut self,
409 column_count: usize,
410 ) -> std::io::Result<TextResultsetStream<'s, Stream>>
411 where
412 Stream: AsyncRead,
413 {
414 TextResultsetStream::new(&mut self.stream, column_count, self.capability).await
415 }
416
417 pub async fn binary_resultset_stream<'s>(
419 &'s mut self,
420 column_count: usize,
421 ) -> std::io::Result<BinaryResultsetStream<'s, Stream>>
422 where
423 Stream: AsyncRead,
424 {
425 BinaryResultsetStream::new(&mut self.stream, self.capability, column_count).await
426 }
427}
428
429pub trait GenericClient {
431 type Stream;
433
434 fn stream(&self) -> &Self::Stream;
436 fn stream_mut(&mut self) -> &mut Self::Stream;
438 fn capability(&self) -> CapabilityFlags;
440}
441impl<C: GenericClient> GenericClient for MutexGuard<'_, C> {
442 type Stream = C::Stream;
443
444 fn stream(&self) -> &Self::Stream {
445 C::stream(self)
446 }
447 fn stream_mut(&mut self) -> &mut Self::Stream {
448 C::stream_mut(self)
449 }
450 fn capability(&self) -> CapabilityFlags {
451 C::capability(self)
452 }
453}
454impl<C: GenericClient> GenericClient for Box<C> {
455 type Stream = C::Stream;
456
457 fn stream(&self) -> &Self::Stream {
458 C::stream(self)
459 }
460 fn stream_mut(&mut self) -> &mut Self::Stream {
461 C::stream_mut(self)
462 }
463 fn capability(&self) -> CapabilityFlags {
464 C::capability(self)
465 }
466}
467impl<S: AsyncWriteExt + Send + Sync + Unpin> GenericClient for Client<S> {
468 type Stream = S;
469
470 fn stream(&self) -> &Self::Stream {
471 &self.stream
472 }
473 fn stream_mut(&mut self) -> &mut Self::Stream {
474 &mut self.stream
475 }
476 fn capability(&self) -> CapabilityFlags {
477 self.capability
478 }
479}
480impl<S: Write> GenericClient for BlockingClient<S> {
481 type Stream = S;
482
483 fn stream(&self) -> &Self::Stream {
484 &self.stream
485 }
486 fn stream_mut(&mut self) -> &mut Self::Stream {
487 &mut self.stream
488 }
489 fn capability(&self) -> CapabilityFlags {
490 self.capability
491 }
492}
493
/// Handle to a server-side prepared statement, wrapping the statement id the server assigned.
///
/// Returned by `prepare`. Release it with `close_statement`, which consumes the handle. The
/// handle is intentionally not `Clone`, so it cannot be closed twice.
#[derive(Debug)]
#[repr(transparent)]
pub struct Statement(u32);
497impl<Stream: AsyncWriteExt + Sync + Send + Unpin> Client<Stream> {
499 pub async fn prepare(&mut self, statement: &str) -> Result<Statement, CommunicationError>
501 where
502 Stream: AsyncRead,
503 {
504 let resp = request_async(StmtPrepareCommand(statement), &mut self.stream, 0, self.capability)
505 .await?
506 .into_result()?;
507
508 for _ in 0..resp.num_params {
510 drop_packet(&mut self.stream).await?;
511 }
512 if !self.capability.support_deprecate_eof() {
513 drop_packet(&mut self.stream).await?
515 }
516
517 for _ in 0..resp.num_columns {
518 drop_packet(&mut self.stream).await?;
519 }
520 if !self.capability.support_deprecate_eof() {
521 drop_packet(&mut self.stream).await?
523 }
524
525 Ok(Statement(resp.statement_id))
526 }
527
528 pub async fn close_statement(&mut self, statement: Statement) -> std::io::Result<()> {
530 write_packet(&mut self.stream, StmtCloseCommand(statement.0), 0).await
531 }
532
533 pub async fn reset_statement(&mut self, statement: &Statement) -> Result<(), CommunicationError>
535 where
536 Stream: AsyncRead,
537 {
538 request_async(StmtResetCommand(statement.0), &mut self.stream, 0, self.capability)
539 .await?
540 .into_result()
541 .map(drop)
542 .map_err(From::from)
543 }
544
545 pub async fn execute_statement(
549 &mut self,
550 statement: &Statement,
551 parameters: &[(Value<'_>, bool)],
552 rebound_parameters: bool,
553 ) -> std::io::Result<StmtExecuteResult>
554 where
555 Stream: AsyncRead,
556 {
557 request_async(
558 StmtExecuteCommand {
559 statement_id: statement.0,
560 flags: StmtExecuteFlags::new(),
561 parameters,
562 requires_rebound_parameters: rebound_parameters,
563 },
564 &mut self.stream,
565 0,
566 self.capability,
567 )
568 .await
569 }
570
571 pub async fn fetch_all_statement<'s>(
573 &'s mut self,
574 statement: &Statement,
575 parameters: &[(Value<'_>, bool)],
576 rebound_parameters: bool,
577 ) -> Result<BinaryResultsetStream<'s, Stream>, CommunicationError>
578 where
579 Stream: AsyncRead,
580 {
581 match self
582 .execute_statement(statement, parameters, rebound_parameters)
583 .await?
584 {
585 StmtExecuteResult::Resultset { column_count } => self
586 .binary_resultset_stream(column_count as _)
587 .await
588 .map_err(From::from),
589 StmtExecuteResult::Err(e) => Err(CommunicationError::from(e)),
590 StmtExecuteResult::Ok(_) => Err(CommunicationError::UnexpectedOKPacket),
591 }
592 }
593}
594
595impl<Stream: Write> BlockingClient<Stream> {
597 pub fn prepare(&mut self, statement: &str) -> Result<Statement, CommunicationError>
599 where
600 Stream: Read,
601 {
602 let resp = request(StmtPrepareCommand(statement), &mut self.stream, 0, self.capability)?.into_result()?;
603
604 for _ in 0..resp.num_params {
606 drop_packet_sync(&mut self.stream)?;
607 }
608 if !self.capability.support_deprecate_eof() {
609 drop_packet_sync(&mut self.stream)?
611 }
612
613 for _ in 0..resp.num_columns {
614 drop_packet_sync(&mut self.stream)?;
615 }
616 if !self.capability.support_deprecate_eof() {
617 drop_packet_sync(&mut self.stream)?
619 }
620
621 Ok(Statement(resp.statement_id))
622 }
623
624 pub fn close_statement(&mut self, statement: Statement) -> std::io::Result<()> {
626 write_packet_sync(&mut self.stream, StmtCloseCommand(statement.0), 0)
627 }
628
629 pub fn reset_statement(&mut self, statement: &Statement) -> Result<(), CommunicationError>
631 where
632 Stream: Read,
633 {
634 request(StmtResetCommand(statement.0), &mut self.stream, 0, self.capability)?
635 .into_result()
636 .map(drop)
637 .map_err(From::from)
638 }
639
640 pub fn execute_statement(
644 &mut self,
645 statement: &Statement,
646 parameters: &[(Value<'_>, bool)],
647 rebound_parameters: bool,
648 ) -> std::io::Result<StmtExecuteResult>
649 where
650 Stream: Read,
651 {
652 request(
653 StmtExecuteCommand {
654 statement_id: statement.0,
655 flags: StmtExecuteFlags::new(),
656 parameters,
657 requires_rebound_parameters: rebound_parameters,
658 },
659 &mut self.stream,
660 0,
661 self.capability,
662 )
663 }
664
665 pub fn fetch_all_statement(
667 &mut self,
668 statement: &Statement,
669 parameters: &[(Value<'_>, bool)],
670 rebound_parameters: bool,
671 ) -> Result<BinaryResultsetIterator<&mut Stream>, CommunicationError>
672 where
673 Stream: Read,
674 {
675 match self.execute_statement(statement, parameters, rebound_parameters)? {
676 StmtExecuteResult::Resultset { column_count } => {
677 self.binary_resultset_iterator(column_count as _).map_err(From::from)
678 }
679 StmtExecuteResult::Err(e) => Err(CommunicationError::from(e)),
680 StmtExecuteResult::Ok(_) => Err(CommunicationError::UnexpectedOKPacket),
681 }
682 }
683}
684
685mod async_utils;
686mod counted_read;
687
688#[cfg(feature = "r2d2-integration")]
689#[cfg_attr(docsrs, doc(cfg(feature = "r2d2-integration")))]
690pub mod r2d2;
691
692#[cfg(feature = "bb8-integration")]
693#[cfg_attr(docsrs, doc(cfg(feature = "bb8-integration")))]
694pub mod bb8;
695
696#[cfg(feature = "autossl")]
697#[cfg_attr(docsrs, doc(cfg(feature = "autossl")))]
698pub mod autossl_client;