1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::buffer_pool::PooledBufferSet;
4use crate::constant::CapabilityFlags;
5use crate::error::{Error, Result};
6use crate::nightly::unlikely;
7use crate::protocol::TextRowPayload;
8use crate::protocol::command::Action;
9use crate::protocol::command::ColumnDefinition;
10use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
11use crate::protocol::command::prepared::Exec;
12use crate::protocol::command::prepared::write_execute;
13use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
14use crate::protocol::command::query::Query;
15use crate::protocol::command::query::write_query;
16use crate::protocol::command::utility::DropHandler;
17use crate::protocol::command::utility::FirstHandler;
18use crate::protocol::command::utility::write_ping;
19use crate::protocol::command::utility::write_reset_connection;
20use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
21use crate::protocol::packet::PacketHeader;
22use crate::protocol::primitive::read_string_lenenc;
23use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
24use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
25use std::net::TcpStream;
26#[cfg(unix)]
27use std::os::unix::net::UnixStream;
28use zerocopy::FromZeros;
29use zerocopy::{FromBytes, IntoBytes};
30
31use super::stream::Stream;
32
/// A synchronous (blocking) client connection to a MySQL or MariaDB server.
///
/// Created with [`Conn::new`] or [`Conn::new_with_stream`]. Once an operation fails with an
/// error for which `is_conn_broken()` is true, [`Conn::is_broken`] reports `true`.
pub struct Conn {
    /// Transport to the server (TCP, Unix socket, or a TLS-upgraded stream with `sync-tls`).
    stream: Stream,
    /// Pooled read/write buffers; also holds the raw initial-handshake bytes that
    /// `initial_handshake.server_version` indexes into.
    buffer_set: PooledBufferSet,
    /// Parsed initial handshake (server version range, connection id, status flags).
    initial_handshake: InitialHandshake,
    /// Capability flags produced by the handshake (`Handshake::finish`).
    capability_flags: CapabilityFlags,
    /// MariaDB-specific extended capabilities produced by the handshake.
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// Whether a transaction is considered open (set by `run_transaction` /
    /// `set_in_transaction`, cleared by a successful `reset`).
    in_transaction: bool,
    /// Set once an operation fails with a connection-breaking error; never cleared.
    is_broken: bool,
}
42
43impl Conn {
44 pub(crate) fn set_in_transaction(&mut self, value: bool) {
45 self.in_transaction = value;
46 }
47
48 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
50 where
51 Error: From<O::Error>,
52 {
53 let opts: crate::opts::Opts = opts.try_into()?;
54
55 #[cfg(unix)]
56 let stream = if let Some(socket_path) = &opts.socket {
57 let stream = UnixStream::connect(socket_path)?;
58 Stream::unix(stream)
59 } else {
60 if opts.host.is_empty() {
61 return Err(Error::BadUsageError(
62 "Missing host in connection options".to_string(),
63 ));
64 }
65 let addr = format!("{}:{}", opts.host, opts.port);
66 let stream = TcpStream::connect(&addr)?;
67 stream.set_nodelay(opts.tcp_nodelay)?;
68 Stream::tcp(stream)
69 };
70
71 #[cfg(not(unix))]
72 let stream = {
73 if opts.socket.is_some() {
74 return Err(Error::BadUsageError(
75 "Unix sockets are not supported on this platform".to_string(),
76 ));
77 }
78 if opts.host.is_empty() {
79 return Err(Error::BadUsageError(
80 "Missing host in connection options".to_string(),
81 ));
82 }
83 let addr = format!("{}:{}", opts.host, opts.port);
84 let stream = TcpStream::connect(&addr)?;
85 stream.set_nodelay(opts.tcp_nodelay)?;
86 Stream::tcp(stream)
87 };
88
89 Self::new_with_stream(stream, &opts)
90 }
91
92 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
94 let mut conn_stream = stream;
95 let mut buffer_set = opts.buffer_pool.get_buffer_set();
96
97 #[cfg(feature = "sync-tls")]
98 let host = opts.host.clone();
99
100 let mut handshake = Handshake::new(opts);
101
102 loop {
103 match handshake.step(&mut buffer_set)? {
104 HandshakeAction::ReadPacket(buffer) => {
105 buffer.clear();
106 read_payload(&mut conn_stream, buffer)?;
107 }
108 HandshakeAction::WritePacket { sequence_id } => {
109 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
110 buffer_set.read_buffer.clear();
111 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
112 }
113 #[cfg(feature = "sync-tls")]
114 HandshakeAction::UpgradeTls { sequence_id } => {
115 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
116 conn_stream = conn_stream.upgrade_to_tls(&host)?;
117 }
118 #[cfg(not(feature = "sync-tls"))]
119 HandshakeAction::UpgradeTls { .. } => {
120 return Err(Error::BadUsageError(
121 "TLS requested but sync-tls feature is not enabled".to_string(),
122 ));
123 }
124 HandshakeAction::Finished => break,
125 }
126 }
127
128 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
129
130 let conn = Self {
131 stream: conn_stream,
132 buffer_set,
133 initial_handshake,
134 capability_flags,
135 mariadb_capabilities,
136 in_transaction: false,
137 is_broken: false,
138 };
139
140 #[cfg(unix)]
142 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
143 conn.try_upgrade_to_unix_socket(opts)
144 } else {
145 conn
146 };
147 #[cfg(not(unix))]
148 let mut conn = conn;
149
150 if let Some(init_command) = &opts.init_command {
152 conn.query_drop(init_command)?;
153 }
154
155 Ok(conn)
156 }
157
158 pub fn server_version(&self) -> &[u8] {
160 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
161 }
162
163 pub fn capability_flags(&self) -> CapabilityFlags {
165 self.capability_flags
166 }
167
168 pub fn is_mysql(&self) -> bool {
170 self.capability_flags.is_mysql()
171 }
172
173 pub fn is_mariadb(&self) -> bool {
175 self.capability_flags.is_mariadb()
176 }
177
178 pub fn connection_id(&self) -> u64 {
180 self.initial_handshake.connection_id as u64
181 }
182
183 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
185 self.initial_handshake.status_flags
186 }
187
188 pub fn is_broken(&self) -> bool {
192 self.is_broken
193 }
194
195 #[inline]
196 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
197 if let Err(e) = &result
198 && e.is_conn_broken()
199 {
200 self.is_broken = true;
201 }
202 result
203 }
204
205 #[cfg(unix)]
208 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
209 let mut handler = SocketPathHandler { path: None };
211 if self.query("SELECT @@socket", &mut handler).is_err() {
212 return self;
213 }
214
215 let socket_path = match handler.path {
216 Some(p) if !p.is_empty() => p,
217 _ => return self,
218 };
219
220 let unix_stream = match UnixStream::connect(&socket_path) {
222 Ok(s) => s,
223 Err(_) => return self,
224 };
225 let stream = Stream::unix(unix_stream);
226
227 let mut opts_unix = opts.clone();
230 opts_unix.upgrade_to_unix_socket = false;
231
232 match Self::new_with_stream(stream, &opts_unix) {
233 Ok(new_conn) => new_conn,
234 Err(_) => self,
235 }
236 }
237
238 fn write_payload(&mut self) -> Result<()> {
239 let mut sequence_id = 0_u8;
240 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
241
242 loop {
243 let chunk_size = buffer[4..].len().min(0xFFFFFF);
244 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
245 .encode_in_place(chunk_size, sequence_id);
246 self.stream.write_all(&buffer[..4 + chunk_size])?;
247
248 if chunk_size < 0xFFFFFF {
249 break;
250 }
251
252 sequence_id = sequence_id.wrapping_add(1);
253 buffer = &mut buffer[0xFFFFFF..];
254 }
255 self.stream.flush()?;
256 Ok(())
257 }
258
259 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
261 let result = self.prepare_inner(sql);
262 self.check_error(result)
263 }
264
265 fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
266 use crate::protocol::command::ColumnDefinitions;
267
268 self.buffer_set.read_buffer.clear();
269
270 write_prepare(self.buffer_set.new_write_buffer(), sql);
271
272 self.write_payload()?;
273 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
274
275 if unlikely(
276 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
277 ) {
278 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
279 }
280
281 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
282 let statement_id = prepare_ok.statement_id();
283 let num_params = prepare_ok.num_params();
284 let num_columns = prepare_ok.num_columns();
285
286 if num_params > 0 {
288 for _ in 0..num_params {
289 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
290 }
291 }
292
293 let column_definitions = if num_columns > 0 {
295 read_column_definition_packets(
296 &mut self.stream,
297 &mut self.buffer_set.column_definition_buffer,
298 num_columns as usize,
299 )?;
300 Some(ColumnDefinitions::new(
301 num_columns as usize,
302 std::mem::take(&mut self.buffer_set.column_definition_buffer),
303 )?)
304 } else {
305 None
306 };
307
308 let mut stmt = PreparedStatement::new(statement_id);
309 if let Some(col_defs) = column_definitions {
310 stmt.set_column_definitions(col_defs);
311 }
312 Ok(stmt)
313 }
314
315 fn drive_exec<H: BinaryResultSetHandler>(
316 &mut self,
317 stmt: &mut PreparedStatement,
318 handler: &mut H,
319 ) -> Result<()> {
320 let cache_metadata = self
321 .mariadb_capabilities
322 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
323 let mut exec = Exec::new(handler, stmt, cache_metadata);
324
325 loop {
326 match exec.step(&mut self.buffer_set)? {
327 Action::NeedPacket(buffer) => {
328 buffer.clear();
329 let _ = read_payload(&mut self.stream, buffer)?;
330 }
331 Action::ReadColumnMetadata { num_columns } => {
332 read_column_definition_packets(
333 &mut self.stream,
334 &mut self.buffer_set.column_definition_buffer,
335 num_columns,
336 )?;
337 }
338 Action::Finished => return Ok(()),
339 }
340 }
341 }
342
343 pub fn exec<'conn, P, H>(
347 &'conn mut self,
348 stmt: &'conn mut PreparedStatement,
349 params: P,
350 handler: &mut H,
351 ) -> Result<()>
352 where
353 P: Params,
354 H: BinaryResultSetHandler,
355 {
356 let result = self.exec_inner(stmt, params, handler);
357 self.check_error(result)
358 }
359
360 fn exec_inner<'conn, P, H>(
361 &'conn mut self,
362 stmt: &'conn mut PreparedStatement,
363 params: P,
364 handler: &mut H,
365 ) -> Result<()>
366 where
367 P: Params,
368 H: BinaryResultSetHandler,
369 {
370 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
371 self.write_payload()?;
372 self.drive_exec(stmt, handler)
373 }
374
375 fn drive_bulk_exec<H: BinaryResultSetHandler>(
376 &mut self,
377 stmt: &mut PreparedStatement,
378 handler: &mut H,
379 ) -> Result<()> {
380 let cache_metadata = self
381 .mariadb_capabilities
382 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
383 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
384
385 loop {
386 match bulk_exec.step(&mut self.buffer_set)? {
387 Action::NeedPacket(buffer) => {
388 buffer.clear();
389 let _ = read_payload(&mut self.stream, buffer)?;
390 }
391 Action::ReadColumnMetadata { num_columns } => {
392 read_column_definition_packets(
393 &mut self.stream,
394 &mut self.buffer_set.column_definition_buffer,
395 num_columns,
396 )?;
397 }
398 Action::Finished => return Ok(()),
399 }
400 }
401 }
402
403 pub fn exec_bulk_insert_or_update<P, I, H>(
408 &mut self,
409 stmt: &mut PreparedStatement,
410 params: P,
411 flags: BulkFlags,
412 handler: &mut H,
413 ) -> Result<()>
414 where
415 P: BulkParamsSet + IntoIterator<Item = I>,
416 I: Params,
417 H: BinaryResultSetHandler,
418 {
419 let result = self.exec_bulk_insert_or_update_inner(stmt, params, flags, handler);
420 self.check_error(result)
421 }
422
423 fn exec_bulk_insert_or_update_inner<P, I, H>(
424 &mut self,
425 stmt: &mut PreparedStatement,
426 params: P,
427 flags: BulkFlags,
428 handler: &mut H,
429 ) -> Result<()>
430 where
431 P: BulkParamsSet + IntoIterator<Item = I>,
432 I: Params,
433 H: BinaryResultSetHandler,
434 {
435 if !self.is_mariadb() {
436 for param in params {
438 self.exec_inner(stmt, param, &mut DropHandler::default())?;
439 }
440 Ok(())
441 } else {
442 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
444 self.write_payload()?;
445 self.drive_bulk_exec(stmt, handler)
446 }
447 }
448
449 pub fn exec_first<Row, P>(
451 &mut self,
452 stmt: &mut PreparedStatement,
453 params: P,
454 ) -> Result<Option<Row>>
455 where
456 Row: for<'buf> crate::raw::FromRawRow<'buf>,
457 P: Params,
458 {
459 let result = self.exec_first_inner(stmt, params);
460 self.check_error(result)
461 }
462
463 fn exec_first_inner<Row, P>(
464 &mut self,
465 stmt: &mut PreparedStatement,
466 params: P,
467 ) -> Result<Option<Row>>
468 where
469 Row: for<'buf> crate::raw::FromRawRow<'buf>,
470 P: Params,
471 {
472 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
473 self.write_payload()?;
474 let mut handler = FirstHandler::<Row>::default();
475 self.drive_exec(stmt, &mut handler)?;
476 Ok(handler.take())
477 }
478
479 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
481 where
482 P: Params,
483 {
484 self.exec(stmt, params, &mut DropHandler::default())
485 }
486
487 pub fn exec_collect<Row, P>(
489 &mut self,
490 stmt: &mut PreparedStatement,
491 params: P,
492 ) -> Result<Vec<Row>>
493 where
494 Row: for<'buf> crate::raw::FromRawRow<'buf>,
495 P: Params,
496 {
497 let mut handler = crate::handler::CollectHandler::<Row>::default();
498 self.exec(stmt, params, &mut handler)?;
499 Ok(handler.into_rows())
500 }
501
502 pub fn exec_foreach<Row, P, F>(
506 &mut self,
507 stmt: &mut PreparedStatement,
508 params: P,
509 f: F,
510 ) -> Result<()>
511 where
512 Row: for<'buf> crate::raw::FromRawRow<'buf>,
513 P: Params,
514 F: FnMut(Row) -> Result<()>,
515 {
516 let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
517 self.exec(stmt, params, &mut handler)
518 }
519
520 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
521 let mut query = Query::new(handler);
522
523 loop {
524 match query.step(&mut self.buffer_set)? {
525 Action::NeedPacket(buffer) => {
526 buffer.clear();
527 let _ = read_payload(&mut self.stream, buffer)?;
528 }
529 Action::ReadColumnMetadata { num_columns } => {
530 read_column_definition_packets(
531 &mut self.stream,
532 &mut self.buffer_set.column_definition_buffer,
533 num_columns,
534 )?;
535 }
536 Action::Finished => return Ok(()),
537 }
538 }
539 }
540
541 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
543 where
544 H: TextResultSetHandler,
545 {
546 let result = self.query_inner(sql, handler);
547 self.check_error(result)
548 }
549
550 fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
551 where
552 H: TextResultSetHandler,
553 {
554 write_query(self.buffer_set.new_write_buffer(), sql);
555 self.write_payload()?;
556 self.drive_query(handler)
557 }
558
559 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
561 let result = self.query_drop_inner(sql);
562 self.check_error(result)
563 }
564
565 fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
566 write_query(self.buffer_set.new_write_buffer(), sql);
567 self.write_payload()?;
568 self.drive_query(&mut DropHandler::default())
569 }
570
571 pub fn ping(&mut self) -> Result<()> {
575 let result = self.ping_inner();
576 self.check_error(result)
577 }
578
579 fn ping_inner(&mut self) -> Result<()> {
580 write_ping(self.buffer_set.new_write_buffer());
581 self.write_payload()?;
582 self.buffer_set.read_buffer.clear();
583 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
584 Ok(())
585 }
586
587 pub fn reset(&mut self) -> Result<()> {
589 let result = self.reset_inner();
590 self.check_error(result)
591 }
592
593 fn reset_inner(&mut self) -> Result<()> {
594 write_reset_connection(self.buffer_set.new_write_buffer());
595 self.write_payload()?;
596 self.buffer_set.read_buffer.clear();
597 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
598 self.in_transaction = false;
599 Ok(())
600 }
601
602 pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
607 where
608 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
609 {
610 if self.in_transaction {
611 return Err(Error::NestedTransaction);
612 }
613
614 self.in_transaction = true;
615
616 if let Err(e) = self.query_drop("BEGIN") {
617 self.in_transaction = false;
618 return Err(e);
619 }
620
621 let tx = super::transaction::Transaction::new(self.connection_id());
622 let result = f(self, tx);
623
624 if self.in_transaction {
626 let rollback_result = self.query_drop("ROLLBACK");
627 self.in_transaction = false;
628
629 if let Err(e) = result {
631 return Err(e);
632 }
633 rollback_result?;
634 }
635
636 result
637 }
638}
639
/// Reads one logical packet payload from `reader` into `buffer`.
///
/// `buffer` is cleared first. A packet whose header length is `0xFFFFFF` is followed by
/// continuation packets. Their payloads are appended until a header with a smaller length
/// is read, so `buffer` ends up holding the whole reassembled payload, without headers.
///
/// Returns the sequence id of the last header read. Sequence ids are not validated here.
///
/// # Errors
///
/// Any I/O error reported by `reader`.
fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
    buffer.clear();

    let mut header = PacketHeader::new_zeroed();
    reader.read_exact(header.as_mut_bytes())?;

    let length = header.length();
    let mut sequence_id = header.sequence_id;

    // `buffer` is empty, so this guarantees at least `length` spare slots.
    buffer.reserve(length);

    {
        let spare = buffer.spare_capacity_mut();
        reader.read_buf_exact(&mut spare[..length])?;
        // SAFETY: capacity >= `length` (reserved above on an empty Vec) and `read_buf_exact`
        // returned Ok for `spare[..length]`. This assumes that success means every element of
        // that slice was initialized (TODO confirm against `Stream::read_buf_exact`). The
        // first `length` elements are therefore initialized.
        unsafe {
            buffer.set_len(length);
        }
    }

    // A maximum-size (0xFFFFFF) packet means more payload follows in the next packet.
    let mut current_length = length;
    while current_length == 0xFFFFFF {
        reader.read_exact(header.as_mut_bytes())?;

        current_length = header.length();
        sequence_id = header.sequence_id;

        buffer.reserve(current_length);
        let spare = buffer.spare_capacity_mut();
        reader.read_buf_exact(&mut spare[..current_length])?;
        // SAFETY: `reserve(current_length)` guarantees at least `current_length` spare slots,
        // and `read_buf_exact` succeeded in filling `spare[..current_length]` (assumed to
        // fully initialize it — TODO confirm). The new length therefore covers only
        // initialized bytes.
        unsafe {
            buffer.set_len(buffer.len() + current_length);
        }
    }

    Ok(sequence_id)
}
680
/// Reads `num_columns` column-definition packets from `reader` into `out`.
///
/// `out` is cleared first. Each packet is stored as its payload length (a native-endian
/// `u32`) followed by the payload bytes, back to back.
///
/// Each packet is read as a single physical packet. Payloads of 0xFFFFFF bytes or more
/// (continuation packets) are not handled. NOTE(review): no terminating EOF packet is
/// consumed here; presumably the caller/protocol negotiation makes that unnecessary —
/// confirm.
///
/// Returns the sequence id of the last header read, or 0 when `num_columns` is 0 (the
/// header starts zeroed).
///
/// # Errors
///
/// Any I/O error reported by `reader`.
fn read_column_definition_packets(
    reader: &mut Stream,
    out: &mut Vec<u8>,
    num_columns: usize,
) -> Result<u8> {
    out.clear();
    let mut header = PacketHeader::new_zeroed();

    for _ in 0..num_columns {
        reader.read_exact(header.as_mut_bytes())?;
        let length = header.length();
        // Length prefix so the consumer can split the concatenated payloads.
        out.extend((length as u32).to_ne_bytes());

        out.reserve(length);
        let spare = out.spare_capacity_mut();
        reader.read_buf_exact(&mut spare[..length])?;
        // SAFETY: `reserve(length)` guarantees at least `length` spare slots, and
        // `read_buf_exact` succeeded for `spare[..length]` (assumed to initialize every
        // element — TODO confirm). The extended length therefore covers initialized bytes only.
        unsafe {
            out.set_len(out.len() + length);
        }
    }

    Ok(header.sequence_id)
}
706
707fn write_handshake_payload(
708 stream: &mut Stream,
709 buffer_set: &mut BufferSet,
710 sequence_id: u8,
711) -> Result<()> {
712 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
713 let mut seq_id = sequence_id;
714
715 loop {
716 let chunk_size = buffer[4..].len().min(0xFFFFFF);
717 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
718 stream.write_all(&buffer[..4 + chunk_size])?;
719
720 if chunk_size < 0xFFFFFF {
721 break;
722 }
723
724 seq_id = seq_id.wrapping_add(1);
725 buffer = &mut buffer[0xFFFFFF..];
726 }
727 stream.flush()?;
728 Ok(())
729}
730
/// Text-protocol result handler that captures the server's `@@socket` value.
///
/// Used by `Conn::try_upgrade_to_unix_socket`. `path` stays `None` when the query yields
/// no rows, only SQL NULL, or only empty strings.
#[cfg(unix)]
struct SocketPathHandler {
    /// Last non-empty socket path seen in a result row (lossily decoded as UTF-8).
    path: Option<String>,
}
736
737#[cfg(unix)]
738impl TextResultSetHandler for SocketPathHandler {
739 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
740 Ok(())
741 }
742 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
743 Ok(())
744 }
745 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
746 Ok(())
747 }
748 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
749 if row.0.first() == Some(&0xFB) {
751 return Ok(());
752 }
753 let (value, _) = read_string_lenenc(row.0)?;
755 if !value.is_empty() {
756 self.path = Some(String::from_utf8_lossy(value).into_owned());
757 }
758 Ok(())
759 }
760}