#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
#![doc = include_str!("../examples/simple.rs")]
use log::{trace, warn};
use std::fmt::Debug;
use std::marker::Unpin;
use std::io::{Cursor, Read, Write};
use std::collections::{HashMap, hash_map::Entry};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::sync::{Mutex, MutexGuard};
use std::convert::TryFrom;
use byteorder::{BigEndian, ReadBytesExt};
use std::future::Future;
/// Size in bytes of the fixed header that precedes the content of every record.
const RECORD_HEADER_SIZE: usize = 8;
/// Bit in the flags byte of a begin-request record. When it is not set, the
/// connection is closed after the request has been handed out.
const FCGI_KEEP_CONN: u8 = 0x01;
/// Panic message used when a `try_lock` on one of the internal mutexes fails.
const ERR_LOCK_FAILED: &str = "A request must not be processed by multiple threads.";
/// Identifier of a request as carried in every record header.
type RequestId = u16;
/// Iterator over parameter names and their raw byte values.
type ParamsIterator<'i> = dyn Iterator<Item=(&'i str, &'i [u8])> + 'i;
/// Iterator over parameter names and their values as strings. A value is
/// `None` when it is not valid UTF-8.
type StrParamsIterator<'i> = dyn Iterator<Item=(&'i str, Option<&'i str>)> + 'i;
/// Mutex guard that grants access to an [`InStream`] of a request.
pub type OwnedInStream<'a> = MutexGuard<'a, InStream>;
/// Error produced when a numeric record type cannot be mapped to a known type.
#[derive(Debug)]
enum TypeError {
    /// The given record type number is not known.
    UnknownRecordType(u8),
}

/// Record types that belong to a specific request and are sent by the web server.
#[derive(Clone, Copy, Debug, PartialEq)]
enum StdReqType {
    BeginRequest = 1,
    Params = 4,
    StdIn = 5,
    Data = 8,
}

impl From<StdReqType> for u8 {
    /// Returns the numeric record type, which is the enum discriminant.
    fn from(request_type: StdReqType) -> Self {
        request_type as Self
    }
}

impl TryFrom<u8> for StdReqType {
    type Error = TypeError;

    /// Maps a numeric record type to its variant, rejecting every number that
    /// is not one of the four known request record types.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let known = match value {
            1 => Self::BeginRequest,
            4 => Self::Params,
            5 => Self::StdIn,
            8 => Self::Data,
            _ => return Err(TypeError::UnknownRecordType(value)),
        };
        Ok(known)
    }
}
/// Record types that belong to a specific request and are sent as a response.
#[derive(Clone, Copy, Debug, PartialEq)]
enum StdRespType {
    EndRequest = 3,
    StdOut = 6,
    StdErr = 7,
}

impl From<StdRespType> for u8 {
    /// Returns the numeric record type, which is the enum discriminant.
    fn from(response_type: StdRespType) -> Self {
        response_type as Self
    }
}
/// Record types that are handled by the library itself and never become part
/// of a request that is handed out to the caller.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SysReqType {
    AbortRequest = 2,
    GetValues = 9,
}

impl From<SysReqType> for u8 {
    /// Returns the numeric record type, which is the enum discriminant.
    fn from(request_type: SysReqType) -> Self {
        request_type as Self
    }
}

impl TryFrom<u8> for SysReqType {
    type Error = TypeError;

    /// Maps a numeric record type to its variant, rejecting every number that
    /// is not one of the two known system record types.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let known = match value {
            2 => Self::AbortRequest,
            9 => Self::GetValues,
            _ => return Err(TypeError::UnknownRecordType(value)),
        };
        Ok(known)
    }
}
/// Record types the library sends in response to system records.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SysRespType {
    GetValuesResult = 10,
    UnknownType = 11,
}

