1use crate::protocol::BackendMessage;
4use crate::{Error, Result};
5use bytes::Bytes;
6use futures::stream::Stream;
7use serde_json::Value;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU64, AtomicU8, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13use tokio::sync::{mpsc, Mutex, Notify};
14
// Numeric encodings of the stream lifecycle, stored in an `AtomicU8` so the
// current state can be read without taking the async mutex.

/// Encoded value for a stream that is running.
const STATE_RUNNING: u8 = 0;
/// Encoded value for a stream that has been paused.
const STATE_PAUSED: u8 = 1;
/// Encoded value for a stream whose channel has been exhausted.
const STATE_COMPLETED: u8 = 2;
/// Encoded value for a stream that reported an error.
const STATE_FAILED: u8 = 3;
21
/// Lifecycle state of a stream.
///
/// `Completed` and `Failed` are terminal: `pause`/`resume` reject streams in
/// either of those states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamState {
    /// The stream is running.
    Running,
    /// The stream has been paused and may later be resumed.
    Paused,
    /// The stream finished; its channel yielded `None`.
    Completed,
    /// The stream reported an error.
    Failed,
}
38
/// Point-in-time statistics about a stream.
///
/// All counters start at zero; `StreamStats::default()` and
/// [`StreamStats::zero`] produce the same value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamStats {
    /// Number of items currently sitting in the channel buffer.
    pub items_buffered: usize,
    /// Rough estimate, in bytes, of the memory held by the buffered items.
    pub estimated_memory: usize,
    /// Total number of rows yielded so far.
    pub total_rows_yielded: u64,
    /// Total number of rows filtered out so far.
    pub total_rows_filtered: u64,
}

