use super::request::{
ComposeRequest, ComposeRequestMulti, Error, GetResponse,
GetResponseMulti, SendRequest, SendRequestMulti,
};
use crate::base::iana::{Rcode, Rtype};
use crate::base::message::Message;
use crate::base::message_builder::StreamTarget;
use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive};
use crate::base::{ParsedName, Serial};
use crate::rdata::AllRecordData;
use crate::utils::config::DefMinMax;
use bytes::{Bytes, BytesMut};
use core::cmp;
use octseq::Octets;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::vec::Vec;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tokio::time::sleep;
use tracing::trace;
/// Response timeout limits, presumably ordered (default, minimum, maximum)
/// as the `DefMinMax` name suggests: 19 s, 1 ms, 10 min.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(19),
    Duration::from_millis(1),
    Duration::from_secs(10 * 60),
);

/// Idle timeout limits in the same order: 10 s, 0, one hour.
const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(10),
    Duration::from_secs(0),
    Duration::from_secs(60 * 60),
);

/// Capacity of the request channel and of each streaming response channel.
const DEF_CHAN_CAP: usize = 8;

/// Capacity of the channel carrying parsed replies from the reader to `run`.
const READ_REPLY_CHAN_CAP: usize = 8;
/// Configuration for a stream transport connection.
#[derive(Clone, Debug)]
pub struct Config {
    /// Response timeout currently in effect.
    ///
    /// The transport replaces its copy of this value with either
    /// `single_response_timeout` or `streaming_response_timeout` each time
    /// it accepts a request.
    response_timeout: Duration,

    /// Response timeout for requests that expect a single response.
    single_response_timeout: Duration,

    /// Response timeout for streaming (multi-response) requests.
    streaming_response_timeout: Duration,

    /// Idle timeout used until the peer supplies one in an
    /// edns-tcp-keepalive option.
    idle_timeout: Duration,
}
impl Config {
    /// Creates a new config with default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the response timeout.
    pub fn response_timeout(&self) -> Duration {
        self.response_timeout
    }

    /// Sets the response timeout.
    ///
    /// The value is limited to the range declared by `RESPONSE_TIMEOUT`.
    /// It becomes the timeout for requests that expect a single response.
    /// It is also copied to the streaming response timeout; call
    /// [`set_streaming_response_timeout`][Self::set_streaming_response_timeout]
    /// afterwards to use a different value for streaming requests.
    pub fn set_response_timeout(&mut self, timeout: Duration) {
        let timeout = RESPONSE_TIMEOUT.limit(timeout);
        self.response_timeout = timeout;
        // The transport replaces its effective timeout with
        // `single_response_timeout` whenever it accepts a single-response
        // request. That field must follow this setting, otherwise the
        // setting has no effect on single-response requests.
        self.single_response_timeout = timeout;
        self.streaming_response_timeout = timeout;
    }

    /// Returns the response timeout used for streaming requests.
    pub fn streaming_response_timeout(&self) -> Duration {
        self.streaming_response_timeout
    }

    /// Sets the response timeout used for streaming requests.
    ///
    /// The value is limited to the range declared by `RESPONSE_TIMEOUT`.
    pub fn set_streaming_response_timeout(&mut self, timeout: Duration) {
        self.streaming_response_timeout = RESPONSE_TIMEOUT.limit(timeout);
    }

    /// Returns the idle timeout used until the peer announces its own.
    pub fn idle_timeout(&self) -> Duration {
        self.idle_timeout
    }