impl From<SysRespType> for u8 {
    /// Returns the numeric record type, which is the enum discriminant.
    fn from(response_type: SysRespType) -> Self {
        response_type as Self
    }
}
#[derive(Clone, Copy, Debug)]
enum Category<S: Copy, T: Copy> {
Std(S),
Sys(T)
}
impl <S: Copy + TryFrom<u8, Error = TypeError>, T: Copy + TryFrom<u8, Error = TypeError>> TryFrom<u8> for Category<S, T> {
type Error = TypeError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
if let Ok(result) = S::try_from(value) {
Ok(Self::Std(result))
} else {
T::try_from(value).map(Self::Sys)
}
}
}
impl <S: std::convert::Into<u8> + Copy, T: std::convert::Into<u8> + Copy> From<Category<S, T>> for u8 {
fn from(cat: Category<S, T>) -> Self {
match cat {
Category::<S, T>::Std(std) => std.into(),
Category::<S, T>::Sys(sys) => sys.into()
}
}
}
/// Type of a record received from the web server.
type RequestType = Category<StdReqType, SysReqType>;
/// Type of a record sent to the web server.
type ResponseType = Category<StdRespType, SysRespType>;
/// Role the web server requests for a request.
#[derive(PartialEq, Debug)]
pub enum Role {
    /// Role number 1.
    Responder,
    /// Role number 2.
    Authorizer,
    /// Role number 3.
    Filter,
}

impl Role {
    /// Maps the role number of a begin-request record to a `Role`.
    ///
    /// Returns `None` for every number other than 1, 2 or 3.
    fn from_number(rl_num: u16) -> Option<Self> {
        let role = match rl_num {
            1 => Self::Responder,
            2 => Self::Authorizer,
            3 => Self::Filter,
            _ => return None,
        };
        Some(role)
    }
}
/// Outcome of a request, reported to the web server in the end-request record.
#[derive(Copy, Clone)]
pub enum RequestResult {
    /// The request was processed. The value is the application status.
    Complete(u32),
    /// The request was rejected because the application is overloaded.
    Overloaded,
    /// The request was rejected because the requested role is not known.
    UnknownRole,
}

impl RequestResult {
    /// Returns the application status: the wrapped value for `Complete`,
    /// `0` for every other result.
    fn app_status(self) -> u32 {
        if let Self::Complete(app_status) = self {
            app_status
        } else {
            0
        }
    }
}

impl From<RequestResult> for u8 {
    /// Returns the protocol status number written into the end-request record.
    fn from(rr: RequestResult) -> Self {
        match rr {
            RequestResult::Complete(_) => 0,
            RequestResult::Overloaded => 2,
            RequestResult::UnknownRole => 3,
        }
    }
}
/// Errors reported by this library.
#[derive(Debug)]
pub enum Error {
    /// Data was appended to an input stream that was already marked as done.
    StreamAlreadyDone,
    /// An output stream was written to after it had been closed.
    StreamAlreadyClosed,
    /// A record was received that is not valid at this point of the request.
    SequenceError,
    /// The record header carries a version other than 1.
    InvalidRecordVersion,
    /// The begin-request record carries a role number that is not known.
    InvalidRoleNumber,
    /// A record with an unknown type number (second field) was received for
    /// the given request id (first field).
    UnknownRecordType(RequestId, u8),
    /// An I/O operation failed.
    IoError(std::io::Error),
}

impl std::fmt::Display for Error {
    /// Writes a human readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Error::StreamAlreadyDone => write!(f, "Input stream is already done"),
            Error::StreamAlreadyClosed => write!(f, "Output stream is already closed"),
            Error::SequenceError => write!(f, "Records out of sequence"),
            Error::InvalidRecordVersion => write!(f, "Only record version 1 supported"),
            Error::InvalidRoleNumber => write!(f, "Unknown role passed from server"),
            Error::UnknownRecordType(request_id, type_id) => {
                write!(f, "Unknown record type {} in request {} received", type_id, request_id)
            }
            Error::IoError(error) => write!(f, "I/O error: {}", error),
        }
    }
}

impl std::error::Error for Error {
    /// Returns the wrapped I/O error for `IoError`, `None` for everything else.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IoError(source) => Some(source),
            _ => None,
        }
    }
}

impl From<std::io::Error> for Error {
    /// Wraps an I/O error so it can be propagated with `?`.
    fn from(io_error: std::io::Error) -> Self {
        Error::IoError(io_error)
    }
}
/// A single record as received from the web server.
struct Record {
    record_type: RequestType,
    request_id: RequestId,
    content: Vec<u8>,
}

