1use crate::buffer::BufferSet;
2use crate::buffer_pool::PooledBufferSet;
3use crate::constant::CapabilityFlags;
4use crate::error::{Error, Result};
5use crate::protocol::command::bulk_exec::{write_bulk_execute, BulkExec, BulkFlags, BulkParamsSet};
6use crate::protocol::command::prepared::write_execute;
7use crate::protocol::command::prepared::Exec;
8use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
9use crate::protocol::command::query::write_query;
10use crate::protocol::command::query::Query;
11use crate::protocol::command::utility::write_ping;
12use crate::protocol::command::utility::write_reset_connection;
13use crate::protocol::command::utility::DropHandler;
14use crate::protocol::command::utility::FirstRowHandler;
15use crate::protocol::command::Action;
16use crate::protocol::command::ColumnDefinition;
17use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
18use crate::protocol::packet::PacketHeader;
19use crate::protocol::primitive::read_string_lenenc;
20use crate::protocol::r#trait::{param::Params, BinaryResultSetHandler, TextResultSetHandler};
21use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
22use crate::protocol::TextRowPayload;
23use crate::PreparedStatement;
24use core::hint::unlikely;
25use core::io::BorrowedBuf;
26use std::net::TcpStream;
27use std::os::unix::net::UnixStream;
28use zerocopy::FromZeros;
29use zerocopy::{FromBytes, IntoBytes};
30
31use super::stream::Stream;
32
/// A synchronous connection to a MySQL or MariaDB server.
pub struct Conn {
    // Transport: TCP or Unix socket; upgraded to TLS during the handshake when the
    // `sync-tls` feature is enabled and the handshake requests it.
    stream: Stream,
    // Read, write and column-definition buffers, taken from `opts.buffer_pool`.
    buffer_set: PooledBufferSet,
    // Parsed server greeting; its `server_version` range indexes into
    // `buffer_set.initial_handshake`.
    initial_handshake: InitialHandshake,
    // Capability flags returned by `Handshake::finish`.
    capability_flags: CapabilityFlags,
    // MariaDB extended capabilities (consulted for MARIADB_CLIENT_CACHE_METADATA).
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    // True while `run_transaction` has an open transaction; used to reject nesting.
    in_transaction: bool,
}
41
42impl Conn {
43 pub(crate) fn set_in_transaction(&mut self, value: bool) {
44 self.in_transaction = value;
45 }
46
47 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
49 where
50 Error: From<O::Error>,
51 {
52 let opts: crate::opts::Opts = opts.try_into()?;
53
54 let stream = if let Some(socket_path) = &opts.socket {
55 let stream = UnixStream::connect(socket_path)?;
56 Stream::unix(stream)
57 } else {
58 let host = opts.host.as_ref().ok_or_else(|| {
59 Error::BadConfigError("Missing host in connection options".to_string())
60 })?;
61
62 let addr = format!("{}:{}", host, opts.port);
63 let stream = TcpStream::connect(&addr)?;
64 stream.set_nodelay(opts.tcp_nodelay)?;
65 Stream::tcp(stream)
66 };
67
68 Self::new_with_stream(stream, &opts)
69 }
70
71 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
73 let mut conn_stream = stream;
74 let mut buffer_set = opts.buffer_pool.get_buffer_set();
75
76 #[cfg(feature = "sync-tls")]
77 let host = opts.host.clone().unwrap_or_default();
78
79 let mut handshake = Handshake::new(opts);
80
81 loop {
82 match handshake.step(&mut buffer_set)? {
83 HandshakeAction::ReadPacket(buffer) => {
84 buffer.clear();
85 read_payload(&mut conn_stream, buffer)?;
86 }
87 HandshakeAction::WritePacket { sequence_id } => {
88 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
89 buffer_set.read_buffer.clear();
90 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
91 }
92 #[cfg(feature = "sync-tls")]
93 HandshakeAction::UpgradeTls { sequence_id } => {
94 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
95 conn_stream = conn_stream.upgrade_to_tls(&host)?;
96 }
97 #[cfg(not(feature = "sync-tls"))]
98 HandshakeAction::UpgradeTls { .. } => {
99 return Err(Error::BadConfigError(
100 "TLS requested but sync-tls feature is not enabled".to_string(),
101 ));
102 }
103 HandshakeAction::Finished => break,
104 }
105 }
106
107 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
108
109 let conn = Self {
110 stream: conn_stream,
111 buffer_set,
112 initial_handshake,
113 capability_flags,
114 mariadb_capabilities,
115 in_transaction: false,
116 };
117
118 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
120 conn.try_upgrade_to_unix_socket(opts)
121 } else {
122 conn
123 };
124
125 if let Some(init_command) = &opts.init_command {
127 conn.query_drop(init_command)?;
128 }
129
130 Ok(conn)
131 }
132
133 pub fn server_version(&self) -> &[u8] {
134 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
135 }
136
137 pub fn capability_flags(&self) -> CapabilityFlags {
139 self.capability_flags
140 }
141
142 pub fn is_mysql(&self) -> bool {
144 self.capability_flags.is_mysql()
145 }
146
147 pub fn is_mariadb(&self) -> bool {
149 self.capability_flags.is_mariadb()
150 }
151
152 pub fn connection_id(&self) -> u64 {
154 self.initial_handshake.connection_id as u64
155 }
156
157 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
159 self.initial_handshake.status_flags
160 }
161
162 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
165 let mut handler = SocketPathHandler { path: None };
167 if self.query("SELECT @@socket", &mut handler).is_err() {
168 return self;
169 }
170
171 let socket_path = match handler.path {
172 Some(p) if !p.is_empty() => p,
173 _ => return self,
174 };
175
176 let unix_stream = match UnixStream::connect(&socket_path) {
178 Ok(s) => s,
179 Err(_) => return self,
180 };
181 let stream = Stream::unix(unix_stream);
182
183 let mut opts_unix = opts.clone();
186 opts_unix.upgrade_to_unix_socket = false;
187
188 match Self::new_with_stream(stream, &opts_unix) {
189 Ok(new_conn) => new_conn,
190 Err(_) => self,
191 }
192 }
193
194 fn write_payload(&mut self) -> Result<()> {
195 let mut sequence_id = 0_u8;
196 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
197
198 loop {
199 let chunk_size = buffer[4..].len().min(0xFFFFFF);
200 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
201 .encode_in_place(chunk_size, sequence_id);
202 self.stream.write_all(&buffer[..4 + chunk_size])?;
203
204 if chunk_size < 0xFFFFFF {
205 break;
206 }
207
208 sequence_id = sequence_id.wrapping_add(1);
209 buffer = &mut buffer[0xFFFFFF..];
210 }
211 self.stream.flush()?;
212 Ok(())
213 }
214
215 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
217 use crate::protocol::command::ColumnDefinitions;
218
219 self.buffer_set.read_buffer.clear();
220
221 write_prepare(self.buffer_set.new_write_buffer(), sql);
222
223 self.write_payload()?;
224 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
225
226 if unlikely(
227 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
228 ) {
229 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
230 }
231
232 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
233 let statement_id = prepare_ok.statement_id();
234 let num_params = prepare_ok.num_params();
235 let num_columns = prepare_ok.num_columns();
236
237 if num_params > 0 {
239 for _ in 0..num_params {
240 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
241 }
242 }
243
244 let column_definitions = if num_columns > 0 {
246 read_column_definition_packets(
247 &mut self.stream,
248 &mut self.buffer_set.column_definition_buffer,
249 num_columns as usize,
250 )?;
251 Some(ColumnDefinitions::new(
252 num_columns as usize,
253 std::mem::take(&mut self.buffer_set.column_definition_buffer),
254 )?)
255 } else {
256 None
257 };
258
259 let mut stmt = PreparedStatement::new(statement_id);
260 if let Some(col_defs) = column_definitions {
261 stmt.set_column_definitions(col_defs);
262 }
263 Ok(stmt)
264 }
265
266 fn drive_exec<H: BinaryResultSetHandler>(
267 &mut self,
268 stmt: &mut PreparedStatement,
269 handler: &mut H,
270 ) -> Result<()> {
271 let cache_metadata = self
272 .mariadb_capabilities
273 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
274 let mut exec = Exec::new(handler, stmt, cache_metadata);
275
276 loop {
277 match exec.step(&mut self.buffer_set)? {
278 Action::NeedPacket(buffer) => {
279 buffer.clear();
280 let _ = read_payload(&mut self.stream, buffer)?;
281 }
282 Action::ReadColumnMetadata { num_columns } => {
283 read_column_definition_packets(
284 &mut self.stream,
285 &mut self.buffer_set.column_definition_buffer,
286 num_columns,
287 )?;
288 }
289 Action::Finished => return Ok(()),
290 }
291 }
292 }
293
294 pub fn exec<'conn, P, H>(
295 &'conn mut self,
296 stmt: &'conn mut PreparedStatement,
297 params: P,
298 handler: &mut H,
299 ) -> Result<()>
300 where
301 P: Params,
302 H: BinaryResultSetHandler,
303 {
304 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
305 self.write_payload()?;
306 self.drive_exec(stmt, handler)
307 }
308
309 fn drive_bulk_exec<H: BinaryResultSetHandler>(
310 &mut self,
311 stmt: &mut PreparedStatement,
312 handler: &mut H,
313 ) -> Result<()> {
314 let cache_metadata = self
315 .mariadb_capabilities
316 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
317 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
318
319 loop {
320 match bulk_exec.step(&mut self.buffer_set)? {
321 Action::NeedPacket(buffer) => {
322 buffer.clear();
323 let _ = read_payload(&mut self.stream, buffer)?;
324 }
325 Action::ReadColumnMetadata { num_columns } => {
326 read_column_definition_packets(
327 &mut self.stream,
328 &mut self.buffer_set.column_definition_buffer,
329 num_columns,
330 )?;
331 }
332 Action::Finished => return Ok(()),
333 }
334 }
335 }
336
337 pub fn exec_bulk<P, I, H>(
339 &mut self,
340 stmt: &mut PreparedStatement,
341 params: P,
342 flags: BulkFlags,
343 handler: &mut H,
344 ) -> Result<()>
345 where
346 P: BulkParamsSet + IntoIterator<Item = I>,
347 I: Params,
348 H: BinaryResultSetHandler,
349 {
350 if !self.is_mariadb() {
351 for param in params {
353 self.exec_drop(stmt, param)?;
354 }
355 Ok(())
356 } else {
357 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
359 self.write_payload()?;
360 self.drive_bulk_exec(stmt, handler)
361 }
362 }
363
364 pub fn exec_first<'conn, P, H>(
371 &'conn mut self,
372 stmt: &'conn mut PreparedStatement,
373 params: P,
374 handler: &mut H,
375 ) -> Result<bool>
376 where
377 P: Params,
378 H: BinaryResultSetHandler,
379 {
380 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
381 self.write_payload()?;
382 let mut first_row_handler = FirstRowHandler::new(handler);
383 self.drive_exec(stmt, &mut first_row_handler)?;
384 Ok(first_row_handler.found_row)
385 }
386
387 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
389 where
390 P: Params,
391 {
392 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
393 self.write_payload()?;
394 self.drive_exec(stmt, &mut DropHandler::default())
395 }
396
397 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
398 let mut query = Query::new(handler);
399
400 loop {
401 match query.step(&mut self.buffer_set)? {
402 Action::NeedPacket(buffer) => {
403 buffer.clear();
404 let _ = read_payload(&mut self.stream, buffer)?;
405 }
406 Action::ReadColumnMetadata { num_columns } => {
407 read_column_definition_packets(
408 &mut self.stream,
409 &mut self.buffer_set.column_definition_buffer,
410 num_columns,
411 )?;
412 }
413 Action::Finished => return Ok(()),
414 }
415 }
416 }
417
418 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
420 where
421 H: TextResultSetHandler,
422 {
423 write_query(self.buffer_set.new_write_buffer(), sql);
424 self.write_payload()?;
425 self.drive_query(handler)
426 }
427
428 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
430 write_query(self.buffer_set.new_write_buffer(), sql);
431 self.write_payload()?;
432 self.drive_query(&mut DropHandler::default())
433 }
434
435 pub fn ping(&mut self) -> Result<()> {
439 write_ping(self.buffer_set.new_write_buffer());
440 self.write_payload()?;
441 self.buffer_set.read_buffer.clear();
442 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
443 Ok(())
444 }
445
446 pub fn reset(&mut self) -> Result<()> {
448 write_reset_connection(self.buffer_set.new_write_buffer());
449 self.write_payload()?;
450 self.buffer_set.read_buffer.clear();
451 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
452 self.in_transaction = false;
453 Ok(())
454 }
455
456 pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
461 where
462 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
463 {
464 if self.in_transaction {
465 return Err(Error::NestedTransaction);
466 }
467
468 self.in_transaction = true;
469
470 if let Err(e) = self.query_drop("BEGIN") {
471 self.in_transaction = false;
472 return Err(e);
473 }
474
475 let tx = super::transaction::Transaction::new(self.connection_id());
476 let result = f(self, tx);
477
478 if self.in_transaction {
480 let rollback_result = self.query_drop("ROLLBACK");
481 self.in_transaction = false;
482
483 if let Err(e) = result {
485 return Err(e);
486 }
487 rollback_result?;
488 }
489
490 result
491 }
492}
493
494fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
497 buffer.clear();
498
499 let mut header = PacketHeader::new_zeroed();
500 reader.read_exact(header.as_mut_bytes())?;
501
502 let length = header.length();
503 let mut sequence_id = header.sequence_id;
504
505 buffer.reserve(length);
506
507 {
508 let spare = buffer.spare_capacity_mut();
509 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
510 reader.read_buf_exact(buf.unfilled())?;
511 unsafe {
513 buffer.set_len(length);
514 }
515 }
516
517 let mut current_length = length;
518 while current_length == 0xFFFFFF {
519 reader.read_exact(header.as_mut_bytes())?;
520
521 current_length = header.length();
522 sequence_id = header.sequence_id;
523
524 buffer.reserve(current_length);
525 let spare = buffer.spare_capacity_mut();
526 let mut buf: BorrowedBuf<'_> = (&mut spare[..current_length]).into();
527 reader.read_buf_exact(buf.unfilled())?;
528 unsafe {
530 buffer.set_len(buffer.len() + current_length);
531 }
532 }
533
534 Ok(sequence_id)
535}
536
537fn read_column_definition_packets(
538 reader: &mut Stream,
539 out: &mut Vec<u8>,
540 num_columns: usize,
541) -> Result<u8> {
542 out.clear();
543 let mut header = PacketHeader::new_zeroed();
544
545 for _ in 0..num_columns {
547 reader.read_exact(header.as_mut_bytes())?;
548 let length = header.length();
549 out.extend((length as u32).to_ne_bytes());
550
551 out.reserve(length);
552 let spare = out.spare_capacity_mut();
553 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
554 reader.read_buf_exact(buf.unfilled())?;
555 unsafe {
557 out.set_len(out.len() + length);
558 }
559 }
560
561 Ok(header.sequence_id)
562}
563
564fn write_handshake_payload(
565 stream: &mut Stream,
566 buffer_set: &mut BufferSet,
567 sequence_id: u8,
568) -> Result<()> {
569 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
570 let mut seq_id = sequence_id;
571
572 loop {
573 let chunk_size = buffer[4..].len().min(0xFFFFFF);
574 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
575 stream.write_all(&buffer[..4 + chunk_size])?;
576
577 if chunk_size < 0xFFFFFF {
578 break;
579 }
580
581 seq_id = seq_id.wrapping_add(1);
582 buffer = &mut buffer[0xFFFFFF..];
583 }
584 stream.flush()?;
585 Ok(())
586}
587
/// Text result-set handler used to capture the value returned by `SELECT @@socket`.
struct SocketPathHandler {
    // Set by `row` to the (lossily UTF-8 decoded) value when it is non-NULL and non-empty;
    // stays `None` otherwise.
    path: Option<String>,
}
592
593impl TextResultSetHandler for SocketPathHandler {
594 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
595 Ok(())
596 }
597 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
598 Ok(())
599 }
600 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
601 Ok(())
602 }
603 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
604 if row.0.first() == Some(&0xFB) {
606 return Ok(());
607 }
608 let (value, _) = read_string_lenenc(row.0)?;
610 if !value.is_empty() {
611 self.path = Some(String::from_utf8_lossy(value).into_owned());
612 }
613 Ok(())
614 }
615}