astra 0.4.0

A blocking HTTP server built on hyper.
Documentation
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::net::{self as sys, Shutdown};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

use hyper::rt::ReadBufCursor;
use mio::{Events, Token};

/// Cheaply cloneable handle to the background I/O reactor.
///
/// Every clone points at the same shared state, so sockets registered through
/// any clone are driven by the single `astra-reactor` thread spawned in
/// [`Reactor::new`].
#[derive(Clone)]
pub struct Reactor {
    /// State shared with the reactor thread and all other handles.
    shared: Arc<Shared>,
}

/// State shared between every [`Reactor`] handle and the reactor thread.
struct Shared {
    /// Registry used to add sockets to, and remove them from, the mio poller.
    registry: mio::Registry,
    /// Monotonic counter from which each registered socket's `Token` is drawn.
    token: AtomicUsize,
    /// Readiness state of every live registration, keyed by its mio token.
    sources: Mutex<HashMap<Token, Arc<Source>>>,
}

impl Reactor {
    /// Creates a reactor and spawns the `astra-reactor` thread that drives it.
    ///
    /// # Errors
    ///
    /// Returns an error if the mio poller cannot be created, its registry cannot
    /// be cloned, or the reactor thread cannot be spawned.
    pub fn new() -> io::Result<Self> {
        let poll = mio::Poll::new()?;
        let shared = Arc::new(Shared {
            token: AtomicUsize::new(0),
            registry: poll.registry().try_clone()?,
            sources: Mutex::new(HashMap::with_capacity(64)),
        });

        std::thread::Builder::new()
            .name("astra-reactor".to_owned())
            .spawn({
                let shared = shared.clone();
                move || shared.run(poll)
            })?;

        Ok(Reactor { shared })
    }

    /// Switches `sys` to non-blocking mode and registers it with the reactor for
    /// both read and write readiness.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be made non-blocking or if the mio
    /// registration fails. In the latter case, no source is left in the map.
    pub fn register(&self, sys: sys::TcpStream) -> io::Result<TcpStream> {
        sys.set_nonblocking(true)?;
        let mut sys = mio::net::TcpStream::from_std(sys);
        let token = Token(self.shared.token.fetch_add(1, Ordering::Relaxed));

        let source = Arc::new(Source {
            token,
            interest: Default::default(),
            triggered: Default::default(),
        });

        // Publish the source *before* registering with mio. The reactor thread can
        // receive readiness events for `token` as soon as registration happens.
        // Events for a token missing from the map are dropped, and because mio
        // readiness is edge-triggered they are not redelivered. A source whose
        // initial events were lost would wait forever, since `triggered` starts
        // out `false`.
        self.shared
            .sources
            .lock()
            .unwrap()
            .insert(token, source.clone());

        if let Err(err) = self.shared.registry.register(
            &mut sys,
            token,
            mio::Interest::READABLE | mio::Interest::WRITABLE,
        ) {
            // Roll back so a failed registration does not leak a map entry.
            self.shared.sources.lock().unwrap().remove(&token);
            return Err(err);
        }

        Ok(TcpStream {
            sys,
            source,
            reactor: self.clone(),
        })
    }

    /// Returns `Ready(Ok(()))` if the reactor has flagged `direction` on `source`.
    /// Otherwise it stores `cx`'s waker (replacing a waker that would not wake the
    /// same task) and returns `Pending`.
    ///
    /// The flag is checked again after the waker is stored, so an event delivered
    /// while the waker was being installed is not missed. This never returns an
    /// error.
    fn poll_ready(
        &self,
        source: &Source,
        direction: usize,
        cx: &Context<'_>,
    ) -> Poll<io::Result<()>> {
        if source.triggered[direction].load(Ordering::Acquire) {
            return Poll::Ready(Ok(()));
        }

        {
            let mut interest = source.interest.lock().unwrap();

            match &mut interest[direction] {
                Some(existing) if existing.will_wake(cx.waker()) => {}
                _ => {
                    interest[direction] = Some(cx.waker().clone());
                }
            }
        }

        // check if anything changed while we were registering
        // our waker
        if source.triggered[direction].load(Ordering::Acquire) {
            return Poll::Ready(Ok(()));
        }

        Poll::Pending
    }