    /// Sets the idle timeout.
    ///
    /// The value is limited to the range declared by `IDLE_TIMEOUT`.
    pub fn set_idle_timeout(&mut self, timeout: Duration) {
        self.idle_timeout = IDLE_TIMEOUT.limit(timeout);
    }
}
impl Default for Config {
    /// Builds a config in which all three response timeouts use the
    /// default of `RESPONSE_TIMEOUT` and the idle timeout uses the default
    /// of `IDLE_TIMEOUT`.
    fn default() -> Self {
        let response_timeout = RESPONSE_TIMEOUT.default();
        Self {
            response_timeout,
            single_response_timeout: response_timeout,
            streaming_response_timeout: response_timeout,
            idle_timeout: IDLE_TIMEOUT.default(),
        }
    }
}
/// A handle for submitting requests to a stream [`Transport`].
///
/// All clones share the sending half of the same request channel. Once
/// every clone has been dropped, the transport's `run` loop ends.
#[derive(Debug)]
pub struct Connection<Req, ReqMulti> {
    /// Sending half of the channel to the transport.
    sender: mpsc::Sender<ChanReq<Req, ReqMulti>>,
}
impl<Req, ReqMulti> Connection<Req, ReqMulti> {
pub fn new<Stream>(
stream: Stream,
) -> (Self, Transport<Stream, Req, ReqMulti>) {
Self::with_config(stream, Default::default())
}
pub fn with_config<Stream>(
stream: Stream,
config: Config,
) -> (Self, Transport<Stream, Req, ReqMulti>) {
let (sender, transport) = Transport::new(stream, config);
(Self { sender }, transport)
}
}
impl<Req, ReqMulti> Connection<Req, ReqMulti>
where
Req: ComposeRequest + 'static,
ReqMulti: ComposeRequestMulti + 'static,
{
async fn handle_request_impl(
self,
msg: Req,
) -> Result<Message<Bytes>, Error> {
let (sender, receiver) = oneshot::channel();
let sender = ReplySender::Single(Some(sender));
let msg = ReqSingleMulti::Single(msg);
let req = ChanReq { sender, msg };
self.sender.send(req).await.map_err(|_| {
Error::ConnectionClosed
})?;
receiver.await.map_err(|_| Error::StreamReceiveError)?
}
async fn handle_streaming_request_impl(
self,
msg: ReqMulti,
sender: mpsc::Sender<Result<Option<Message<Bytes>>, Error>>,
) -> Result<(), Error> {
let reply_sender = ReplySender::Stream(sender);
let msg = ReqSingleMulti::Multi(msg);
let req = ChanReq {
sender: reply_sender,
msg,
};
self.sender.send(req).await.map_err(|_| {
Error::ConnectionClosed
})?;
Ok(())
}
pub fn get_request(&self, request_msg: Req) -> Request {
Request {
fut: Box::pin(self.clone().handle_request_impl(request_msg)),
}
}
fn get_streaming_request(&self, request_msg: ReqMulti) -> RequestMulti {
let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
RequestMulti {
stream: receiver,
fut: Some(Box::pin(
self.clone()
.handle_streaming_request_impl(request_msg, sender),
)),
}
}
}
impl<Req, ReqMulti> Clone for Connection<Req, ReqMulti> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
}
}
}
impl<Req, ReqMulti> SendRequest<Req> for Connection<Req, ReqMulti>
where
    Req: ComposeRequest + 'static,
    ReqMulti: ComposeRequestMulti + Debug + Send + Sync + 'static,
{
    /// Creates a boxed single-response request for `request_msg`.
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        let request = self.get_request(request_msg);
        Box::new(request)
    }
}
impl<Req, ReqMulti> SendRequestMulti<ReqMulti> for Connection<Req, ReqMulti>
where
    Req: ComposeRequest + Debug + Send + Sync + 'static,
    ReqMulti: ComposeRequestMulti + 'static,
{
    /// Creates a boxed streaming request for `request_msg`.
    fn send_request(
        &self,
        request_msg: ReqMulti,
    ) -> Box<dyn GetResponseMulti + Send + Sync> {
        let request = self.get_streaming_request(request_msg);
        Box::new(request)
    }
}
/// A single-response request, as returned by [`Connection::get_request`].
pub struct Request {
    /// Future that hands the request to the transport and resolves to the
    /// reply or an error.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}
impl Request {
    /// Drives the stored future to completion and returns its result.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        self.fut.as_mut().await
    }
}
impl GetResponse for Request {
    /// Returns a boxed future that resolves to the reply for this request.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let response = self.get_response_impl();
        Box::pin(response)
    }
}
impl Debug for Request {
    /// Formats the request. The boxed future cannot be printed and is shown
    /// as `_`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Request");
        builder.field("fut", &format_args!("_"));
        builder.finish()
    }
}
/// A streaming (multi-response) request, created by
/// `Connection::get_streaming_request`.
pub struct RequestMulti {
    /// Receives the responses; `Ok(None)` marks the end of the stream.
    stream: mpsc::Receiver<Result<Option<Message<Bytes>>, Error>>,

