// barter_integration/stream/ext/indexed.rs

1use barter_instrument::index::error::IndexError;
2use derive_more::Constructor;
3use futures::Stream;
4use pin_project::pin_project;
5use std::{
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10/// Type that indexes data structures.
11///
12/// An example `Indexer` use case is "keying" an event: <br>
13/// Unindexed = MarketEvent<MarketDataInstrument, DataKind> <br>
14/// Indexed = MarketEvent<InstrumentIndex, DataKind>
15pub trait Indexer {
16    type Unindexed;
17    type Indexed;
18
19    /// Index the input.
20    fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError>;
21}
22
23/// Stream adapter that indexes items using an [`Indexer`].
24#[derive(Debug, Constructor)]
25#[pin_project]
26pub struct IndexedStream<Stream, Indexer> {
27    #[pin]
28    stream: Stream,
29    indexer: Indexer,
30}
31
32impl<St, Index> Stream for IndexedStream<St, Index>
33where
34    St: Stream,
35    Index: Indexer<Unindexed = St::Item>,
36{
37    type Item = Result<Index::Indexed, IndexError>;
38
39    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40        let this = self.project();
41        match this.stream.poll_next(cx) {
42            Poll::Ready(Some(item)) => Poll::Ready(Some(this.indexer.index(item))),
43            Poll::Ready(None) => Poll::Ready(None),
44            Poll::Pending => Poll::Pending,
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use std::collections::HashMap;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    #[derive(Debug, Clone)]
    struct UnindexedData {
        key: String,
        value: i32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct IndexedData {
        index: usize,
        value: i32,
    }

    /// Test [`Indexer`] that resolves `UnindexedData::key` through a lookup table.
    struct MapIndexer {
        map: HashMap<String, usize>,
    }

    impl Indexer for MapIndexer {
        type Unindexed = UnindexedData;
        type Indexed = IndexedData;

        fn index(&self, item: Self::Unindexed) -> Result<Self::Indexed, IndexError> {
            self.map
                .get(&item.key)
                .map(|&index| IndexedData {
                    index,
                    value: item.value,
                })
                .ok_or_else(|| IndexError::InstrumentIndex(format!("key '{}' not found", item.key)))
        }
    }

    /// `MapIndexer` that knows keys "a", "b" and "c" (indices 0, 1 and 2).
    fn abc_indexer() -> MapIndexer {
        MapIndexer {
            map: HashMap::from([
                ("a".to_string(), 0),
                ("b".to_string(), 1),
                ("c".to_string(), 2),
            ]),
        }
    }

    /// Convenience constructor for test input items.
    fn unindexed(key: &str, value: i32) -> UnindexedData {
        UnindexedData {
            key: key.to_string(),
            value,
        }
    }

    #[tokio::test]
    async fn test_indexed_stream() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(abc_indexer());

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(unindexed("a", 10)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 0,
                value: 10
            }))
        );

        tx.send(unindexed("b", 20)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 1,
                value: 20
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_indexed_stream_unknown_key_yields_error_and_continues() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<UnindexedData>();
        let mut stream = UnboundedReceiverStream::new(rx).with_index(abc_indexer());

        // An item the indexer rejects is surfaced to the consumer as an `Err`...
        tx.send(unindexed("z", 99)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err(IndexError::InstrumentIndex(
                "key 'z' not found".to_string()
            )))
        );

        // ...and does not terminate the stream: later items are still indexed.
        tx.send(unindexed("c", 30)).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Ok(IndexedData {
                index: 2,
                value: 30
            }))
        );

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}