use crate::protocol::BackendMessage;
use crate::{Result, WireError};
use bytes::Bytes;
use futures::stream::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, AtomicU8, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, Notify};
/// Raw `u8` encoding of [`StreamState::Running`]; the value a new stream's atomic state starts at.
pub const STATE_RUNNING: u8 = 0;
/// Raw `u8` encoding of [`StreamState::Paused`] as stored in the stream's atomic state.
pub const STATE_PAUSED: u8 = 1;
/// Raw `u8` encoding of [`StreamState::Completed`] as stored in the stream's atomic state.
pub const STATE_COMPLETED: u8 = 2;
/// Raw `u8` encoding of [`StreamState::Failed`] as stored in the stream's atomic state.
pub const STATE_FAILED: u8 = 3;
/// Lifecycle state of a stream, as reported by `JsonStream::state_snapshot`.
///
/// Marked `#[non_exhaustive]`, so matches written outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum StreamState {
/// Initial state of a newly constructed stream.
Running,
/// Entered through a successful `pause()` call.
Paused,
/// Recorded when the underlying channel reports end-of-stream.
Completed,
/// Recorded when an error item is yielded or the memory-limit check trips.
Failed,
}
/// Point-in-time counters describing a stream.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of items currently waiting in the stream's channel.
    pub items_buffered: usize,
    /// Rough memory footprint of the buffered items, in bytes.
    pub estimated_memory: usize,
    /// Running total of rows yielded.
    pub total_rows_yielded: u64,
    /// Running total of rows filtered out.
    pub total_rows_filtered: u64,
}