    /// Future that hands the request to the transport. It is taken and
    /// awaited on the first call to `get_response`.
    #[allow(clippy::type_complexity)]
    fut: Option<
        Pin<Box<dyn Future<Output = Result<(), Error>> + Send + Sync>>,
    >,
}
impl RequestMulti {
async fn get_response_impl(
&mut self,
) -> Result<Option<Message<Bytes>>, Error> {
if self.fut.is_some() {
let fut = self.fut.take().expect("Some expected");
fut.await?;
}
self.stream
.recv()
.await
.ok_or(Error::ConnectionClosed)
.map_err(|_| Error::ConnectionClosed)?
}
}
impl GetResponseMulti for RequestMulti {
    /// Returns a boxed future that resolves to the next response of the
    /// stream.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        Box::pin(self.get_response_impl())
    }
}
impl Debug for RequestMulti {
    /// Formats the request. The boxed future cannot be printed and is shown
    /// as `_`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use this type's own name (previously copied as "Request") so the
        // two request types can be told apart in debug output.
        f.debug_struct("RequestMulti")
            .field("fut", &format_args!("_"))
            .finish()
    }
}
/// The transport side of a connection.
///
/// Owns the stream and processes requests sent by [`Connection`] handles
/// while its `run` method is awaited.
#[derive(Debug)]
pub struct Transport<Stream, Req, ReqMulti> {
    /// The underlying byte stream.
    stream: Stream,

    /// Transport configuration.
    config: Config,

    /// Receiving half of the request channel.
    receiver: mpsc::Receiver<ChanReq<Req, ReqMulti>>,
}
/// Where the transport delivers the response(s) for a request.
#[derive(Debug)]
enum ReplySender {
    /// Single response. The sender is taken when the response is delivered.
    Single(Option<oneshot::Sender<ChanResp>>),

    /// Stream of responses. `Ok(None)` marks the end of the stream.
    Stream(mpsc::Sender<Result<Option<Message<Bytes>>, Error>>),
}
impl ReplySender {
async fn send(&mut self, resp: ChanResp) -> Result<(), ()> {
match self {
ReplySender::Single(sender) => match sender.take() {
Some(sender) => sender.send(resp).map_err(|_| ()),
None => Err(()),
},
ReplySender::Stream(sender) => {
sender.send(resp.map(Some)).await.map_err(|_| ())
}
}
}
async fn send_eof(&mut self) -> Result<(), ()> {
match self {
ReplySender::Single(_) => {
panic!("cannot send EOF for Single");
}
ReplySender::Stream(sender) => {
sender.send(Ok(None)).await.map_err(|_| ())
}
}
}
fn is_stream(&self) -> bool {
matches!(self, Self::Stream(_))
}
}
/// A queued request: either single-response or streaming.
#[derive(Debug)]
enum ReqSingleMulti<Req, ReqMulti> {
    /// Request that expects exactly one response.
    Single(Req),

    /// Streaming request. The transport only accepts AXFR and IXFR here.
    Multi(ReqMulti),
}

/// A request as sent from a [`Connection`] to the [`Transport`].
#[derive(Debug)]
struct ChanReq<Req, ReqMulti> {
    /// The request message.
    msg: ReqSingleMulti<Req, ReqMulti>,

    /// Where to deliver the response(s).
    sender: ReplySender,
}

/// A single response, or the error that prevented one.
type ChanResp = Result<Message<Bytes>, Error>;
/// Mutable bookkeeping of the `run` loop.
struct Status {
    /// Current connection state.
    state: ConnState,

    /// Whether the next request should try to carry an edns-tcp-keepalive
    /// option. Cleared once a request accepted the option.
    send_keepalive: bool,

    /// Idle timeout currently in effect. Starts as `Config::idle_timeout`
    /// and is replaced by the timeout from the peer's edns-tcp-keepalive
    /// option when one is received.
    idle_timeout: Duration,
}

/// Connection state as tracked by the `run` loop.
enum ConnState {
    /// Requests may be outstanding. Holds when the response timer was last
    /// (re)started, or `None` when no response timer is running.
    Active(Option<Instant>),

    /// No request is outstanding. Holds when the connection became idle.
    Idle(Instant),

    /// The idle timeout expired.
    IdleTimeout,

    /// Reading from the stream failed.
    ReadError(Error),

    /// No reply arrived within the response timeout.
    ReadTimeout,

