1use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::{Arc, Mutex, Weak};
21
22use crate::protocol::message::{backend::Message, frontend};
23use crate::types::Oid;
24use tracing::{trace, warn};
25
26use super::connection::RawConnection;
27use super::error::{Error, Result};
28use super::row::Row;
29use super::statement::Column;
30use super::sync_stream::SyncStream;
31
/// A Rust value that can be turned into the raw bytes of a bound query parameter.
///
/// The implementations in this module encode fixed-width numbers as their
/// little-endian bytes, `bool` as a single `0`/`1` byte, strings as their UTF-8
/// bytes, and byte buffers verbatim. `execute_prepared` binds every parameter
/// with format code `1` (binary), so these payloads are sent as-is.
pub trait SqlParam {
    /// Returns the encoded parameter payload for `self`.
    fn encode(&self) -> Vec<u8>;
}
44
45impl SqlParam for i16 {
46 #[inline]
47 fn encode(&self) -> Vec<u8> {
48 self.to_le_bytes().to_vec()
49 }
50}
51
52impl SqlParam for i32 {
53 #[inline]
54 fn encode(&self) -> Vec<u8> {
55 self.to_le_bytes().to_vec()
56 }
57}
58
59impl SqlParam for i64 {
60 #[inline]
61 fn encode(&self) -> Vec<u8> {
62 self.to_le_bytes().to_vec()
63 }
64}
65
66impl SqlParam for f32 {
67 #[inline]
68 fn encode(&self) -> Vec<u8> {
69 self.to_le_bytes().to_vec()
70 }
71}
72
73impl SqlParam for f64 {
74 #[inline]
75 fn encode(&self) -> Vec<u8> {
76 self.to_le_bytes().to_vec()
77 }
78}
79
80impl SqlParam for bool {
81 #[inline]
82 fn encode(&self) -> Vec<u8> {
83 vec![u8::from(*self)]
84 }
85}
86
87impl SqlParam for &str {
88 #[inline]
89 fn encode(&self) -> Vec<u8> {
90 self.as_bytes().to_vec()
91 }
92}
93
94impl SqlParam for String {
95 #[inline]
96 fn encode(&self) -> Vec<u8> {
97 self.as_bytes().to_vec()
98 }
99}
100
101impl SqlParam for &String {
102 #[inline]
103 fn encode(&self) -> Vec<u8> {
104 self.as_bytes().to_vec()
105 }
106}
107
108impl SqlParam for Vec<u8> {
109 #[inline]
110 fn encode(&self) -> Vec<u8> {
111 self.clone()
112 }
113}
114
115impl SqlParam for &[u8] {
116 #[inline]
117 fn encode(&self) -> Vec<u8> {
118 self.to_vec()
119 }
120}
121
/// Builds a list of encoded, non-NULL query parameters from Rust values.
///
/// Each argument must implement `SqlParam`. It is encoded with
/// `SqlParam::encode` and wrapped in `Some`. A trailing comma is accepted.
///
/// With one or more arguments the macro yields a `Vec<Option<Vec<u8>>>`.
/// With no arguments it yields an empty `&[Option<Vec<u8>>]` slice.
#[macro_export]
macro_rules! params {
    () => {
        &[] as &[Option<Vec<u8>>]
    };
    ($($param:expr),+ $(,)?) => {{
        use $crate::client::prepare::SqlParam;
        Vec::from([$(Some($param.encode())),+])
    }};
}
149
/// Process-wide counter from which unique statement names are derived.
static STATEMENT_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a statement name not previously handed out by this process.
///
/// `fetch_add` is atomic regardless of ordering, so `Relaxed` still
/// guarantees every caller receives a distinct id.
fn generate_statement_name() -> String {
    format!(
        "__hyper_stmt_{id}",
        id = STATEMENT_COUNTER.fetch_add(1, Ordering::Relaxed)
    )
}
158
159#[derive(Debug)]
179pub struct PreparedStatement {
180 name: String,
182 query: String,
184 param_types: Vec<Oid>,
186 columns: Vec<Column>,
188}
189
190#[derive(Debug)]
209pub struct OwnedPreparedStatement {
210 statement: PreparedStatement,
212 connection: Weak<Mutex<RawConnection<SyncStream>>>,
214}
215
216impl OwnedPreparedStatement {
217 pub(crate) fn new(
219 statement: PreparedStatement,
220 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
221 ) -> Self {
222 OwnedPreparedStatement {
223 statement,
224 connection: Arc::downgrade(connection),
225 }
226 }
227
228 #[must_use]
230 pub fn name(&self) -> &str {
231 self.statement.name()
232 }
233
234 #[must_use]
236 pub fn query(&self) -> &str {
237 self.statement.query()
238 }
239
240 #[must_use]
242 pub fn param_types(&self) -> &[Oid] {
243 self.statement.param_types()
244 }
245
246 #[must_use]
248 pub fn param_count(&self) -> usize {
249 self.statement.param_count()
250 }
251
252 #[must_use]
254 pub fn columns(&self) -> &[Column] {
255 self.statement.columns()
256 }
257
258 #[must_use]
260 pub fn column_count(&self) -> usize {
261 self.statement.column_count()
262 }
263
264 #[must_use]
266 pub fn statement(&self) -> &PreparedStatement {
267 &self.statement
268 }
269
270 pub fn close(self) -> Result<()> {
283 if let Some(conn) = self.connection.upgrade() {
284 close_statement(&conn, &self.statement)?;
285 }
286 std::mem::forget(self);
288 Ok(())
289 }
290}
291
292impl Drop for OwnedPreparedStatement {
293 fn drop(&mut self) {
294 if let Some(conn) = self.connection.upgrade() {
296 if let Err(e) = close_statement_internal(&conn, &self.statement) {
297 warn!(
298 target: "hyperdb_api",
299 statement_name = %self.statement.name,
300 error = %e,
301 "failed-to-close-prepared-statement-during-drop"
302 );
303 }
304 }
305 }
308}
309
310impl PreparedStatement {
311 #[must_use]
313 pub fn name(&self) -> &str {
314 &self.name
315 }
316
317 #[must_use]
319 pub fn query(&self) -> &str {
320 &self.query
321 }
322
323 #[must_use]
325 pub fn param_types(&self) -> &[Oid] {
326 &self.param_types
327 }
328
329 #[must_use]
331 pub fn param_count(&self) -> usize {
332 self.param_types.len()
333 }
334
335 #[must_use]
337 pub fn columns(&self) -> &[Column] {
338 &self.columns
339 }
340
341 #[must_use]
343 pub fn column_count(&self) -> usize {
344 self.columns.len()
345 }
346}
347
348pub fn prepare(
358 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
359 query: &str,
360 param_types: &[Oid],
361) -> Result<PreparedStatement> {
362 let name = generate_statement_name();
363 let mut conn = connection
364 .lock()
365 .map_err(|_| Error::connection("connection mutex poisoned"))?;
366
367 frontend::parse(&name, query, param_types, conn.write_buf())?;
369
370 frontend::describe(b'S', &name, conn.write_buf())?;
372
373 frontend::sync(conn.write_buf());
375 conn.flush()?;
376
377 let mut parsed_params = Vec::new();
379 let mut parsed_columns = Vec::new();
380
381 loop {
382 let msg = conn.read_message()?;
383 match msg {
384 Message::ParseComplete => {
385 }
387 Message::ParameterDescription(desc) => {
388 for oid in desc.parameters().filter_map(|r| {
389 r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing parameter OID")).ok()
390 }) {
391 parsed_params.push(oid);
392 }
393 }
394 Message::RowDescription(desc) => {
395 for f in desc.fields().filter_map(|r| {
396 r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing row description field")).ok()
397 }) {
398 parsed_columns.push(Column::new(
399 f.name().to_string(),
400 f.type_oid(),
401 f.type_modifier(),
402 super::statement::ColumnFormat::from_code(f.format()),
403 ));
404 }
405 }
406 Message::NoData => {
407 }
409 Message::ReadyForQuery(_) => {
410 break;
411 }
412 Message::ErrorResponse(body) => {
413 return Err(conn.consume_error(&body));
414 }
415 _ => {}
416 }
417 }
418
419 Ok(PreparedStatement {
420 name,
421 query: query.to_string(),
422 param_types: parsed_params,
423 columns: parsed_columns,
424 })
425}
426
427pub fn execute_prepared(
439 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
440 statement: &PreparedStatement,
441 params: &[Option<&[u8]>],
442) -> Result<Vec<Row>> {
443 let mut conn = connection
444 .lock()
445 .map_err(|_| Error::connection("connection mutex poisoned"))?;
446
447 let param_formats: Vec<i16> = vec![1; params.len()]; let result_formats: Vec<i16> = vec![1; statement.columns.len()]; frontend::bind(
452 "", &statement.name,
454 ¶m_formats,
455 params,
456 &result_formats,
457 conn.write_buf(),
458 )?;
459
460 frontend::execute("", 0, conn.write_buf())?; frontend::sync(conn.write_buf());
465 conn.flush()?;
466
467 let mut rows = Vec::new();
469 let columns = Arc::new(statement.columns.clone());
470
471 loop {
472 let msg = conn.read_message()?;
473 match msg {
474 Message::BindComplete => {
475 }
477 Message::DataRow(data) => {
478 rows.push(Row::new(Arc::clone(&columns), data)?);
479 }
480 Message::CommandComplete(_) => {
481 }
483 Message::EmptyQueryResponse => {
484 }
486 Message::ReadyForQuery(_) => {
487 break;
488 }
489 Message::ErrorResponse(body) => {
490 return Err(conn.consume_error(&body));
491 }
492 _ => {}
493 }
494 }
495
496 Ok(rows)
497}
498
499pub fn execute_prepared_no_result(
506 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
507 statement: &PreparedStatement,
508 params: &[Option<&[u8]>],
509) -> Result<u64> {
510 let mut conn = connection
511 .lock()
512 .map_err(|_| Error::connection("connection mutex poisoned"))?;
513
514 let param_formats: Vec<i16> = vec![1; params.len()];
516 let result_formats: Vec<i16> = vec![];
517
518 frontend::bind(
519 "",
520 &statement.name,
521 ¶m_formats,
522 params,
523 &result_formats,
524 conn.write_buf(),
525 )?;
526
527 frontend::execute("", 0, conn.write_buf())?;
529
530 frontend::sync(conn.write_buf());
532 conn.flush()?;
533
534 let mut affected_rows = 0u64;
536
537 loop {
538 let msg = conn.read_message()?;
539 match msg {
540 Message::BindComplete => {}
541 Message::CommandComplete(body) => {
542 if let Ok(tag) = body.tag() {
543 affected_rows = parse_affected_rows(tag);
544 }
545 }
546 Message::EmptyQueryResponse => {}
547 Message::ReadyForQuery(_) => {
548 break;
549 }
550 Message::ErrorResponse(body) => {
551 return Err(conn.consume_error(&body));
552 }
553 _ => {}
554 }
555 }
556
557 Ok(affected_rows)
558}
559
560pub fn close_statement(
570 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
571 statement: &PreparedStatement,
572) -> Result<()> {
573 close_statement_internal(connection, statement)
574}
575
576fn close_statement_internal(
578 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
579 statement: &PreparedStatement,
580) -> Result<()> {
581 let mut conn = connection
582 .lock()
583 .map_err(|_| Error::connection("connection mutex poisoned"))?;
584
585 frontend::close(b'S', &statement.name, conn.write_buf())?;
587
588 frontend::sync(conn.write_buf());
590 conn.flush()?;
591
592 loop {
594 let msg = conn.read_message()?;
595 match msg {
596 Message::CloseComplete => {}
597 Message::ReadyForQuery(_) => {
598 break;
599 }
600 Message::ErrorResponse(body) => {
601 return Err(conn.consume_error(&body));
602 }
603 _ => {}
604 }
605 }
606
607 Ok(())
608}
609
610pub fn prepare_owned(
616 connection: &Arc<Mutex<RawConnection<SyncStream>>>,
617 query: &str,
618 param_types: &[Oid],
619) -> Result<OwnedPreparedStatement> {
620 let statement = prepare(connection, query, param_types)?;
621 Ok(OwnedPreparedStatement::new(statement, connection))
622}
623
/// Extracts the row count from a `CommandComplete` tag.
///
/// `INSERT` tags have the form `INSERT <oid> <rows>`. `UPDATE`, `DELETE`,
/// `MERGE`, `SELECT`, `MOVE`, `FETCH` and `COPY` tags have the form
/// `<command> <rows>`. Any other tag, a missing count, or a count that does
/// not parse as `u64` yields `0`.
fn parse_affected_rows(tag: &str) -> u64 {
    let mut words = tag.split_whitespace();
    let count = match words.next() {
        // Skip the OID field that precedes the row count.
        Some("INSERT") => words.nth(1),
        Some("UPDATE" | "DELETE" | "MERGE" | "SELECT" | "MOVE" | "FETCH" | "COPY") => {
            words.next()
        }
        _ => None,
    };
    count.and_then(|n| n.parse::<u64>().ok()).unwrap_or(0)
}
636
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_param_i16() {
        assert_eq!(0_i16.encode(), vec![0, 0]);
        assert_eq!(1_i16.encode(), vec![1, 0]);
        assert_eq!((-1_i16).encode(), vec![255, 255]);
    }

    #[test]
    fn test_sql_param_i32() {
        assert_eq!(0_i32.encode(), vec![0, 0, 0, 0]);
        assert_eq!(1_i32.encode(), vec![1, 0, 0, 0]);
        assert_eq!((-1_i32).encode(), vec![255, 255, 255, 255]);
        assert_eq!(256_i32.encode(), vec![0, 1, 0, 0]);
    }

    #[test]
    fn test_sql_param_i64() {
        assert_eq!(0_i64.encode(), vec![0, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(1_i64.encode(), vec![1, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(
            (-1_i64).encode(),
            vec![255, 255, 255, 255, 255, 255, 255, 255]
        );
    }

    #[test]
    #[expect(
        clippy::float_cmp,
        reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
    )]
    fn test_sql_param_f32() {
        let encoded = 1.5_f32.encode();
        assert_eq!(encoded.len(), 4);
        let decoded = f32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(decoded, 1.5);
    }

    #[test]
    #[expect(
        clippy::float_cmp,
        reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
    )]
    fn test_sql_param_f64() {
        let encoded = 1.5_f64.encode();
        assert_eq!(encoded.len(), 8);
        let decoded = f64::from_le_bytes([
            encoded[0], encoded[1], encoded[2], encoded[3], encoded[4], encoded[5], encoded[6],
            encoded[7],
        ]);
        assert_eq!(decoded, 1.5);
    }

    #[test]
    fn test_sql_param_bool() {
        assert_eq!(true.encode(), vec![1]);
        assert_eq!(false.encode(), vec![0]);
    }

    #[test]
    fn test_sql_param_str() {
        assert_eq!("hello".encode(), b"hello".to_vec());
        assert_eq!("".encode(), Vec::<u8>::new());
        assert_eq!("héllo".encode(), "héllo".as_bytes().to_vec());
    }

    #[test]
    fn test_sql_param_string() {
        let s = String::from("hello");
        assert_eq!(s.encode(), b"hello".to_vec());
        // Exercise the `&String` impl explicitly. Method-call syntax on a
        // `&String` receiver resolves to the `String` impl instead.
        let borrowed: &String = &s;
        assert_eq!(SqlParam::encode(&borrowed), b"hello".to_vec());
    }

    #[test]
    fn test_sql_param_bytes() {
        let bytes: Vec<u8> = vec![1, 2, 3, 4];
        assert_eq!(bytes.encode(), vec![1, 2, 3, 4]);
        assert_eq!(bytes.as_slice().encode(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_params_macro_empty() {
        let p = params![];
        assert!(p.is_empty());
    }

    #[test]
    fn test_params_macro_single() {
        let p = params![42_i32];
        assert_eq!(p.len(), 1);
        assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
    }

    #[test]
    fn test_params_macro_multiple() {
        let p = params![42_i32, "hello", true];
        assert_eq!(p.len(), 3);
        assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
        assert_eq!(p[1], Some(b"hello".to_vec()));
        assert_eq!(p[2], Some(vec![1]));
    }

    #[test]
    fn test_parse_affected_rows() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), 5);
        assert_eq!(parse_affected_rows("UPDATE 3"), 3);
        assert_eq!(parse_affected_rows("DELETE 2"), 2);
        assert_eq!(parse_affected_rows("SELECT 7"), 7);
        assert_eq!(parse_affected_rows("COPY 10"), 10);
        assert_eq!(parse_affected_rows("CREATE TABLE"), 0);
        assert_eq!(parse_affected_rows("INSERT 0"), 0);
        assert_eq!(parse_affected_rows("UPDATE many"), 0);
        assert_eq!(parse_affected_rows(""), 0);
    }
}