use core::pin::Pin;
use core::task::{Context, Poll};
use futures_util::stream::{Stream, StreamExt};
use crate::xfer::{DnsHandle, DnsRequest, DnsResponse};
use crate::{DnsError, NetError};
/// A [`DnsHandle`] wrapper that re-sends a request whenever the inner handle's
/// response stream yields an error.
///
/// `attempts` is the number of *additional* sends permitted after the first one.
/// Errors reported as `NetError::Busy` do not consume an attempt, while
/// `NetError::NoConnections` and `NoRecordsFound` errors are returned without retrying.
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
// NOTE(review): both fields are read by the `DnsHandle` impl; presumably this allow covers
// feature combinations in which the type itself goes unused — confirm before removing.
#[allow(dead_code)]
pub struct RetryDnsHandle<H> {
    handle: H,
    attempts: usize,
}

impl<H> RetryDnsHandle<H> {
    /// Wraps `handle` so that each request may be retried up to `attempts` times.
    pub fn new(handle: H, attempts: usize) -> Self {
        RetryDnsHandle { attempts, handle }
    }
}
impl<H: DnsHandle> DnsHandle for RetryDnsHandle<H> {
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin>>;
    type Runtime = H::Runtime;

    /// Issues `request` on the inner handle right away and returns a stream that
    /// re-sends a clone of it whenever that attempt yields a retryable error.
    fn send(&self, request: DnsRequest) -> Self::Response {
        // The first attempt goes out immediately; later attempts are issued by the stream.
        let first_attempt = self.handle.send(request.clone());
        let retrying = RetrySendStream {
            stream: first_attempt,
            handle: self.handle.clone(),
            remaining_attempts: self.attempts,
            request,
        };
        Box::pin(retrying)
    }
}
/// Stream returned by `RetryDnsHandle::send`: polls the attempt currently in flight and,
/// on a retryable error, replaces it with a fresh `send` of the same request.
struct RetrySendStream<H: DnsHandle> {
    /// The original request; cloned for every re-send.
    request: DnsRequest,
    /// Inner handle used to issue re-sends.
    handle: H,
    /// Response stream of the attempt currently in flight.
    stream: <H as DnsHandle>::Response,
    /// Retries still allowed; decremented for every retried error except `NetError::Busy`.
    remaining_attempts: usize,
}
impl<H: DnsHandle> Stream for RetrySendStream<H> {
    type Item = Result<DnsResponse, NetError>;

    /// Polls the in-flight attempt, transparently re-sending the request on retryable errors.
    ///
    /// Items that are not errors (`Ok`, end of stream, `Pending`) are passed through unchanged.
    /// An error is handled as follows:
    /// - no attempts remaining: the error is returned;
    /// - `NetError::NoConnections` or `DnsError::NoRecordsFound`: returned, never retried;
    /// - `NetError::Busy`: the request is re-sent without consuming an attempt, and this
    ///   method then yields (`Poll::Pending` after waking itself). This keeps a handle that
    ///   reports `Busy` synchronously from monopolising the thread, and lets the executor
    ///   run whatever task relieves the backpressure;
    /// - any other error: one attempt is consumed and the request is re-sent immediately.
    ///
    /// Note: no state records whether earlier items of the current attempt were already
    /// yielded, so an error arriving mid-stream also triggers a full re-send.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let err = match self.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Err(e))) => e,
                poll => return poll,
            };

            // Decide whether to retry; `busy` records whether this retry must yield first.
            let busy = match (self.remaining_attempts, err) {
                (0, err) => return Poll::Ready(Some(Err(err))),
                (
                    _,
                    err @ NetError::NoConnections
                    | err @ NetError::Dns(DnsError::NoRecordsFound(_)),
                ) => return Poll::Ready(Some(Err(err))),
                // Busy is not a real attempt, so it does not consume the retry budget.
                (_, NetError::Busy) => true,
                (_, _) => {
                    self.remaining_attempts -= 1;
                    false
                }
            };

            let request = self.request.clone();
            self.stream = self.handle.send(request);

            if busy {
                // Busy retries are unbounded, so never spin on them inside a single poll:
                // schedule ourselves again and give other tasks a chance to run.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
    }
}
#[cfg(all(test, feature = "tokio"))]
mod test {
    use core::sync::atomic::{AtomicU16, Ordering};
    use std::sync::Arc;

    use futures_executor::block_on;
    use futures_util::future::{err, ok};
    use futures_util::stream::{Stream, once};

    use super::*;
    use crate::proto::op::Message;
    use crate::runtime::TokioRuntimeProvider;
    use crate::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
    use test_support::subscribe;

    /// Fake handle that fails its first `retries` sends.
    ///
    /// After that, if `last_succeed` is set, it answers with a response whose id is the
    /// number of failed sends so far. Otherwise it keeps failing.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin>;
        type Runtime = TokioRuntimeProvider;

        fn send(&self, _: DnsRequest) -> Self::Response {
            let failed_so_far = self.attempts.load(Ordering::SeqCst);
            if self.last_succeed && failed_so_far >= self.retries {
                let mut message = Message::query();
                message.metadata.id = failed_so_far;
                let response = DnsResponse::from_message(message.into_response()).unwrap();
                return Box::new(once(ok(response)));
            }

            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(NetError::from("last retry set to fail"))))
        }
    }

    /// Builds a retrying handle (2 retries) over a client that fails once before
    /// succeeding, or fails forever when `last_succeed` is false.
    fn retrying_client(last_succeed: bool) -> RetryDnsHandle<TestClient> {
        let inner = TestClient {
            last_succeed,
            retries: 1,
            attempts: Arc::new(AtomicU16::new(0)),
        };
        RetryDnsHandle::new(inner, 2)
    }

    #[test]
    fn test_retry() {
        subscribe();
        let handle = retrying_client(true);
        let request = DnsRequest::from(Message::query());
        let result = block_on(handle.send(request).first_answer()).expect("should have succeeded");
        // The response id records how many sends the TestClient failed before succeeding.
        assert_eq!(result.id, 1);
    }

    #[test]
    fn test_error() {
        subscribe();
        let handle = retrying_client(false);
        let request = DnsRequest::from(Message::query());
        assert!(block_on(handle.send(request).first_answer()).is_err());
    }
}