impl StreamStats {
    /// Returns a snapshot in which every counter is zero.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn zero() -> Self {
        Self {
            total_rows_filtered: 0,
            total_rows_yielded: 0,
            estimated_memory: 0,
            items_buffered: 0,
        }
    }
}
/// Stream of JSON values received over a channel, with optional memory limiting and
/// pause/resume bookkeeping.
pub struct JsonStream {
    /// Channel the stream polls for parsed values or errors.
    receiver: mpsc::Receiver<Result<Value>>,
    /// Sender half of a cancellation channel. It is only held, never used, in this file.
    /// NOTE(review): presumably dropping it signals cancellation to the producer — confirm.
    _cancel_tx: mpsc::Sender<()>,
    /// Label passed to every metrics call made for this stream.
    entity: String,
    /// Shared counter reported as `total_rows_yielded` by `stats()`.
    rows_yielded: Arc<AtomicU64>,
    /// Shared counter reported as `total_rows_filtered` by `stats()`.
    rows_filtered: Arc<AtomicU64>,
    /// Optional limit that the estimated buffered memory is checked against on each poll.
    max_memory: Option<usize>,
    /// Optional fraction of `max_memory`; when set, the scaled value replaces the raw limit
    /// as the failure threshold.
    soft_limit_fail_threshold: Option<f32>,
    /// Lock-free copy of the stream state, encoded with the `STATE_*` constants.
    state_atomic: Arc<AtomicU8>,
    /// Pause/resume machinery, created lazily the first time it is needed.
    pause_resume: Option<PauseResumeState>,
    /// Number of `poll_next` calls so far; used to sample metrics every 1000th poll.
    poll_count: AtomicU64,
}
/// Lazily allocated state backing `JsonStream::pause` and `JsonStream::resume`.
pub struct PauseResumeState {
    /// Mutex-guarded stream state, updated by `pause()` and `resume()`.
    state: Arc<Mutex<StreamState>>,
    /// Notified once when `pause()` moves the state from running to paused.
    pause_signal: Arc<Notify>,
    /// Notified once when `resume()` moves the state from paused to running.
    resume_signal: Arc<Notify>,
    /// Counter exposed through `JsonStream::paused_occupancy`; it is not written in this file.
    paused_occupancy: Arc<AtomicUsize>,
    /// Optional timeout managed by `set_pause_timeout` / `clear_pause_timeout`.
    /// NOTE(review): the code that enforces it is not visible here.
    pause_timeout: Option<Duration>,
}
impl JsonStream {
/// Builds a stream in the running state around an existing channel.
///
/// * `receiver` – channel the stream will poll for items.
/// * `cancel_tx` – stored for the lifetime of the stream.
/// * `entity` – label used for metrics.
/// * `max_memory` – optional limit on the estimated buffered memory.
/// * `_soft_limit_warn_threshold` – accepted but not used by this constructor.
/// * `soft_limit_fail_threshold` – optional fraction of `max_memory` used as the failure threshold.
///
/// Row counters and the poll counter start at zero; pause/resume state is not allocated yet.
pub(crate) fn new(
    receiver: mpsc::Receiver<Result<Value>>,
    cancel_tx: mpsc::Sender<()>,
    entity: String,
    max_memory: Option<usize>,
    _soft_limit_warn_threshold: Option<f32>,
    soft_limit_fail_threshold: Option<f32>,
) -> Self {
    let zeroed_counter = || Arc::new(AtomicU64::new(0));
    let rows_yielded = zeroed_counter();
    let rows_filtered = zeroed_counter();
    let state_atomic = Arc::new(AtomicU8::new(STATE_RUNNING));

    Self {
        receiver,
        _cancel_tx: cancel_tx,
        entity,
        rows_yielded,
        rows_filtered,
        max_memory,
        soft_limit_fail_threshold,
        state_atomic,
        pause_resume: None,
        poll_count: AtomicU64::new(0),
    }
}
/// Returns the pause/resume state, allocating it on first use.
///
/// A freshly allocated state starts as `StreamState::Running`, with unsignalled notifiers,
/// a zero paused-occupancy counter and no pause timeout.
fn ensure_pause_resume(&mut self) -> &mut PauseResumeState {
    // `get_or_insert_with` initializes and borrows in one step, so no `expect` is needed.
    self.pause_resume.get_or_insert_with(|| PauseResumeState {
        state: Arc::new(Mutex::new(StreamState::Running)),
        pause_signal: Arc::new(Notify::new()),
        resume_signal: Arc::new(Notify::new()),
        paused_occupancy: Arc::new(AtomicUsize::new(0)),
        pause_timeout: None,
    })
}
/// Reads the current stream state from the atomic, without awaiting.
///
/// If the stored byte is not one of the `STATE_*` constants, the result falls back to
/// `Completed` when the receiver is closed and `Running` otherwise.
pub fn state_snapshot(&self) -> StreamState {
    match self.state_atomic.load(Ordering::Acquire) {
        STATE_FAILED => StreamState::Failed,
        STATE_COMPLETED => StreamState::Completed,
        STATE_PAUSED => StreamState::Paused,
        STATE_RUNNING => StreamState::Running,
        // Unknown encoding: derive the answer from the channel itself.
        _ if self.receiver.is_closed() => StreamState::Completed,
        _ => StreamState::Running,
    }
}
/// Returns the paused-occupancy counter, or `0` when pause/resume state was never allocated.
pub fn paused_occupancy(&self) -> usize {
    match self.pause_resume.as_ref() {
        Some(pr) => pr.paused_occupancy.load(Ordering::Relaxed),
        None => 0,
    }
}
/// Stores `duration` as the pause timeout, allocating pause/resume state if necessary,
/// and logs the new value at debug level.
pub fn set_pause_timeout(&mut self, duration: Duration) {
    let pr = self.ensure_pause_resume();
    pr.pause_timeout = Some(duration);
    tracing::debug!("pause timeout set to {:?}", duration);
}
pub fn clear_pause_timeout(&mut self) {
if let Some(ref mut pr) = self.pause_resume {
pr.pause_timeout = None;
tracing::debug!("pause timeout cleared");
}
}
/// Returns the configured pause timeout, or `None` if none is set or pause/resume state
/// was never allocated.
pub(crate) fn pause_timeout(&self) -> Option<Duration> {
    self.pause_resume.as_ref()?.pause_timeout
}
pub async fn pause(&mut self) -> Result<()> {
let entity = self.entity.clone();
self.state_atomic_set_paused();
let pr = self.ensure_pause_resume();
let mut state = pr.state.lock().await;
match *state {
StreamState::Running => {
pr.pause_signal.notify_one();
*state = StreamState::Paused;
crate::metrics::counters::stream_paused(&entity);
Ok(())
}
StreamState::Paused => {
Ok(())
}
StreamState::Completed | StreamState::Failed => {
Err(WireError::Protocol(
"cannot pause a completed or failed stream".to_string(),
))
}
}
}
/// Resumes a paused stream.
///
/// Returns `Ok(())` without doing anything when pause/resume state was never allocated
/// or when the guarded state is already running. When the guarded state is paused, it is
/// set back to running, one waiter on the resume signal is notified and a
/// `stream_resumed` metric is recorded. The atomic state is reset to running only if it
/// currently reads paused, so a terminal value stored there is preserved.
///
/// # Errors
///
/// Returns `WireError::Protocol` when the guarded state is completed or failed.
pub async fn resume(&mut self) -> Result<()> {
    let current = self.state_atomic_get();

    let Some(pr) = self.pause_resume.as_ref() else {
        return Ok(());
    };

    if current == STATE_PAUSED {
        self.state_atomic.store(STATE_RUNNING, Ordering::Release);
    }

    let mut state = pr.state.lock().await;
    match *state {
        StreamState::Paused => {
            pr.resume_signal.notify_one();
            *state = StreamState::Running;
            // `pr` only borrows the `pause_resume` field, so `entity` can be used in
            // place without cloning it.
            crate::metrics::counters::stream_resumed(&self.entity);
            Ok(())
        }
        StreamState::Running => Ok(()),
        StreamState::Completed | StreamState::Failed => Err(WireError::Protocol(
            "cannot resume a completed or failed stream".to_string(),
        )),
    }
}
/// Logs `reason` at debug level, then pauses the stream via `pause()`.
///
/// # Errors
///
/// Propagates whatever `pause()` returns.
pub async fn pause_with_reason(&mut self, reason: &str) -> Result<()> {
    tracing::debug!("pausing stream: {}", reason);
    let pausing = self.pause();
    pausing.await
}
/// Returns a new handle to the mutex-guarded state, or `None` when pause/resume state
/// has not been allocated.
pub(crate) fn clone_state(&self) -> Option<Arc<Mutex<StreamState>>> {
    let pr = self.pause_resume.as_ref()?;
    Some(Arc::clone(&pr.state))
}
/// Returns a new handle to the pause notifier, or `None` when pause/resume state has not
/// been allocated.
pub(crate) fn clone_pause_signal(&self) -> Option<Arc<Notify>> {
    let pr = self.pause_resume.as_ref()?;
    Some(Arc::clone(&pr.pause_signal))
}
/// Returns a new handle to the resume notifier, or `None` when pause/resume state has not
/// been allocated.
pub(crate) fn clone_resume_signal(&self) -> Option<Arc<Notify>> {
    let pr = self.pause_resume.as_ref()?;
    Some(Arc::clone(&pr.resume_signal))
}
/// Returns a new handle to the lock-free state byte shared with this stream.
pub(crate) fn clone_state_atomic(&self) -> Arc<AtomicU8> {
    let shared = &self.state_atomic;
    Arc::clone(shared)
}
pub(crate) fn state_atomic_get(&self) -> u8 {
self.state_atomic.load(Ordering::Acquire)
}
pub(crate) fn state_atomic_set_paused(&self) {
self.state_atomic.store(STATE_PAUSED, Ordering::Release);
}
pub(crate) fn state_atomic_set_completed(&self) {
self.state_atomic.store(STATE_COMPLETED, Ordering::Release);
}
pub(crate) fn state_atomic_set_failed(&self) {
self.state_atomic.store(STATE_FAILED, Ordering::Release);
}
/// Takes a snapshot of the stream's counters.
///
/// `estimated_memory` is a fixed 2048-byte estimate per buffered item, not a measurement.
pub fn stats(&self) -> StreamStats {
    let queued = self.receiver.len();
    StreamStats {
        items_buffered: queued,
        estimated_memory: queued * 2048,
        total_rows_yielded: self.rows_yielded.load(Ordering::Relaxed),
        total_rows_filtered: self.rows_filtered.load(Ordering::Relaxed),
    }
}
}
impl Stream for JsonStream {
type Item = Result<Value>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let poll_idx = self.poll_count.fetch_add(1, Ordering::Relaxed);
if poll_idx.is_multiple_of(1000) {
let occupancy = self.receiver.len() as u64;
crate::metrics::histograms::channel_occupancy(&self.entity, occupancy);
crate::metrics::gauges::stream_buffered_items(&self.entity, occupancy as usize);
}
if let Some(limit) = self.max_memory {
let items_buffered = self.receiver.len();
let estimated_memory = items_buffered * 2048;
if let Some(fail_threshold) = self.soft_limit_fail_threshold {
let threshold_bytes = (limit as f32 * fail_threshold) as usize;
if estimated_memory > threshold_bytes {
crate::metrics::counters::memory_limit_exceeded(&self.entity);
self.state_atomic_set_failed();
return Poll::Ready(Some(Err(WireError::MemoryLimitExceeded {
limit,
estimated_memory,
})));
}
} else if estimated_memory > limit {
crate::metrics::counters::memory_limit_exceeded(&self.entity);
self.state_atomic_set_failed();
return Poll::Ready(Some(Err(WireError::MemoryLimitExceeded {
limit,
estimated_memory,
})));
}
}
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(Ok(value))) => Poll::Ready(Some(Ok(value))),
Poll::Ready(Some(Err(e))) => {
self.state_atomic_set_failed();
Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => {
self.state_atomic_set_completed();
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
/// Extracts the payload of a single-column `DataRow` message.
///
/// # Errors
///
/// Returns `WireError::Protocol` when `msg` is not a `DataRow`, when the row does not
/// have exactly one field, or when that field is null.
pub fn extract_json_bytes(msg: &BackendMessage) -> Result<Bytes> {
    let BackendMessage::DataRow(fields) = msg else {
        return Err(WireError::Protocol("expected DataRow".into()));
    };

    let [field] = fields.as_slice() else {
        return Err(WireError::Protocol(format!(
            "expected 1 field, got {}",
            fields.len()
        )));
    };

    match field {
        Some(bytes) => Ok(bytes.clone()),
        None => Err(WireError::Protocol("null data field".into())),
    }
}
pub fn parse_json(data: Bytes) -> Result<Value> {
let value: Value = serde_json::from_slice(&data)?;
Ok(value)
}
#[cfg(test)]
mod tests;