zero_mysql/tokio/
conn.rs

1use tokio::net::{TcpStream, UnixStream};
2use tracing::instrument;
3use zerocopy::{FromBytes, FromZeros, IntoBytes};
4
5use crate::PreparedStatement;
6use crate::buffer::BufferSet;
7use crate::buffer_pool::PooledBufferSet;
8use crate::constant::CapabilityFlags;
9use crate::error::{Error, Result};
10use crate::protocol::TextRowPayload;
11use crate::protocol::command::Action;
12use crate::protocol::command::ColumnDefinition;
13use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
14use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
15use crate::protocol::command::query::{Query, write_query};
16use crate::protocol::command::utility::{
17    DropHandler, FirstRowHandler, write_ping, write_reset_connection,
18};
19use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
20use crate::protocol::packet::PacketHeader;
21use crate::protocol::primitive::read_string_lenenc;
22use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
23use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
24
25use super::stream::Stream;
26
27pub struct Conn {
28    stream: Stream,
29    buffer_set: PooledBufferSet,
30    initial_handshake: InitialHandshake,
31    capability_flags: CapabilityFlags,
32    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
33    in_transaction: bool,
34    is_broken: bool,
35}
36
37impl Conn {
38    /// Create a new MySQL connection from connection options (async)
39    pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
40    where
41        Error: From<O::Error>,
42    {
43        let opts: crate::opts::Opts = opts.try_into()?;
44
45        let stream = if let Some(socket_path) = &opts.socket {
46            let stream = UnixStream::connect(socket_path).await?;
47            Stream::unix(stream)
48        } else {
49            if opts.host.is_empty() {
50                return Err(Error::BadUsageError(
51                    "Missing host in connection options".to_string(),
52                ));
53            }
54            let addr = format!("{}:{}", opts.host, opts.port);
55            let stream = TcpStream::connect(&addr).await?;
56            stream.set_nodelay(opts.tcp_nodelay)?;
57            Stream::tcp(stream)
58        };
59
60        Self::new_with_stream(stream, &opts).await
61    }
62
63    /// Create a new MySQL connection with an existing stream (async)
64    pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
65        let mut conn_stream = stream;
66        let mut buffer_set = opts.buffer_pool.get_buffer_set();
67
68        #[cfg(feature = "tokio-tls")]
69        let host = opts.host.clone();
70
71        let mut handshake = Handshake::new(opts);
72
73        loop {
74            match handshake.step(&mut buffer_set)? {
75                HandshakeAction::ReadPacket(buffer) => {
76                    buffer.clear();
77                    read_payload(&mut conn_stream, buffer).await?;
78                }
79                HandshakeAction::WritePacket { sequence_id } => {
80                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
81                    buffer_set.read_buffer.clear();
82                    read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
83                }
84                #[cfg(feature = "tokio-tls")]
85                HandshakeAction::UpgradeTls { sequence_id } => {
86                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
87                    conn_stream = conn_stream.upgrade_to_tls(&host).await?;
88                }
89                #[cfg(not(feature = "tokio-tls"))]
90                HandshakeAction::UpgradeTls { .. } => {
91                    return Err(Error::BadUsageError(
92                        "TLS requested but tokio-tls feature is not enabled".to_string(),
93                    ));
94                }
95                HandshakeAction::Finished => break,
96            }
97        }
98
99        let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
100
101        let conn = Self {
102            stream: conn_stream,
103            buffer_set,
104            initial_handshake,
105            capability_flags,
106            mariadb_capabilities,
107            in_transaction: false,
108            is_broken: false,
109        };
110
111        // Upgrade to Unix socket if connected via TCP to loopback
112        let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
113            conn.try_upgrade_to_unix_socket(opts).await
114        } else {
115            conn
116        };
117
118        // Execute init command if specified
119        if let Some(init_command) = &opts.init_command {
120            conn.query_drop(init_command).await?;
121        }
122
123        Ok(conn)
124    }
125
126    pub fn server_version(&self) -> &[u8] {
127        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
128    }
129
130    /// Get the negotiated capability flags
131    pub fn capability_flags(&self) -> CapabilityFlags {
132        self.capability_flags
133    }
134
135    /// Check if the server is MySQL (as opposed to MariaDB)
136    pub fn is_mysql(&self) -> bool {
137        self.capability_flags.is_mysql()
138    }
139
140    /// Check if the server is MariaDB (as opposed to MySQL)
141    pub fn is_mariadb(&self) -> bool {
142        self.capability_flags.is_mariadb()
143    }
144
145    /// Get the connection ID assigned by the server
146    pub fn connection_id(&self) -> u64 {
147        self.initial_handshake.connection_id as u64
148    }
149
150    /// Get the server status flags from the initial handshake
151    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
152        self.initial_handshake.status_flags
153    }
154
155    /// Check if the connection is broken due to a previous I/O error
156    pub fn is_broken(&self) -> bool {
157        self.is_broken
158    }
159
160    #[inline]
161    fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
162        if let Err(e) = &result
163            && e.is_conn_broken()
164        {
165            self.is_broken = true;
166        }
167        result
168    }
169
170    pub(crate) fn set_in_transaction(&mut self, value: bool) {
171        self.in_transaction = value;
172    }
173
174    /// Try to upgrade to Unix socket connection.
175    /// Returns upgraded conn on success, original conn on failure.
176    async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
177        // Query the server for its Unix socket path
178        let mut handler = SocketPathHandler { path: None };
179        if self.query("SELECT @@socket", &mut handler).await.is_err() {
180            return self;
181        }
182
183        let socket_path = match handler.path {
184            Some(p) if !p.is_empty() => p,
185            _ => return self,
186        };
187
188        // Connect via Unix socket
189        let unix_stream = match UnixStream::connect(&socket_path).await {
190            Ok(s) => s,
191            Err(_) => return self,
192        };
193        let stream = Stream::unix(unix_stream);
194
195        // Create new connection over Unix socket (re-handshakes)
196        // Disable upgrade_to_unix_socket to prevent infinite recursion
197        let mut opts_unix = opts.clone();
198        opts_unix.upgrade_to_unix_socket = false;
199
200        match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
201            Ok(new_conn) => new_conn,
202            Err(_) => self,
203        }
204    }
205
206    /// Write a MySQL packet from write_buffer asynchronously, splitting it into 16MB chunks if necessary
207    #[instrument(skip_all)]
208    async fn write_payload(&mut self) -> Result<()> {
209        let mut sequence_id = 0_u8;
210        let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
211
212        loop {
213            let chunk_size = buffer[4..].len().min(0xFFFFFF);
214            PacketHeader::mut_from_bytes(&mut buffer[0..4])?
215                .encode_in_place(chunk_size, sequence_id);
216            self.stream.write_all(&buffer[..4 + chunk_size]).await?;
217
218            if chunk_size < 0xFFFFFF {
219                break;
220            }
221
222            sequence_id = sequence_id.wrapping_add(1);
223            buffer = &mut buffer[0xFFFFFF..];
224        }
225        self.stream.flush().await?;
226        Ok(())
227    }
228
229    /// Prepare a statement and return the PreparedStatement (async)
230    ///
231    /// Returns `Ok(PreparedStatement)` on success.
232    pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
233        let result = self.prepare_inner(sql).await;
234        self.check_error(result)
235    }
236
237    async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
238        use crate::protocol::command::ColumnDefinitions;
239
240        self.buffer_set.read_buffer.clear();
241
242        write_prepare(self.buffer_set.new_write_buffer(), sql);
243
244        self.write_payload().await?;
245
246        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
247
248        if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
249            Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
250        }
251
252        let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
253        let statement_id = prepare_ok.statement_id();
254        let num_params = prepare_ok.num_params();
255        let num_columns = prepare_ok.num_columns();
256
257        // Skip param definitions (we don't cache them)
258        for _ in 0..num_params {
259            let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
260        }
261
262        // Read and cache column definitions for MARIADB_CLIENT_CACHE_METADATA support
263        let column_definitions = if num_columns > 0 {
264            self.read_column_definition_packets(num_columns as usize)
265                .await?;
266            Some(ColumnDefinitions::new(
267                num_columns as usize,
268                std::mem::take(&mut self.buffer_set.column_definition_buffer),
269            )?)
270        } else {
271            None
272        };
273
274        let mut stmt = PreparedStatement::new(statement_id);
275        if let Some(col_defs) = column_definitions {
276            stmt.set_column_definitions(col_defs);
277        }
278        Ok(stmt)
279    }
280
281    #[tracing::instrument(skip_all)]
282    async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
283        let mut header = PacketHeader::new_zeroed();
284        let out = &mut self.buffer_set.column_definition_buffer;
285        out.clear();
286
287        // For each column, write [4 bytes len][payload]
288        for _ in 0..num_columns {
289            self.stream.read_exact(header.as_mut_bytes()).await?;
290            let length = header.length();
291            out.extend((length as u32).to_ne_bytes());
292
293            out.reserve(length);
294            let spare = out.spare_capacity_mut();
295            self.stream.read_buf_exact(&mut spare[..length]).await?;
296            // SAFETY: read_buf_exact filled exactly `length` bytes
297            unsafe {
298                out.set_len(out.len() + length);
299            }
300        }
301
302        Ok(header.sequence_id)
303    }
304
305    async fn drive_exec<H: BinaryResultSetHandler>(
306        &mut self,
307        stmt: &mut crate::PreparedStatement,
308        handler: &mut H,
309    ) -> Result<()> {
310        let cache_metadata = self
311            .mariadb_capabilities
312            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
313        let mut exec = Exec::new(handler, stmt, cache_metadata);
314
315        loop {
316            match exec.step(&mut self.buffer_set)? {
317                Action::NeedPacket(buffer) => {
318                    buffer.clear();
319                    let _ = read_payload(&mut self.stream, buffer).await?;
320                }
321                Action::ReadColumnMetadata { num_columns } => {
322                    self.read_column_definition_packets(num_columns).await?;
323                }
324                Action::Finished => return Ok(()),
325            }
326        }
327    }
328
329    async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
330        let mut query = Query::new(handler);
331
332        loop {
333            match query.step(&mut self.buffer_set)? {
334                Action::NeedPacket(buffer) => {
335                    buffer.clear();
336                    let _ = read_payload(&mut self.stream, buffer).await?;
337                }
338                Action::ReadColumnMetadata { num_columns } => {
339                    self.read_column_definition_packets(num_columns).await?;
340                }
341                Action::Finished => return Ok(()),
342            }
343        }
344    }
345
346    /// Execute a prepared statement with a result set handler (async)
347    pub async fn exec<P, H>(
348        &mut self,
349        stmt: &mut PreparedStatement,
350        params: P,
351        handler: &mut H,
352    ) -> Result<()>
353    where
354        P: Params,
355        H: BinaryResultSetHandler,
356    {
357        let result = self.exec_inner(stmt, params, handler).await;
358        self.check_error(result)
359    }
360
361    async fn exec_inner<P, H>(
362        &mut self,
363        stmt: &mut PreparedStatement,
364        params: P,
365        handler: &mut H,
366    ) -> Result<()>
367    where
368        P: Params,
369        H: BinaryResultSetHandler,
370    {
371        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
372        self.write_payload().await?;
373        self.drive_exec(stmt, handler).await
374    }
375
376    async fn drive_bulk_exec<H: BinaryResultSetHandler>(
377        &mut self,
378        stmt: &mut crate::PreparedStatement,
379        handler: &mut H,
380    ) -> Result<()> {
381        let cache_metadata = self
382            .mariadb_capabilities
383            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
384        let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
385
386        loop {
387            match bulk_exec.step(&mut self.buffer_set)? {
388                Action::NeedPacket(buffer) => {
389                    buffer.clear();
390                    let _ = read_payload(&mut self.stream, buffer).await?;
391                }
392                Action::ReadColumnMetadata { num_columns } => {
393                    self.read_column_definition_packets(num_columns).await?;
394                }
395                Action::Finished => return Ok(()),
396            }
397        }
398    }
399
400    /// Execute a bulk prepared statement with a result set handler (async)
401    pub async fn exec_bulk_insert_or_update<P, I, H>(
402        &mut self,
403        stmt: &mut PreparedStatement,
404        params: P,
405        flags: BulkFlags,
406        handler: &mut H,
407    ) -> Result<()>
408    where
409        P: BulkParamsSet + IntoIterator<Item = I>,
410        I: Params,
411        H: BinaryResultSetHandler,
412    {
413        let result = self
414            .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
415            .await;
416        self.check_error(result)
417    }
418
419    async fn exec_bulk_insert_or_update_inner<P, I, H>(
420        &mut self,
421        stmt: &mut PreparedStatement,
422        params: P,
423        flags: BulkFlags,
424        handler: &mut H,
425    ) -> Result<()>
426    where
427        P: BulkParamsSet + IntoIterator<Item = I>,
428        I: Params,
429        H: BinaryResultSetHandler,
430    {
431        if !self.is_mariadb() {
432            // Fallback to multiple exec_drop for non-MariaDB servers
433            for param in params {
434                self.exec_inner(stmt, param, &mut DropHandler::default())
435                    .await?;
436            }
437            Ok(())
438        } else {
439            // Use MariaDB bulk execute protocol
440            write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
441            self.write_payload().await?;
442            self.drive_bulk_exec(stmt, handler).await
443        }
444    }
445
446    /// Execute a prepared statement and return only the first row, dropping the rest (async)
447    ///
448    /// # Returns
449    /// * `Ok(true)` - First row was found and processed
450    /// * `Ok(false)` - No rows in result set
451    /// * `Err(Error)` - Query execution or handler callback failed
452    pub async fn exec_first<P, H>(
453        &mut self,
454        stmt: &mut PreparedStatement,
455        params: P,
456        handler: &mut H,
457    ) -> Result<bool>
458    where
459        P: Params,
460        H: BinaryResultSetHandler,
461    {
462        let result = self.exec_first_inner(stmt, params, handler).await;
463        self.check_error(result)
464    }
465
466    async fn exec_first_inner<P, H>(
467        &mut self,
468        stmt: &mut PreparedStatement,
469        params: P,
470        handler: &mut H,
471    ) -> Result<bool>
472    where
473        P: Params,
474        H: BinaryResultSetHandler,
475    {
476        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
477        self.write_payload().await?;
478        let mut first_row_handler = FirstRowHandler::new(handler);
479        self.drive_exec(stmt, &mut first_row_handler).await?;
480        Ok(first_row_handler.found_row)
481    }
482
483    /// Execute a prepared statement and discard all results (async)
484    #[instrument(skip_all)]
485    pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
486    where
487        P: Params,
488    {
489        self.exec(stmt, params, &mut DropHandler::default()).await
490    }
491
492    /// Execute a text protocol SQL query (async)
493    pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
494    where
495        H: TextResultSetHandler,
496    {
497        let result = self.query_inner(sql, handler).await;
498        self.check_error(result)
499    }
500
501    async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
502    where
503        H: TextResultSetHandler,
504    {
505        write_query(self.buffer_set.new_write_buffer(), sql);
506        self.write_payload().await?;
507        self.drive_query(handler).await
508    }
509
510    /// Execute a text protocol SQL query and discard all results (async)
511    #[instrument(skip_all)]
512    pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
513        let result = self.query_drop_inner(sql).await;
514        self.check_error(result)
515    }
516
517    async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
518        write_query(self.buffer_set.new_write_buffer(), sql);
519        self.write_payload().await?;
520        self.drive_query(&mut DropHandler::default()).await
521    }
522
523    /// Send a ping to the server to check if the connection is alive (async)
524    ///
525    /// This sends a COM_PING command to the MySQL server and waits for an OK response.
526    pub async fn ping(&mut self) -> Result<()> {
527        let result = self.ping_inner().await;
528        self.check_error(result)
529    }
530
531    async fn ping_inner(&mut self) -> Result<()> {
532        write_ping(self.buffer_set.new_write_buffer());
533        self.write_payload().await?;
534        self.buffer_set.read_buffer.clear();
535        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
536        Ok(())
537    }
538
539    /// Reset the connection to its initial state (async)
540    pub async fn reset(&mut self) -> Result<()> {
541        let result = self.reset_inner().await;
542        self.check_error(result)
543    }
544
545    async fn reset_inner(&mut self) -> Result<()> {
546        write_reset_connection(self.buffer_set.new_write_buffer());
547        self.write_payload().await?;
548        self.buffer_set.read_buffer.clear();
549        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
550        self.in_transaction = false;
551        Ok(())
552    }
553
554    /// Execute a closure within a transaction (async)
555    ///
556    /// # Errors
557    /// Returns `Error::NestedTransaction` if called while already in a transaction
558    pub async fn run_transaction<F, Fut, R>(&mut self, f: F) -> Result<R>
559    where
560        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
561        Fut: core::future::Future<Output = Result<R>>,
562    {
563        if self.in_transaction {
564            return Err(Error::NestedTransaction);
565        }
566
567        self.in_transaction = true;
568
569        if let Err(err) = self.query_drop("BEGIN").await {
570            self.in_transaction = false;
571            return Err(err);
572        }
573
574        let tx = super::transaction::Transaction::new(self.connection_id());
575        let result = f(self, tx).await;
576
577        // If the transaction was not explicitly committed or rolled back, roll it back
578        if self.in_transaction {
579            let rollback_result = self.query_drop("ROLLBACK").await;
580            self.in_transaction = false;
581
582            // Return the first error (either from closure or rollback)
583            if let Err(e) = result {
584                return Err(e);
585            }
586            rollback_result?;
587        }
588
589        result
590    }
591}
592
593/// Read a complete MySQL payload asynchronously, concatenating packets if they span multiple 16MB chunks
594/// Returns the sequence_id of the last packet read.
595#[instrument(skip_all)]
596async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
597    let mut packet_header = PacketHeader::new_zeroed();
598
599    buffer.clear();
600    reader.read_exact(packet_header.as_mut_bytes()).await?;
601
602    let length = packet_header.length();
603    let mut sequence_id = packet_header.sequence_id;
604
605    buffer.reserve(length);
606
607    // read the first payload
608    {
609        let spare = buffer.spare_capacity_mut();
610        reader.read_buf_exact(&mut spare[..length]).await?;
611        // SAFETY: read_buf_exact filled exactly `length` bytes
612        unsafe {
613            buffer.set_len(length);
614        }
615    }
616
617    let mut current_length = length;
618    while current_length == 0xFFFFFF {
619        reader.read_exact(packet_header.as_mut_bytes()).await?;
620
621        current_length = packet_header.length();
622        sequence_id = packet_header.sequence_id;
623
624        buffer.reserve(current_length);
625        let spare = buffer.spare_capacity_mut();
626        reader.read_buf_exact(&mut spare[..current_length]).await?;
627        // SAFETY: read_buf_exact filled exactly `current_length` bytes
628        unsafe {
629            buffer.set_len(buffer.len() + current_length);
630        }
631    }
632
633    Ok(sequence_id)
634}
635
636async fn write_handshake_payload(
637    stream: &mut Stream,
638    buffer_set: &mut BufferSet,
639    sequence_id: u8,
640) -> Result<()> {
641    let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
642    let mut seq_id = sequence_id;
643
644    loop {
645        let chunk_size = buffer[4..].len().min(0xFFFFFF);
646        PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
647        stream.write_all(&buffer[..4 + chunk_size]).await?;
648
649        if chunk_size < 0xFFFFFF {
650            break;
651        }
652
653        seq_id = seq_id.wrapping_add(1);
654        buffer = &mut buffer[0xFFFFFF..];
655    }
656    stream.flush().await?;
657    Ok(())
658}
659
/// Result set handler that records the value returned by `SELECT @@socket`.
struct SocketPathHandler {
    /// Socket path taken from the result row; stays `None` when the value is NULL
    /// or empty.
    path: Option<String>,
}
664
665impl TextResultSetHandler for SocketPathHandler {
666    fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
667        Ok(())
668    }
669    fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
670        Ok(())
671    }
672    fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
673        Ok(())
674    }
675    fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
676        // 0xFB indicates NULL value
677        if row.0.first() == Some(&0xFB) {
678            return Ok(());
679        }
680        // Parse the first length-encoded string
681        let (value, _) = read_string_lenenc(row.0)?;
682        if !value.is_empty() {
683            self.path = Some(String::from_utf8_lossy(value).into_owned());
684        }
685        Ok(())
686    }
687}