impl StreamStats {
    /// Returns statistics with every counter set to zero.
    ///
    /// Equivalent to `StreamStats::default()`; kept as a named constructor for
    /// existing callers.
    #[must_use]
    pub fn zero() -> Self {
        Self::default()
    }
}
68
69pub struct JsonStream {
71 receiver: mpsc::Receiver<Result<Value>>,
72 _cancel_tx: mpsc::Sender<()>, entity: String, rows_yielded: Arc<AtomicU64>, rows_filtered: Arc<AtomicU64>, max_memory: Option<usize>, soft_limit_fail_threshold: Option<f32>, state_atomic: Arc<AtomicU8>,
83
84 pause_resume: Option<PauseResumeState>,
87
88 poll_count: AtomicU64, }
91
92pub struct PauseResumeState {
95 state: Arc<Mutex<StreamState>>, pause_signal: Arc<Notify>, resume_signal: Arc<Notify>, paused_occupancy: Arc<AtomicUsize>, pause_timeout: Option<Duration>, }
101
102impl JsonStream {
103 pub(crate) fn new(
105 receiver: mpsc::Receiver<Result<Value>>,
106 cancel_tx: mpsc::Sender<()>,
107 entity: String,
108 max_memory: Option<usize>,
109 _soft_limit_warn_threshold: Option<f32>,
110 soft_limit_fail_threshold: Option<f32>,
111 ) -> Self {
112 Self {
113 receiver,
114 _cancel_tx: cancel_tx,
115 entity,
116 rows_yielded: Arc::new(AtomicU64::new(0)),
117 rows_filtered: Arc::new(AtomicU64::new(0)),
118 max_memory,
119 soft_limit_fail_threshold,
120
121 state_atomic: Arc::new(AtomicU8::new(STATE_RUNNING)),
123
124 pause_resume: None,
126
127 poll_count: AtomicU64::new(0),
129 }
130 }
131
132 fn ensure_pause_resume(&mut self) -> &mut PauseResumeState {
134 if self.pause_resume.is_none() {
135 self.pause_resume = Some(PauseResumeState {
136 state: Arc::new(Mutex::new(StreamState::Running)),
137 pause_signal: Arc::new(Notify::new()),
138 resume_signal: Arc::new(Notify::new()),
139 paused_occupancy: Arc::new(AtomicUsize::new(0)),
140 pause_timeout: None,
141 });
142 }
143 self.pause_resume.as_mut().unwrap()
144 }
145
146 pub fn state_snapshot(&self) -> StreamState {
154 match self.state_atomic.load(Ordering::Acquire) {
156 STATE_RUNNING => StreamState::Running,
157 STATE_PAUSED => StreamState::Paused,
158 STATE_COMPLETED => StreamState::Completed,
159 STATE_FAILED => StreamState::Failed,
160 _ => {
161 if self.receiver.is_closed() {
163 StreamState::Completed
164 } else {
165 StreamState::Running
166 }
167 }
168 }
169 }
170
171 pub fn paused_occupancy(&self) -> usize {
176 self.pause_resume
177 .as_ref()
178 .map(|pr| pr.paused_occupancy.load(Ordering::Relaxed))
179 .unwrap_or(0)
180 }
181
182 pub fn set_pause_timeout(&mut self, duration: Duration) {
199 self.ensure_pause_resume().pause_timeout = Some(duration);
200 tracing::debug!("pause timeout set to {:?}", duration);
201 }
202
203 pub fn clear_pause_timeout(&mut self) {
205 if let Some(ref mut pr) = self.pause_resume {
206 pr.pause_timeout = None;
207 tracing::debug!("pause timeout cleared");
208 }
209 }
210
211 pub(crate) fn pause_timeout(&self) -> Option<Duration> {
213 self.pause_resume.as_ref().and_then(|pr| pr.pause_timeout)
214 }
215
216 pub async fn pause(&mut self) -> Result<()> {
234 let entity = self.entity.clone();
235
236 self.state_atomic_set_paused();
238
239 let pr = self.ensure_pause_resume();
240 let mut state = pr.state.lock().await;
241
242 match *state {
243 StreamState::Running => {
244 pr.pause_signal.notify_one();
246 *state = StreamState::Paused;
248
249 crate::metrics::counters::stream_paused(&entity);
251 Ok(())
252 }
253 StreamState::Paused => {
254 Ok(())
256 }
257 StreamState::Completed | StreamState::Failed => {
258 Err(Error::Protocol(
260 "cannot pause a completed or failed stream".to_string(),
261 ))
262 }
263 }
264 }
265
266 pub async fn resume(&mut self) -> Result<()> {
284 let current = self.state_atomic_get();
287
288 if let Some(ref mut pr) = self.pause_resume {
290 let entity = self.entity.clone();
291
292 if current == STATE_PAUSED {
294 self.state_atomic.store(STATE_RUNNING, Ordering::Release);
296 }
297
298 let mut state = pr.state.lock().await;
299
300 match *state {
301 StreamState::Paused => {
302 pr.resume_signal.notify_one();
304 *state = StreamState::Running;
306
307 crate::metrics::counters::stream_resumed(&entity);
309 Ok(())
310 }
311 StreamState::Running => {
312 Ok(())
314 }
315 StreamState::Completed | StreamState::Failed => {
316 Err(Error::Protocol(
318 "cannot resume a completed or failed stream".to_string(),
319 ))
320 }
321 }
322 } else {
323 Ok(())
325 }
326 }
327
328 pub async fn pause_with_reason(&mut self, reason: &str) -> Result<()> {
344 tracing::debug!("pausing stream: {}", reason);
345 self.pause().await
346 }
347
348 pub(crate) fn clone_state(&self) -> Option<Arc<Mutex<StreamState>>> {
350 self.pause_resume.as_ref().map(|pr| Arc::clone(&pr.state))
351 }
352
353 pub(crate) fn clone_pause_signal(&self) -> Option<Arc<Notify>> {
355 self.pause_resume
356 .as_ref()
357 .map(|pr| Arc::clone(&pr.pause_signal))
358 }
359
360 pub(crate) fn clone_resume_signal(&self) -> Option<Arc<Notify>> {
362 self.pause_resume
363 .as_ref()
364 .map(|pr| Arc::clone(&pr.resume_signal))
365 }
366
367 pub(crate) fn clone_state_atomic(&self) -> Arc<AtomicU8> {
373 Arc::clone(&self.state_atomic)
374 }
375
376 pub(crate) fn state_atomic_get(&self) -> u8 {
378 self.state_atomic.load(Ordering::Acquire)
379 }
380
381 pub(crate) fn state_atomic_set_paused(&self) {
383 self.state_atomic.store(STATE_PAUSED, Ordering::Release);
384 }
385
386 pub(crate) fn state_atomic_set_completed(&self) {
388 self.state_atomic.store(STATE_COMPLETED, Ordering::Release);
389 }
390
391 pub(crate) fn state_atomic_set_failed(&self) {
393 self.state_atomic.store(STATE_FAILED, Ordering::Release);
394 }
395
396 pub fn stats(&self) -> StreamStats {
408 let items_buffered = self.receiver.len();
409 let estimated_memory = items_buffered * 2048; let total_rows_yielded = self.rows_yielded.load(Ordering::Relaxed);
411 let total_rows_filtered = self.rows_filtered.load(Ordering::Relaxed);
412
413 StreamStats {
414 items_buffered,
415 estimated_memory,
416 total_rows_yielded,
417 total_rows_filtered,
418 }
419 }
420
421 #[allow(unused)]
423 pub(crate) fn increment_rows_yielded(&self, count: u64) {
424 self.rows_yielded.fetch_add(count, Ordering::Relaxed);
425 }
426
427 #[allow(unused)]
429 pub(crate) fn increment_rows_filtered(&self, count: u64) {
430 self.rows_filtered.fetch_add(count, Ordering::Relaxed);
431 }
432
433 #[allow(unused)]
435 pub(crate) fn clone_rows_yielded(&self) -> Arc<AtomicU64> {
436 Arc::clone(&self.rows_yielded)
437 }
438
439 #[allow(unused)]
441 pub(crate) fn clone_rows_filtered(&self) -> Arc<AtomicU64> {
442 Arc::clone(&self.rows_filtered)
443 }
444}
445
446impl Stream for JsonStream {
447 type Item = Result<Value>;
448
449 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
450 let poll_idx = self.poll_count.fetch_add(1, Ordering::Relaxed);
453 if poll_idx.is_multiple_of(1000) {
454 let occupancy = self.receiver.len() as u64;
455 crate::metrics::histograms::channel_occupancy(&self.entity, occupancy);
456 crate::metrics::gauges::stream_buffered_items(&self.entity, occupancy as usize);
457 }
458
459 if let Some(limit) = self.max_memory {
462 let items_buffered = self.receiver.len();
463 let estimated_memory = items_buffered * 2048; if let Some(fail_threshold) = self.soft_limit_fail_threshold {
467 let threshold_bytes = (limit as f32 * fail_threshold) as usize;
468 if estimated_memory > threshold_bytes {
469 crate::metrics::counters::memory_limit_exceeded(&self.entity);
471 self.state_atomic_set_failed();
472 return Poll::Ready(Some(Err(Error::MemoryLimitExceeded {
473 limit,
474 estimated_memory,
475 })));
476 }
477 } else if estimated_memory > limit {
478 crate::metrics::counters::memory_limit_exceeded(&self.entity);
480 self.state_atomic_set_failed();
481 return Poll::Ready(Some(Err(Error::MemoryLimitExceeded {
482 limit,
483 estimated_memory,
484 })));
485 }
486
487 }
490
491 match self.receiver.poll_recv(cx) {
492 Poll::Ready(Some(Ok(value))) => Poll::Ready(Some(Ok(value))),
493 Poll::Ready(Some(Err(e))) => {
494 self.state_atomic_set_failed();
496 Poll::Ready(Some(Err(e)))
497 }
498 Poll::Ready(None) => {
499 self.state_atomic_set_completed();
501 Poll::Ready(None)
502 }
503 Poll::Pending => Poll::Pending,
504 }
505 }
506}
507
508pub fn extract_json_bytes(msg: &BackendMessage) -> Result<Bytes> {
510 match msg {
511 BackendMessage::DataRow(fields) => {
512 if fields.len() != 1 {
513 return Err(Error::Protocol(format!(
514 "expected 1 field, got {}",
515 fields.len()
516 )));
517 }
518
519 let field = &fields[0];
520 field
521 .clone()
522 .ok_or_else(|| Error::Protocol("null data field".into()))
523 }
524 _ => Err(Error::Protocol("expected DataRow".into())),
525 }
526}
527
528pub fn parse_json(data: Bytes) -> Result<Value> {
530 let value: Value = serde_json::from_slice(&data)?;
531 Ok(value)
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed JSON document shared by the extraction and parsing tests.
    const SAMPLE_JSON: &[u8] = b"{\"key\":\"value\"}";

    /// A single-field `DataRow` yields exactly the bytes it carries.
    #[test]
    fn test_extract_json_bytes() {
        let payload = Bytes::from_static(SAMPLE_JSON);
        let row = BackendMessage::DataRow(vec![Some(payload.clone())]);

        let got = extract_json_bytes(&row).unwrap();
        assert_eq!(got, payload);
    }

    /// A null field is rejected.
    #[test]
    fn test_extract_null_field() {
        let row = BackendMessage::DataRow(vec![None]);
        let outcome = extract_json_bytes(&row);
        assert!(outcome.is_err());
    }

    /// Valid JSON parses into an indexable value.
    #[test]
    fn test_parse_json() {
        let parsed = parse_json(Bytes::from_static(SAMPLE_JSON)).unwrap();

        assert_eq!(parsed["key"], "value");
    }

    /// Text that is not JSON produces an error.
    #[test]
    fn test_parse_invalid_json() {
        let outcome = parse_json(Bytes::from_static(b"not json"));
        assert!(outcome.is_err());
    }

    /// `StreamStats::zero` starts every counter at zero.
    #[test]
    fn test_stream_stats_creation() {
        let StreamStats {
            items_buffered,
            estimated_memory,
            total_rows_yielded,
            total_rows_filtered,
        } = StreamStats::zero();

        assert_eq!(items_buffered, 0);
        assert_eq!(estimated_memory, 0);
        assert_eq!(total_rows_yielded, 0);
        assert_eq!(total_rows_filtered, 0);
    }

    /// 100 buffered items at 2048 bytes each is 204800 bytes.
    #[test]
    fn test_stream_stats_memory_estimation() {
        let snapshot = StreamStats {
            items_buffered: 100,
            estimated_memory: 100 * 2048,
            total_rows_yielded: 100,
            total_rows_filtered: 10,
        };

        assert_eq!(snapshot.estimated_memory, 204800);
    }

    /// Cloning preserves every field.
    #[test]
    fn test_stream_stats_clone() {
        let source = StreamStats {
            items_buffered: 50,
            estimated_memory: 100000,
            total_rows_yielded: 500,
            total_rows_filtered: 50,
        };

        let copy = source.clone();
        assert_eq!(copy.items_buffered, source.items_buffered);
        assert_eq!(copy.estimated_memory, source.estimated_memory);
        assert_eq!(copy.total_rows_yielded, source.total_rows_yielded);
        assert_eq!(copy.total_rows_filtered, source.total_rows_filtered);
    }

    /// Every pair of state codes is distinct.
    #[test]
    fn test_stream_state_constants() {
        let codes = [STATE_RUNNING, STATE_PAUSED, STATE_COMPLETED, STATE_FAILED];

        for (idx, first) in codes.iter().enumerate() {
            for second in &codes[idx + 1..] {
                assert_ne!(first, second);
            }
        }
    }

    /// Each variant equals itself, and distinct variants differ.
    #[test]
    fn test_stream_state_enum_equality() {
        let variants = [
            StreamState::Running,
            StreamState::Paused,
            StreamState::Completed,
            StreamState::Failed,
        ];

        for variant in variants {
            assert_eq!(variant, variant);
        }
        assert_ne!(StreamState::Running, StreamState::Paused);
    }
}