use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::future::Future;
use std::sync::mpsc::Receiver;
use std::task::Poll;
use std::task::Context;
use std::pin::Pin;
use std::collections::HashMap;
use std::cell::Cell;
use std::io::Error;
use pasts;
use async_std;
use async_std::prelude::*;
use async_std::net::TcpStream;
/// Value produced by every future in a worker thread's task list; tells
/// `async_thread_main` what to do next.
enum AsyncMsg {
    /// Stop the worker loop (produced when a `Message::Terminate` arrives).
    Quit,
    /// A new job arrived. Carries the control-channel receiver back so the
    /// loop can wait for the next message, plus the job itself.
    NewTask(Receiver<Message>, WebserverTask),
    /// A job finished; the worker decrements its task counter.
    OldTask,
}
/// Type-erased job future polled by a worker thread. It is a plain `Box`, not
/// `Pin<Box<..>>`, which is why `slice_select` pins it with `unsafe`.
type WebserverTask = Box<dyn Future<Output = AsyncMsg> + Send>;
/// Waits until any task in `tasks` completes and returns its output, removing
/// that task from the vector.
///
/// On every wake-up the tasks are polled in vector order, so earlier entries
/// win when several are ready at once. Unfinished tasks stay in `tasks` for the
/// next call. If `tasks` is empty, the returned future never completes.
async fn slice_select<T>(tasks: &mut Vec<Box<dyn Future<Output = T> + Send>>) -> T {
    /// Future that resolves with the output of the first ready task in `pending`.
    struct FirstReady<'a, T> {
        pending: &'a mut Vec<Box<dyn Future<Output = T> + Send>>,
    }

    impl<'a, T> Future for FirstReady<'a, T> {
        type Output = T;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
            // `FirstReady` only holds a `&mut`, so it is `Unpin` and may be unpinned.
            let pending = &mut *self.get_mut().pending;
            let mut index = 0;
            while index < pending.len() {
                // SAFETY: each task lives in its own heap allocation behind its
                // `Box`, so moving the `Vec` or the `Box` never moves the task.
                // Being an unsized `dyn Future`, it cannot be moved out of the
                // box by safe code either. It stays put until `remove` below
                // drops it in place.
                let task = unsafe { Pin::new_unchecked(&mut *pending[index]) };
                if let Poll::Ready(output) = task.poll(cx) {
                    drop(pending.remove(index));
                    return Poll::Ready(output);
                }
                index += 1;
            }
            Poll::Pending
        }
    }

    FirstReady { pending: tasks }.await
}
/// Blocks the calling thread until the next control message arrives on `recv`.
///
/// Returns `AsyncMsg::NewTask` for `Message::NewJob`. The receiver is handed
/// back inside it so the caller can wait for the next message.
///
/// Returns `AsyncMsg::Quit` for `Message::Terminate`. It also returns `Quit`
/// when the channel is disconnected (every `Sender` dropped): no further work
/// can arrive then, so the worker should stop instead of panicking.
fn async_thread_main_future(recv: Receiver<Message>) -> AsyncMsg {
    match recv.recv() {
        Ok(Message::NewJob(task)) => AsyncMsg::NewTask(recv, task),
        Ok(Message::Terminate) | Err(_) => AsyncMsg::Quit,
    }
}
/// Event loop run by each worker thread.
///
/// `in_flight` holds one listener future, created with `pasts::spawn_blocking`,
/// that waits for the next control message. It also holds every job received
/// so far. Whichever finishes first decides what happens next:
/// * `NewTask`: re-arm the listener with the returned receiver, then start
///   polling the new job;
/// * `OldTask`: a job finished, so decrement the shared task counter;
/// * `Quit`: leave the loop; jobs still in flight are dropped unfinished.
async fn async_thread_main(recv: Receiver<Message>, num_tasks: Arc<AtomicUsize>) {
    // Wraps the blocking `recv()` with `pasts::spawn_blocking` so it can sit in
    // the task list next to the jobs.
    let listen = |rx: Receiver<Message>| -> WebserverTask {
        Box::new(pasts::spawn_blocking(move || async_thread_main_future(rx)))
    };
    let mut in_flight: Vec<WebserverTask> = vec![listen(recv)];
    loop {
        let finished = slice_select(&mut in_flight).await;
        match finished {
            AsyncMsg::Quit => break,
            AsyncMsg::OldTask => {
                num_tasks.fetch_sub(1, Ordering::Relaxed);
            }
            AsyncMsg::NewTask(rx, job) => {
                in_flight.push(listen(rx));
                in_flight.push(job);
            }
        }
    }
}
/// Entry point of every worker thread.
///
/// Blocks this thread on `async_thread_main` using
/// `pasts::ThreadInterrupt`'s `Interrupt::block_on`.
fn thread_main(recv: Receiver<Message>, num_tasks: Arc<AtomicUsize>) {
    let event_loop = async_thread_main(recv, num_tasks);
    <pasts::ThreadInterrupt as pasts::Interrupt>::block_on(event_loop);
}
/// Handle to one worker thread plus the counter used to balance load across
/// workers.
struct Thread {
    /// Jobs sent to this worker that have not finished yet. `send` increments
    /// it; the worker decrements it when a job yields `AsyncMsg::OldTask`.
    num_tasks: Arc<AtomicUsize>,
    /// Worker's join handle; taken (left `None`) when this handle is dropped.
    join: Option<std::thread::JoinHandle<()>>,
    /// Control channel used to hand jobs and the shutdown request to the worker.
    sender: std::sync::mpsc::Sender<Message>,
}
impl Thread {
    /// Spawns a worker thread running `thread_main` with a fresh control
    /// channel and task counter, and returns its handle.
    pub fn new() -> Self {
        let (sender, receiver) = std::sync::mpsc::channel();
        let num_tasks = Arc::new(AtomicUsize::new(0));
        let worker_tasks = Arc::clone(&num_tasks);
        let join = Some(std::thread::spawn(move || {
            thread_main(receiver, worker_tasks)
        }));
        Thread {
            num_tasks,
            join,
            sender,
        }
    }