    /// Writing to the stream failed.
    WriteError(Error),
}
impl std::fmt::Display for ConnState {
    /// Writes a short description of the state. For `Active` and `Idle`
    /// it includes how many whole seconds ago the state's timestamp was
    /// taken; `Active` without a timestamp reports 0.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ConnState::Active(since) => {
                let secs = match since {
                    Some(start) => {
                        Instant::now().duration_since(*start).as_secs()
                    }
                    None => 0,
                };
                write!(f, "Active (since {}s ago)", secs)
            }
            ConnState::Idle(since) => {
                let secs = Instant::now().duration_since(*since).as_secs();
                write!(f, "Idle (since {}s ago)", secs)
            }
            ConnState::IdleTimeout => f.write_str("IdleTimeout"),
            ConnState::ReadError(err) => write!(f, "ReadError: {}", err),
            ConnState::ReadTimeout => f.write_str("ReadTimeout"),
            ConnState::WriteError(err) => write!(f, "WriteError: {}", err),
        }
    }
}
/// Progress through a zone transfer (AXFR or IXFR) response stream, as
/// advanced by `check_stream`.
#[derive(Debug)]
enum XFRState {
    /// AXFR requested; no record seen yet.
    AXFRInit,

    /// AXFR (or IXFR answered AXFR-style): the first SOA carried this
    /// serial. The next SOA must carry the same serial and ends the
    /// transfer.
    AXFRFirstSoa(Serial),

    /// IXFR requested; no record seen yet.
    IXFRInit,

    /// IXFR: the first SOA carried this serial. The next record decides
    /// how the response continues.
    IXFRFirstSoa(Serial),

    /// IXFR: inside a difference sequence, waiting for its second SOA.
    /// Carries the serial of the first SOA.
    IXFRFirstDiffSoa(Serial),

    /// IXFR: after the second SOA of a difference sequence. A SOA with the
    /// carried serial ends the transfer; any other SOA starts the next
    /// sequence.
    IXFRSecondDiffSoa(Serial),

    /// The transfer is complete.
    Done,

