// barter_integration/stream/ext/indexed.rs
use barter_instrument::index::error::IndexError;
use derive_more::Constructor;
use futures::{stream::FusedStream, Stream};
use pin_project::pin_project;
use std::{
    pin::Pin,
    task::{Context, Poll},
};

10pub trait Indexer {
16 type Unindexed;
17 type Indexed;
18
19 fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError>;
21}
22
23#[derive(Debug, Constructor)]
25#[pin_project]
26pub struct IndexedStream<Stream, Indexer> {
27 #[pin]
28 stream: Stream,
29 indexer: Indexer,
30}
31
32impl<St, Index> Stream for IndexedStream<St, Index>
33where
34 St: Stream,
35 Index: Indexer<Unindexed = St::Item>,
36{
37 type Item = Result<Index::Indexed, IndexError>;
38
39 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40 let this = self.project();
41 match this.stream.poll_next(cx) {
42 Poll::Ready(Some(item)) => Poll::Ready(Some(this.indexer.index(item))),
43 Poll::Ready(None) => Poll::Ready(None),
44 Poll::Pending => Poll::Pending,
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use std::collections::HashMap;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    #[derive(Debug, Clone)]
    struct UnindexedData {
        key: String,
        value: i32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct IndexedData {
        index: usize,
        value: i32,
    }

    /// Test `Indexer` that resolves `UnindexedData::key` through a lookup table.
    struct MapIndexer {
        map: HashMap<String, usize>,
    }

    impl MapIndexer {
        /// Indexer that knows keys "a" => 0, "b" => 1 and "c" => 2.
        fn abc() -> Self {
            let map = HashMap::from([
                ("a".to_string(), 0),
                ("b".to_string(), 1),
                ("c".to_string(), 2),
            ]);
            Self { map }
        }
    }

    impl Indexer for MapIndexer {
        type Unindexed = UnindexedData;
        type Indexed = IndexedData;

        fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError> {
            self.map
                .get(&item.key)
                .map(|&index| IndexedData {
                    index,
                    value: item.value,
                })
                .ok_or_else(|| IndexError::InstrumentIndex(format!("key '{}' not found", item.key)))
        }
    }

    /// Shorthand for building an input item.
    fn unindexed(key: &str, value: i32) -> UnindexedData {
        UnindexedData {
            key: key.to_string(),
            value,
        }
    }

    #[tokio::test]
    async fn test_indexed_stream() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(MapIndexer::abc());

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(unindexed("a", 10)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 0,
                value: 10
            }))
        );

        tx.send(unindexed("b", 20)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 1,
                value: 20
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    /// An item the indexer cannot resolve is yielded as an error, and the stream keeps
    /// producing subsequent items rather than terminating.
    #[tokio::test]
    async fn test_indexed_stream_yields_error_and_continues() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(MapIndexer::abc());

        tx.send(unindexed("unknown", 1)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err(IndexError::InstrumentIndex(
                "key 'unknown' not found".to_string()
            )))
        );

        tx.send(unindexed("c", 30)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 2,
                value: 30
            }))
        );

        assert_pending!(stream.poll_next_unpin(&mut cx));

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}