use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures_util::future::BoxFuture;
use futures_util::stream::FuturesUnordered;
use futures_util::{Stream, StreamExt};
use tl_proto::TlRead;

use super::node::DhtNode;
use super::peers_iter::PeersIter;
use crate::proto;

/// Stream for the `DhtNode::values` method.
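///
/// A minimal usage sketch (illustrative only; `dht.values::<MyValue>(key)` is an
/// assumed entry point for constructing this stream, since `DhtValuesStream::new`
/// is crate-private):
///
/// ```ignore
/// use futures_util::StreamExt;
///
/// let mut values = dht.values::<MyValue>(key).use_new_peers(true);
/// while let Some((key_description, value)) = values.next().await {
///     // each item is a `(proto::dht::KeyDescriptionOwned, T)` pair
///     // returned by one of the queried peers
/// }
/// ```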
#[must_use = "streams do nothing unless polled"]
pub struct DhtValuesStream<T> {
    dht: Arc<DhtNode>,
    query: Bytes,
    batch_len: Option<usize>,
    known_peers_version: u64,
    use_new_peers: bool,
    peers_iter: PeersIter,
    futures: FuturesUnordered<ValueFuture<T>>,
    future_count: usize,
    _marker: std::marker::PhantomData<T>,
}

impl<T> Unpin for DhtValuesStream<T> {}

impl<T> DhtValuesStream<T>
where
    for<'a> T: TlRead<'a, Repr = tl_proto::Boxed> + Send + 'static,
{
    pub(super) fn new(dht: Arc<DhtNode>, key: proto::dht::Key<'_>) -> Self {
        let key_id = tl_proto::hash_as_boxed(key);
        let peers_iter = PeersIter::with_key_id(key_id);

        let batch_len = Some(dht.options().default_value_batch_len);
        let known_peers_version = dht.known_peers().version();

        let query = tl_proto::serialize(proto::rpc::DhtFindValue { key: &key_id, k: 6 }).into();

        Self {
            dht,
            query,
            batch_len,
            known_peers_version,
            use_new_peers: false,
            peers_iter,
            futures: Default::default(),
            future_count: usize::MAX, // sentinel: marks an unpolled stream, set to 0 on the first poll
            _marker: Default::default(),
        }
    }

    /// Use all known DHT nodes in the peers iterator instead of the default batch size
    pub fn use_full_batch(mut self) -> Self {
        self.batch_len = None;
        self
    }

    /// Whether the stream should refill the peers iterator when new nodes are discovered
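    ///
    /// When enabled, the stream re-fills its peers iterator whenever the node's
    /// known-peers version changes between yielded query results.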
    pub fn use_new_peers(mut self, enable: bool) -> Self {
        self.use_new_peers = enable;
        self
    }

    fn refill_futures(&mut self) {
        // Spawn at most `MAX_PARALLEL_FUTURES` queries at a time
        while let Some(peer_id) = self.peers_iter.next() {
            let dht = self.dht.clone();
            let query = self.query.clone();

            self.futures.push(Box::pin(async move {
                match dht.query_raw(&peer_id, query).await {
                    Ok(Some(result)) => match dht.parse_value_result::<T>(&result) {
                        Ok(Some(value)) => Some(value),
                        Ok(None) => None,
                        Err(e) => {
                            tracing::warn!("Failed to parse queried value: {e}");
                            None
                        }
                    },
                    Ok(None) => None,
                    Err(e) => {
                        tracing::warn!("Failed to query value: {e}");
                        None
                    }
                }
            }));

            self.future_count += 1;
            if self.future_count >= MAX_PARALLEL_FUTURES {
                break;
            }
        }
    }
}

impl<T> Stream for DhtValuesStream<T>
where
    for<'a> T: TlRead<'a, Repr = tl_proto::Boxed> + Send + 'static,
{
    type Item = ReceivedValue<T>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Fill the peers iterator during the first poll
        if this.future_count == usize::MAX {
            this.peers_iter.fill(&this.dht, this.batch_len);
            this.future_count = 0;
        }

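        // Drive the in-flight queries; futures that resolve to `None`
        // (failed or empty responses) are skipped without yielding an item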
        loop {
            // Keep starting new futures when we can
            if this.future_count < MAX_PARALLEL_FUTURES {
                this.refill_futures();
            }

            match this.futures.poll_next_unpin(cx) {
                Poll::Ready(Some(value)) => {
                    // Refill peers iterator when the version has changed and `use_new_peers` is set
                    let version = this.dht.known_peers().version();
                    if this.use_new_peers && version != this.known_peers_version {
                        this.peers_iter.fill(&this.dht, this.batch_len);
                        this.known_peers_version = version;
                    }

                    // Decrease the number of parallel futures on each new item from `futures`
                    this.future_count -= 1;

                    if let Some(value) = value {
                        break Poll::Ready(Some(value));
                    }
                }
                Poll::Ready(None) => break Poll::Ready(None),
                Poll::Pending => break Poll::Pending,
            }
        }
    }
}

/// Boxed future created for each queried peer; resolves to `None` if the peer
/// fails to answer or has no value for the requested key.
type ValueFuture<T> = BoxFuture<'static, Option<ReceivedValue<T>>>;

/// Item yielded by the stream: the key description reported by the peer
/// together with the parsed value.
type ReceivedValue<T> = (proto::dht::KeyDescriptionOwned, T);

/// Maximum number of peer queries kept in flight at a time.
const MAX_PARALLEL_FUTURES: usize = 5;