    /// An inconsistency was detected; every further message is reported as
    /// not matching the request.
    Error,
}
impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti> {
fn new(
stream: Stream,
config: Config,
) -> (mpsc::Sender<ChanReq<Req, ReqMulti>>, Self) {
let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
(
sender,
Self {
config,
stream,
receiver,
},
)
}
}
impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti>
where
Stream: AsyncRead + AsyncWrite,
Req: ComposeRequest,
ReqMulti: ComposeRequestMulti,
{
pub async fn run(mut self) {
let (reply_sender, mut reply_receiver) =
mpsc::channel::<Message<Bytes>>(READ_REPLY_CHAN_CAP);
let (read_stream, mut write_stream) = tokio::io::split(self.stream);
let reader_fut = Self::reader(read_stream, reply_sender);
tokio::pin!(reader_fut);
let mut status = Status {
state: ConnState::Active(None),
idle_timeout: self.config.idle_timeout,
send_keepalive: true,
};
let mut query_vec =
Queries::<(ChanReq<Req, ReqMulti>, Option<XFRState>)>::new();
let mut reqmsg: Option<Vec<u8>> = None;
let mut reqmsg_offset = 0;
loop {
let opt_timeout = match status.state {
ConnState::Active(opt_instant) => {
if let Some(instant) = opt_instant {
let elapsed = instant.elapsed();
if elapsed > self.config.response_timeout {
Self::error(
Error::StreamReadTimeout,
&mut query_vec,
)
.await;
status.state = ConnState::ReadTimeout;
break;
}
Some(self.config.response_timeout - elapsed)
} else {
None
}
}
ConnState::Idle(instant) => {
let elapsed = instant.elapsed();
if elapsed >= status.idle_timeout {
status.state = ConnState::IdleTimeout;
break;
}
Some(status.idle_timeout - elapsed)
}
ConnState::IdleTimeout
| ConnState::ReadError(_)
| ConnState::WriteError(_) => None, ConnState::ReadTimeout => {
panic!("should not be in loop with ReadTimeout");
}
};
let timeout = match opt_timeout {
Some(timeout) => timeout,
None =>
{
self.config.response_timeout
}
};
let sleep_fut = sleep(timeout);
let recv_fut = self.receiver.recv();
let (do_write, msg) = match &reqmsg {
None => {
let msg: &[u8] = &[];
(false, msg)
}
Some(msg) => {
let msg: &[u8] = msg;
(true, msg)
}
};
tokio::select! {
biased;
res = &mut reader_fut => {
while let Ok(answer) = reply_receiver.try_recv() {
Self::demux_reply(answer, &mut status, &mut query_vec).await;
}
match res {
Ok(_) =>
panic!("reader terminated"),
Err(error) => {
Self::error(error.clone(), &mut query_vec).await;
status.state = ConnState::ReadError(error);
break
}
}
}
opt_answer = reply_receiver.recv() => {
let answer = opt_answer.expect("reader died?");
Self::demux_reply(answer, &mut status, &mut query_vec).await;
}
res = write_stream.write(&msg[reqmsg_offset..]),
if do_write => {
match res {
Err(error) => {
let error =
Error::StreamWriteError(Arc::new(error));
Self::error(error.clone(), &mut query_vec).await;
status.state =
ConnState::WriteError(error);
break;
}
Ok(len) => {
reqmsg_offset += len;
if reqmsg_offset >= msg.len() {
reqmsg = None;
reqmsg_offset = 0;
}
}
}
}
res = recv_fut, if !do_write => {
match res {
Some(req) => {
if req.sender.is_stream() {
self.config.response_timeout =
self.config.streaming_response_timeout;
} else {
self.config.response_timeout =
self.config.single_response_timeout;
}
Self::insert_req(
req, &mut status, &mut reqmsg, &mut query_vec
);
}
None => {
break;
}
}
}
_ = sleep_fut => {
}
}
match status.state {
ConnState::Active(_) | ConnState::Idle(_) => {
}
ConnState::IdleTimeout => break,
ConnState::ReadError(_)
| ConnState::ReadTimeout
| ConnState::WriteError(_) => {
panic!("Should not be here");
}
}
}
trace!("Closing TCP connecting in state: {}", status.state);
_ = write_stream.shutdown().await;
}
async fn reader(
mut sock: tokio::io::ReadHalf<Stream>,
sender: mpsc::Sender<Message<Bytes>>,
) -> Result<(), Error> {
loop {
let read_res = sock.read_u16().await;
let len = match read_res {
Ok(len) => len,
Err(error) => {
return Err(Error::StreamReadError(Arc::new(error)));
}
} as usize;
let mut buf = BytesMut::with_capacity(len);
loop {
let curlen = buf.len();
if curlen >= len {
if curlen > len {
panic!(
"reader: got too much data {curlen}, expetect {len}");
}
break;
}
let read_res = sock.read_buf(&mut buf).await;
match read_res {
Ok(readlen) => {
if readlen == 0 {
return Err(Error::StreamUnexpectedEndOfData);
}
}
Err(error) => {
return Err(Error::StreamReadError(Arc::new(error)));
}
};
}
let reply_message = Message::<Bytes>::from_octets(buf.into());
match reply_message {
Ok(answer) => {
sender
.send(answer)
.await
.expect("can't send reply to run");
}
Err(_) => {
return Err(Error::ShortMessage);
}
}
}
}
async fn error(
error: Error,
query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
) {
for (mut req, _) in query_vec.drain() {
_ = req.sender.send(Err(error.clone())).await;
}
}
fn handle_opts<Octs: Octets + AsRef<[u8]>>(
opts: &OptRecord<Octs>,
status: &mut Status,
) {
for option in opts.opt().iter().flatten() {
if let AllOptData::TcpKeepalive(tcpkeepalive) = option {
Self::handle_keepalive(tcpkeepalive, status);
}
}
}
async fn demux_reply(
answer: Message<Bytes>,
status: &mut Status,
query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
) {
if let Some(opts) = answer.opt() {
Self::handle_opts(&opts, status);
};
status.state = ConnState::Active(Some(Instant::now()));
let id = answer.header().id();
let (mut req, mut opt_xfr_data) = match query_vec.try_remove(id) {
Some(req) => req,
None => {
return;
}
};
let mut send_eof = false;
let answer = if match &req.msg {
ReqSingleMulti::Single(msg) => msg.is_answer(answer.for_slice()),
ReqSingleMulti::Multi(msg) => {
let xfr_data =
opt_xfr_data.expect("xfr_data should be present");
let (eof, xfr_data, is_answer) =
check_stream(msg, xfr_data, &answer);
send_eof = eof;
opt_xfr_data = Some(xfr_data);
is_answer
}
} {
Ok(answer)
} else {
Err(Error::WrongReplyForQuery)
};
_ = req.sender.send(answer).await;
if req.sender.is_stream() {
if send_eof {
_ = req.sender.send_eof().await;
} else {
query_vec.insert_at(id, (req, opt_xfr_data));
}
}
if query_vec.is_empty() {
status.state = ConnState::Active(None);
status.state = if status.idle_timeout.is_zero() {
ConnState::IdleTimeout
} else {
ConnState::Idle(Instant::now())
}
}
}
fn insert_req(
mut req: ChanReq<Req, ReqMulti>,
status: &mut Status,
reqmsg: &mut Option<Vec<u8>>,
query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
) {
match &status.state {
ConnState::Active(timer) => {
if timer.is_none() {
status.state = ConnState::Active(Some(Instant::now()));
}
}
ConnState::Idle(_) => {
status.state = ConnState::Active(Some(Instant::now()));
}
ConnState::IdleTimeout => {
_ = req.sender.send(Err(Error::StreamIdleTimeout));
return;
}
ConnState::ReadError(error) => {
_ = req.sender.send(Err(error.clone()));
return;
}
ConnState::ReadTimeout => {
_ = req.sender.send(Err(Error::StreamReadTimeout));
return;
}
ConnState::WriteError(error) => {
_ = req.sender.send(Err(error.clone()));
return;
}
}
let xfr_data = match &req.msg {
ReqSingleMulti::Single(_) => None,
ReqSingleMulti::Multi(msg) => {
let qtype = match msg.to_message().and_then(|m| {
m.sole_question()
.map_err(|_| Error::MessageParseError)
.map(|q| q.qtype())
}) {
Ok(msg) => msg,
Err(e) => {
_ = req.sender.send(Err(e));
return;
}
};
if qtype == Rtype::AXFR {
Some(XFRState::AXFRInit)
} else if qtype == Rtype::IXFR {
Some(XFRState::IXFRInit)
} else {
_ = req.sender.send(Err(Error::FormError));
return;
}
}
};
let (index, (req, _)) = match query_vec.insert((req, xfr_data)) {
Ok(res) => res,
Err((mut req, _)) => {
_ = req
.sender
.send(Err(Error::StreamTooManyOutstandingQueries));
return;
}
};
let hdr = match &mut req.msg {
ReqSingleMulti::Single(msg) => msg.header_mut(),
ReqSingleMulti::Multi(msg) => msg.header_mut(),
};
hdr.set_id(index);
if status.send_keepalive
&& match &mut req.msg {
ReqSingleMulti::Single(msg) => {
msg.add_opt(&TcpKeepalive::new(None)).is_ok()
}
ReqSingleMulti::Multi(msg) => {
msg.add_opt(&TcpKeepalive::new(None)).is_ok()
}
}
{
status.send_keepalive = false;
}
match Self::convert_query(&req.msg) {
Ok(msg) => {
*reqmsg = Some(msg);
}
Err(err) => {
if let Some((mut req, _)) = query_vec.try_remove(index) {
_ = req.sender.send(Err(err));
}
}
}
}
fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) {
if let Some(value) = opt_value.timeout() {
let value_dur = Duration::from(value);
status.idle_timeout = value_dur;
}
}
fn convert_query(
msg: &ReqSingleMulti<Req, ReqMulti>,
) -> Result<Vec<u8>, Error> {
match msg {
ReqSingleMulti::Single(msg) => {
let mut target = StreamTarget::new_vec();
msg.append_message(&mut target)
.map_err(|_| Error::StreamLongMessage)?;
Ok(target.into_target())
}
ReqSingleMulti::Multi(msg) => {
let target = StreamTarget::new_vec();
let target = msg
.append_message(target)
.map_err(|_| Error::StreamLongMessage)?;
Ok(target.finish().into_target())
}
}
}
}
/// Checks one response message of a zone transfer (AXFR/IXFR) stream.
///
/// `msg` is the original request and `xfr_state` the state reached after
/// the previous responses. The records of the answer section are run
/// through the [`XFRState`] machine:
///
/// - AXFR: the first record must be a SOA. The transfer is complete at the
///   next SOA, which must carry the same serial.
/// - IXFR: the first record must be a SOA.
///   - A second SOA with the same serial completes the transfer, and so
///     does reaching the end of the message right after the first SOA.
///   - A SOA with a different serial starts difference sequences. The
///     transfer then ends at a SOA carrying the first serial that follows
///     a complete pair of difference SOAs.
///   - Any other record after the first SOA switches to AXFR-style
///     checking.
///
/// Returns `(eof, new_state, is_answer)`. `eof` is true when no further
/// responses are expected for this request. `is_answer` is false when the
/// message is not a valid part of the stream.
fn check_stream<CRM>(
    msg: &CRM,
    mut xfr_state: XFRState,
    answer: &Message<Bytes>,
) -> (bool, XFRState, bool)
where
    CRM: ComposeRequestMulti,
{
    // Before any record has been seen, the message must answer the
    // request. Later messages are not checked against the request here.
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            if !msg.is_answer(answer.for_slice()) {
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) => {}
        // A message after the transfer completed is an error.
        XFRState::Done => {
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        // Once in the error state, every message is rejected but the
        // stream is not ended.
        XFRState::Error => {
            return (false, xfr_state, false)
        }
    }

    // An error rcode ends the stream if the message answers the request.
    if answer.header().rcode() != Rcode::NOERROR {
        if !msg.is_answer(answer.for_slice()) {
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        return (true, xfr_state, true);
    }

    // An unparsable answer section ends the stream with an error.
    let ans_sec = match answer.answer() {
        Ok(ans) => ans,
        Err(_) => {
            xfr_state = XFRState::Error;
            return (true, xfr_state, false);
        }
    };
    for rr in
        ans_sec.into_records::<AllRecordData<Bytes, ParsedName<Bytes>>>()
    {
        // An unparsable record ends the stream with an error.
        let rr = match rr {
            Ok(rr) => rr,
            Err(_) => {
                xfr_state = XFRState::Error;
                return (true, xfr_state, false);
            }
        };
        match xfr_state {
            // AXFR: the first record must be a SOA.
            XFRState::AXFRInit => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::AXFRFirstSoa(soa.serial());
                    continue;
                }
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            // AXFR: non-SOA records pass; the next SOA must repeat the
            // first serial and completes the transfer.
            XFRState::AXFRFirstSoa(serial) => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }
                    xfr_state = XFRState::Error;
                    return (false, xfr_state, false);
                }
            }
            // IXFR: the first record must be a SOA.
            XFRState::IXFRInit => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::IXFRFirstSoa(soa.serial());
                    continue;
                }
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            // IXFR: the record after the first SOA decides the format.
            // A SOA with the same serial completes the transfer, a SOA
            // with another serial starts a difference sequence, and
            // anything else means an AXFR-style response.
            XFRState::IXFRFirstSoa(serial) => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }
                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }
                xfr_state = XFRState::AXFRFirstSoa(serial);
            }
            // Inside a difference sequence: the next SOA is its second
            // SOA.
            XFRState::IXFRFirstDiffSoa(serial) => {
                if let AllRecordData::Soa(_) = rr.data() {
                    xfr_state = XFRState::IXFRSecondDiffSoa(serial);
                    continue;
                }
            }
            // After a difference sequence's second SOA: a SOA with the
            // first serial completes the transfer, any other SOA opens the
            // next sequence.
            XFRState::IXFRSecondDiffSoa(serial) => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }
                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }
            }
            // Records after the final SOA are an error.
            XFRState::Done => {
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            // The loop returns as soon as the error state is entered.
            XFRState::Error => panic!("should not be here"),
        }
    }

    // Decide what the state at the end of this message means.
    match xfr_state {
        // No record at all in the first message: error.
        XFRState::AXFRInit | XFRState::IXFRInit => {
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        // Transfer still in progress: more messages are expected.
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) => {}
        // An IXFR message holding only the first SOA completes the
        // transfer.
        XFRState::IXFRFirstSoa(_) => {
            xfr_state = XFRState::Done;
            return (true, xfr_state, true);
        }
        XFRState::Done => return (true, xfr_state, true),
        XFRState::Error => unreachable!(),
    }
    (false, xfr_state, true)
}
/// A table of in-flight entries addressed by `u16` index.
///
/// The index of an entry is used as the DNS message ID of the request it
/// holds. New entries are appended rather than placed into a free slot
/// until at least half of the slots are free. At most 32768 entries can
/// be stored at once.
#[derive(Clone, Debug)]
struct Queries<T> {
    /// Number of occupied slots in `vec`.
    count: usize,