impl Record {
    /// Reads one complete record (header, content and padding) from `rd`.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidRecordVersion` if the version byte is not 1. This is
    ///   reported before any content is read.
    /// * `Error::UnknownRecordType` if the type number is not known. This is
    ///   reported only after content and padding have been consumed.
    /// * `Error::IoError` if reading from `rd` fails.
    async fn new<R: AsyncRead + Unpin>(rd: &mut R) -> Result<Self, Error> {
        let mut header = [0u8; RECORD_HEADER_SIZE];
        rd.read_exact(&mut header).await?;

        // Header layout: version, type, request id (big endian), content
        // length (big endian), padding length, one unused byte.
        let [version, type_nr, id_hi, id_lo, length_hi, length_lo, padding, _unused] = header;
        if version != 1 {
            return Err(Error::InvalidRecordVersion);
        }
        let record_type = RequestType::try_from(type_nr);
        let request_id = u16::from_be_bytes([id_hi, id_lo]);
        let content_length = usize::from(u16::from_be_bytes([length_hi, length_lo]));
        let padding_length = u64::from(padding);

        let mut content = vec![0; content_length];
        rd.read_exact(&mut content).await?;

        // The padding carries no information, it is read and thrown away.
        if padding_length > 0 {
            tokio::io::copy(&mut rd.take(padding_length), &mut tokio::io::sink()).await?;
        }

        trace!("FastCGI: In record {{T:{:?}, ID: {}, L:{}}}", record_type, request_id, RECORD_HEADER_SIZE + content.len() + padding_length as usize);

        // The type is checked last so the whole record has been consumed when
        // an unknown type is reported.
        let record_type = record_type.map_err(|TypeError::UnknownRecordType(record_type_nr)| {
            Error::UnknownRecordType(request_id, record_type_nr)
        })?;

        Ok(Self { record_type, request_id, content })
    }

    /// Returns `true` if the record type is a system type.
    fn is_sys_record(&self) -> bool {
        match self.record_type {
            Category::Sys(_) => true,
            Category::Std(_) => false,
        }
    }

    /// Returns the content of the record without header and padding.
    fn get_content(&self) -> &[u8] {
        self.content.as_slice()
    }

