use std::io;
use std::marker::Unpin;
use futures_util::future;
use futures_util::pin_mut;
use futures_util::future::Either;
use log::debug;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::broadcast;
use tokio::task::spawn;
use tokio_stream::{Stream, StreamExt};
use super::pdu;
use super::payload::{Action, Payload, Timing};
use super::state::State;
/// A source of payload data that RTR connections answer queries from.
///
/// The server clones the source once per accepted connection and moves the
/// clone into a spawned task, which is why it must be
/// `Clone + Sync + Send + 'static`.
pub trait PayloadSource
where
    Self: Clone + Sync + Send + 'static,
{
    /// The type yielding the complete payload set, see [`full`](Self::full).
    type Set: PayloadSet;

    /// The type yielding payload changes, see [`diff`](Self::diff).
    type Diff: PayloadDiff;

    /// Returns whether the source has data to serve.
    ///
    /// While this returns `false`, Serial and Reset Queries are answered
    /// with an error PDU instead of data.
    fn ready(&self) -> bool;

    /// Returns the state announced to routers in a Serial Notify PDU.
    fn notify(&self) -> State;

    /// Returns a state together with the full payload set.
    ///
    /// Used to answer a Reset Query: every item of the set is sent as an
    /// announcement, and the state goes into the Cache Response and End of
    /// Data PDUs.
    fn full(&self) -> (State, Self::Set);

    /// Returns the state to report and the changes relative to `state`.
    ///
    /// Used to answer a Serial Query. Returning `None` makes the connection
    /// send a Cache Reset instead.
    fn diff(&self, state: State) -> Option<(State, Self::Diff)>;

    /// Returns the timing values placed into End of Data PDUs.
    fn timing(&self) -> Timing;
}
/// A complete set of payload items, consumed one item at a time.
pub trait PayloadSet
where
    Self: Sync + Send + 'static,
{
    /// Returns the next payload item, or `None` once the set is exhausted.
    fn next(&mut self) -> Option<&Payload>;
}
/// A sequence of payload changes, consumed one change at a time.
pub trait PayloadDiff
where
    Self: Sync + Send + 'static,
{
    /// Returns the next payload item and the action to apply to it.
    ///
    /// Returns `None` once all changes have been returned.
    fn next(&mut self) -> Option<(&Payload, Action)>;
}
/// A transport stream an RTR connection can run over.
///
/// Apart from being an async byte stream, a socket is told through
/// [`update`](Self::update) whenever a data response has been completed.
pub trait Socket
where
    Self: AsyncRead + AsyncWrite + Unpin + Sync + Send + 'static,
{
    /// Called after a Cache Response ... End of Data sequence was flushed.
    ///
    /// `state` is the state just sent to the router. `reset` is `true` when
    /// the response answered a Reset Query and `false` for a Serial Query.
    /// The default implementation ignores the call.
    fn update(&self, _state: State, _reset: bool) {}
}

impl Socket for tokio::net::TcpStream {}
/// An RTR server accepting connections and serving data from a source.
///
/// `Listener` is a stream of incoming sockets. Each socket is served by its
/// own task, which answers queries from a clone of `Source`.
pub struct Server<Listener, Source> {
    /// The stream of incoming connections.
    listener: Listener,

    /// The notification sender every connection subscribes to.
    notify: NotifySender,

    /// The payload source each connection receives a clone of.
    source: Source,
}

impl<Listener, Source> Server<Listener, Source> {
    /// Creates a new server from a listener, a notifier, and a data source.
    pub fn new(
        listener: Listener, notify: NotifySender, source: Source
    ) -> Self {
        Self { listener, notify, source }
    }