    /// Index from which the search for a free slot starts.
    curr: usize,

    /// The slots; `None` marks a free one.
    vec: Vec<Option<T>>,
}

impl<T> Queries<T> {
    /// Creates an empty table.
    fn new() -> Self {
        Queries {
            vec: Vec::new(),
            curr: 0,
            count: 0,
        }
    }

    /// Returns `true` if no slot is occupied.
    fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Stores `req` and returns its index together with a reference to the
    /// stored value.
    ///
    /// Returns `Err(req)` when the table already holds as many entries as
    /// it allows.
    fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> {
        // Keep the table at most half full so that every index fits in a
        // u16 and free slots remain plentiful.
        if self.count * 2 > usize::from(u16::MAX) {
            return Err(req);
        }

        // Only search for a free slot once at least half of the slots are
        // free; otherwise grow the vector.
        let free_slot = if self.vec.len() < 2 * self.count {
            None
        } else {
            self.vec
                .iter()
                .enumerate()
                .skip(self.curr)
                .find(|(_, slot)| slot.is_none())
                .map(|(pos, _)| pos)
        };

        let pos = match free_slot {
            Some(pos) => {
                self.vec[pos] = Some(req);
                pos
            }
            None => {
                self.vec.push(Some(req));
                self.vec.len() - 1
            }
        };

        self.count += 1;
        if pos == self.curr {
            self.curr += 1;
        }

        let slot = self.vec[pos].as_mut().expect("no inserted item?");
        let id = u16::try_from(pos).expect("query vec too large");
        Ok((id, slot))
    }