    /// Returns the id of the request this record belongs to.
    fn get_request_id(&self) -> RequestId {
        self.request_id
    }
}
#[derive(Debug)]
pub struct InStream {
data: Vec<u8>,
read_pos: Option<usize>
}
impl Read for InStream {
fn read(&mut self, out: &mut [u8]) -> std::result::Result<usize, std::io::Error> {
let read_pos = self.read_pos.unwrap();
let c = std::io::Read::read(&mut &self.data[read_pos..], out)?;
self.read_pos = Some(read_pos + c);
Ok(c)
}
fn read_exact(&mut self, out: &mut [u8]) -> std::result::Result<(), std::io::Error> {
let read_pos = self.read_pos.unwrap();
std::io::Read::read_exact(&mut &self.data[read_pos..], out)?;
self.read_pos = Some(read_pos + out.len());
Ok(())
}
}
impl InStream{
fn new(already_done: bool) -> Self {
InStream {
data: Vec::new(),
read_pos: if already_done { Some(0) } else { None }
}
}
fn append(&mut self, data: &[u8]) -> Result<(), Error> {
if ! data.is_empty() {
if self.read_pos.is_none() {
self.data.extend_from_slice(data);
Ok(())
} else {
Err(Error::StreamAlreadyDone)
}
} else {
self.read_pos = Some(0);
Ok(())
}
}
fn is_done(&self) -> bool {
self.read_pos.is_some()
}
}
/// A single request, assembled from the records received for one request id.
pub struct Request <W: AsyncWrite + Unpin> {
/// Role taken from the begin-request record.
pub role: Role,
// `true` if the begin-request record had the `FCGI_KEEP_CONN` flag set.
keep_connection: bool,
// Id taken from the begin-request record.
request_id: RequestId,
// Parameters collected from params records. Keys are stored in lower case.
params: HashMap<String, Vec<u8>>,
// Set once an empty params record has been received.
params_done: bool,
// Writer used for all records sent for this request.
orw: Arc<OutRecordWriter<W>>,
// Input stream filled from stdin records.
stdin: Mutex<InStream>,
// Input stream filled from data records.
data: Mutex<InStream>
}
impl <W: AsyncWrite + Unpin> Request<W> {
fn new(record: &Record, writer: Arc<Mutex<W>>) -> Result<Self, Error> {
let mut content = record.get_content();
if let Category::Std(StdReqType::BeginRequest) = record.record_type {
if let Some(role) = Role::from_number(byteorder::ReadBytesExt::read_u16::<BigEndian>(&mut content).unwrap()) { let keep_connection = (byteorder::ReadBytesExt::read_u8(&mut content)? & FCGI_KEEP_CONN) == FCGI_KEEP_CONN;
Ok(Self {
params: HashMap::new(),
params_done: false,
orw: Arc::from(OutRecordWriter::new(writer, record.request_id)),
stdin: Mutex::from(InStream::new(role == Role::Authorizer)), data: Mutex::from(InStream::new(role != Role::Filter)), role,
keep_connection,
request_id: record.request_id
})
} else {
Err(Error::InvalidRoleNumber)
}
} else {
Err(Error::SequenceError)
}
}
fn read_length<T: Read>(src: &mut T) -> Result<u32, std::io::Error> {
let length: u32 = u32::from(src.read_u8()?);
if length & 0x80 == 0 {
Ok(length)
} else {
let length_byte2 = u32::from(src.read_u8()?);
let length_byte10 = u32::from(src.read_u16::<BigEndian>()?);
Ok((length & 0x7f) << 24 | length_byte2 << 16 | length_byte10)
}
}
fn add_nv_pairs(params: &mut HashMap<String, Vec<u8>>, src: &[u8], lowercase_keys: bool) -> Result<(), std::io::Error>{
let mut src_slice = src;
while !src_slice.is_empty() {
let name_length = Request::<W>::read_length(&mut src_slice)?;
let value_length = Request::<W>::read_length(&mut src_slice)?;
let mut name_buffer = vec![0; name_length as usize];
let mut value_buffer = vec![0; value_length as usize];
std::io::Read::read_exact(&mut src_slice, &mut name_buffer)?;
std::io::Read::read_exact(&mut src_slice, &mut value_buffer)?;
let key = String::from_utf8_lossy(&name_buffer);
let key = if lowercase_keys {
key.to_ascii_lowercase()
} else {
key.into_owned()
};
trace!("FastCGI: NV-Pair[\"{}\"]=\"{}\"", key, String::from_utf8_lossy(&value_buffer));
params.insert(key, value_buffer);
}
Ok(())
}
pub fn get_param(&self, name: &str) -> Option<&Vec<u8>> {
if self.params_done {
self.params.get(&name.to_ascii_lowercase())
} else {
None
}
}
pub fn get_str_param(&self, name: &str) -> Option<&str> {
if self.params_done {
match self.params.get(&name.to_ascii_lowercase()).map(|v| std::str::from_utf8(v)) {
None => None,
Some(Ok(value)) => Some(value),
Some(Err(_)) => {
warn!("FastCGI: Parameter {} is not valid utf8.", name);
None
}
}
} else {
None
}
}
pub fn params_iter(&self) -> Option<Box<ParamsIterator>> {
if self.params_done {
Some(Box::new(self.params.iter().map(|v| {
(v.0.as_str(), &v.1[..])
})))
} else {
None
}
}
pub fn str_params_iter(&self) -> Option<Box<StrParamsIterator>> {
if self.params_done {
Some(Box::new(self.params.iter().map(|v| {
(v.0.as_str(), std::str::from_utf8(v.1).ok())
})))
} else {
None
}
}
fn check_ready(&mut self) -> bool {
self.get_stdin().is_done() && self.get_data().is_done() && self.params_done
}
fn update(&mut self, record: &Record) -> Result<bool, Error> {
assert!(record.request_id == self.request_id);
if self.check_ready() {
return Err(Error::SequenceError);
}
if let Category::Std(record_type) = record.record_type {
match record_type {
StdReqType::BeginRequest => {
return Err(Error::SequenceError);
},
StdReqType::Params => {
if record.content.is_empty() {
self.params_done = true;
} else {
if self.params_done { warn!("FastCGI: Protocol error. Params received after params stream was marked as done."); }
Self::add_nv_pairs(&mut self.params, record.get_content(), true)?;
}
},
StdReqType::StdIn => {
self.get_stdin().append(record.get_content())?;
},
StdReqType::Data => {
self.get_data().append(record.get_content())?;
}
};
Ok(self.check_ready())
} else {
Err(Error::SequenceError)
}
}
pub fn get_request_id(&self) -> RequestId {
self.request_id
}
pub fn get_stdout(&self) -> OutStream<W> {
OutStream::new(Category::Std(StdRespType::StdOut), self.orw.clone())
}
pub fn get_stderr(&self) -> OutStream<W> {
OutStream::new(Category::Std(StdRespType::StdErr), self.orw.clone())
}
pub fn get_stdin(&self) -> OwnedInStream {
self.stdin.try_lock().expect(ERR_LOCK_FAILED)
}
pub fn get_data(&self) -> OwnedInStream {
self.data.try_lock().expect(ERR_LOCK_FAILED)
}
pub async fn process<F: Future<Output = RequestResult>, C: FnOnce(Arc<Self>) -> F>(self, callback: C) -> Result<(), Error> {
let rc_self = Arc::from(self);
let result = callback(rc_self.clone()).await;
if let Ok(this) = Arc::try_unwrap(rc_self) {
this.get_stdout().close().await?;
this.get_stderr().close().await?;
this.orw.write_finish(result).await?;
} else {
panic!("StdErr or StdOut leaked out of process.")
}
Ok(())
}
}
/// Reads records from one connection and assembles them into requests.
pub struct Requests <R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send> {
// Source of the incoming records.
reader: R,
// Destination of all outgoing records, shared with every request.
writer: Arc<Mutex<W>>,
// Requests that have been started but are not complete yet.
requests: HashMap<RequestId, Request<W>>,
// Set when the last request handed out did not ask to keep the connection.
close_on_next: bool,
// Value reported for `FCGI_MAX_CONNS`.
max_conns: u8,
// Value reported for `FCGI_MAX_REQS`.
max_reqs: u8
}
impl <'w, R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send> Requests<R, W> {
pub fn new(rd: R, wr: W, max_conns: u8, max_reqs: u8) -> Self {
Self {
requests: HashMap::with_capacity(1),
reader: rd,
writer: Arc::from(Mutex::from(wr)),
close_on_next: false,
max_conns,
max_reqs
}
}
pub fn from_split_socket(split_socket: (R, W), max_conns: u8, max_reqs: u8) -> Self {
Self::new(split_socket.0, split_socket.1, max_conns, max_reqs)
}
async fn process_sys(&self, record: Record) -> Result<Option<RequestId>, Error> {
if let Category::Sys(record_type) = record.record_type {
let output_stream = OutRecordWriter::new(self.writer.clone(), record.request_id);
let result = match record_type {
SysReqType::GetValues => {
let mut params = HashMap::new();
Request::<W>::add_nv_pairs(&mut params, record.get_content(), false)?;
#[cfg(debug_assertions)]
let mut params: Vec<(String, _)> = params.into_iter().collect();
#[cfg(debug_assertions)]
params.sort_by(|a, b| { a.0.cmp(&b.0) });
let mut output = Vec::with_capacity(128);
for (name, _) in params {
let result = match &*name {
"FCGI_MAX_CONNS" => Some(self.max_conns),
"FCGI_MAX_REQS" => Some(self.max_reqs),
"FCGI_MPXS_CONNS" => Some(1),
_ => None
};
if let Some(result) = result {
let result_str = result.to_string();
Write::write_all(&mut output, &[name.len() as u8])?;
Write::write_all(&mut output, &[result_str.len() as u8])?;
Write::write_all(&mut output, name.as_bytes())?;
Write::write_all(&mut output, result_str.as_bytes())?;
}
}
output_stream.write_data(Category::Sys(SysRespType::GetValuesResult), &output[..]).await?;
Ok(None)
},
SysReqType::AbortRequest => {
output_stream.write_finish(RequestResult::Complete(0)).await?;
Ok(Some(record.get_request_id()))
}
};
output_stream.flush().await?;
result
} else {
panic!("process_sys called with non sys record.");
}
}
pub async fn next(&mut self) -> Result<Option<Request<W>>, Error> {
if self.close_on_next {
if !self.requests.is_empty() {
warn!("FastCGI: The web-server interleaved requests on this connection but did not use the FCGI_KEEP_CONN flag. {} requests will get lost.", self.requests.len());
}
Ok(None)
} else {
loop
{
match Record::new(&mut self.reader).await {
Ok(record) => {
if record.is_sys_record() {
if let Some(canceled_request_id) = self.process_sys(record).await? {
self.requests.remove(&canceled_request_id);
}
} else {
let request_ready = match self.requests.entry(record.get_request_id()) {
Entry::Occupied(mut e) => { e.get_mut().update(&record) },
Entry::Vacant(e) => { e.insert(Request::new(&record, self.writer.clone())?); Ok(false) }
}?;
if request_ready {
let request = self.requests.remove(&record.get_request_id()).unwrap();
self.close_on_next = !request.keep_connection;
return Ok(Some(request));
}
}
},
Err(Error::IoError(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
if self.requests.is_empty() {
return Ok(None)
} else {
return Err(Error::from(err));
}
},
Err(Error::UnknownRecordType(request_id, type_id)) => {
let output_stream = OutRecordWriter::new(self.writer.clone(), request_id);
output_stream.write_unkown_type(type_id).await?;
},
Err(err) => {
return Err(err);
}
}
}
}
}
}
impl <W: AsyncWrite + Unpin> Debug for Request<W> {
    /// Writes the request id, the keep-connection flag, both input streams
    /// and all parameters. A parameter that can be read as a string is
    /// written in quotes, every other one as the debug form of its raw value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Request {{ request_id: {}, keep_connection: {:?}, stdin: {:?}, data: {:?}, params: {{", self.request_id, self.keep_connection, self.stdin, self.data)?;