    /// Number of unfinished jobs on this worker.
    ///
    /// `Relaxed` ordering is enough: the value is only a load-balancing hint
    /// and may be stale as soon as it is read.
    pub fn tasks(&self) -> usize {
        self.num_tasks.load(Ordering::Relaxed)
    }

    /// Hands `future` to the worker thread and counts it as an unfinished job.
    ///
    /// The counter is decremented again only when the future resolves to
    /// `AsyncMsg::OldTask`. A future resolving to `AsyncMsg::Quit` stops the
    /// worker loop.
    ///
    /// # Panics
    /// Panics if the worker thread has exited and dropped its end of the channel.
    pub fn send<T>(&self, future: T)
    where
        T: Future<Output = AsyncMsg> + Send + 'static,
    {
        self.num_tasks.fetch_add(1, Ordering::Relaxed);
        self.sender
            .send(Message::NewJob(Box::new(future)))
            .expect("worker thread has exited");
    }
}
impl Drop for Thread {
    /// Asks the worker to stop and waits for it to exit.
    fn drop(&mut self) {
        // Never panic in `drop`. If the worker already exited (e.g. it
        // panicked), its receiver is gone, so `send` fails and `join` returns
        // the panic payload. Unwrapping either would panic here, which aborts
        // the process if this handle is dropped during unwinding. The worker's
        // panic has already been printed on its own thread (default panic
        // hook), so both results are deliberately ignored.
        let _ = self.sender.send(Message::Terminate);
        if let Some(thread) = self.join.take() {
            let _ = thread.join();
        }
    }
}
async fn async_main(web: Arc<Web>) {
let listener = async_std::net::TcpListener::bind("127.0.0.1:7878")
.await
.unwrap();
let mut threads = vec![];
let mut incoming = listener.incoming();
for _ in 0..4 {
threads.push(Thread::new());
}
while let Some(stream) = incoming.next().await {
let mut thread_id = 0;
let mut thread_tasks = threads[0].tasks();
for id in 1..threads.len() {
let n_tasks = threads[id].tasks();
if n_tasks < thread_tasks {
thread_id = id;
thread_tasks = n_tasks;
}
}
let stream = stream.unwrap();
let stream = Arc::new(stream);
let future = handle_connection(stream, Arc::clone(&web));
threads[thread_id].send(future);
}
}
/// Handler registered for a URL.
///
/// It is called once per matching request with the response `Stream` and
/// returns the future that writes the response. It is boxed so handlers of
/// different closure types can live in one `HashMap`.
type ResourceGenerator = Box<dyn Fn(Stream) -> Box<dyn Future<Output = Result<(), Error>> + Send> + Send + Sync>;
pub struct WebServer {
web: Web,
}
impl WebServer {
pub fn with_resources(path: &'static str) -> WebServer {
let urls = HashMap::new();
WebServer { web: Web { path, urls } }
}
pub fn url<F: 'static, G: 'static>(mut self, url: &'static str, func: G)
-> Self
where F: Future<Output = Result<(), std::io::Error>> + Send, G: Fn(Stream) -> F + Sync + Send
{
self.web.urls.insert(url, ("text/html; charset=utf-8", Box::new(
move |stream| Box::new(func(stream))
)));
self
}
pub fn url_with_type<F: 'static, G: 'static>(
mut self,
url: &'static str,
func: G,
content_type: &'static str)
-> Self
where F: Future<Output = Result<(), std::io::Error>> + Send, G: Fn(Stream) -> F + Sync + Send
{
self.web.urls.insert(url, (content_type, Box::new(
move |stream| Box::new(func(stream))
)));
self
}
pub fn start(self) {
let web = Arc::new(self.web);
<pasts::ThreadInterrupt as pasts::Interrupt>::block_on(async_main(web));
}
}
/// Server configuration, shared with every connection through an `Arc`.
struct Web {
    /// Resource directory prefix; file paths are built by appending the
    /// request path (or `/index.html`, `/404.html`) to it.
    path: &'static str,
    /// Registered URL handlers: request path -> (Content-Type, handler).
    urls: HashMap<&'static str, (&'static str, ResourceGenerator)>,
}
unsafe impl Sync for Stream {}
pub struct Stream {
internal: Cell<Option<InternalStream>>
}
impl Stream {
pub async fn send(&self) -> Result<(), std::io::Error> {
let mut this = self.internal.take().unwrap();
let ret = this.send().await;
self.internal.set(Some(this));
ret
}
pub fn push_str(&self, text: &str) {
let mut this = self.internal.take().unwrap();
this.push_str(text);
self.internal.set(Some(this));
}
pub fn push_data(&self, bytes: &[u8]) {
let mut this = self.internal.take().unwrap();
this.push_data(bytes);
self.internal.set(Some(this));
}
}
struct InternalStream {
stream: Arc<TcpStream>,
output: Vec<u8>,
}
impl InternalStream {
pub async fn send(&mut self) -> Result<(), std::io::Error> {
let stream = Arc::get_mut(&mut self.stream).unwrap();
stream.write(&self.output).await?;
stream.flush().await?;
Ok(())
}
pub fn push_str(&mut self, text: &str) {
self.output.extend(text.bytes());
}
pub fn push_data(&mut self, bytes: &[u8]) {
self.output.extend(bytes);
}
}
/// Control message sent from a `Thread` handle to its worker thread.
enum Message {
    /// A job (future) for the worker to poll to completion.
    NewJob(WebserverTask),
    /// Ask the worker to stop; sent when the `Thread` handle is dropped.
    Terminate,
}
async fn handle_connection(mut streama: Arc<TcpStream>, web: Arc<Web>) -> AsyncMsg {
let stream = Arc::get_mut(&mut streama).unwrap();
let mut buffer = [0; 512];
stream.read(&mut buffer).await.unwrap();
if !buffer.starts_with(b"GET ") {
return AsyncMsg::OldTask;
}
let mut end = 4;
let path = loop {
if end == buffer.len() {
return AsyncMsg::OldTask;
}
if buffer[end] == b' ' {
break &buffer[4..end];
}
end += 1;
};
if !buffer[end+1..].starts_with(b"HTTP/1.1\r\n") {
return AsyncMsg::OldTask;
}
let mut streamb = InternalStream { stream: streama, output: vec![] };
let path = if let Ok(path) = std::str::from_utf8(path) {
path
} else {
return AsyncMsg::OldTask;
};
let mut index = web.path.to_string();
index.push_str("/index.html");
let mut e404 = web.path.to_string();
e404.push_str("/404.html");
if "/" == path {
if let Ok(contents) = std::fs::read_to_string(index) {
streamb.push_str("HTTP/1.1 200 OK\nContent-Type: ");
streamb.push_str("text/html; charset=utf-8");
streamb.push_str("\r\n\r\n");
streamb.push_str(&contents);
streamb.send().await.unwrap();
} else {
streamb.push_str("HTTP/1.1 404 NOT FOUND\nContent-Type: ");
streamb.push_str("text/html; charset=utf-8");
streamb.push_str("\r\n\r\n");
if let Ok(cs) = std::fs::read_to_string(e404) {
streamb.push_str(&cs);
} else {
streamb.push_str("404 NOT FOUND");
};
streamb.send().await.unwrap();
}
} else {
let mut page = web.path.to_string();
page.push_str(path);
if let Some(request) = web.urls.get(path) {
{
streamb.push_str("HTTP/1.1 200 OK\nContent-Type: ");
streamb.push_str(request.0);
streamb.push_str("\r\n\r\n");
}
Pin::from(request.1(Stream { internal: Cell::new(Some(streamb)) }))
.await
.unwrap();
} else if let Ok(contents) = std::fs::read_to_string(page) {
streamb.push_str("HTTP/1.1 200 OK\nContent-Type: ");
streamb.push_str("text/html; charset=utf-8");
streamb.push_str("\r\n\r\n");
streamb.push_str(&contents);
streamb.send().await.unwrap();
} else {
streamb.push_str("HTTP/1.1 404 NOT FOUND\nContent-Type: ");
streamb.push_str("text/html; charset=utf-8");
streamb.push_str("\r\n\r\n");
if let Ok(cs) = std::fs::read_to_string(e404) {
streamb.push_str(&cs);
} else {
streamb.push_str("404 NOT FOUND");
};
streamb.send().await.unwrap();
}
};
AsyncMsg::OldTask
}