use core::pin::Pin;
use std::collections::HashMap;
use std::sync::Arc;
use futures_util::future::FutureExt;
use futures_util::lock::Mutex;
use futures_util::stream::Stream;
use super::{
ClientHandle,
rc_stream::{RcStream, rc_stream},
};
use crate::{
NetError,
proto::op::{DnsRequest, DnsResponse, Query},
xfer::DnsHandle,
};
/// A [`ClientHandle`] wrapper that memoizes requests by query.
///
/// The first request for a given [`Query`] is sent through the wrapped client and its
/// response stream is stored; later requests whose first query is equal reuse that stored
/// stream instead of sending again. Nothing in this module removes entries from the map.
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
pub struct MemoizeClientHandle<H: ClientHandle> {
    /// The wrapped handle that actually sends a request when no stored stream exists.
    client: H,
    /// Stored response streams keyed by query; the `Arc` makes every clone of this handle
    /// share the same map.
    // NOTE(review): relies on each `RcStream` clone observing the full response — confirm in `rc_stream`.
    active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
}
impl<H> MemoizeClientHandle<H>
where
    H: ClientHandle,
{
    /// Creates a memoizing wrapper around `client` with an empty query map.
    pub fn new(client: H) -> Self {
        Self {
            client,
            active_queries: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the shared response stream for the first query in `request`.
    ///
    /// The request is sent through `client` only if no stream has been recorded for that
    /// query yet; otherwise a clone of the recorded stream is returned and `request` is
    /// dropped unsent.
    ///
    /// # Panics
    ///
    /// Panics if `request` contains no queries.
    async fn inner_send(
        request: DnsRequest,
        active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
        client: H,
    ) -> impl Stream<Item = Result<DnsResponse, NetError>> {
        // Only the first query is used as the memoization key; any further queries in the
        // request do not distinguish entries.
        let query = request.queries.first().expect("no query!").clone();

        // The guard is held across both the lookup and the insert, so two concurrent callers
        // for the same query cannot both end up sending it.
        let mut active_queries = active_queries.lock().await;

        // A single `entry` lookup serves both cases: on a hit it yields the stored stream, on
        // a miss it sends the request and stores the new stream. `client.send` runs only on a
        // miss, exactly as with a separate `get` check, but without hashing the key twice.
        active_queries
            .entry(query)
            .or_insert_with(|| rc_stream(client.send(request)))
            .clone()
    }
}
impl<H: ClientHandle> DnsHandle for MemoizeClientHandle<H> {
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
    type Runtime = H::Runtime;

    /// Sends `request` through the memoizing lookup and returns the result as a boxed stream.
    ///
    /// The lookup is an async step (it first waits for the shared map's lock). It is
    /// flattened into the stream it resolves to, so no work happens until the returned
    /// stream is polled.
    fn send(&self, request: DnsRequest) -> Self::Response {
        let shared_queries = Arc::clone(&self.active_queries);
        let inner_client = self.client.clone();

        let lookup = Self::inner_send(request, shared_queries, inner_client);
        Box::pin(lookup.flatten_stream())
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    // println! is used to trace each real send made by the fake client below.

    use core::pin::Pin;
    use std::sync::Arc;

    use futures_util::lock::Mutex;
    use futures_util::stream;

    use super::*;
    use crate::{
        proto::{
            op::{DnsRequest, DnsResponse, Message, MessageType, OpCode, Query},
            rr::RecordType,
        },
        runtime::TokioRuntimeProvider,
        xfer::{DnsHandle, FirstAnswer},
    };
    use test_support::subscribe;

    /// Fake handle whose every real send answers with the next message id, so the test can
    /// tell a fresh send apart from a memoized replay.
    #[derive(Clone)]
    struct TestClient {
        /// Id stamped on the next response; incremented after each send.
        i: Arc<Mutex<u16>>,
    }

    impl DnsHandle for TestClient {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = TokioRuntimeProvider;

        fn send(&self, request: DnsRequest) -> Self::Response {
            let counter = Arc::clone(&self.i);
            let answer = async move {
                let mut next_id = counter.lock().await;
                let id = *next_id;
                let message = Message::new(id, MessageType::Query, OpCode::Query).into_response();
                let query = request.queries.first().expect("no query!").clone();
                std::println!("sending {}: {}", id, query);
                *next_id += 1;
                Ok::<_, NetError>(DnsResponse::from_message(message).unwrap())
            };
            Box::pin(stream::once(answer))
        }
    }

    /// Repeating a query must return the memoized response (same id) instead of sending again,
    /// while each distinct query triggers exactly one real send.
    #[test]
    fn test_memoized() {
        use futures_executor::block_on;

        subscribe();

        let client = MemoizeClientHandle::new(TestClient {
            i: Arc::new(Mutex::new(0)),
        });

        let build_query = |record_type: RecordType| {
            let mut message = Message::query();
            message.add_query(Query::new().set_query_type(record_type).clone());
            message
        };
        let test1 = build_query(RecordType::A);
        let test2 = build_query(RecordType::AAAA);

        let answer_id = |message: &Message| {
            block_on(client.send(DnsRequest::from(message.clone())).first_answer())
                .unwrap()
                .id
        };

        // The first two sends are real (ids 0 and 1); the repeats must come from the memo.
        for &(message, expected_id) in &[(&test1, 0), (&test2, 1), (&test1, 0), (&test2, 1)] {
            assert_eq!(answer_id(message), expected_id);
        }
    }
}