    /// Puts `req` into slot `id`, which is expected to be free (for
    /// example because it was just emptied by `try_remove`).
    ///
    /// # Panics
    ///
    /// Panics if `id` lies beyond the end of the table.
    fn insert_at(&mut self, id: u16, req: T) {
        let pos = usize::from(id);
        self.vec[pos] = Some(req);
        self.count += 1;
        if pos == self.curr {
            self.curr += 1;
        }
    }

    /// Removes and returns the entry at `index`. Returns `None` if the
    /// index is out of range or the slot is free.
    fn try_remove(&mut self, index: u16) -> Option<T> {
        let pos = usize::from(index);
        let item = self.vec.get_mut(pos).and_then(Option::take)?;
        self.count = self.count.saturating_sub(1);
        self.curr = cmp::min(self.curr, pos);
        Some(item)
    }

    /// Empties the table and yields every stored entry in index order.
    fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
        self.count = 0;
        self.curr = 0;
        self.vec.drain(..).flatten()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Entries inserted into `Queries` come back from `try_remove` under
    /// the IDs they were given, and `count` always equals the number of
    /// occupied slots, including after freed slots are reused.
    #[test]
    // `i` indexes both `idxs` and the expected item value, so a range loop
    // reads more clearly than an iterator here.
    #[allow(clippy::needless_range_loop)]
    fn queries_insert_remove() {
        // Remembers the ID assigned to item `i`, or `None` once removed.
        let mut idxs = [None; 20];
        let mut queries = Queries::new();
        for i in 0..12 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 12);
        assert_eq!(queries.vec.iter().flatten().count(), 12);
        // Free some slots in the middle of the table.
        for i in [1, 2, 3, 4, 7, 9] {
            let item = queries
                .try_remove(idxs[i].expect("test failed"))
                .expect("test failed");
            assert_eq!(i, item);
            idxs[i] = None;
        }
        assert_eq!(queries.count, 6);
        assert_eq!(queries.vec.iter().flatten().count(), 6);
        for i in 12..20 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 14);
        assert_eq!(queries.vec.iter().flatten().count(), 14);
        // Remove everything that is left.
        for i in 0..20 {
            if let Some(idx) = idxs[i] {
                let item = queries.try_remove(idx).expect("test failed");
                assert_eq!(i, item);
            }
        }
        assert_eq!(queries.count, 0);
        assert_eq!(queries.vec.iter().flatten().count(), 0);
    }

    /// Attempting far more inserts than there are IDs must not panic. The
    /// results are ignored; rejected inserts are fine.
    #[test]
    fn queries_overrun() {
        let mut queries = Queries::new();
        for i in 0..usize::from(u16::MAX) * 2 {
            let _ = queries.insert(i);
        }
    }
}