use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_util::future::BoxFuture;
use futures_util::stream::FuturesUnordered;
use futures_util::{Stream, StreamExt};
use tl_proto::TlRead;
use super::node::Node;
use super::peers_iter::PeersIter;
use crate::proto;
/// Stream of values found in the DHT for a single key.
///
/// Nothing is sent until the stream is first polled. Each item is a value that a
/// peer returned and that parsed successfully; failed or empty answers are
/// skipped. The stream ends once no queries remain in flight and the peer
/// iterator yields no more peers.
#[must_use = "streams do nothing unless polled"]
pub struct DhtValuesStream<T> {
    /// DHT node used to send the queries and parse their results.
    dht: Arc<Node>,
    /// Pre-serialized `DhtFindValue` request, sent unchanged to every peer.
    query: Bytes,
    /// Batch length passed to `PeersIter::fill`: `Some(default_value_batch_len)`
    /// initially, `None` after `use_full_batch`.
    batch_len: Option<usize>,
    /// Last observed `known_peers()` version. It is taken at construction and
    /// updated whenever `use_new_peers` triggers a refill of `peers_iter`.
    known_peers_version: u64,
    /// Whether to refill `peers_iter` when the known-peers version changes.
    use_new_peers: bool,
    /// Source of the peers to query next.
    peers_iter: PeersIter,
    /// Queries currently in flight.
    futures: FuturesUnordered<ValueFuture<T>>,
    /// Number of queries in flight. `usize::MAX` is a sentinel meaning the stream
    /// has not been polled yet, so `peers_iter` has not been filled.
    future_count: usize,
    _marker: std::marker::PhantomData<T>,
}
// No field is structurally pinned (`poll_next` just calls `get_mut`), so the stream
// is `Unpin` regardless of `T`. Without this impl, `PhantomData<T>` would make it
// `!Unpin` whenever `T` is, and `get_mut` would not be callable.
impl<T> Unpin for DhtValuesStream<T> {}
impl<T> DhtValuesStream<T>
where
    for<'a> T: TlRead<'a, Repr = tl_proto::Boxed> + Send + 'static,
{
    /// Creates a stream that looks up the value stored under `key`.
    ///
    /// The `DhtFindValue` request is serialized once here and reused for every
    /// peer. Peers are not selected until the stream is first polled. The batch
    /// length defaults to `dht.options().default_value_batch_len`.
    pub(super) fn new(dht: Arc<Node>, key: proto::dht::Key<'_>) -> Self {
        let key_id = tl_proto::hash_as_boxed(key);
        let peers_iter = PeersIter::with_key_id(key_id);
        let batch_len = Some(dht.options().default_value_batch_len);
        // Snapshot so that `use_new_peers` can detect later changes to the peer set.
        let known_peers_version = dht.known_peers().version();
        // NOTE(review): `k: 6` is presumably how many closer nodes a peer should
        // return when it does not hold the value. Confirm against the DHT protocol.
        let query = tl_proto::serialize(proto::rpc::DhtFindValue { key: &key_id, k: 6 }).into();

        Self {
            dht,
            query,
            batch_len,
            known_peers_version,
            use_new_peers: false,
            peers_iter,
            futures: Default::default(),
            // Sentinel: `poll_next` fills `peers_iter` on the first poll and resets this to 0.
            future_count: usize::MAX,
            _marker: Default::default(),
        }
    }

    /// Drops the batch-length limit, so `None` is passed to `PeersIter::fill`.
    ///
    /// NOTE(review): this presumably lets the iterator consider every known peer
    /// instead of a fixed-size batch. Confirm in `PeersIter::fill`.
    pub fn use_full_batch(mut self) -> Self {
        self.batch_len = None;
        self
    }

    /// Enables or disables refilling the peer iterator as the peer set changes.
    ///
    /// When `enable` is true, each time a query completes the stream compares the
    /// DHT's known-peers version with the last recorded one. If they differ, it
    /// refills the peer iterator, so peers learned after the search started can be
    /// picked up (subject to `PeersIter::fill`). This is disabled by default.
    pub fn use_new_peers(mut self, enable: bool) -> Self {
        self.use_new_peers = enable;
        self
    }

    /// Starts queries for the next peers until `MAX_PARALLEL_FUTURES` are in
    /// flight or the peer iterator is exhausted.
    ///
    /// Each query resolves to `None` when the peer has no value, the request
    /// fails, or the response cannot be parsed. Errors are logged with
    /// `tracing::warn!`.
    fn refill_futures(&mut self) {
        // The limit is checked *before* taking a peer. The in-flight count can then
        // never exceed MAX_PARALLEL_FUTURES, and the increment below cannot overflow.
        while self.future_count < MAX_PARALLEL_FUTURES {
            let peer_id = match self.peers_iter.next() {
                Some(peer_id) => peer_id,
                None => break,
            };

            let dht = self.dht.clone();
            let query = self.query.clone();
            self.futures.push(Box::pin(async move {
                match dht.query_raw(&peer_id, query).await {
                    Ok(Some(result)) => match dht.parse_value_result::<T>(&result) {
                        Ok(value) => value,
                        Err(e) => {
                            tracing::warn!("failed to parse queried value: {e}");
                            None
                        }
                    },
                    Ok(None) => None,
                    Err(e) => {
                        tracing::warn!("failed to query value: {e}");
                        None
                    }
                }
            }));
            self.future_count += 1;
        }
    }
}
impl<T> Stream for DhtValuesStream<T>
where
    for<'a> T: TlRead<'a, Repr = tl_proto::Boxed> + Send + 'static,
{
    type Item = ReceivedValue<T>;

    /// Advances the in-flight queries and yields the next value that a peer returned.
    ///
    /// The first poll seeds the peer iterator, detected through the `usize::MAX`
    /// sentinel in `future_count`. Each pass tops up the query set and then polls
    /// it. Queries that finished without a value are skipped. The stream ends when
    /// no queries are left in flight.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // Lazy start: peers are only collected once somebody actually polls.
        if stream.future_count == usize::MAX {
            stream.peers_iter.fill(&stream.dht, stream.batch_len);
            stream.future_count = 0;
        }

        loop {
            if stream.future_count < MAX_PARALLEL_FUTURES {
                stream.refill_futures();
            }

            let completed = match stream.futures.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(completed)) => completed,
            };

            // The version is read on every completion. It is only acted on when
            // refreshing from newly known peers is enabled.
            let current_version = stream.dht.known_peers().version();
            if stream.use_new_peers && current_version != stream.known_peers_version {
                stream.peers_iter.fill(&stream.dht, stream.batch_len);
                stream.known_peers_version = current_version;
            }

            stream.future_count -= 1;

            if let Some(received) = completed {
                return Poll::Ready(Some(received));
            }
        }
    }
}
/// Concurrency limit for value queries. `future_count` is compared against this
/// value to decide whether more queries should be started.
const MAX_PARALLEL_FUTURES: usize = 5;

/// A value received from the DHT, paired with its `KeyDescriptionOwned`.
type ReceivedValue<T> = (proto::dht::KeyDescriptionOwned, T);

/// A single in-flight query. It resolves to `None` when the peer had no value or
/// the request or parsing failed.
type ValueFuture<T> = BoxFuture<'static, Option<ReceivedValue<T>>>;