    /// Accepts connections until the listener is exhausted.
    ///
    /// Each accepted socket is served by a newly spawned, detached task.
    /// The first error the listener yields is returned immediately. Errors
    /// of individual connection tasks are not reported here.
    pub async fn run<Sock>(mut self) -> Result<(), io::Error>
    where
        Listener: Stream<Item = Result<Sock, io::Error>> + Unpin,
        Sock: Socket,
        Source: PayloadSource,
    {
        loop {
            let accepted = match self.listener.next().await {
                Some(accepted) => accepted,
                None => return Ok(()),
            };
            let connection = Connection::new(
                accepted?, self.notify.subscribe(), self.source.clone()
            );
            // Dropping the join handle detaches the connection task.
            drop(spawn(connection.run()));
        }
    }
}
/// A single RTR session with one router.
struct Connection<Sock, Source> {
    /// The transport to the router.
    sock: Sock,

    /// The receiver for server notifications.
    notify: NotifyReceiver,

    /// The payload source queries are answered from.
    source: Source,

    /// The protocol version fixed by the first accepted header, if any.
    version: Option<u8>,
}

impl<Sock, Source> Connection<Sock, Source> {
    /// Creates a connection that has not agreed on a protocol version yet.
    fn new(sock: Sock, notify: NotifyReceiver, source: Source) -> Self {
        Self { sock, notify, source, version: None }
    }

