fraiseql_wire/connection/conn/
core.rs1use super::config::ConnectionConfig;
4use super::helpers::extract_entity_from_query;
5use crate::auth::ScramClient;
6use crate::connection::state::ConnectionState;
7use crate::connection::transport::Transport;
8use crate::protocol::{
9 decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
10};
11use crate::{Error, Result};
12use bytes::{Buf, BytesMut};
13use std::sync::atomic::{AtomicU64, Ordering};
14use tracing::Instrument;
15
/// Process-wide counter of flushed result chunks, shared by every streaming
/// query on every connection.
///
/// It is used only to sample chunk histograms: a chunk is recorded when its
/// index from this counter is a multiple of 10.
static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0_u64);
19
20pub struct Connection {
22 transport: Transport,
23 state: ConnectionState,
24 read_buf: BytesMut,
25 process_id: Option<i32>,
26 secret_key: Option<i32>,
27}
28
29impl Connection {
30 pub fn new(transport: Transport) -> Self {
32 Self {
33 transport,
34 state: ConnectionState::Initial,
35 read_buf: BytesMut::with_capacity(8192),
36 process_id: None,
37 secret_key: None,
38 }
39 }
40
41 pub const fn state(&self) -> ConnectionState {
43 self.state
44 }
45
46 pub async fn startup(&mut self, config: &ConnectionConfig) -> Result<()> {
54 async {
55 self.state.transition(ConnectionState::AwaitingAuth)?;
56
57 let mut params = vec![
59 ("user".to_string(), config.user.clone()),
60 ("database".to_string(), config.database.clone()),
61 ];
62
63 if let Some(app_name) = &config.application_name {
65 params.push(("application_name".to_string(), app_name.clone()));
66 }
67
68 if let Some(timeout) = config.statement_timeout {
70 params.push((
71 "statement_timeout".to_string(),
72 timeout.as_millis().to_string(),
73 ));
74 }
75
76 if let Some(digits) = config.extra_float_digits {
78 params.push(("extra_float_digits".to_string(), digits.to_string()));
79 }
80
81 for (k, v) in &config.params {
83 params.push((k.clone(), v.clone()));
84 }
85
86 let startup = FrontendMessage::Startup {
88 version: crate::protocol::constants::PROTOCOL_VERSION,
89 params,
90 };
91 self.send_message(&startup).await?;
92
93 self.state.transition(ConnectionState::Authenticating)?;
95 self.authenticate(config).await?;
96
97 self.state.transition(ConnectionState::Idle)?;
98 tracing::info!("startup complete");
99 Ok(())
100 }
101 .instrument(tracing::info_span!(
102 "startup",
103 user = %config.user,
104 database = %config.database
105 ))
106 .await
107 }
108
109 async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
111 let auth_start = std::time::Instant::now();
112 let mut auth_mechanism = "unknown";
113
114 loop {
115 let msg = self.receive_message().await?;
116
117 match msg {
118 BackendMessage::Authentication(auth) => match auth {
119 AuthenticationMessage::Ok => {
120 tracing::debug!("authentication successful");
121 crate::metrics::counters::auth_successful(auth_mechanism);
122 crate::metrics::histograms::auth_duration(
123 auth_mechanism,
124 auth_start.elapsed().as_millis() as u64,
125 );
126 }
128 AuthenticationMessage::CleartextPassword => {
129 auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
130 crate::metrics::counters::auth_attempted(auth_mechanism);
131
132 let password = config
133 .password
134 .as_ref()
135 .ok_or_else(|| Error::Authentication("password required".into()))?;
136 let pwd_msg = FrontendMessage::Password(password.as_str().to_string());
138 self.send_message(&pwd_msg).await?;
139 }
140 AuthenticationMessage::Md5Password { .. } => {
141 return Err(Error::Authentication(
142 "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
143 ));
144 }
145 AuthenticationMessage::Sasl { mechanisms } => {
146 auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
147 crate::metrics::counters::auth_attempted(auth_mechanism);
148 self.handle_sasl(&mechanisms, config).await?;
149 }
150 AuthenticationMessage::SaslContinue { .. } => {
151 return Err(Error::Protocol(
152 "unexpected SaslContinue outside of SASL flow".into(),
153 ));
154 }
155 AuthenticationMessage::SaslFinal { .. } => {
156 return Err(Error::Protocol(
157 "unexpected SaslFinal outside of SASL flow".into(),
158 ));
159 }
160 },
161 BackendMessage::BackendKeyData {
162 process_id,
163 secret_key,
164 } => {
165 self.process_id = Some(process_id);
166 self.secret_key = Some(secret_key);
167 }
168 BackendMessage::ParameterStatus { name, value } => {
169 tracing::debug!("parameter status: {} = {}", name, value);
170 }
171 BackendMessage::ReadyForQuery { status: _ } => {
172 break;
173 }
174 BackendMessage::ErrorResponse(err) => {
175 crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
176 return Err(Error::Authentication(err.to_string()));
177 }
178 _ => {
179 return Err(Error::Protocol(format!(
180 "unexpected message during auth: {:?}",
181 msg
182 )));
183 }
184 }
185 }
186
187 Ok(())
188 }
189
190 async fn handle_sasl(
192 &mut self,
193 mechanisms: &[String],
194 config: &ConnectionConfig,
195 ) -> Result<()> {
196 if !mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
198 return Err(Error::Authentication(format!(
199 "server does not support SCRAM-SHA-256. Available: {}",
200 mechanisms.join(", ")
201 )));
202 }
203
204 let password = config.password.as_ref().ok_or_else(|| {
206 Error::Authentication("password required for SCRAM authentication".into())
207 })?;
208
209 let mut scram = ScramClient::new(config.user.clone(), password.as_str().to_string());
212 tracing::debug!("initiating SCRAM-SHA-256 authentication");
213
214 let client_first = scram.client_first();
216 let msg = FrontendMessage::SaslInitialResponse {
217 mechanism: "SCRAM-SHA-256".to_string(),
218 data: client_first.into_bytes(),
219 };
220 self.send_message(&msg).await?;
221
222 let server_first_msg = self.receive_message().await?;
224 let server_first_data = match server_first_msg {
225 BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
226 BackendMessage::ErrorResponse(err) => {
227 return Err(Error::Authentication(format!("SASL server error: {}", err)));
228 }
229 _ => {
230 return Err(Error::Protocol(
231 "expected SaslContinue message during SASL authentication".into(),
232 ));
233 }
234 };
235
236 let server_first = String::from_utf8(server_first_data).map_err(|e| {
237 Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
238 })?;
239
240 tracing::debug!("received SCRAM server first message");
241
242 let (client_final, scram_state) = scram
244 .client_final(&server_first)
245 .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
246
247 let msg = FrontendMessage::SaslResponse {
249 data: client_final.into_bytes(),
250 };
251 self.send_message(&msg).await?;
252
253 let server_final_msg = self.receive_message().await?;
255 let server_final_data = match server_final_msg {
256 BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
257 BackendMessage::ErrorResponse(err) => {
258 return Err(Error::Authentication(format!("SASL server error: {}", err)));
259 }
260 _ => {
261 return Err(Error::Protocol(
262 "expected SaslFinal message during SASL authentication".into(),
263 ));
264 }
265 };
266
267 let server_final = String::from_utf8(server_final_data).map_err(|e| {
268 Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
269 })?;
270
271 scram
273 .verify_server_final(&server_final, &scram_state)
274 .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
275
276 tracing::debug!("SCRAM-SHA-256 authentication successful");
277 Ok(())
278 }
279
280 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
288 if self.state != ConnectionState::Idle {
289 return Err(Error::ConnectionBusy(format!(
290 "connection in state: {}",
291 self.state
292 )));
293 }
294
295 self.state.transition(ConnectionState::QueryInProgress)?;
296
297 let query_msg = FrontendMessage::Query(query.to_string());
298 self.send_message(&query_msg).await?;
299
300 self.state.transition(ConnectionState::ReadingResults)?;
301
302 let mut messages = Vec::new();
303
304 loop {
305 let msg = self.receive_message().await?;
306 let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
307 messages.push(msg);
308
309 if is_ready {
310 break;
311 }
312 }
313
314 self.state.transition(ConnectionState::Idle)?;
315 Ok(messages)
316 }
317
318 async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
320 let buf = encode_message(msg)?;
321 self.transport.write_all(&buf).await?;
322 self.transport.flush().await?;
323 Ok(())
324 }
325
326 async fn receive_message(&mut self) -> Result<BackendMessage> {
328 loop {
329 if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
331 self.read_buf.advance(consumed);
332 return Ok(msg);
333 }
334
335 let n = self.transport.read_buf(&mut self.read_buf).await?;
337 if n == 0 {
338 return Err(Error::ConnectionClosed);
339 }
340 }
341 }
342
343 pub async fn close(mut self) -> Result<()> {
350 self.state.transition(ConnectionState::Closed)?;
351 let _ = self.send_message(&FrontendMessage::Terminate).await;
352 self.transport.shutdown().await?;
353 Ok(())
354 }
355
356 #[allow(clippy::too_many_arguments)] pub async fn streaming_query(
371 mut self,
372 query: &str,
373 chunk_size: usize,
374 max_memory: Option<usize>,
375 soft_limit_warn_threshold: Option<f32>,
376 soft_limit_fail_threshold: Option<f32>,
377 enable_adaptive_chunking: bool,
378 adaptive_min_chunk_size: Option<usize>,
379 adaptive_max_chunk_size: Option<usize>,
380 ) -> Result<crate::stream::JsonStream> {
381 async {
382 let startup_start = std::time::Instant::now();
383
384 use crate::json::validate_row_description;
385 use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
386 use serde_json::Value;
387 use tokio::sync::mpsc;
388
389 if self.state != ConnectionState::Idle {
390 return Err(Error::ConnectionBusy(format!(
391 "connection in state: {}",
392 self.state
393 )));
394 }
395
396 self.state.transition(ConnectionState::QueryInProgress)?;
397
398 let query_msg = FrontendMessage::Query(query.to_string());
399 self.send_message(&query_msg).await?;
400
401 self.state.transition(ConnectionState::ReadingResults)?;
402
403 let row_desc;
406 loop {
407 let msg = self.receive_message().await?;
408
409 match msg {
410 BackendMessage::ErrorResponse(err) => {
411 tracing::debug!("PostgreSQL error response: {}", err);
413 loop {
414 let msg = self.receive_message().await?;
415 if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
416 break;
417 }
418 }
419 return Err(Error::Sql(err.to_string()));
420 }
421 BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
422 tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
424 continue;
426 }
427 BackendMessage::ParameterStatus { .. } => {
428 tracing::debug!("PostgreSQL parameter status change received");
430 continue;
431 }
432 BackendMessage::NoticeResponse(notice) => {
433 tracing::debug!("PostgreSQL notice: {}", notice);
435 continue;
436 }
437 BackendMessage::RowDescription(_) => {
438 row_desc = msg;
439 break;
440 }
441 BackendMessage::ReadyForQuery { .. } => {
442 return Err(Error::Protocol(
445 "no result set received from query - \
446 check that the entity name is correct and the table/view exists"
447 .into(),
448 ));
449 }
450 _ => {
451 return Err(Error::Protocol(format!(
452 "unexpected message type in query response: {:?}",
453 msg
454 )));
455 }
456 }
457 }
458
459 validate_row_description(&row_desc)?;
460
461 let startup_duration = startup_start.elapsed().as_millis() as u64;
463 let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
464 crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
465
466 let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
468 let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
469
470 let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
472 let entity_for_stream = entity_for_metrics.clone(); let stream = JsonStream::new(
475 result_rx,
476 cancel_tx,
477 entity_for_stream,
478 max_memory,
479 soft_limit_warn_threshold,
480 soft_limit_fail_threshold,
481 );
482
483 let state_lock = stream.clone_state();
485 let pause_signal = stream.clone_pause_signal();
486 let resume_signal = stream.clone_resume_signal();
487
488 let state_atomic = stream.clone_state_atomic();
490
491 let pause_timeout = stream.pause_timeout();
493
494 let query_start = std::time::Instant::now();
496
497 tokio::spawn(async move {
498 let strategy = ChunkingStrategy::new(chunk_size);
499 let mut chunk = strategy.new_chunk();
500 let mut total_rows = 0u64;
501
502 let _adaptive = if enable_adaptive_chunking {
504 let mut adp = AdaptiveChunking::new();
505
506 if let Some(min) = adaptive_min_chunk_size {
508 if let Some(max) = adaptive_max_chunk_size {
509 adp = adp.with_bounds(min, max);
510 }
511 }
512
513 Some(adp)
514 } else {
515 None
516 };
517 let _current_chunk_size = chunk_size;
518
519 loop {
520 if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
523 if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
525 (&state_lock, &pause_signal, &resume_signal)
526 {
527 let current_state = state_lock.lock().await;
528 if *current_state == crate::stream::StreamState::Paused {
529 tracing::debug!("stream paused, waiting for resume");
530 drop(current_state); if let Some(timeout) = pause_timeout {
534 if tokio::time::timeout(timeout, resume_signal.notified()).await == Ok(()) {
535 tracing::debug!("stream resumed");
536 } else {
537 tracing::debug!("pause timeout expired, auto-resuming");
538 crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
539 }
540 } else {
541 resume_signal.notified().await;
543 tracing::debug!("stream resumed");
544 }
545
546 let mut state = state_lock.lock().await;
548 *state = crate::stream::StreamState::Running;
549 }
550 }
551 }
552
553 tokio::select! {
554 _ = cancel_rx.recv() => {
556 tracing::debug!("query cancelled");
557 crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
558 break;
559 }
560
561 msg_result = self.receive_message() => {
563 match msg_result {
564 Ok(msg) => match msg {
565 BackendMessage::DataRow(_) => {
566 match extract_json_bytes(&msg) {
567 Ok(json_bytes) => {
568 chunk.push(json_bytes);
569
570 if strategy.is_full(&chunk) {
571 let chunk_start = std::time::Instant::now();
572 let rows = chunk.into_rows();
573 let chunk_size_rows = rows.len() as u64;
574
575 const BATCH_SIZE: usize = 8;
578 let mut batch = Vec::with_capacity(BATCH_SIZE);
579 let mut send_error = false;
580
581 for row_bytes in rows {
582 match parse_json(row_bytes) {
583 Ok(value) => {
584 total_rows += 1;
585 batch.push(Ok(value));
586
587 if batch.len() == BATCH_SIZE {
589 for item in batch.drain(..) {
590 if result_tx.send(item).await.is_err() {
591 crate::metrics::counters::query_completed("error", &entity_for_metrics);
592 send_error = true;
593 break;
594 }
595 }
596 if send_error {
597 break;
598 }
599 }
600 }
601 Err(e) => {
602 crate::metrics::counters::json_parse_error(&entity_for_metrics);
603 let _ = result_tx.send(Err(e)).await;
604 crate::metrics::counters::query_completed("error", &entity_for_metrics);
605 send_error = true;
606 break;
607 }
608 }
609 }
610
611 if !send_error {
613 for item in batch {
614 if result_tx.send(item).await.is_err() {
615 crate::metrics::counters::query_completed("error", &entity_for_metrics);
616 break;
617 }
618 }
619 }
620
621 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
623
624 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
626 if chunk_idx.is_multiple_of(10) {
627 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
628 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
629 }
630
631 chunk = strategy.new_chunk();
637 }
638 }
639 Err(e) => {
640 crate::metrics::counters::json_parse_error(&entity_for_metrics);
641 let _ = result_tx.send(Err(e)).await;
642 crate::metrics::counters::query_completed("error", &entity_for_metrics);
643 break;
644 }
645 }
646 }
647 BackendMessage::CommandComplete(_) => {
648 if !chunk.is_empty() {
650 let chunk_start = std::time::Instant::now();
651 let rows = chunk.into_rows();
652 let chunk_size_rows = rows.len() as u64;
653
654 const BATCH_SIZE: usize = 8;
656 let mut batch = Vec::with_capacity(BATCH_SIZE);
657 let mut send_error = false;
658
659 for row_bytes in rows {
660 match parse_json(row_bytes) {
661 Ok(value) => {
662 total_rows += 1;
663 batch.push(Ok(value));
664
665 if batch.len() == BATCH_SIZE {
667 for item in batch.drain(..) {
668 if result_tx.send(item).await.is_err() {
669 crate::metrics::counters::query_completed("error", &entity_for_metrics);
670 send_error = true;
671 break;
672 }
673 }
674 if send_error {
675 break;
676 }
677 }
678 }
679 Err(e) => {
680 crate::metrics::counters::json_parse_error(&entity_for_metrics);
681 let _ = result_tx.send(Err(e)).await;
682 crate::metrics::counters::query_completed("error", &entity_for_metrics);
683 send_error = true;
684 break;
685 }
686 }
687 }
688
689 if !send_error {
691 for item in batch {
692 if result_tx.send(item).await.is_err() {
693 crate::metrics::counters::query_completed("error", &entity_for_metrics);
694 break;
695 }
696 }
697 }
698
699 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
701 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
702 if chunk_idx.is_multiple_of(10) {
703 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
704 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
705 }
706 chunk = strategy.new_chunk();
707 }
708
709 let query_duration = query_start.elapsed().as_millis() as u64;
711 crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
712 crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
713 crate::metrics::counters::query_completed("success", &entity_for_metrics);
714 }
715 BackendMessage::ReadyForQuery { .. } => {
716 break;
717 }
718 BackendMessage::ErrorResponse(err) => {
719 crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
720 crate::metrics::counters::query_completed("error", &entity_for_metrics);
721 let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
722 break;
723 }
724 _ => {
725 crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
726 crate::metrics::counters::query_completed("error", &entity_for_metrics);
727 let _ = result_tx.send(Err(Error::Protocol(
728 format!("unexpected message: {:?}", msg)
729 ))).await;
730 break;
731 }
732 },
733 Err(e) => {
734 crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
735 crate::metrics::counters::query_completed("error", &entity_for_metrics);
736 let _ = result_tx.send(Err(e)).await;
737 break;
738 }
739 }
740 }
741 }
742 }
743 });
744
745 Ok(stream)
746 }
747 .instrument(tracing::debug_span!(
748 "streaming_query",
749 query = %query,
750 chunk_size = %chunk_size
751 ))
752 .await
753 }
754}