fraiseql_wire/stream/
filter.rs1use crate::Result;
4use futures::stream::Stream;
5use serde_json::Value;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::task::{Context, Poll};
9
10pub type Predicate = Box<dyn Fn(&Value) -> bool + Send>;
12
13pub struct FilteredStream<S> {
15 inner: S,
16 predicate: Predicate,
17 eval_count: AtomicU64,
19}
20
21impl<S> FilteredStream<S> {
22 pub fn new(inner: S, predicate: Predicate) -> Self {
24 Self {
25 inner,
26 predicate,
27 eval_count: AtomicU64::new(0),
28 }
29 }
30}
31
32impl<S> Stream for FilteredStream<S>
33where
34 S: Stream<Item = Result<Value>> + Unpin,
35{
36 type Item = Result<Value>;
37
38 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
39 loop {
40 match Pin::new(&mut self.inner).poll_next(cx) {
41 Poll::Ready(Some(Ok(value))) => {
42 let eval_idx = self.eval_count.fetch_add(1, Ordering::Relaxed);
45 let passed = if eval_idx.is_multiple_of(1000) {
46 let filter_start = std::time::Instant::now();
48 let result = (self.predicate)(&value);
49 let filter_duration = filter_start.elapsed().as_millis() as u64;
50 crate::metrics::histograms::filter_duration("unknown", filter_duration);
51 result
52 } else {
53 (self.predicate)(&value)
55 };
56
57 if passed {
58 return Poll::Ready(Some(Ok(value)));
59 }
60 crate::metrics::counters::rows_filtered("unknown", 1);
62 continue;
63 }
64 Poll::Ready(Some(Err(e))) => {
65 return Poll::Ready(Some(Err(e)));
67 }
68 Poll::Ready(None) => {
69 return Poll::Ready(None);
71 }
72 Poll::Pending => {
73 return Poll::Pending;
74 }
75 }
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    // Unwrapping is acceptable in tests: a panic is simply a failed test.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::Error;
    use futures::{stream, StreamExt};

    /// Only rows whose `active` flag is `true` come out, in their original order.
    #[tokio::test]
    async fn test_filter_stream() {
        let rows: Vec<Result<Value>> = vec![
            Ok(serde_json::json!({"id": 1, "active": true})),
            Ok(serde_json::json!({"id": 2, "active": false})),
            Ok(serde_json::json!({"id": 3, "active": true})),
        ];
        let only_active: Predicate = Box::new(|row| row["active"].as_bool().unwrap_or(false));

        let kept_ids: Vec<i64> = FilteredStream::new(stream::iter(rows), only_active)
            .map(|item| item.unwrap()["id"].as_i64().unwrap())
            .collect()
            .await;

        assert_eq!(kept_ids, vec![1, 3]);
    }

    /// An `Err` item is forwarded in place and the stream keeps going afterwards.
    #[tokio::test]
    async fn test_filter_propagates_errors() {
        let decode_failure =
            Error::JsonDecode(serde_json::Error::io(std::io::Error::other("test error")));
        let rows = vec![
            Ok(serde_json::json!({"id": 1})),
            Err(decode_failure),
            Ok(serde_json::json!({"id": 2})),
        ];
        let accept_all: Predicate = Box::new(|_| true);

        let mut filtered = FilteredStream::new(stream::iter(rows), accept_all);

        // Expected outcome of each successive item: ok, error, ok.
        for expect_ok in [true, false, true] {
            let item = filtered.next().await.unwrap();
            assert_eq!(item.is_ok(), expect_ok);
        }
    }

    /// When the predicate rejects every row the stream ends without yielding.
    #[tokio::test]
    async fn test_filter_all_filtered_out() {
        let rows = vec![
            Ok(serde_json::json!({"id": 1})),
            Ok(serde_json::json!({"id": 2})),
        ];
        let reject_all: Predicate = Box::new(|_| false);

        let mut filtered = FilteredStream::new(stream::iter(rows), reject_all);

        let first = filtered.next().await;
        assert!(first.is_none());
    }
}