    /// Returns the connection's protocol version, or 0 if none is set yet.
    fn version(&self) -> u8 {
        match self.version {
            Some(version) => version,
            None => 0,
        }
    }
}
impl<Sock: Socket, Source: PayloadSource> Connection<Sock, Source> {
    /// Serves the router until it disconnects or a fatal error occurs.
    ///
    /// Queries are answered one at a time, and server notifications are
    /// forwarded as Serial Notify PDUs. Every error report produced while
    /// parsing a query uses code 3 (Invalid Request), 4 (Unsupported
    /// Protocol Version) or 8 (Unexpected Protocol Version). RFC 8210
    /// section 12 marks all of these as fatal, so the connection is closed
    /// once the report has been sent. Continuing would also misread any
    /// unread remainder of the offending PDU as the next header.
    async fn run(mut self) -> Result<(), io::Error> {
        while let Some(query) = self.recv().await? {
            match query {
                Query::Serial(state) => self.serial(state).await?,
                Query::Reset => self.reset().await?,
                Query::Error(err) => {
                    self.error(err).await?;
                    // Fatal error: returning drops `self` and closes the
                    // socket.
                    return Ok(());
                }
                Query::Notify => self.notify().await?,
            }
        }
        Ok(())
    }
}
impl<Sock, Source> Connection<Sock, Source>
where Sock: AsyncRead + Unpin {
    /// Waits for the next query from the router or a server notification.
    ///
    /// Returns:
    /// * `Ok(None)` if reading a header hit EOF, i.e., the router went away;
    /// * `Ok(Some(Query::Notify))` if a notification arrived first;
    /// * `Ok(Some(Query::Error(_)))` if the received PDU must be answered
    ///   with the contained error report;
    /// * `Ok(Some(Query::Serial(_)))` or `Ok(Some(Query::Reset))` for a
    ///   valid query;
    /// * `Err(_)` if reading failed or the router sent an Error Report PDU.
    ///
    /// NOTE(review): if the notification wins the `select` while the header
    /// has only been partially read, the header future is dropped. Unless
    /// `pdu::Header::read` is cancellation safe, the bytes it had already
    /// consumed are lost and the stream desynchronises — confirm.
    async fn recv(&mut self) -> Result<Option<Query>, io::Error> {
        let header = {
            let notify = self.notify.recv();
            let header = pdu::Header::read(&mut self.sock);
            pin_mut!(notify);
            pin_mut!(header);
            match future::select(notify, header).await {
                Either::Left(_) => return Ok(Some(Query::Notify)),
                Either::Right((Ok(header), _)) => header,
                Either::Right((Err(err), _)) => {
                    if err.kind() == io::ErrorKind::UnexpectedEof {
                        return Ok(None)
                    }
                    else {
                        return Err(err)
                    }
                }
            }
        };
        if let Err(err) = self.check_version(header) {
            return Ok(Some(err))
        }
        match header.pdu() {
            pdu::SerialQuery::PDU => {
                debug!("RTR: Got serial query.");
                match Self::check_length(
                    header, pdu::SerialQuery::size()
                ) {
                    Ok(()) => {
                        let payload = pdu::SerialQueryPayload::read(
                            &mut self.sock
                        ).await?;
                        Ok(Some(Query::Serial(State::from_parts(
                            header.session(), payload.serial()
                        ))))
                    }
                    Err(err) => {
                        debug!("RTR: ... with bad length");
                        Ok(Some(err))
                    }
                }
            }
            pdu::ResetQuery::PDU => {
                debug!("RTR: Got reset query.");
                match Self::check_length(
                    header, pdu::ResetQuery::size()
                ) {
                    Ok(()) => Ok(Some(Query::Reset)),
                    Err(err) => {
                        debug!("RTR: ... with bad length");
                        Ok(Some(err))
                    }
                }
            }
            pdu::Error::PDU => {
                debug!("RTR: Got error reply.");
                Err(io::Error::new(io::ErrorKind::Other, "got error PDU"))
            }
            pdu => {
                debug!("RTR: Got query with PDU {}.", pdu);
                Ok(Some(Query::Error(
                    pdu::Error::new(
                        header.version(),
                        3,
                        header,
                        "expected Serial Query or Reset Query"
                    )
                )))
            }
        }
    }

    /// Checks the header's protocol version against the connection.
    ///
    /// The first header with a supported version fixes the connection's
    /// version. Returns an error query with code 8 (Unexpected Protocol
    /// Version) if a later header carries a different version. Returns an
    /// error query with code 4 (Unsupported Protocol Version) if no version
    /// is fixed yet and the header's version is too high. That report is
    /// encoded in the highest version this server speaks: RFC 8210 section
    /// 7 requires the cache to answer in a version it supports, so the
    /// router can downgrade.
    fn check_version(
        &mut self,
        header: pdu::Header
    ) -> Result<(), Query> {
        /// The highest RTR protocol version supported by this server.
        const MAX_VERSION: u8 = 1;

        if let Some(current) = self.version {
            if current != header.version() {
                Err(Query::Error(
                    pdu::Error::new(
                        header.version(),
                        8,
                        header,
                        "version switched during connection"
                    )
                ))
            }
            else {
                Ok(())
            }
        }
        else if header.version() > MAX_VERSION {
            Err(Query::Error(
                pdu::Error::new(
                    MAX_VERSION,
                    4,
                    header,
                    "only versions 0 and 1 supported"
                )
            ))
        }
        else {
            self.version = Some(header.version());
            Ok(())
        }
    }

    /// Checks that the header announces exactly `expected` bytes.
    ///
    /// Returns an error query with code 3 (Invalid Request) otherwise.
    fn check_length(header: pdu::Header, expected: u32) -> Result<(), Query> {
        if header.length() != expected {
            Err(Query::Error(
                pdu::Error::new(
                    header.version(),
                    3,
                    header,
                    "invalid length"
                )
            ))
        }
        else {
            Ok(())
        }
    }
}
impl<Sock: Socket, Source: PayloadSource> Connection<Sock, Source> {
async fn serial(&mut self, state: State) -> Result<(), io::Error> {
debug!("RTR server: request for serial {}", state.serial());
if !self.source.ready() {
return pdu::Error::new(
self.version(), 2, b"", b"Running initial validation"
).write(&mut self.sock).await;
}
match self.source.diff(state) {
Some((state, mut diff)) => {
debug!("RTR server: source has a diff");
pdu::CacheResponse::new(
self.version(), state,
).write(&mut self.sock).await?;
while let Some((payload, action)) = diff.next() {
pdu::Payload::new(
self.version(), action.into_flags(), payload
).write(&mut self.sock).await?;
}
let timing = self.source.timing();
pdu::EndOfData::new(
self.version(), state, timing
).write(&mut self.sock).await?;
self.sock.flush().await?;
self.sock.update(state, false);
Ok(())
}
None => {
debug!("RTR server: source ain't got no diff for that.");
pdu::CacheReset::new(self.version()).write(
&mut self.sock
).await
}
}
}
async fn reset(&mut self) -> Result<(), io::Error> {
if !self.source.ready() {
return pdu::Error::new(
self.version(), 2, "", b"Running initial validation"
).write(&mut self.sock).await;
}
let (state, mut iter) = self.source.full();
pdu::CacheResponse::new(
self.version(), state
).write(&mut self.sock).await?;
while let Some(payload) = iter.next() {
pdu::Payload::new(
self.version(), Action::Announce.into_flags(), payload
).write(&mut self.sock).await?;
}
let timing = self.source.timing();
pdu::EndOfData::new(
self.version(), state, timing
).write(&mut self.sock).await?;
self.sock.flush().await?;
self.sock.update(state, true);
Ok(())
}
async fn error(
&mut self, err: pdu::Error
) -> Result<(), io::Error> {
err.write(&mut self.sock).await?;
self.sock.flush().await
}
async fn notify(&mut self) -> Result<(), io::Error> {
let state = self.source.notify();
pdu::SerialNotify::new(
self.version(), state
).write(&mut self.sock).await
}
}
/// The next thing a connection has to respond to.
enum Query {
    /// The router sent a valid Reset Query.
    Reset,

