use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use futures_util::future::FutureExt;
use futures_util::lock::Mutex;
use futures_util::stream::Stream;
use trust_dns_proto::{
error::ProtoError,
xfer::{DnsHandle, DnsRequest, DnsResponse},
};
use crate::client::rc_stream::{rc_stream, RcStream};
use crate::client::ClientHandle;
use crate::op::Query;
/// A [`ClientHandle`] wrapper that remembers the response stream created for each [`Query`].
///
/// A request whose first query matches one that was already seen is answered with a clone
/// of the stored stream rather than being forwarded to the wrapped client again.
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
pub struct MemoizeClientHandle<H>
where
    H: ClientHandle,
{
    /// The wrapped handle that requests are forwarded to on a miss.
    client: H,
    /// Streams created so far, keyed by the query that produced them.
    ///
    /// Held behind an `Arc`, so every clone of this handle shares the same table.
    active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
}
impl<H> MemoizeClientHandle<H>
where
    H: ClientHandle,
{
    /// Wraps `client` in a memoizing handle that starts with an empty query table.
    pub fn new(client: H) -> Self {
        Self {
            client,
            active_queries: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the shared response stream for the first query in `request`.
    ///
    /// If the table already holds a stream for that query, a clone of it is returned and
    /// `request` is dropped without being sent. Otherwise `request` is sent through
    /// `client`, the resulting stream is wrapped with `rc_stream`, stored in the table and
    /// a clone of it is returned.
    ///
    /// Only the first query of the request is used as the key.
    ///
    /// NOTE(review): nothing in this file removes entries from `active_queries`, so a
    /// stored stream keeps being handed out for as long as the table lives — confirm that
    /// this is intended (no eviction is visible here).
    ///
    /// # Panics
    ///
    /// Panics with `"no query!"` if `request` contains no queries.
    async fn inner_send(
        request: DnsRequest,
        active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
        mut client: H,
    ) -> impl Stream<Item = Result<DnsResponse, ProtoError>> {
        let query = request.queries().first().expect("no query!").clone();

        // The guard is held until this function returns, so the lookup and the insertion
        // happen under one lock acquisition.
        let mut active_queries = active_queries.lock().await;

        // One entry lookup covers both cases: an occupied entry is cloned as-is, a vacant
        // one is filled by sending the request. The closure only runs on a miss.
        active_queries
            .entry(query)
            .or_insert_with(|| rc_stream(client.send(request)))
            .clone()
    }
}
impl<H> DnsHandle for MemoizeClientHandle<H>
where
    H: ClientHandle,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
    type Error = ProtoError;

    /// Sends `request` through the memoizing lookup and returns the response stream.
    ///
    /// The request is converted eagerly, but the table lookup itself is an `async`
    /// computation: it is flattened into the returned stream and therefore runs when that
    /// stream is polled, not when `send` is called.
    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
        let request: DnsRequest = request.into();

        // `inner_send` takes ownership of everything it needs, so hand it its own handle
        // to the shared table and its own copy of the wrapped client.
        let table = Arc::clone(&self.active_queries);
        let upstream = self.client.clone();

        let lookup = Self::inner_send(request, table, upstream);
        Box::pin(lookup.flatten_stream())
    }
}
#[cfg(test)]
mod test {
    // The stub client prints each request it actually sends.
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use std::pin::Pin;
    use std::sync::Arc;

    use futures::lock::Mutex;
    use futures::*;
    use trust_dns_proto::{
        error::ProtoError,
        xfer::{DnsHandle, DnsRequest, DnsResponse},
    };

    use crate::client::*;
    use crate::op::*;
    use crate::rr::*;
    use trust_dns_proto::xfer::FirstAnswer;

    /// Stub handle that answers every request it receives with an empty message whose id
    /// is the number of requests it has answered before.
    #[derive(Clone)]
    struct TestClient {
        /// Id to assign to the next response; shared between clones of the stub.
        i: Arc<Mutex<u16>>,
    }

    impl DnsHandle for TestClient {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
        type Error = ProtoError;

        /// Produces a one-item stream holding a response tagged with the current counter
        /// value, then bumps the counter.
        fn send<R: Into<DnsRequest> + Send + 'static>(&mut self, request: R) -> Self::Response {
            let counter = Arc::clone(&self.i);

            // `async move` takes ownership of `counter` and `request`, which the boxed
            // stream needs in order to outlive this call.
            let reply = async move {
                let mut response = Message::new();
                let mut next_id = counter.lock().await;
                response.set_id(*next_id);

                let query = request.into().queries().first().expect("no query!").clone();
                println!("sending {}: {}", *next_id, query);

                *next_id += 1;
                Ok(response.into())
            };

            Box::pin(stream::once(reply))
        }
    }

    /// Repeating a query must yield the id of its first response, which shows that the
    /// stub was not asked again (a fresh send would have produced a new, higher id).
    #[test]
    fn test_memoized() {
        use futures::executor::block_on;

        let mut client = MemoizeClientHandle::new(TestClient {
            i: Arc::new(Mutex::new(0)),
        });

        let mut a_request = Message::new();
        a_request.add_query(Query::new().set_query_type(RecordType::A).clone());
        let mut aaaa_request = Message::new();
        aaaa_request.add_query(Query::new().set_query_type(RecordType::AAAA).clone());

        // (request, id expected on its response), in the order the requests are sent.
        let rounds = vec![
            (a_request.clone(), 0),
            (aaaa_request.clone(), 1),
            (a_request, 0),
            (aaaa_request, 1),
        ];

        for (message, expected_id) in rounds {
            let result = block_on(client.send(message).first_answer())
                .ok()
                .unwrap();
            assert_eq!(result.id(), expected_id);
        }
    }
}