use std::collections::VecDeque;
use futures::{Stream, StreamExt};
use serde_json::{json, Value as JsonValue};
use crate::cancel_token::CancellationFlag;
use crate::store::filter::SqlValue;
use crate::store::postgres_backend::{
bind_value, build_select_sql, classify_sql_error, introspect_conn,
map_pg_row, PostgresStoreBackend, StoreError, StoreRow,
};
use crate::stream_effect::BackpressurePolicy;
/// Backpressure policy used for retrieves by default: stop reading once the
/// row bound is reached and report the result as `truncated`, rather than
/// erroring (`Fail`) or evicting rows (`DropOldest`).
pub const DEFAULT_RETRIEVE_POLICY: BackpressurePolicy = BackpressurePolicy::PauseUpstream;
/// Default value for the `max_rows` bound handed to [`drain_with_policy`].
/// What happens when a result exceeds it depends on the `BackpressurePolicy`.
pub const DEFAULT_MAX_ROWS: usize = 10_000;
/// Result of draining a row stream under a backpressure policy.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RowStreamOutcome {
    /// Rows retained after the policy was applied, in arrival order.
    pub rows: Vec<StoreRow>,
    /// Number of rows successfully pulled from the upstream stream (including
    /// any row that was later evicted or that triggered truncation).
    pub total_seen: usize,
    /// Rows evicted from the front of the window (only `DropOldest` sets this).
    pub dropped: usize,
    /// True when reading stopped at the row bound while more rows were
    /// available (only `PauseUpstream` sets this).
    pub truncated: bool,
    /// True when the drain stopped because the cancellation flag was set.
    pub cancelled: bool,
}
/// Drains `stream` into memory, bounding the result according to `policy`.
///
/// Per-policy behavior relative to `max_rows`:
/// - `Fail`: returns `StoreError::Query` as soon as a row beyond `max_rows`
///   arrives.
/// - `DropOldest`: keeps only the most recent `max_rows` rows; every evicted
///   row is counted in `dropped`.
/// - `PauseUpstream`: stops reading when a row beyond `max_rows` arrives,
///   discards that row, and sets `truncated`.
/// - `DegradeQuality`: applies no bound; every row is kept.
///
/// Cancellation is checked both before waiting for the next row and right
/// after one arrives. Either way the drain stops with `cancelled` set and
/// returns the rows kept so far; a row received after cancellation is
/// discarded without being counted.
///
/// `total_seen` counts every row successfully pulled from `stream`, including
/// a row that caused `PauseUpstream` truncation or a `DropOldest` eviction.
///
/// # Errors
/// Returns the first `Err` yielded by `stream`, or `StoreError::Query` under
/// the `Fail` policy when the result exceeds `max_rows`.
pub async fn drain_with_policy<S>(
    mut stream: S,
    policy: BackpressurePolicy,
    max_rows: usize,
    cancel: &CancellationFlag,
) -> Result<RowStreamOutcome, StoreError>
where
    S: Stream<Item = Result<StoreRow, StoreError>> + Unpin,
{
    let mut kept: VecDeque<StoreRow> = VecDeque::new();
    let mut outcome = RowStreamOutcome::default();
    loop {
        // Check before awaiting: an already-cancelled drain must not block on
        // the upstream cursor just to fetch a row it will throw away.
        if cancel.is_cancelled() {
            outcome.cancelled = true;
            break;
        }
        let item = match stream.next().await {
            Some(item) => item,
            None => break,
        };
        // Cancellation may have been requested while we were waiting.
        if cancel.is_cancelled() {
            outcome.cancelled = true;
            break;
        }
        let row = item?;
        outcome.total_seen += 1;
        match policy {
            BackpressurePolicy::Fail => {
                if kept.len() >= max_rows {
                    return Err(StoreError::Query {
                        op: "retrieve",
                        source: format!(
                            "result set exceeds the {max_rows}-row stream \
                             bound (backpressure policy: fail)"
                        ),
                    });
                }
                kept.push_back(row);
            }
            BackpressurePolicy::DropOldest => {
                kept.push_back(row);
                if kept.len() > max_rows {
                    kept.pop_front();
                    outcome.dropped += 1;
                }
            }
            BackpressurePolicy::PauseUpstream => {
                if kept.len() >= max_rows {
                    outcome.truncated = true;
                    break;
                }
                kept.push_back(row);
            }
            BackpressurePolicy::DegradeQuality => {
                kept.push_back(row);
            }
        }
    }
    // `From<VecDeque>` reuses the deque's buffer instead of re-collecting.
    outcome.rows = Vec::from(kept);
    Ok(outcome)
}
/// Runs a `SELECT` against `table` (filtered by `where_expr` with `bindings`)
/// and drains the rows through [`drain_with_policy`].
///
/// Two paths:
/// 1. If `backend` has a cached schema for `table`, the SQL is built from it
///    and streamed directly from the pool. A schema-drift error evicts the
///    cache entry and falls through to path 2. Any other result, success or
///    error, is returned as-is.
/// 2. Otherwise a transaction is opened and `table` is introspected inside it.
///    The SQL is then built and streamed within that transaction. If
///    introspection fails, a warning is logged and the SQL is built without a
///    schema or column types.
///    - If the drain succeeds, the transaction is committed and a successful
///      introspection result is cached.
///    - If the drain fails, the transaction is rolled back and the drain error
///      is returned, so a commit failure can no longer mask it.
///
/// # Errors
/// - `StoreError::Connect` if the transaction cannot be begun or committed.
/// - Errors from `build_select_sql`, row decoding, the query itself, or the
///   backpressure policy (see [`drain_with_policy`]).
pub async fn stream_retrieve(
    backend: &PostgresStoreBackend,
    table: &str,
    where_expr: &str,
    policy: BackpressurePolicy,
    max_rows: usize,
    cancel: &CancellationFlag,
    bindings: &std::collections::HashMap<String, String>,
) -> Result<RowStreamOutcome, StoreError> {
    // Fast path: reuse the cached schema and stream straight from the pool.
    if let Some(resolved) = backend.cached_schema(table) {
        let (sql, params): (String, Vec<SqlValue>) = build_select_sql(
            table,
            Some(resolved.schema.as_str()),
            where_expr,
            bindings,
            &resolved.column_types,
        )?;
        let mut query = sqlx::query(&sql).persistent(false);
        for value in &params {
            query = bind_value(query, value);
        }
        let cursor = query.fetch(backend.pool()).map(|item| {
            item.map_err(|e| classify_sql_error("retrieve", e))
                .and_then(|pg_row| map_pg_row(&pg_row))
        });
        match drain_with_policy(cursor, policy, max_rows, cancel).await {
            Ok(outcome) => return Ok(outcome),
            Err(e) if e.is_schema_drift() => {
                // The cached schema is stale: forget it and re-introspect below.
                backend.evict_schema(table);
            }
            Err(e) => return Err(e),
        }
    }

    // Slow path: introspect and stream inside a single transaction.
    let mut tx = backend
        .pool()
        .begin()
        .await
        .map_err(|e| StoreError::Connect { source: e.to_string() })?;
    let resolved = introspect_conn(&mut tx, table).await;
    let no_types = std::collections::HashMap::new();
    let (schema, column_types) = match &resolved {
        Ok(r) => (Some(r.schema.as_str()), &r.column_types),
        Err(e) => {
            tracing::warn!(
                target: "axon::store",
                table = %table,
                op = "introspect_in_tx_stream",
                error = %e,
                d_letter = "D3+38.x.a",
                "store introspection failed inside the stream-cursor \
                 transaction; falling back to bare-table cursor — the \
                 drain will likely fail with the same root cause."
            );
            (None, &no_types)
        }
    };
    let (sql, params): (String, Vec<SqlValue>) =
        build_select_sql(table, schema, where_expr, bindings, column_types)?;
    let mut query = sqlx::query(&sql).persistent(false);
    for value in &params {
        query = bind_value(query, value);
    }
    // Scope the cursor so its borrow of `tx` ends before commit/rollback.
    let drained = {
        let cursor = query.fetch(&mut *tx).map(|item| {
            item.map_err(|e| classify_sql_error("retrieve", e))
                .and_then(|pg_row| map_pg_row(&pg_row))
        });
        drain_with_policy(cursor, policy, max_rows, cancel).await
    };
    let outcome = match drained {
        Ok(outcome) => outcome,
        Err(e) => {
            // The drain error is the root cause. Roll back (best effort) and
            // report it instead of letting a commit error mask it.
            if let Err(rb) = tx.rollback().await {
                tracing::debug!(
                    target: "axon::store",
                    table = %table,
                    error = %rb,
                    "rollback after failed stream-cursor drain also failed"
                );
            }
            return Err(e);
        }
    };
    tx.commit()
        .await
        .map_err(|e| StoreError::Connect { source: e.to_string() })?;
    // Only cache a schema that a successful drain ran against.
    if let Ok(r) = resolved {
        backend.cache_schema(table, r);
    }
    Ok(outcome)
}
pub fn stream_metadata(
policy: BackpressurePolicy,
outcome: &RowStreamOutcome,
) -> JsonValue {
json!({
"policy": policy.slug(),
"total_seen": outcome.total_seen,
"dropped": outcome.dropped,
"truncated": outcome.truncated,
"cancelled": outcome.cancelled,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Builds a single-column row `{ id: <id> }`.
    fn row(id: i64) -> StoreRow {
        let columns = vec![(String::from("id"), Value::from(id))];
        StoreRow { columns }
    }

    /// A finite, error-free stream of `n` rows with ids `0..n`.
    fn ok_stream(n: usize) -> impl Stream<Item = Result<StoreRow, StoreError>> + Unpin {
        let items: Vec<Result<StoreRow, StoreError>> =
            (0..n).map(|i| Ok(row(i as i64))).collect();
        futures::stream::iter(items)
    }

    /// Drains `stream` with a fresh flag that is never cancelled.
    async fn drain<S>(
        stream: S,
        policy: BackpressurePolicy,
        max_rows: usize,
    ) -> Result<RowStreamOutcome, StoreError>
    where
        S: Stream<Item = Result<StoreRow, StoreError>> + Unpin,
    {
        let never_cancelled = CancellationFlag::new();
        drain_with_policy(stream, policy, max_rows, &never_cancelled).await
    }

    #[tokio::test]
    async fn fail_policy_allows_a_result_within_the_bound() {
        let outcome = drain(ok_stream(5), BackpressurePolicy::Fail, 10)
            .await
            .unwrap();
        assert_eq!(outcome.rows.len(), 5);
        assert_eq!(outcome.total_seen, 5);
    }

    #[tokio::test]
    async fn fail_policy_errors_when_the_result_exceeds_the_bound() {
        let result = drain(ok_stream(50), BackpressurePolicy::Fail, 10).await;
        assert!(matches!(result, Err(StoreError::Query { .. })));
    }

    #[tokio::test]
    async fn drop_oldest_keeps_the_most_recent_window() {
        let outcome = drain(ok_stream(100), BackpressurePolicy::DropOldest, 10)
            .await
            .unwrap();
        assert_eq!(outcome.rows.len(), 10, "bounded to the window");
        assert_eq!(outcome.dropped, 90);
        assert_eq!(outcome.total_seen, 100);
        let first = outcome.rows.first().unwrap();
        let last = outcome.rows.last().unwrap();
        assert_eq!(first.get("id"), Some(&Value::from(90)));
        assert_eq!(last.get("id"), Some(&Value::from(99)));
    }

    #[tokio::test]
    async fn pause_upstream_truncates_at_the_bound() {
        let outcome = drain(ok_stream(100), BackpressurePolicy::PauseUpstream, 10)
            .await
            .unwrap();
        assert_eq!(outcome.rows.len(), 10);
        assert!(outcome.truncated, "more rows existed past the bound");
        let first = outcome.rows.first().unwrap();
        let last = outcome.rows.last().unwrap();
        assert_eq!(first.get("id"), Some(&Value::from(0)));
        assert_eq!(last.get("id"), Some(&Value::from(9)));
    }

    #[tokio::test]
    async fn pause_upstream_within_the_bound_is_not_truncated() {
        let outcome = drain(ok_stream(3), BackpressurePolicy::PauseUpstream, 10)
            .await
            .unwrap();
        assert_eq!(outcome.rows.len(), 3);
        assert!(!outcome.truncated);
    }

    #[tokio::test]
    async fn degrade_quality_is_the_oss_identity_drain() {
        let outcome = drain(ok_stream(50), BackpressurePolicy::DegradeQuality, 10)
            .await
            .unwrap();
        assert_eq!(outcome.rows.len(), 50);
        assert_eq!(outcome.dropped, 0);
        assert!(!outcome.truncated);
    }

    #[tokio::test]
    async fn a_cancelled_flag_stops_the_drain_immediately() {
        let cancel = CancellationFlag::new();
        cancel.cancel();
        let outcome = drain_with_policy(
            ok_stream(100),
            BackpressurePolicy::DegradeQuality,
            1000,
            &cancel,
        )
        .await
        .unwrap();
        assert!(outcome.cancelled);
        assert!(outcome.rows.is_empty(), "no row consumed after cancel");
    }

    #[tokio::test]
    async fn a_row_decode_error_aborts_the_drain() {
        let decode_failure = StoreError::Decode {
            column: "x".into(),
            pg_type: "INT4".into(),
            source: "boom".into(),
        };
        let items: Vec<Result<StoreRow, StoreError>> =
            vec![Ok(row(0)), Err(decode_failure), Ok(row(2))];
        let result = drain(
            futures::stream::iter(items),
            BackpressurePolicy::DegradeQuality,
            100,
        )
        .await;
        assert!(matches!(result, Err(StoreError::Decode { .. })));
    }

    #[tokio::test]
    async fn an_empty_result_drains_cleanly() {
        let outcome = drain(ok_stream(0), DEFAULT_RETRIEVE_POLICY, DEFAULT_MAX_ROWS)
            .await
            .unwrap();
        assert!(outcome.rows.is_empty());
        assert_eq!(outcome.total_seen, 0);
        assert!(!outcome.truncated && !outcome.cancelled);
    }

    #[test]
    fn stream_metadata_carries_the_drain_disposition() {
        let outcome = RowStreamOutcome {
            rows: vec![row(1)],
            total_seen: 100,
            dropped: 99,
            truncated: false,
            cancelled: false,
        };
        let meta = stream_metadata(BackpressurePolicy::DropOldest, &outcome);
        assert_eq!(meta["policy"], "drop_oldest");
        assert_eq!(meta["total_seen"], 100);
        assert_eq!(meta["dropped"], 99);
        assert_eq!(meta["truncated"], false);
    }

    #[test]
    fn defaults_are_pause_upstream_and_a_sane_bound() {
        assert_eq!(DEFAULT_RETRIEVE_POLICY, BackpressurePolicy::PauseUpstream);
        assert!(DEFAULT_MAX_ROWS >= 1000);
    }
}