    /// Marks `direction` on `source` as not ready. The next `poll_ready` call for
    /// it then waits for a fresh reactor event.
    fn clear_trigger(&self, source: &Source, direction: usize) {
        source.triggered[direction].store(false, Ordering::Release);
    }
}

impl Shared {
    /// Body of the `astra-reactor` thread: polls for events forever.
    ///
    /// Poll failures are logged with `log::warn!` and the loop continues. The
    /// loop has no exit, so the `io::Result` is never actually produced.
    fn run(&self, mut poll: mio::Poll) -> io::Result<()> {
        let mut events = Events::with_capacity(64);
        let mut wakers = Vec::new();

        loop {
            if let Err(err) = self.poll(&mut poll, &mut events, &mut wakers) {
                log::warn!("Failed to poll reactor: {}", err);
            }

            events.clear();
        }
    }

    /// Blocks until mio reports events, marks the matching sources ready, and
    /// wakes any tasks waiting on them.
    ///
    /// `wakers` is a reusable scratch buffer. Wakers are collected while each
    /// source's `interest` lock is held and are invoked only after every event
    /// has been processed, so no lock is held while wake code runs.
    ///
    /// # Errors
    ///
    /// Returns the error from `mio::Poll::poll`. `Interrupted` is the exception:
    /// it is treated as a wakeup with no events.
    fn poll(
        &self,
        poll: &mut mio::Poll,
        events: &mut Events,
        wakers: &mut Vec<Waker>,
    ) -> io::Result<()> {
        if let Err(err) = poll.poll(events, None) {
            if err.kind() != io::ErrorKind::Interrupted {
                return Err(err);
            }

            return Ok(());
        }

        for event in events.iter() {
            let source = {
                let sources = self.sources.lock().unwrap();
                match sources.get(&event.token()) {
                    Some(source) => source.clone(),
                    // The stream was dropped (or is not yet published); ignore.
                    None => continue,
                }
            };

            // Treat hang-up and error readiness as readiness for the affected
            // direction(s). mio does not guarantee that these conditions are also
            // reported as plain readable/writable readiness on every platform.
            // A task waiting on a closed or failed socket must still be woken so
            // its next I/O call can observe EOF or the error instead of hanging.
            // A spurious wakeup only costs one extra `WouldBlock` syscall.
            let ready = [
                (
                    direction::READ,
                    event.is_readable() || event.is_read_closed() || event.is_error(),
                ),
                (
                    direction::WRITE,
                    event.is_writable() || event.is_write_closed() || event.is_error(),
                ),
            ];

            let mut interest = source.interest.lock().unwrap();

            for &(dir, is_ready) in &ready {
                if !is_ready {
                    continue;
                }

                if let Some(waker) = interest[dir].take() {
                    wakers.push(waker);
                }

                source.triggered[dir].store(true, Ordering::Release);
            }
        }

        for waker in wakers.drain(..) {
            waker.wake();
        }

        Ok(())
    }
}

/// Slot indices into the per-direction `interest` and `triggered` arrays kept
/// for each registered socket.
mod direction {
    /// Slot used for read readiness.
    pub const READ: usize = 0;
    /// Slot used for write readiness; it immediately follows `READ`.
    pub const WRITE: usize = READ + 1;
}

/// Per-socket readiness state shared between a stream and the reactor thread.
///
/// Both arrays are indexed by `direction::READ` and `direction::WRITE`.
struct Source {
    /// Waker of the task currently waiting on each direction, if any.
    interest: Mutex<[Option<Waker>; 2]>,
    /// Set by the reactor thread when an event arrives for a direction; cleared
    /// by the stream when it consumes that readiness.
    triggered: [AtomicBool; 2],
    /// mio token identifying this socket in the reactor's source map.
    token: Token,
}

