fraiseql_wire/connection/conn/
core.rs1use super::config::ConnectionConfig;
4use super::helpers::extract_entity_from_query;
5use crate::auth::ScramClient;
6use crate::connection::state::ConnectionState;
7use crate::connection::transport::Transport;
8use crate::protocol::{
9 decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
10};
11use crate::{Result, WireError};
12use bytes::{Buf, BytesMut};
13use std::sync::atomic::{AtomicU64, Ordering};
14use tracing::Instrument;
15
/// Process-wide counter of result chunks flushed by streaming queries.
///
/// Shared by every connection in the process; it is only used to sample chunk
/// metrics, which are recorded for every tenth chunk.
static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0_u64);
19
20pub struct Connection {
22 transport: Transport,
23 state: ConnectionState,
24 read_buf: BytesMut,
25 process_id: Option<i32>,
26 secret_key: Option<i32>,
27}
28
29impl Connection {
30 pub fn new(transport: Transport) -> Self {
32 Self {
33 transport,
34 state: ConnectionState::Initial,
35 read_buf: BytesMut::with_capacity(8192),
36 process_id: None,
37 secret_key: None,
38 }
39 }
40
41 pub const fn state(&self) -> ConnectionState {
43 self.state
44 }
45
46 pub async fn startup(&mut self, config: &ConnectionConfig) -> Result<()> {
54 async {
55 self.state.transition(ConnectionState::AwaitingAuth)?;
56
57 let mut params = vec![
59 ("user".to_string(), config.user.clone()),
60 ("database".to_string(), config.database.clone()),
61 ];
62
63 if let Some(app_name) = &config.application_name {
65 params.push(("application_name".to_string(), app_name.clone()));
66 }
67
68 if let Some(timeout) = config.statement_timeout {
70 params.push((
71 "statement_timeout".to_string(),
72 timeout.as_millis().to_string(),
73 ));
74 }
75
76 if let Some(digits) = config.extra_float_digits {
78 params.push(("extra_float_digits".to_string(), digits.to_string()));
79 }
80
81 for (k, v) in &config.params {
83 params.push((k.clone(), v.clone()));
84 }
85
86 let startup = FrontendMessage::Startup {
88 version: crate::protocol::constants::PROTOCOL_VERSION,
89 params,
90 };
91 self.send_message(&startup).await?;
92
93 self.state.transition(ConnectionState::Authenticating)?;
95 self.authenticate(config).await?;
96
97 self.state.transition(ConnectionState::Idle)?;
98 tracing::info!("startup complete");
99 Ok(())
100 }
101 .instrument(tracing::info_span!(
102 "startup",
103 user = %config.user,
104 database = %config.database
105 ))
106 .await
107 }
108
109 async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
111 let auth_start = std::time::Instant::now();
112 let mut auth_mechanism = "unknown";
113
114 loop {
115 let msg = self.receive_message().await?;
116
117 match msg {
118 BackendMessage::Authentication(auth) => match auth {
119 AuthenticationMessage::Ok => {
120 tracing::debug!("authentication successful");
121 crate::metrics::counters::auth_successful(auth_mechanism);
122 crate::metrics::histograms::auth_duration(
123 auth_mechanism,
124 auth_start.elapsed().as_millis() as u64,
125 );
126 }
128 AuthenticationMessage::CleartextPassword => {
129 auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
130 crate::metrics::counters::auth_attempted(auth_mechanism);
131
132 let password = config
133 .password
134 .as_ref()
135 .ok_or_else(|| WireError::Authentication("password required".into()))?;
136 let pwd_msg = FrontendMessage::Password(password.as_str().to_string());
138 self.send_message(&pwd_msg).await?;
139 }
140 AuthenticationMessage::Md5Password { .. } => {
141 return Err(WireError::Authentication(
142 "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
143 ));
144 }
145 AuthenticationMessage::Sasl { mechanisms } => {
146 auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
147 crate::metrics::counters::auth_attempted(auth_mechanism);
148 self.handle_sasl(&mechanisms, config).await?;
149 }
150 AuthenticationMessage::SaslContinue { .. } => {
151 return Err(WireError::Protocol(
152 "unexpected SaslContinue outside of SASL flow".into(),
153 ));
154 }
155 AuthenticationMessage::SaslFinal { .. } => {
156 return Err(WireError::Protocol(
157 "unexpected SaslFinal outside of SASL flow".into(),
158 ));
159 }
160 },
161 BackendMessage::BackendKeyData {
162 process_id,
163 secret_key,
164 } => {
165 self.process_id = Some(process_id);
166 self.secret_key = Some(secret_key);
167 }
168 BackendMessage::ParameterStatus { name, value } => {
169 tracing::debug!("parameter status: {} = {}", name, value);
170 }
171 BackendMessage::ReadyForQuery { status: _ } => {
172 break;
173 }
174 BackendMessage::ErrorResponse(err) => {
175 crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
176 return Err(WireError::Authentication(err.to_string()));
177 }
178 _ => {
179 return Err(WireError::Protocol(format!(
180 "unexpected message during auth: {:?}",
181 msg
182 )));
183 }
184 }
185 }
186
187 Ok(())
188 }
189
190 async fn handle_sasl(
192 &mut self,
193 mechanisms: &[String],
194 config: &ConnectionConfig,
195 ) -> Result<()> {
196 if !mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
198 return Err(WireError::Authentication(format!(
199 "server does not support SCRAM-SHA-256. Available: {}",
200 mechanisms.join(", ")
201 )));
202 }
203
204 let password = config.password.as_ref().ok_or_else(|| {
206 WireError::Authentication("password required for SCRAM authentication".into())
207 })?;
208
209 let mut scram = ScramClient::new(config.user.clone(), password.as_str().to_string());
212 tracing::debug!("initiating SCRAM-SHA-256 authentication");
213
214 let client_first = scram.client_first();
216 let msg = FrontendMessage::SaslInitialResponse {
217 mechanism: "SCRAM-SHA-256".to_string(),
218 data: client_first.into_bytes(),
219 };
220 self.send_message(&msg).await?;
221
222 let server_first_msg = self.receive_message().await?;
224 let server_first_data = match server_first_msg {
225 BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
226 BackendMessage::ErrorResponse(err) => {
227 return Err(WireError::Authentication(format!(
228 "SASL server error: {}",
229 err
230 )));
231 }
232 _ => {
233 return Err(WireError::Protocol(
234 "expected SaslContinue message during SASL authentication".into(),
235 ));
236 }
237 };
238
239 let server_first = String::from_utf8(server_first_data).map_err(|e| {
240 WireError::Authentication(format!("invalid UTF-8 in server first message: {}", e))
241 })?;
242
243 tracing::debug!("received SCRAM server first message");
244
245 let (client_final, scram_state) = scram
247 .client_final(&server_first)
248 .map_err(|e| WireError::Authentication(format!("SCRAM error: {}", e)))?;
249
250 let msg = FrontendMessage::SaslResponse {
252 data: client_final.into_bytes(),
253 };
254 self.send_message(&msg).await?;
255
256 let server_final_msg = self.receive_message().await?;
258 let server_final_data = match server_final_msg {
259 BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
260 BackendMessage::ErrorResponse(err) => {
261 return Err(WireError::Authentication(format!(
262 "SASL server error: {}",
263 err
264 )));
265 }
266 _ => {
267 return Err(WireError::Protocol(
268 "expected SaslFinal message during SASL authentication".into(),
269 ));
270 }
271 };
272
273 let server_final = String::from_utf8(server_final_data).map_err(|e| {
274 WireError::Authentication(format!("invalid UTF-8 in server final message: {}", e))
275 })?;
276
277 scram
279 .verify_server_final(&server_final, &scram_state)
280 .map_err(|e| WireError::Authentication(format!("SCRAM verification failed: {}", e)))?;
281
282 tracing::debug!("SCRAM-SHA-256 authentication successful");
283 Ok(())
284 }
285
286 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
294 if self.state != ConnectionState::Idle {
295 return Err(WireError::ConnectionBusy(format!(
296 "connection in state: {}",
297 self.state
298 )));
299 }
300
301 self.state.transition(ConnectionState::QueryInProgress)?;
302
303 let query_msg = FrontendMessage::Query(query.to_string());
304 self.send_message(&query_msg).await?;
305
306 self.state.transition(ConnectionState::ReadingResults)?;
307
308 let mut messages = Vec::new();
309
310 loop {
311 let msg = self.receive_message().await?;
312 let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
313 messages.push(msg);
314
315 if is_ready {
316 break;
317 }
318 }
319
320 self.state.transition(ConnectionState::Idle)?;
321 Ok(messages)
322 }
323
324 async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
326 let buf = encode_message(msg)?;
327 self.transport.write_all(&buf).await?;
328 self.transport.flush().await?;
329 Ok(())
330 }
331
332 async fn receive_message(&mut self) -> Result<BackendMessage> {
334 loop {
335 if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
337 self.read_buf.advance(consumed);
338 return Ok(msg);
339 }
340
341 let n = self.transport.read_buf(&mut self.read_buf).await?;
343 if n == 0 {
344 return Err(WireError::ConnectionClosed);
345 }
346 }
347 }
348
349 pub async fn close(mut self) -> Result<()> {
356 self.state.transition(ConnectionState::Closed)?;
357 let _ = self.send_message(&FrontendMessage::Terminate).await;
358 self.transport.shutdown().await?;
359 Ok(())
360 }
361
362 #[allow(clippy::too_many_arguments)] pub async fn streaming_query(
374 mut self,
375 query: &str,
376 chunk_size: usize,
377 max_memory: Option<usize>,
378 soft_limit_warn_threshold: Option<f32>,
379 soft_limit_fail_threshold: Option<f32>,
380 enable_adaptive_chunking: bool,
381 adaptive_min_chunk_size: Option<usize>,
382 adaptive_max_chunk_size: Option<usize>,
383 ) -> Result<crate::stream::JsonStream> {
384 async {
385 let startup_start = std::time::Instant::now();
386
387 use crate::json::validate_row_description;
388 use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
389 use serde_json::Value;
390 use tokio::sync::mpsc;
391
392 if self.state != ConnectionState::Idle {
393 return Err(WireError::ConnectionBusy(format!(
394 "connection in state: {}",
395 self.state
396 )));
397 }
398
399 self.state.transition(ConnectionState::QueryInProgress)?;
400
401 let query_msg = FrontendMessage::Query(query.to_string());
402 self.send_message(&query_msg).await?;
403
404 self.state.transition(ConnectionState::ReadingResults)?;
405
406 let row_desc;
409 loop {
410 let msg = self.receive_message().await?;
411
412 match msg {
413 BackendMessage::ErrorResponse(err) => {
414 tracing::debug!("PostgreSQL error response: {}", err);
416 loop {
417 let msg = self.receive_message().await?;
418 if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
419 break;
420 }
421 }
422 return Err(WireError::Sql(err.to_string()));
423 }
424 BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
425 tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
427 continue;
429 }
430 BackendMessage::ParameterStatus { .. } => {
431 tracing::debug!("PostgreSQL parameter status change received");
433 continue;
434 }
435 BackendMessage::NoticeResponse(notice) => {
436 tracing::debug!("PostgreSQL notice: {}", notice);
438 continue;
439 }
440 BackendMessage::RowDescription(_) => {
441 row_desc = msg;
442 break;
443 }
444 BackendMessage::ReadyForQuery { .. } => {
445 return Err(WireError::Protocol(
448 "no result set received from query - \
449 check that the entity name is correct and the table/view exists"
450 .into(),
451 ));
452 }
453 _ => {
454 return Err(WireError::Protocol(format!(
455 "unexpected message type in query response: {:?}",
456 msg
457 )));
458 }
459 }
460 }
461
462 validate_row_description(&row_desc)?;
463
464 let startup_duration = startup_start.elapsed().as_millis() as u64;
466 let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
467 crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
468
469 let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
471 let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
472
473 let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
475 let entity_for_stream = entity_for_metrics.clone(); let stream = JsonStream::new(
478 result_rx,
479 cancel_tx,
480 entity_for_stream,
481 max_memory,
482 soft_limit_warn_threshold,
483 soft_limit_fail_threshold,
484 );
485
486 let state_lock = stream.clone_state();
488 let pause_signal = stream.clone_pause_signal();
489 let resume_signal = stream.clone_resume_signal();
490
491 let state_atomic = stream.clone_state_atomic();
493
494 let pause_timeout = stream.pause_timeout();
496
497 let query_start = std::time::Instant::now();
499
500 tokio::spawn(async move {
501 let strategy = ChunkingStrategy::new(chunk_size);
502 let mut chunk = strategy.new_chunk();
503 let mut total_rows = 0u64;
504
505 let _adaptive = if enable_adaptive_chunking {
507 let mut adp = AdaptiveChunking::new();
508
509 if let Some(min) = adaptive_min_chunk_size {
511 if let Some(max) = adaptive_max_chunk_size {
512 adp = adp.with_bounds(min, max);
513 }
514 }
515
516 Some(adp)
517 } else {
518 None
519 };
520 let _current_chunk_size = chunk_size;
521
522 loop {
523 if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
526 if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
528 (&state_lock, &pause_signal, &resume_signal)
529 {
530 let current_state = state_lock.lock().await;
531 if *current_state == crate::stream::StreamState::Paused {
532 tracing::debug!("stream paused, waiting for resume");
533 drop(current_state); if let Some(timeout) = pause_timeout {
537 if tokio::time::timeout(timeout, resume_signal.notified()).await == Ok(()) {
538 tracing::debug!("stream resumed");
539 } else {
540 tracing::debug!("pause timeout expired, auto-resuming");
541 crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
542 }
543 } else {
544 resume_signal.notified().await;
546 tracing::debug!("stream resumed");
547 }
548
549 let mut state = state_lock.lock().await;
551 *state = crate::stream::StreamState::Running;
552 }
553 }
554 }
555
556 tokio::select! {
557 _ = cancel_rx.recv() => {
559 tracing::debug!("query cancelled");
560 crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
561 break;
562 }
563
564 msg_result = self.receive_message() => {
566 match msg_result {
567 Ok(msg) => match msg {
568 BackendMessage::DataRow(_) => {
569 match extract_json_bytes(&msg) {
570 Ok(json_bytes) => {
571 chunk.push(json_bytes);
572
573 if strategy.is_full(&chunk) {
574 let chunk_start = std::time::Instant::now();
575 let rows = chunk.into_rows();
576 let chunk_size_rows = rows.len() as u64;
577
578 const BATCH_SIZE: usize = 8;
581 let mut batch = Vec::with_capacity(BATCH_SIZE);
582 let mut send_error = false;
583
584 for row_bytes in rows {
585 match parse_json(row_bytes) {
586 Ok(value) => {
587 total_rows += 1;
588 batch.push(Ok(value));
589
590 if batch.len() == BATCH_SIZE {
592 for item in batch.drain(..) {
593 if result_tx.send(item).await.is_err() {
594 crate::metrics::counters::query_completed("error", &entity_for_metrics);
595 send_error = true;
596 break;
597 }
598 }
599 if send_error {
600 break;
601 }
602 }
603 }
604 Err(e) => {
605 crate::metrics::counters::json_parse_error(&entity_for_metrics);
606 let _ = result_tx.send(Err(e)).await;
607 crate::metrics::counters::query_completed("error", &entity_for_metrics);
608 send_error = true;
609 break;
610 }
611 }
612 }
613
614 if !send_error {
616 for item in batch {
617 if result_tx.send(item).await.is_err() {
618 crate::metrics::counters::query_completed("error", &entity_for_metrics);
619 break;
620 }
621 }
622 }
623
624 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
626
627 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
629 if chunk_idx.is_multiple_of(10) {
630 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
631 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
632 }
633
634 chunk = strategy.new_chunk();
640 }
641 }
642 Err(e) => {
643 crate::metrics::counters::json_parse_error(&entity_for_metrics);
644 let _ = result_tx.send(Err(e)).await;
645 crate::metrics::counters::query_completed("error", &entity_for_metrics);
646 break;
647 }
648 }
649 }
650 BackendMessage::CommandComplete(_) => {
651 if !chunk.is_empty() {
653 let chunk_start = std::time::Instant::now();
654 let rows = chunk.into_rows();
655 let chunk_size_rows = rows.len() as u64;
656
657 const BATCH_SIZE: usize = 8;
659 let mut batch = Vec::with_capacity(BATCH_SIZE);
660 let mut send_error = false;
661
662 for row_bytes in rows {
663 match parse_json(row_bytes) {
664 Ok(value) => {
665 total_rows += 1;
666 batch.push(Ok(value));
667
668 if batch.len() == BATCH_SIZE {
670 for item in batch.drain(..) {
671 if result_tx.send(item).await.is_err() {
672 crate::metrics::counters::query_completed("error", &entity_for_metrics);
673 send_error = true;
674 break;
675 }
676 }
677 if send_error {
678 break;
679 }
680 }
681 }
682 Err(e) => {
683 crate::metrics::counters::json_parse_error(&entity_for_metrics);
684 let _ = result_tx.send(Err(e)).await;
685 crate::metrics::counters::query_completed("error", &entity_for_metrics);
686 send_error = true;
687 break;
688 }
689 }
690 }
691
692 if !send_error {
694 for item in batch {
695 if result_tx.send(item).await.is_err() {
696 crate::metrics::counters::query_completed("error", &entity_for_metrics);
697 break;
698 }
699 }
700 }
701
702 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
704 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
705 if chunk_idx.is_multiple_of(10) {
706 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
707 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
708 }
709 chunk = strategy.new_chunk();
710 }
711
712 let query_duration = query_start.elapsed().as_millis() as u64;
714 crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
715 crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
716 crate::metrics::counters::query_completed("success", &entity_for_metrics);
717 }
718 BackendMessage::ReadyForQuery { .. } => {
719 break;
720 }
721 BackendMessage::ErrorResponse(err) => {
722 crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
723 crate::metrics::counters::query_completed("error", &entity_for_metrics);
724 let _ = result_tx.send(Err(WireError::Sql(err.to_string()))).await;
725 break;
726 }
727 _ => {
728 crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
729 crate::metrics::counters::query_completed("error", &entity_for_metrics);
730 let _ = result_tx.send(Err(WireError::Protocol(
731 format!("unexpected message: {:?}", msg)
732 ))).await;
733 break;
734 }
735 },
736 Err(e) => {
737 crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
738 crate::metrics::counters::query_completed("error", &entity_for_metrics);
739 let _ = result_tx.send(Err(e)).await;
740 break;
741 }
742 }
743 }
744 }
745 }
746 });
747
748 Ok(stream)
749 }
750 .instrument(tracing::debug_span!(
751 "streaming_query",
752 query = %query,
753 chunk_size = %chunk_size
754 ))
755 .await
756 }
757}