use crate::network::http::session::Session;
use bytes::{Buf, BufMut, BytesMut};
use http::{HeaderName, HeaderValue};
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::net::IpAddr;
use std::str::FromStr;
#[cfg(feature = "net-ws-server")]
use crate::network::http::ws;
/// Protocol prefix of every response status line, including the trailing space.
const HTTP11: &[u8] = "HTTP/1.1 ".as_bytes();
/// Line terminator used in HTTP/1.x message heads.
const CRLF: &[u8] = "\r\n".as_bytes();
/// Target spare capacity for the response buffer and the largest accepted
/// WebSocket frame payload (32 KiB).
pub(crate) const BUF_LEN: usize = 32 * 1024;
/// Capacity of the request header array and limit on response headers.
pub(crate) const MAX_HEADERS: usize = 32;
/// Writes all of `buf` to `w`, consuming bytes from `buf` as they are accepted.
///
/// On `WouldBlock` the current `may` coroutine yields and the write is retried.
/// `Interrupted` is retried immediately, mirroring `Write::write_all`. On success
/// `buf` is empty; on error it still holds the unsent tail.
///
/// # Errors
/// `WriteZero` if the writer accepts no bytes, or any other error from `write`.
#[cfg(feature = "net-ws-server")]
#[inline]
fn drain_nb<W: Write + ?Sized>(w: &mut W, buf: &mut BytesMut) -> io::Result<()> {
    use std::io::ErrorKind;
    while !buf.is_empty() {
        match w.write(&buf[..]) {
            Ok(0) => {
                return Err(io::Error::new(
                    ErrorKind::WriteZero,
                    "drain_nb: write returned 0",
                ));
            }
            Ok(n) => buf.advance(n),
            Err(ref e) if e.kind() == ErrorKind::WouldBlock => may::coroutine::yield_now(),
            // A signal interrupted the call before any data moved; just retry.
            Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// HTTP/1.x session for one parsed request on a `Read + Write` stream.
///
/// Borrows the request/response buffers (`'buf`), the header storage used by
/// the parser (`'header`) and the stream itself (`'stream`). The response head
/// is kept in two parts: the status line plus `Server`/`Date` headers in
/// `status_buf`, and user headers followed by the body in `rsp_buf`. Both parts
/// are written to the stream together.
pub struct H1Session<'buf, 'header, 'stream, S>
where
    S: Read + Write,
    'buf: 'stream,
{
    /// IP address of the connected peer.
    peer_addr: &'stream IpAddr,
    /// Parsed request head. Its method, path and header slices point into the
    /// bytes of `req_buf` that precede its current view (`new_session` parses
    /// the head, then advances `req_buf` past it).
    req: httparse::Request<'header, 'buf>,
    /// Request bytes following the parsed head (body and/or later frames).
    req_buf: &'buf mut BytesMut,
    /// Number of response headers appended so far; capped at `MAX_HEADERS`.
    rsp_headers_len: usize,
    /// Queued response headers and body, not yet written to `stream`.
    rsp_buf: &'buf mut BytesMut,
    /// Underlying connection.
    stream: &'stream mut S,
    /// Whether `status_buf` currently holds a status line.
    status_set: bool,
    /// Status line plus `Server`/`Date` headers (fixed 192-byte capacity).
    status_buf: heapless::Vec<u8, 192>,
    /// True between `start_h1_streaming` and the end of the streamed body.
    streaming: bool,
}
/// Writes `head` followed by `tail` to `w` until both are fully accepted.
///
/// Both parts are passed to `write_vectored`, so they can leave in a single
/// call. Progress is tracked by re-slicing whatever is still unsent. The
/// current coroutine yields on `WouldBlock`, and `Interrupted` is retried.
/// `zero_err` builds the error returned when the writer accepts zero bytes.
#[inline]
fn write_pair_nb<W: Write + ?Sized>(
    w: &mut W,
    mut head: &[u8],
    mut tail: &[u8],
    zero_err: impl Fn() -> io::Error,
) -> io::Result<()> {
    use std::io::{ErrorKind, IoSlice};
    while !(head.is_empty() && tail.is_empty()) {
        // Keep a non-empty slice in the first slot.
        let bufs = if head.is_empty() {
            [IoSlice::new(tail), IoSlice::new(&[])]
        } else {
            [IoSlice::new(head), IoSlice::new(tail)]
        };
        match w.write_vectored(&bufs) {
            Ok(0) => return Err(zero_err()),
            // Only part of `head` went out.
            Ok(n) if n < head.len() => head = &head[n..],
            // `head` is done; the remainder of `n` was taken from `tail`.
            Ok(n) => {
                tail = &tail[n - head.len()..];
                head = &[];
            }
            Err(ref e) if e.kind() == ErrorKind::WouldBlock => may::coroutine::yield_now(),
            Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}

#[async_trait::async_trait(?Send)]
impl<'buf, 'header, 'stream, S> Session for H1Session<'buf, 'header, 'stream, S>
where
    S: Read + Write,
{
    /// Returns the IP address of the connected peer.
    #[inline]
    fn peer_addr(&self) -> &IpAddr {
        self.peer_addr
    }

    /// Resolves the request authority as `(host, optional port)`.
    ///
    /// Tries, in order: the `Host` header, the target of a `CONNECT` request,
    /// and the authority of an absolute-form `http://` / `https://` target.
    #[inline]
    fn req_host(&self) -> Option<(String, Option<u16>)> {
        use super::server::parse_authority;
        if let Some(host) = self
            .req
            .headers
            .iter()
            .find(|h| h.name.eq_ignore_ascii_case("host"))
            .and_then(|h| std::str::from_utf8(h.value).ok())
            && let Some(a) = parse_authority(host.trim())
        {
            return Some(a);
        }
        if matches!(self.req.method, Some("CONNECT"))
            && let Some(path) = self.req.path
            && let Some(a) = parse_authority(path.trim())
        {
            return Some(a);
        }
        if let Some(path) = self.req.path
            && let Some((scheme, rest)) = path.split_once("://")
            && (scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https"))
        {
            let auth_end = rest.find('/').unwrap_or(rest.len());
            if let Some(a) = parse_authority(rest[..auth_end].trim()) {
                return Some(a);
            }
        }
        None
    }

    /// Returns the request method. Falls back to `Method::default()` if the
    /// token cannot be converted, and to `GET` if no method was parsed.
    #[inline]
    fn req_method(&self) -> http::Method {
        if let Some(method) = self.req.method {
            return http::Method::from_str(method).unwrap_or_default();
        }
        http::Method::GET
    }

    /// Returns the raw method token, if one was parsed.
    #[inline]
    fn req_method_str(&self) -> Option<&str> {
        self.req.method
    }

    /// Returns the request target (path plus query) as an owned string.
    #[inline]
    fn req_path(&self) -> String {
        self.req.path.unwrap_or_default().into()
    }

    /// Returns the request target as bytes, or an empty slice if absent.
    #[inline]
    fn req_path_bytes(&self) -> &[u8] {
        self.req.path.unwrap_or_default().as_bytes()
    }

    /// Returns the text after the first `?` in the target, or `""`.
    #[inline]
    fn req_query(&self) -> String {
        if let Some(path) = self.req.path
            && let Some((_, query)) = path.split_once('?')
        {
            return query.to_string();
        }
        String::new()
    }

    /// Maps the parsed minor version to `http::Version` (unknown -> HTTP/0.9).
    #[inline]
    fn req_http_version(&self) -> http::Version {
        match self.req.version {
            Some(1) => http::Version::HTTP_11,
            Some(0) => http::Version::HTTP_10,
            _ => http::Version::HTTP_09,
        }
    }

    /// Collects the request headers into an `http::HeaderMap`.
    ///
    /// Repeated header names keep every value (`append`, not `insert`).
    /// Entries whose name or value the `http` crate rejects are skipped.
    #[inline]
    fn req_headers(&self) -> http::HeaderMap {
        let mut map = http::HeaderMap::new();
        for h in self.req.headers.iter() {
            if let Ok(v) = HeaderValue::from_bytes(h.value)
                && let Ok(header_name) = HeaderName::from_str(h.name)
            {
                map.append(header_name, v);
            }
        }
        map
    }

    /// Returns the first value of `header` (case-insensitive match), if valid.
    #[inline]
    fn req_header(&self, header: &http::HeaderName) -> Option<http::HeaderValue> {
        for h in self.req.headers.iter() {
            if h.name.eq_ignore_ascii_case(header.as_str()) {
                return HeaderValue::from_bytes(h.value).ok();
            }
        }
        None
    }

    /// Returns the request body announced by `Content-Length`.
    ///
    /// Bytes already buffered after the head are used first. Any remainder is
    /// read from the stream into `req_buf` until `timeout` elapses. A missing,
    /// unparsable or zero `Content-Length` yields an empty slice; no other
    /// body framing is handled here.
    ///
    /// # Errors
    /// `TimedOut` once the deadline passes, `UnexpectedEof` if the peer closes
    /// early, or any other error reported by the stream.
    #[inline]
    fn req_body(&mut self, timeout: std::time::Duration) -> io::Result<&[u8]> {
        let content_length = self
            .req
            .headers
            .iter()
            .find(|h| h.name.eq_ignore_ascii_case("Content-Length"))
            .and_then(|h| std::str::from_utf8(h.value).ok())
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(0);
        if content_length == 0 {
            return Ok(&[]);
        }
        if self.req_buf.len() >= content_length {
            return Ok(&self.req_buf[..content_length]);
        }
        // NOTE(review): `self.req` borrows head bytes that sit in `req_buf` just
        // before its current view (see `new_session`). If this `reserve` has to
        // reclaim or reallocate storage, those bytes may be overwritten or freed,
        // leaving the parsed method/path/headers dangling. Bodies that fit in the
        // existing spare capacity avoid this; larger ones need a redesign.
        self.req_buf.reserve(content_length - self.req_buf.len());
        let mut read = self.req_buf.len();
        let deadline = std::time::Instant::now() + timeout;
        while read < content_length {
            if std::time::Instant::now() > deadline {
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "body read timed out",
                ));
            }
            let spare = self.req_buf.spare_capacity_mut();
            let to_read = spare.len().min(content_length - read);
            if to_read == 0 {
                may::coroutine::yield_now();
                continue;
            }
            // SAFETY: `spare` holds at least `to_read` elements. The memory may be
            // uninitialized; this relies on `S::read` only writing into `buf`.
            let buf =
                unsafe { std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, to_read) };
            match self.stream.read(buf) {
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "connection closed before body fully read",
                    ));
                }
                Ok(n) => {
                    // SAFETY: `read` reported `n` bytes written at the start of
                    // the spare capacity, so they are initialized.
                    unsafe {
                        self.req_buf.advance_mut(n);
                    }
                    read += n;
                }
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    may::coroutine::yield_now();
                }
                // Interrupted before any data arrived; retry the read.
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
            // Periodically yield so a long body read does not hog the scheduler.
            if read.is_multiple_of(1024) {
                may::coroutine::yield_now();
            }
        }
        Ok(&self.req_buf[..content_length])
    }

    /// Not supported for HTTP/1.1 sessions; always returns `None`.
    #[inline]
    async fn req_body_async(
        &mut self,
        _timeout: std::time::Duration,
    ) -> Option<std::io::Result<bytes::Bytes>> {
        None
    }

    /// Appends `status` verbatim to the response buffer; nothing is sent here.
    #[inline]
    fn write_all_eom(&mut self, status: &[u8]) -> std::io::Result<()> {
        self.rsp_buf.extend_from_slice(status);
        Ok(())
    }

    /// Builds the status line plus `Server` and `Date` headers in `status_buf`.
    ///
    /// `status_buf` has a fixed capacity; bytes that would overflow it are
    /// silently dropped (the `.ok()` calls).
    #[inline]
    fn status_code(&mut self, status: http::StatusCode) -> &mut Self {
        const SERVER_NAME: &str =
            concat!("\r\nServer: Sib ", env!("SIB_BUILD_VERSION"), "\r\nDate: ");
        self.status_buf.clear();
        self.status_buf.extend_from_slice(HTTP11).ok();
        self.status_buf
            .extend_from_slice(status.as_str().as_bytes())
            .ok();
        self.status_buf.extend_from_slice(b" ").ok();
        if let Some(reason) = status.canonical_reason() {
            self.status_buf.extend_from_slice(reason.as_bytes()).ok();
        }
        self.status_buf
            .extend_from_slice(SERVER_NAME.as_bytes())
            .ok();
        self.status_buf
            .extend_from_slice(crate::network::http::date::current_date_str().as_bytes())
            .ok();
        self.status_buf.extend_from_slice(CRLF).ok();
        self.status_set = true;
        self
    }

    /// Terminates the header section and flushes the response head.
    ///
    /// After this call, body bytes are written directly with `send_h1_data`.
    /// Defaults the status to `200 OK` if none was set.
    ///
    /// # Errors
    /// Fails if streaming has already started, or on any write error.
    fn start_h1_streaming(&mut self) -> std::io::Result<()> {
        if self.streaming {
            return Err(std::io::Error::other(
                "start_h1_streaming called while already streaming",
            ));
        }
        if !self.status_set {
            self.status_code(http::StatusCode::OK);
        }
        // Blank line ending the header section.
        self.rsp_buf.extend_from_slice(CRLF);
        write_pair_nb(self.stream, &self.status_buf[..], &self.rsp_buf[..], || {
            std::io::Error::other("write_vectored got zero in start_h1_streaming")
        })?;
        self.status_buf.clear();
        self.rsp_buf.clear();
        self.rsp_headers_len = 0;
        self.streaming = true;
        Ok(())
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    async fn start_h1_streaming_async(&mut self) -> std::io::Result<()> {
        Err(std::io::Error::other(
            "start_h1_streaming_async is not supported in H1Session",
        ))
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    #[cfg(feature = "net-h2-server")]
    #[inline]
    fn start_h2_streaming(&mut self) -> std::io::Result<super::h2_session::H2Stream> {
        Err(std::io::Error::other(
            "start_h2_streaming is not supported in H1Session",
        ))
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    #[inline]
    async fn start_h3_streaming(&mut self) -> std::io::Result<()> {
        Err(std::io::Error::other(
            "start_h3_streaming is not supported in H1Session",
        ))
    }

    /// Writes `chunk` verbatim to the stream while streaming.
    ///
    /// No chunked framing is added. `end_stream` leaves streaming mode.
    ///
    /// # Errors
    /// Fails if `start_h1_streaming` was not called, or on any write error.
    fn send_h1_data(&mut self, chunk: &[u8], end_stream: bool) -> std::io::Result<()> {
        if !self.streaming {
            return Err(std::io::Error::other(
                "send_h1_data called before start_h1_streaming",
            ));
        }
        let mut data = chunk;
        while !data.is_empty() {
            match self.stream.write(data) {
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "send_h1_data got write zero",
                    ));
                }
                Ok(n) => data = &data[n..],
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    may::coroutine::yield_now();
                }
                // Interrupted before any data moved; retry the write.
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        if end_stream {
            self.streaming = false;
        }
        Ok(())
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    async fn send_h1_data_async(&mut self, _data: &[u8], _last: bool) -> io::Result<()> {
        Err(io::Error::other(
            "send_h1_data_async is not supported in H1Session",
        ))
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    #[inline]
    async fn send_h3_data(
        &mut self,
        _chunk: bytes::Bytes,
        _end_stream: bool,
    ) -> std::io::Result<()> {
        Err(std::io::Error::other(
            "send_h3_data is not supported in H1Session",
        ))
    }

    /// Appends `name: value` to the response head.
    ///
    /// # Errors
    /// `ArgumentListTooLong` once `MAX_HEADERS` headers have been added.
    #[inline]
    fn header(&mut self, name: HeaderName, value: HeaderValue) -> std::io::Result<&mut Self> {
        if self.rsp_headers_len >= MAX_HEADERS {
            return Err(io::Error::new(
                io::ErrorKind::ArgumentListTooLong,
                "too many headers",
            ));
        }
        // `as_str` yields the same text as `Display`, without a temporary String.
        self.rsp_buf.extend_from_slice(name.as_str().as_bytes());
        self.rsp_buf.extend_from_slice(b": ");
        self.rsp_buf.extend_from_slice(value.as_bytes());
        self.rsp_buf.extend_from_slice(CRLF);
        self.rsp_headers_len += 1;
        Ok(self)
    }

    /// Appends `name: value` (unvalidated strings) to the response head.
    ///
    /// # Errors
    /// `ArgumentListTooLong` once `MAX_HEADERS` headers have been added.
    #[inline]
    fn header_str(&mut self, name: &str, value: &str) -> std::io::Result<&mut Self> {
        if self.rsp_headers_len >= MAX_HEADERS {
            return Err(io::Error::new(
                io::ErrorKind::ArgumentListTooLong,
                "too many headers",
            ));
        }
        self.rsp_buf.extend_from_slice(name.as_bytes());
        self.rsp_buf.extend_from_slice(b": ");
        self.rsp_buf.extend_from_slice(value.as_bytes());
        self.rsp_buf.extend_from_slice(CRLF);
        self.rsp_headers_len += 1;
        Ok(self)
    }

    /// Appends every entry of `headers`, stopping at the first error.
    #[inline]
    fn headers(&mut self, headers: &http::HeaderMap) -> std::io::Result<&mut Self> {
        for (k, v) in headers {
            self.header(k.clone(), v.clone())?;
        }
        Ok(self)
    }

    /// Appends every `(name, value)` pair, stopping at the first error.
    #[inline]
    fn headers_str(&mut self, header_val: &[(&str, &str)]) -> std::io::Result<&mut Self> {
        for (name, value) in header_val {
            self.header_str(name, value)?;
        }
        Ok(self)
    }

    /// Ends the header section with a blank line and appends `body`.
    ///
    /// Defaults the status to `200 OK` if none was set. No `Content-Length`
    /// is added here.
    #[inline]
    fn body(&mut self, body: bytes::Bytes) -> &mut Self {
        if !self.status_set {
            self.status_code(http::StatusCode::OK);
        }
        self.rsp_buf.extend_from_slice(CRLF);
        self.rsp_buf.extend_from_slice(&body);
        self
    }

    /// Finishes the response and resets per-response state.
    ///
    /// In streaming mode the head and body were already sent, so only state is
    /// reset. Otherwise the status line and queued headers/body are written.
    #[inline]
    fn eom(&mut self) -> std::io::Result<()> {
        if self.streaming {
            self.rsp_buf.clear();
            self.status_buf.clear();
            self.status_set = false;
            self.streaming = false;
            self.rsp_headers_len = 0;
            return Ok(());
        }
        if !self.status_set {
            self.status_code(http::StatusCode::OK);
        }
        write_pair_nb(self.stream, &self.status_buf[..], &self.rsp_buf[..], || {
            io::Error::new(io::ErrorKind::WriteZero, "h1 eom write zero")
        })?;
        self.rsp_buf.clear();
        self.status_buf.clear();
        self.status_set = false;
        self.rsp_headers_len = 0;
        Ok(())
    }

    /// Not supported for HTTP/1.1 sessions; always returns an error.
    #[inline]
    async fn eom_async(&mut self) -> std::io::Result<()> {
        Err(std::io::Error::other(
            "eom_async is not supported in H1Session",
        ))
    }

    /// Whether the request is a WebSocket upgrade, per `ws::is_h1_ws_upgrade`.
    #[cfg(feature = "net-ws-server")]
    #[inline]
    fn is_ws(&self) -> bool {
        ws::is_h1_ws_upgrade(&self.req_method(), &self.req_headers())
    }

    /// Completes the WebSocket opening handshake with `101 Switching Protocols`.
    ///
    /// Answers `400` with `Connection: close` when the request is not an upgrade
    /// or lacks `Sec-WebSocket-Key`. On success all buffers and response state
    /// are reset for frame I/O.
    #[cfg(feature = "net-ws-server")]
    #[inline]
    fn ws_accept(&mut self) -> std::io::Result<()> {
        let method = self.req_method();
        let headers = self.req_headers();
        let key = if ws::is_h1_ws_upgrade(&method, &headers) {
            self.req_header(&HeaderName::from_static("sec-websocket-key"))
        } else {
            None
        };
        let Some(key) = key else {
            // Send a complete head: without the blank line after the headers the
            // client keeps waiting for more header fields.
            return self
                .status_code(http::StatusCode::BAD_REQUEST)
                .header_str("Connection", "close")?
                .header_str("Content-Length", "0")?
                .body(bytes::Bytes::new())
                .eom();
        };
        let key_str = key.to_str().map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, "bad sec-websocket-key")
        })?;
        let accept = ws::sec_websocket_accept(key_str)?;
        let mut resp = format!(
            "HTTP/1.1 101 Switching Protocols\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Accept: {accept}\r\n"
        );
        // RFC 6455 §4.2.2: the server must reply with ONE of the offered
        // sub-protocols (here the first non-empty entry), never the whole list.
        let offered = self.req_header(&HeaderName::from_static("sec-websocket-protocol"));
        if let Some(proto) = offered
            .as_ref()
            .and_then(|v| v.to_str().ok())
            .and_then(|list| list.split(',').map(str::trim).find(|p| !p.is_empty()))
        {
            resp.push_str("Sec-WebSocket-Protocol: ");
            resp.push_str(proto);
            resp.push_str("\r\n");
        }
        resp.push_str("\r\n");
        // Same WouldBlock-aware path as `ws_write` / `ws_close`.
        self.rsp_buf.clear();
        self.rsp_buf.extend_from_slice(resp.as_bytes());
        drain_nb(self.stream, self.rsp_buf)?;
        self.req_buf.clear();
        self.rsp_buf.clear();
        self.status_buf.clear();
        self.status_set = false;
        self.rsp_headers_len = 0;
        self.streaming = false;
        Ok(())
    }

    /// Reads the next WebSocket frame as `(opcode, payload, fin)`.
    ///
    /// Reads more bytes from the stream until a frame parses. Fails if a
    /// payload exceeds `BUF_LEN`, or if more than `BUF_LEN + 64 KiB` is
    /// buffered without a complete frame.
    #[cfg(feature = "net-ws-server")]
    #[inline]
    fn ws_read(&mut self) -> std::io::Result<(ws::OpCode, bytes::Bytes, bool)> {
        const MAX_BUFFERED: usize = BUF_LEN + 64 * 1024;
        loop {
            if let Some(frame) = ws::try_parse_frame(self.req_buf)? {
                if frame.payload.len() > BUF_LEN {
                    return Err(std::io::Error::other(format!("max WS frame is {BUF_LEN}")));
                }
                return Ok((frame.op, frame.payload, frame.fin));
            }
            if self.req_buf.len() > MAX_BUFFERED {
                return Err(std::io::Error::other("ws buffered data too large"));
            }
            if !crate::network::http::h1_server::read(self.stream, self.req_buf)? {
                may::coroutine::yield_now();
            }
        }
    }

    /// Encodes one unmasked frame and writes it out via `drain_nb`.
    #[cfg(feature = "net-ws-server")]
    #[inline]
    fn ws_write(
        &mut self,
        code: ws::OpCode,
        payload: &bytes::Bytes,
        fin: bool,
    ) -> std::io::Result<()> {
        let frame = ws::encode_frame(code, payload, fin, None);
        self.rsp_buf.extend_from_slice(&frame);
        drain_nb(self.stream, self.rsp_buf)
    }

    /// Sends a close frame with status 1000 and an optional UTF-8 reason.
    ///
    /// # Errors
    /// `InvalidInput` if the reason exceeds 123 bytes or is not UTF-8.
    #[cfg(feature = "net-ws-server")]
    #[inline]
    fn ws_close(&mut self, reason: Option<&bytes::Bytes>) -> std::io::Result<()> {
        // Control payloads are capped at 125 bytes: a 2-byte status code plus at
        // most 123 bytes of reason text.
        const MAX_REASON: usize = 123;
        let reason: &[u8] = match reason {
            Some(r) => &r[..],
            None => &[],
        };
        if reason.len() > MAX_REASON {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "close reason too long",
            ));
        }
        if std::str::from_utf8(reason).is_err() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "close reason not utf8",
            ));
        }
        // <= 125 after the check above, so it fits the 7-bit length field.
        let total = 2 + reason.len();
        // 0x88 = FIN | opcode 0x8 (close); second byte is the unmasked length.
        self.rsp_buf.extend_from_slice(&[0x88, total as u8]);
        self.rsp_buf.extend_from_slice(&1000u16.to_be_bytes());
        self.rsp_buf.extend_from_slice(reason);
        drain_nb(self.stream, self.rsp_buf)
    }
}
/// Parses the request head buffered in `req_buf` and builds an [`H1Session`].
///
/// Returns `Ok(None)` while the head is incomplete, so the caller can read
/// more bytes and retry. On success the head is consumed from `req_buf` (its
/// view then starts right after the head). `rsp_buf` gets more room if it has
/// fewer than 1024 spare bytes.
///
/// # Errors
/// Returns an error wrapping the `httparse` failure when the head is malformed.
pub fn new_session<'header, 'buf, 'stream, S>(
    stream: &'stream mut S,
    peer_addr: &'stream IpAddr,
    headers: &'header mut [MaybeUninit<httparse::Header<'buf>>; MAX_HEADERS],
    req_buf: &'buf mut BytesMut,
    rsp_buf: &'buf mut BytesMut,
) -> io::Result<Option<H1Session<'buf, 'header, 'stream, S>>>
where
    S: Read + Write,
{
    // SAFETY(review): widens the borrow of `req_buf`'s bytes to `'buf` so the
    // parsed request can sit next to `&'buf mut BytesMut` inside the session.
    // This holds only while those bytes are never moved, overwritten or freed.
    // `advance` below merely moves the view start, but any later operation
    // that reclaims or reallocates `req_buf` invalidates the parsed head.
    let head_bytes: &'buf [u8] = unsafe { std::mem::transmute(req_buf.chunk()) };
    let mut req = httparse::Request::new(&mut []);
    let parsed = req
        .parse_with_uninit_headers(head_bytes, headers)
        .map_err(|e| io::Error::other(format!("failed to parse http request: {e:?}")))?;
    let httparse::Status::Complete(head_len) = parsed else {
        return Ok(None);
    };
    req_buf.advance(head_len);

    let spare = rsp_buf.capacity() - rsp_buf.len();
    if spare < 1024 {
        rsp_buf.reserve(BUF_LEN - spare);
    }

    let session = H1Session {
        peer_addr,
        req,
        req_buf,
        rsp_headers_len: 0,
        rsp_buf,
        stream,
        status_set: false,
        status_buf: heapless::Vec::new(),
        streaming: false,
    };
    Ok(Some(session))
}