1use tokio::net::{TcpStream, UnixStream};
2use tracing::instrument;
3use zerocopy::{FromBytes, FromZeros, IntoBytes};
4
5use crate::PreparedStatement;
6use crate::buffer::BufferSet;
7use crate::buffer_pool::PooledBufferSet;
8use crate::constant::CapabilityFlags;
9use crate::error::{Error, Result};
10use crate::protocol::command::Action;
11use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
12use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
13use crate::protocol::command::query::{Query, write_query};
14use crate::protocol::command::utility::{
15 DropHandler, FirstRowHandler, write_ping, write_reset_connection,
16};
17use crate::protocol::command::ColumnDefinition;
18use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
19use crate::protocol::packet::PacketHeader;
20use crate::protocol::primitive::read_string_lenenc;
21use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
22use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
23use crate::protocol::TextRowPayload;
24
25use super::stream::Stream;
26
27pub struct Conn {
28 stream: Stream,
29 buffer_set: PooledBufferSet,
30 initial_handshake: InitialHandshake,
31 capability_flags: CapabilityFlags,
32 mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
33 in_transaction: bool,
34}
35
36impl Conn {
37 pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
39 where
40 Error: From<O::Error>,
41 {
42 let opts: crate::opts::Opts = opts.try_into()?;
43
44 let stream = if let Some(socket_path) = &opts.socket {
45 let stream = UnixStream::connect(socket_path).await?;
46 Stream::unix(stream)
47 } else {
48 let host = opts.host.as_ref().ok_or_else(|| {
49 Error::BadConfigError("Missing host in connection options".to_string())
50 })?;
51
52 let addr = format!("{}:{}", host, opts.port);
53 let stream = TcpStream::connect(&addr).await?;
54 stream.set_nodelay(opts.tcp_nodelay)?;
55 Stream::tcp(stream)
56 };
57
58 Self::new_with_stream(stream, &opts).await
59 }
60
61 pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
63 let mut conn_stream = stream;
64 let mut buffer_set = opts.buffer_pool.get_buffer_set();
65
66 #[cfg(feature = "tokio-tls")]
67 let host = opts.host.clone().unwrap_or_default();
68
69 let mut handshake = Handshake::new(opts);
70
71 loop {
72 match handshake.step(&mut buffer_set)? {
73 HandshakeAction::ReadPacket(buffer) => {
74 buffer.clear();
75 read_payload(&mut conn_stream, buffer).await?;
76 }
77 HandshakeAction::WritePacket { sequence_id } => {
78 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
79 buffer_set.read_buffer.clear();
80 read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
81 }
82 #[cfg(feature = "tokio-tls")]
83 HandshakeAction::UpgradeTls { sequence_id } => {
84 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
85 conn_stream = conn_stream.upgrade_to_tls(&host).await?;
86 }
87 #[cfg(not(feature = "tokio-tls"))]
88 HandshakeAction::UpgradeTls { .. } => {
89 return Err(Error::BadConfigError(
90 "TLS requested but tokio-tls feature is not enabled".to_string(),
91 ));
92 }
93 HandshakeAction::Finished => break,
94 }
95 }
96
97 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
98
99 let conn = Self {
100 stream: conn_stream,
101 buffer_set,
102 initial_handshake,
103 capability_flags,
104 mariadb_capabilities,
105 in_transaction: false,
106 };
107
108 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
110 conn.try_upgrade_to_unix_socket(opts).await
111 } else {
112 conn
113 };
114
115 if let Some(init_command) = &opts.init_command {
117 conn.query_drop(init_command).await?;
118 }
119
120 Ok(conn)
121 }
122
123 pub fn server_version(&self) -> &[u8] {
124 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
125 }
126
127 pub fn capability_flags(&self) -> CapabilityFlags {
129 self.capability_flags
130 }
131
132 pub fn is_mysql(&self) -> bool {
134 self.capability_flags.is_mysql()
135 }
136
137 pub fn is_mariadb(&self) -> bool {
139 self.capability_flags.is_mariadb()
140 }
141
142 pub fn connection_id(&self) -> u64 {
144 self.initial_handshake.connection_id as u64
145 }
146
147 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
149 self.initial_handshake.status_flags
150 }
151
152 pub(crate) fn set_in_transaction(&mut self, value: bool) {
153 self.in_transaction = value;
154 }
155
156 async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
159 let mut handler = SocketPathHandler { path: None };
161 if self.query("SELECT @@socket", &mut handler).await.is_err() {
162 return self;
163 }
164
165 let socket_path = match handler.path {
166 Some(p) if !p.is_empty() => p,
167 _ => return self,
168 };
169
170 let unix_stream = match UnixStream::connect(&socket_path).await {
172 Ok(s) => s,
173 Err(_) => return self,
174 };
175 let stream = Stream::unix(unix_stream);
176
177 let mut opts_unix = opts.clone();
180 opts_unix.upgrade_to_unix_socket = false;
181
182 match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
183 Ok(new_conn) => new_conn,
184 Err(_) => self,
185 }
186 }
187
188 #[instrument(skip_all)]
190 async fn write_payload(&mut self) -> Result<()> {
191 let mut sequence_id = 0_u8;
192 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
193
194 loop {
195 let chunk_size = buffer[4..].len().min(0xFFFFFF);
196 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
197 .encode_in_place(chunk_size, sequence_id);
198 self.stream.write_all(&buffer[..4 + chunk_size]).await?;
199
200 if chunk_size < 0xFFFFFF {
201 break;
202 }
203
204 sequence_id = sequence_id.wrapping_add(1);
205 buffer = &mut buffer[0xFFFFFF..];
206 }
207 self.stream.flush().await?;
208 Ok(())
209 }
210
211 pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
215 use crate::protocol::command::ColumnDefinitions;
216
217 self.buffer_set.read_buffer.clear();
218
219 write_prepare(self.buffer_set.new_write_buffer(), sql);
220
221 self.write_payload().await?;
222
223 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
224
225 if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
226 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
227 }
228
229 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
230 let statement_id = prepare_ok.statement_id();
231 let num_params = prepare_ok.num_params();
232 let num_columns = prepare_ok.num_columns();
233
234 for _ in 0..num_params {
236 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
237 }
238
239 let column_definitions = if num_columns > 0 {
241 self.read_column_definition_packets(num_columns as usize)
242 .await?;
243 Some(ColumnDefinitions::new(
244 num_columns as usize,
245 std::mem::take(&mut self.buffer_set.column_definition_buffer),
246 )?)
247 } else {
248 None
249 };
250
251 let mut stmt = PreparedStatement::new(statement_id);
252 if let Some(col_defs) = column_definitions {
253 stmt.set_column_definitions(col_defs);
254 }
255 Ok(stmt)
256 }
257
258 #[tracing::instrument(skip_all)]
259 async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
260 let mut header = PacketHeader::new_zeroed();
261 let out = &mut self.buffer_set.column_definition_buffer;
262 out.clear();
263
264 for _ in 0..num_columns {
266 self.stream.read_exact(header.as_mut_bytes()).await?;
267 let length = header.length();
268 out.extend((length as u32).to_ne_bytes());
269
270 out.reserve(length);
271 let spare = out.spare_capacity_mut();
272 self.stream.read_buf_exact(&mut spare[..length]).await?;
273 unsafe {
275 out.set_len(out.len() + length);
276 }
277 }
278
279 Ok(header.sequence_id)
280 }
281
282 async fn drive_exec<H: BinaryResultSetHandler>(
283 &mut self,
284 stmt: &mut crate::PreparedStatement,
285 handler: &mut H,
286 ) -> Result<()> {
287 let cache_metadata = self
288 .mariadb_capabilities
289 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
290 let mut exec = Exec::new(handler, stmt, cache_metadata);
291
292 loop {
293 match exec.step(&mut self.buffer_set)? {
294 Action::NeedPacket(buffer) => {
295 buffer.clear();
296 let _ = read_payload(&mut self.stream, buffer).await?;
297 }
298 Action::ReadColumnMetadata { num_columns } => {
299 self.read_column_definition_packets(num_columns).await?;
300 }
301 Action::Finished => return Ok(()),
302 }
303 }
304 }
305
306 async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
307 let mut query = Query::new(handler);
308
309 loop {
310 match query.step(&mut self.buffer_set)? {
311 Action::NeedPacket(buffer) => {
312 buffer.clear();
313 let _ = read_payload(&mut self.stream, buffer).await?;
314 }
315 Action::ReadColumnMetadata { num_columns } => {
316 self.read_column_definition_packets(num_columns).await?;
317 }
318 Action::Finished => return Ok(()),
319 }
320 }
321 }
322
323 pub async fn exec<P, H>(
325 &mut self,
326 stmt: &mut PreparedStatement,
327 params: P,
328 handler: &mut H,
329 ) -> Result<()>
330 where
331 P: Params,
332 H: BinaryResultSetHandler,
333 {
334 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
335 self.write_payload().await?;
336 self.drive_exec(stmt, handler).await
337 }
338
339 async fn drive_bulk_exec<H: BinaryResultSetHandler>(
340 &mut self,
341 stmt: &mut crate::PreparedStatement,
342 handler: &mut H,
343 ) -> Result<()> {
344 let cache_metadata = self
345 .mariadb_capabilities
346 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
347 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
348
349 loop {
350 match bulk_exec.step(&mut self.buffer_set)? {
351 Action::NeedPacket(buffer) => {
352 buffer.clear();
353 let _ = read_payload(&mut self.stream, buffer).await?;
354 }
355 Action::ReadColumnMetadata { num_columns } => {
356 self.read_column_definition_packets(num_columns).await?;
357 }
358 Action::Finished => return Ok(()),
359 }
360 }
361 }
362
363 pub async fn exec_bulk<P, I, H>(
365 &mut self,
366 stmt: &mut PreparedStatement,
367 params: P,
368 flags: BulkFlags,
369 handler: &mut H,
370 ) -> Result<()>
371 where
372 P: BulkParamsSet + IntoIterator<Item = I>,
373 I: Params,
374 H: BinaryResultSetHandler,
375 {
376 if !self.is_mariadb() {
377 for param in params {
379 self.exec_drop(stmt, param).await?;
380 }
381 Ok(())
382 } else {
383 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
385 self.write_payload().await?;
386 self.drive_bulk_exec(stmt, handler).await
387 }
388 }
389
390 pub async fn exec_first<P, H>(
397 &mut self,
398 stmt: &mut PreparedStatement,
399 params: P,
400 handler: &mut H,
401 ) -> Result<bool>
402 where
403 P: Params,
404 H: BinaryResultSetHandler,
405 {
406 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
407 self.write_payload().await?;
408 let mut first_row_handler = FirstRowHandler::new(handler);
409 self.drive_exec(stmt, &mut first_row_handler).await?;
410 Ok(first_row_handler.found_row)
411 }
412
413 #[instrument(skip_all)]
415 pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
416 where
417 P: Params,
418 {
419 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
420 self.write_payload().await?;
421 self.drive_exec(stmt, &mut DropHandler::default()).await
422 }
423
424 pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
426 where
427 H: TextResultSetHandler,
428 {
429 write_query(self.buffer_set.new_write_buffer(), sql);
430 self.write_payload().await?;
431 self.drive_query(handler).await
432 }
433
434 #[instrument(skip_all)]
436 pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
437 write_query(self.buffer_set.new_write_buffer(), sql);
438 self.write_payload().await?;
439 self.drive_query(&mut DropHandler::default()).await
440 }
441
442 pub async fn ping(&mut self) -> Result<()> {
446 write_ping(self.buffer_set.new_write_buffer());
447 self.write_payload().await?;
448 self.buffer_set.read_buffer.clear();
449 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
450 Ok(())
451 }
452
453 pub async fn reset(&mut self) -> Result<()> {
455 write_reset_connection(self.buffer_set.new_write_buffer());
456 self.write_payload().await?;
457 self.buffer_set.read_buffer.clear();
458 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
459 self.in_transaction = false;
460 Ok(())
461 }
462
463 pub async fn run_transaction<F, Fut, R>(&mut self, f: F) -> Result<R>
468 where
469 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
470 Fut: core::future::Future<Output = Result<R>>,
471 {
472 if self.in_transaction {
473 return Err(Error::NestedTransaction);
474 }
475
476 self.in_transaction = true;
477
478 if let Err(err) = self.query_drop("BEGIN").await {
479 self.in_transaction = false;
480 return Err(err);
481 }
482
483 let tx = super::transaction::Transaction::new(self.connection_id());
484 let result = f(self, tx).await;
485
486 if self.in_transaction {
488 let rollback_result = self.query_drop("ROLLBACK").await;
489 self.in_transaction = false;
490
491 if let Err(e) = result {
493 return Err(e);
494 }
495 rollback_result?;
496 }
497
498 result
499 }
500}
501
502#[instrument(skip_all)]
505async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
506 let mut packet_header = PacketHeader::new_zeroed();
507
508 buffer.clear();
509 reader.read_exact(packet_header.as_mut_bytes()).await?;
510
511 let length = packet_header.length();
512 let mut sequence_id = packet_header.sequence_id;
513
514 buffer.reserve(length);
515
516 {
518 let spare = buffer.spare_capacity_mut();
519 reader.read_buf_exact(&mut spare[..length]).await?;
520 unsafe {
522 buffer.set_len(length);
523 }
524 }
525
526 let mut current_length = length;
527 while current_length == 0xFFFFFF {
528 reader.read_exact(packet_header.as_mut_bytes()).await?;
529
530 current_length = packet_header.length();
531 sequence_id = packet_header.sequence_id;
532
533 buffer.reserve(current_length);
534 let spare = buffer.spare_capacity_mut();
535 reader.read_buf_exact(&mut spare[..current_length]).await?;
536 unsafe {
538 buffer.set_len(buffer.len() + current_length);
539 }
540 }
541
542 Ok(sequence_id)
543}
544
545async fn write_handshake_payload(
546 stream: &mut Stream,
547 buffer_set: &mut BufferSet,
548 sequence_id: u8,
549) -> Result<()> {
550 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
551 let mut seq_id = sequence_id;
552
553 loop {
554 let chunk_size = buffer[4..].len().min(0xFFFFFF);
555 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
556 stream.write_all(&buffer[..4 + chunk_size]).await?;
557
558 if chunk_size < 0xFFFFFF {
559 break;
560 }
561
562 seq_id = seq_id.wrapping_add(1);
563 buffer = &mut buffer[0xFFFFFF..];
564 }
565 stream.flush().await?;
566 Ok(())
567}
568
569struct SocketPathHandler {
571 path: Option<String>,
572}
573
574impl TextResultSetHandler for SocketPathHandler {
575 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
576 Ok(())
577 }
578 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
579 Ok(())
580 }
581 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
582 Ok(())
583 }
584 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
585 if row.0.first() == Some(&0xFB) {
587 return Ok(());
588 }
589 let (value, _) = read_string_lenenc(row.0)?;
591 if !value.is_empty() {
592 self.path = Some(String::from_utf8_lossy(value).into_owned());
593 }
594 Ok(())
595 }
596}