        for (param_index, param_key) in self.params.keys().enumerate() {
            if param_index > 0 {
                f.write_str(", ")?;
            }
            match self.get_str_param(param_key) {
                Some(str_value) => write!(f, "{}: \"{}\"", param_key, str_value)?,
                None => write!(f, "{}: {:?}", param_key, self.get_param(param_key))?
            }
        }

        // Closes the params block and the request itself, so the braces are
        // balanced. No line break is added to the debug output.
        f.write_str("} }")
    }
}
/// Writes records for one request id to the shared output of the connection.
#[derive(Debug)]
struct OutRecordWriter<W: AsyncWrite> {
// Output of the connection, shared by all writers created for it.
inner_stream: Arc<Mutex<W>>,
// Request id put into the header of every record written.
request_id: RequestId,
}
impl <W: AsyncWrite + Unpin> OutRecordWriter<W> {
fn new(inner_stream: Arc<Mutex<W>>, request_id: RequestId) -> Self {
Self {
inner_stream,
request_id
}
}
async fn write_data(&self, record_type: ResponseType, data: &[u8]) -> std::result::Result<usize, std::io::Error> {
trace!("FastCGI: Out record {{T:{:?}, ID: {}, L:{}}}", record_type, self.request_id, RECORD_HEADER_SIZE + data.len());
let mut message_header = Vec::with_capacity(8);
byteorder::WriteBytesExt::write_u8(&mut message_header, 1).unwrap(); byteorder::WriteBytesExt::write_u8(&mut message_header, record_type.into()).unwrap(); byteorder::WriteBytesExt::write_u16::<BigEndian>(&mut message_header, self.request_id).unwrap(); byteorder::WriteBytesExt::write_u16::<BigEndian>(&mut message_header, data.len() as u16).unwrap(); byteorder::WriteBytesExt::write_u8(&mut message_header, 0).unwrap(); byteorder::WriteBytesExt::write_u8(&mut message_header, 0).unwrap();
let mut is = self.inner_stream.try_lock().expect(ERR_LOCK_FAILED);
is.write_all_buf(&mut Cursor::new(message_header)).await?;
if !data.is_empty() {
is.write_all_buf(&mut Cursor::new(data)).await?;
Ok(data.len())
} else {
Ok(0)
}
}
async fn write_finish(&self, result: RequestResult) -> Result<(), std::io::Error> {
let mut end_message = Vec::with_capacity(8);
byteorder::WriteBytesExt::write_u32::<BigEndian>(&mut end_message, result.app_status()).unwrap();
byteorder::WriteBytesExt::write_u8(&mut end_message, result.into()).unwrap();
std::io::Write::write_all(&mut end_message, &[0u8; 3]).unwrap();
self.write_data(Category::Std(StdRespType::EndRequest), &end_message[..]).await?;
Ok(())
}
async fn write_unkown_type(&self, type_id: u8) -> Result<(), std::io::Error> {
let mut ut_message = Vec::with_capacity(8);
byteorder::WriteBytesExt::write_u8(&mut ut_message, type_id).unwrap();
std::io::Write::write_all(&mut ut_message, &[0u8; 7]).unwrap();
self.write_data(ResponseType::Sys(SysRespType::UnknownType), &ut_message[..]).await?;
Ok(())
}
async fn flush(&self) -> std::result::Result<(), std::io::Error> {
self.inner_stream.try_lock().expect(ERR_LOCK_FAILED).flush().await
}
}
/// Output stream of a request. Everything written is sent as records of one
/// fixed type.
pub struct OutStream<W: AsyncWrite + Unpin> {
// Writer the records are sent through.
orw: Arc<OutRecordWriter<W>>,
// Type of the records this stream produces.
record_type: ResponseType,
// Set once `close` has completed on this instance.
closed: bool
}
impl <W: AsyncWrite + Unpin> OutStream<W> {
    /// Creates a stream that sends records of type `record_type` through `orw`.
    fn new(record_type: ResponseType, orw: Arc<OutRecordWriter<W>>) -> Self {
        Self {
            orw,
            record_type,
            closed: false
        }
    }