    /// The router sent a valid Serial Query for the given state.
    Serial(State),

    /// A notification arrived on the connection's notify receiver.
    Notify,

    /// The received PDU was rejected; the contained report is to be sent.
    Error(pdu::Error),
}
/// The sending side of the channel that tells connections about new data.
///
/// Clones share the same underlying broadcast channel. Every connection
/// subscribes to it when it is accepted.
#[derive(Clone, Debug)]
pub struct NotifySender(broadcast::Sender<()>);

impl NotifySender {
    /// Creates a new sender with no subscribers.
    pub fn new() -> NotifySender {
        // A capacity of one is enough: the receiving side reports a lagged
        // channel as a single notification.
        let (sender, _) = broadcast::channel(1);
        NotifySender(sender)
    }

    /// Sends a notification to all current subscribers.
    ///
    /// The send result is ignored, so having no subscribers is fine.
    pub fn notify(&mut self) {
        let _ = self.0.send(());
    }

    /// Creates a receiver for a newly accepted connection.
    fn subscribe(&self) -> NotifyReceiver {
        let receiver = self.0.subscribe();
        NotifyReceiver(Some(receiver))
    }
}

impl Default for NotifySender {
    fn default() -> Self {
        NotifySender::new()
    }
}
/// The receiving side of a connection's notification channel.
///
/// The inner receiver is set to `None` once the channel has been found
/// closed.
#[derive(Debug)]
struct NotifyReceiver(Option<broadcast::Receiver<()>>);

impl NotifyReceiver {
    /// Waits for the next notification.
    ///
    /// If the receiver lagged behind, the missed notifications are reported
    /// as one. Once the channel is closed, the receiver is discarded. From
    /// then on this future (and every later call) stays pending forever, so
    /// it never wins a race against other futures.
    pub async fn recv(&mut self) {
        use tokio::sync::broadcast::error::{RecvError, TryRecvError};

        if let Some(rx) = self.0.as_mut() {
            let closed = match rx.recv().await {
                Ok(()) => false,
                // Consume the retained message so the lag counts as a
                // single notification, unless the channel closed meanwhile.
                Err(RecvError::Lagged(_)) => {
                    matches!(rx.try_recv(), Err(TryRecvError::Closed))
                }
                Err(RecvError::Closed) => true,
            };
            if !closed {
                return;
            }
        }
        self.0 = None;
        future::pending().await
    }
}