/// A non-blocking TCP stream registered with a [`Reactor`].
///
/// Field order is also drop order: the socket is dropped before the reactor
/// handle and the shared readiness state.
pub struct TcpStream {
    /// The underlying mio socket.
    pub sys: mio::net::TcpStream,
    /// Reactor handle, used to poll readiness and to deregister on drop.
    reactor: Reactor,
    /// Readiness state shared with the reactor thread.
    source: Arc<Source>,
}

impl TcpStream {
    pub fn poll_io<T>(
        &self,
        direction: usize,
        mut f: impl FnMut() -> io::Result<T>,
        cx: &Context<'_>,
    ) -> Poll<io::Result<T>> {
        loop {
            if self
                .reactor
                .poll_ready(&self.source, direction, cx)?
                .is_pending()
            {
                return Poll::Pending;
            }

            match f() {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    self.reactor.clear_trigger(&self.source, direction);
                }
                val => return Poll::Ready(val),
            }
        }
    }
}

impl hyper::rt::Read for TcpStream {
    /// Reads from the socket into hyper's cursor. Returns `Pending` until the
    /// reactor reports read readiness.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        // `Read::read` needs an initialized `&mut [u8]`, but the cursor exposes its
        // unfilled region as `MaybeUninit<u8>`. The whole region is zeroed on every
        // call, even if a previous call already initialized it, so it can be
        // viewed as plain bytes.
        let initialized = unsafe {
            let buf = buf.as_mut();

            // Zero the buffer.
            std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len());

            // Safety: The buffer was initialized above.
            // NOTE(review): the raw-pointer round-trip also gives the resulting
            // slice an unbounded lifetime, detaching it from the borrow of `buf`.
            // That is what lets `buf.advance` be called below. The slice must not
            // be used after `advance`; re-check this if the code is reworked.
            &mut *(buf as *mut [MaybeUninit<u8>] as *mut [u8])
        };

        self.poll_io(direction::READ, || (&self.sys).read(initialized), cx)
            .map_ok(|n| {
                // Safety: The entire buffer was initialized above.
                // (This assumes `read` honours the `Read` contract that
                // `n <= initialized.len()`.)
                unsafe { buf.advance(n) };
            })
    }
}

impl hyper::rt::Write for TcpStream {
    /// Writes as much of `buf` as the socket accepts right now. Returns `Pending`
    /// until the reactor reports write readiness.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream: &TcpStream = &self;
        stream.poll_io(
            direction::WRITE,
            || {
                let mut sock = &stream.sys;
                sock.write(buf)
            },
            cx,
        )
    }

    /// Flushes the socket, gated on write readiness like `poll_write`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream: &TcpStream = &self;
        stream.poll_io(
            direction::WRITE,
            || {
                let mut sock = &stream.sys;
                sock.flush()
            },
            cx,
        )
    }

    /// Shuts down the write half of the connection immediately, without waiting
    /// on the reactor.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let outcome = self.sys.shutdown(Shutdown::Write);
        Poll::Ready(outcome)
    }
}

impl Drop for TcpStream {
    /// Unregisters the stream from the reactor.
    ///
    /// It first forgets the stream's source, so any late events for its token are
    /// ignored. It then removes the socket from the mio registry. Errors are
    /// ignored, since there is nothing useful to do with them during drop.
    fn drop(&mut self) {
        {
            // Tolerate a poisoned lock. Panicking inside `drop` while the thread
            // is already unwinding would abort the process, and removing an entry
            // from the map is still memory-safe after a poisoning panic.
            let mut sources = self
                .reactor
                .shared
                .sources
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            sources.remove(&self.source.token);
        }

        // Release the map lock before the deregistration syscall so the reactor
        // thread's event dispatch is not held up behind it.
        let _ = self.reactor.shared.registry.deregister(&mut self.sys);
    }
}