    /// Writes `data` to the stream and returns the number of bytes written,
    /// which is always `data.len()`.
    ///
    /// Data that does not fit into a single record is split over several
    /// records. Empty data sends nothing at all: a record without content
    /// marks the end of the stream and is only sent when the stream is
    /// closed.
    ///
    /// # Errors
    ///
    /// * `Error::StreamAlreadyClosed` if the stream has been closed.
    /// * `Error::IoError` if writing a record fails.
    pub async fn write(&mut self, data: &[u8]) -> std::result::Result<usize, Error> {
        // Largest content a single record is allowed to carry here.
        const MAX_CHUNK_SIZE: usize = (u16::MAX - 1) as usize;

        if self.closed {
            return Err(Error::StreamAlreadyClosed);
        }

        for chunk in data.chunks(MAX_CHUNK_SIZE) {
            self.orw.write_data(self.record_type, chunk).await?;
        }

        Ok(data.len())
    }

    /// Flushes the underlying output.
    pub async fn flush(&self) -> std::result::Result<(), std::io::Error> {
        self.orw.flush().await
    }

    /// Closes the stream: sends a record without content, flushes the output
    /// and marks this instance as closed.
    ///
    /// # Errors
    ///
    /// * `Error::StreamAlreadyClosed` if the stream has been closed before.
    /// * `Error::IoError` if writing or flushing fails.
    async fn close(&mut self) -> Result<(), Error>{
        if self.closed {
            return Err(Error::StreamAlreadyClosed);
        }

        self.orw.write_data(self.record_type, &[]).await?;
        self.flush().await?;
        self.closed = true;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_test::io::Builder;
// Compiles only if `T` is `Send`. The value itself is not used.
fn is_send<T: Send>(_: T) { }
// Compile-time check: the future is never polled, it only has to be `Send`.
#[test]
fn check_send() {
is_send(async move {
let mut requests = Requests::new(Builder::new().build(), Builder::new().build(), 10, 10);
is_send(&requests);
while let Ok(Some(request)) = requests.next().await {
request.process(|_request| async move {
RequestResult::Complete(0)
}).await.unwrap();
}
});
}
}