use crate::{
errors::CatBridgeError,
net::{Extensions, errors::CommonNetAPIError},
};
use bytes::{Bytes, BytesMut};
use fnv::FnvHasher;
use futures::Future;
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
hash::{Hash, Hasher},
marker::Send,
net::SocketAddr,
};
use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
#[cfg(feature = "servers")]
use crate::{errors::NetworkError, net::errors::CommonNetNetworkError};
#[cfg(feature = "servers")]
use std::sync::Arc;
#[cfg(feature = "servers")]
use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
#[cfg(feature = "servers")]
use tracing::error;
/// Build an owned value out of a borrowed input.
///
/// This is the by-reference counterpart of [`From`]: the conversion only ever
/// sees a shared reference to its input, so the input is never consumed.
pub trait FromRef<InputTy> {
    /// Produce `Self` from a reference to `input`.
    fn from_ref(input: &InputTy) -> Self;
}

/// Any cloneable type can be produced from a reference to itself by cloning.
impl<InnerTy: Clone> FromRef<InnerTy> for InnerTy {
    fn from_ref(input: &InnerTy) -> Self {
        Clone::clone(input)
    }
}
/// A request: a byte body, the socket address it is attributed to, caller
/// supplied state, and a set of per-request extensions.
///
/// `State` is the application state type carried alongside every request.
pub struct Request<State: Clone + Send + Sync + 'static> {
/// Raw bytes of the request body.
body: Bytes,
/// Per-request extension storage. Note that cloning a request does not copy
/// this; the clone receives a fresh `Extensions::new()`.
ext: Extensions,
/// Address this request is attributed to.
source_address: SocketAddr,
/// Application state carried with the request.
state: State,
/// Explicit stream identifier. When `None`, `stream_id()` falls back to an
/// FNV hash of `source_address`.
stream_id: Option<u64>,
/// Optional explicit read amount (only with the `clients` feature). This
/// type merely stores and returns the value.
// NOTE(review): how consumers interpret the read amount is not visible in this file — confirm.
#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
#[cfg(feature = "clients")]
explicit_read_amount: Option<usize>,
/// Shared, lock-protected slot holding an optional byte cache together with
/// the TCP stream (only with the `servers` feature). Used by
/// `unsafe_read_more_bytes_from_stream` to pull additional bytes.
#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
#[cfg(feature = "servers")]
#[allow(
// TODO(mythra): refactor to type.
clippy::type_complexity,
)]
stream_access: Option<Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>>,
}
impl<State: Clone + Send + Sync + 'static> Request<State>
where
State: Default,
{
#[must_use]
pub fn new(body: Bytes, source_address: SocketAddr, stream_id: Option<u64>) -> Self {
Self {
body,
ext: Extensions::new(),
source_address,
state: Default::default(),
stream_id,
#[cfg(feature = "clients")]
explicit_read_amount: None,
#[cfg(feature = "servers")]
stream_access: None,
}
}
}
impl<State: Clone + Send + Sync + 'static> Request<State> {
    /// Build a request that carries an explicitly supplied `state`.
    ///
    /// The request starts with fresh [`Extensions`]; the feature-gated
    /// `explicit_read_amount` and `stream_access` fields start out as `None`.
    #[must_use]
    pub fn new_with_state(
        body: Bytes,
        source_address: SocketAddr,
        state: State,
        stream_id: Option<u64>,
    ) -> Self {
        Self {
            body,
            ext: Extensions::new(),
            source_address,
            state,
            stream_id,
            #[cfg(feature = "clients")]
            explicit_read_amount: None,
            #[cfg(feature = "servers")]
            stream_access: None,
        }
    }

    /// Build a request with explicit `state` and an explicit read amount.
    ///
    /// Identical to [`Request::new_with_state`] except that
    /// `explicit_read_amount` is stored as `Some(explicit_read_amount)`.
    #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
    #[cfg(feature = "clients")]
    #[must_use]
    pub fn new_with_state_and_read_amount(
        body: Bytes,
        source_address: SocketAddr,
        state: State,
        stream_id: Option<u64>,
        explicit_read_amount: usize,
    ) -> Self {
        Self {
            body,
            ext: Extensions::new(),
            source_address,
            state,
            stream_id,
            explicit_read_amount: Some(explicit_read_amount),
            #[cfg(feature = "servers")]
            stream_access: None,
        }
    }

    /// Build a request with explicit `state` that also holds a handle to the
    /// underlying TCP stream and its byte cache.
    ///
    /// The handle is what `unsafe_read_more_bytes_from_stream` later uses to
    /// pull additional bytes.
    #[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
    #[cfg(feature = "servers")]
    #[allow(
        // TODO(mythra): refactor to type.
        clippy::type_complexity,
    )]
    #[must_use]
    pub fn new_with_state_and_stream(
        body: Bytes,
        source_address: SocketAddr,
        state: State,
        stream_id: Option<u64>,
        stream_and_nagle_cache: Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>,
    ) -> Self {
        Self {
            body,
            ext: Extensions::new(),
            source_address,
            state,
            stream_id,
            #[cfg(feature = "clients")]
            explicit_read_amount: None,
            stream_access: Some(stream_and_nagle_cache),
        }
    }

    /// Replace the request body with `new_body`, dropping the previous one.
    pub fn swap_body(&mut self, new_body: Bytes) {
        self.body = new_body;
    }

    /// Overwrite both the source address and the explicit stream id.
    pub const fn update_request_source(&mut self, source: SocketAddr, stream_id: Option<u64>) {
        self.source_address = source;
        self.stream_id = stream_id;
    }

    /// Identifier of the stream this request belongs to.
    ///
    /// Returns the explicit id when one was set; otherwise an FNV hash of the
    /// source address is used as the id.
    #[must_use]
    pub fn stream_id(&self) -> u64 {
        self.stream_id.unwrap_or_else(|| {
            let mut hasher = FnvHasher::default();
            self.source_address.hash(&mut hasher);
            hasher.finish()
        })
    }

    /// The explicit read amount stored on this request, if any.
    #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
    #[cfg(feature = "clients")]
    #[must_use]
    pub const fn explicit_read_amount(&self) -> Option<usize> {
        self.explicit_read_amount
    }

    /// Store `new_read_amount` as the explicit read amount.
    #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
    #[cfg(feature = "clients")]
    pub const fn set_explicit_read_amount(&mut self, new_read_amount: usize) {
        self.explicit_read_amount = Some(new_read_amount);
    }

    /// Read exactly `to_read` additional bytes for this request.
    ///
    /// Bytes are taken from the cached buffer first (at most `to_read` of
    /// them); whatever is still missing is then read from the TCP stream. The
    /// returned buffer is always exactly `to_read` bytes long.
    ///
    /// # Errors
    ///
    /// - `CommonNetNetworkError::StreamNoLongerProcessing` when this request
    ///   has no stream handle, or the shared slot no longer holds a stream.
    /// - `NetworkError::IO` when waiting for readiness or reading fails. This
    ///   includes the stream reaching end-of-file before `to_read` bytes were
    ///   available, which surfaces as an "unexpected EOF" I/O error instead of
    ///   waiting forever.
    #[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
    #[cfg(feature = "servers")]
    pub async fn unsafe_read_more_bytes_from_stream(
        &self,
        to_read: usize,
    ) -> Result<Bytes, CatBridgeError> {
        if let Some(strm) = self.stream_access.as_ref() {
            let mut guard = strm.lock().await;
            if let Some((opt_cache, stream)) = guard.as_mut() {
                // Serve as much as possible from the cache, but never more
                // than was asked for; the remainder stays cached.
                let mut buff = match opt_cache.as_mut() {
                    Some(cache) => {
                        let from_cache = cache.len().min(to_read);
                        cache.split_to(from_cache)
                    }
                    None => BytesMut::with_capacity(to_read),
                };

                let already_buffered = buff.len();
                if already_buffered < to_read {
                    stream.readable().await.map_err(NetworkError::IO)?;
                    // Size the buffer to the exact target and fill only the
                    // missing tail. `read_exact` cannot over-read past
                    // `to_read`, and it reports EOF as an error rather than
                    // returning zero bytes forever.
                    buff.resize(to_read, 0);
                    stream
                        .read_exact(&mut buff[already_buffered..])
                        .await
                        .map_err(NetworkError::IO)?;
                }

                return Ok::<Bytes, CatBridgeError>(buff.freeze());
            }
        }

        error!("called unsafe_read_more_bytes on a stream that is not processing!");
        Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
    }

    /// Borrow the request body.
    #[must_use]
    pub const fn body(&self) -> &Bytes {
        &self.body
    }

    /// Mutably borrow the request body.
    #[must_use]
    pub fn body_mut(&mut self) -> &mut Bytes {
        &mut self.body
    }

    /// Replace the request body with `new_body`.
    pub fn set_body(&mut self, new_body: Bytes) {
        self.body = new_body;
    }

    /// Consume the request and return only its body.
    #[must_use]
    pub fn body_owned(self) -> Bytes {
        self.body
    }

    /// Borrow the per-request extensions.
    #[must_use]
    pub const fn extensions(&self) -> &Extensions {
        &self.ext
    }

    /// Mutably borrow the per-request extensions.
    #[must_use]
    pub fn extensions_mut(&mut self) -> &mut Extensions {
        &mut self.ext
    }

    /// Consume the request and return only its extensions.
    #[must_use]
    pub fn extensions_owned(self) -> Extensions {
        self.ext
    }

    /// Borrow the application state carried by this request.
    #[must_use]
    pub const fn state(&self) -> &State {
        &self.state
    }

    /// Mutably borrow the application state carried by this request.
    #[must_use]
    pub fn state_mut(&mut self) -> &mut State {
        &mut self.state
    }

    /// Address this request is attributed to.
    #[must_use]
    pub const fn source(&self) -> &SocketAddr {
        &self.source_address
    }

    /// Whether the source address is an IPv4 address.
    #[must_use]
    pub fn is_ipv4(&self) -> bool {
        self.source_address.ip().is_ipv4()
    }

    /// Whether the source address is an IPv6 address.
    #[must_use]
    pub fn is_ipv6(&self) -> bool {
        self.source_address.ip().is_ipv6()
    }
}
impl<State: Clone + Send + Sync + 'static> Clone for Request<State> {
fn clone(&self) -> Self {
Request {
body: self.body.clone(),
ext: Extensions::new(),
source_address: self.source_address,
state: self.state.clone(),
stream_id: self.stream_id,
#[cfg(feature = "clients")]
explicit_read_amount: self.explicit_read_amount,
#[cfg(feature = "servers")]
stream_access: self.stream_access.clone(),
}
}
}
/// Debug output lists the body, source address, stream id and the
/// feature-gated fields. The state and extensions are left out, which is why
/// the struct is rendered as non-exhaustive.
impl<State> Debug for Request<State>
where
    State: Clone + Debug + Send + Sync + 'static,
{
    fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
        let mut out = fmt.debug_struct("Request");
        out.field("body", &self.body);
        out.field("source_address", &self.source_address);
        out.field("stream_id", &self.stream_id);
        #[cfg(feature = "clients")]
        out.field("explicit_read_amount", &self.explicit_read_amount);
        // The stream handle itself is not printed, only whether one is present.
        #[cfg(feature = "servers")]
        out.field(
            "stream_access",
            &self.stream_access.as_ref().map_or("<none>", |_| "<stream>"),
        );
        out.finish_non_exhaustive()
    }
}
/// Field names exposed for `Request` through the `valuable` structured-value
/// API. The entries (including the feature-gated ones) must stay in the same
/// order, and under the same `cfg` gates, as the values array built in
/// `Valuable::visit` for `Request`.
const REQUEST_FIELDS: &[NamedField<'static>] = &[
NamedField::new("body"),
NamedField::new("source_address"),
NamedField::new("stream_id"),
#[cfg(feature = "clients")]
NamedField::new("explicit_read_amount"),
#[cfg(feature = "servers")]
NamedField::new("stream_access"),
];
impl<State: Clone + Send + Sync + 'static> Structable for Request<State> {
    /// Describe `Request` as a struct named `"Request"` with the named fields
    /// listed in `REQUEST_FIELDS`.
    fn definition(&self) -> StructDef<'_> {
        let fields = Fields::Named(REQUEST_FIELDS);
        StructDef::new_static("Request", fields)
    }
}
impl<State: Clone + Send + Sync + 'static> Valuable for Request<State> {
    /// Expose the request as a structable value.
    fn as_value(&self) -> Value<'_> {
        Value::Structable(self)
    }

    /// Visit the request's fields in `REQUEST_FIELDS` order.
    ///
    /// The body is rendered as an upper-case hex debug string and the source
    /// address through its `Display` form. The stream handle is reduced to a
    /// `"<stream>"` / `"<none>"` marker.
    fn visit(&self, visitor: &mut dyn Visit) {
        let body_hex = format!("{:02X?}", self.body);
        let source = self.source_address.to_string();
        #[cfg(feature = "servers")]
        let stream_marker = match self.stream_access {
            Some(_) => "<stream>",
            None => "<none>",
        };

        let values = [
            Valuable::as_value(&body_hex),
            Valuable::as_value(&source),
            Valuable::as_value(&self.stream_id),
            #[cfg(feature = "clients")]
            Valuable::as_value(&self.explicit_read_amount),
            #[cfg(feature = "servers")]
            Valuable::as_value(&stream_marker),
        ];
        visitor.visit_named_fields(&NamedValues::new(REQUEST_FIELDS, &values));
    }
}
/// A response: an optional byte body plus a flag asking for the connection to
/// be closed.
#[derive(Clone, Debug)]
pub struct Response {
/// Bytes of the response body; `None` means there is no body.
pub body: Option<Bytes>,
/// `true` when this response requests that the connection be closed.
// NOTE(review): when the close actually happens relative to writing the body is decided outside this file — confirm.
pub request_connection_close: bool,
}
impl Response {
#[must_use]
pub const fn new_empty() -> Self {
Self {
body: None,
request_connection_close: false,
}
}
#[must_use]
pub const fn empty_close() -> Self {
Self {
body: None,
request_connection_close: true,
}
}
#[must_use]
pub const fn new_with_body(body: Bytes) -> Self {
Self {
body: Some(body),
request_connection_close: false,
}
}
#[must_use]
pub const fn body(&self) -> Option<&Bytes> {
self.body.as_ref()
}
#[must_use]
pub fn body_mut(&mut self) -> Option<&mut Bytes> {
self.body.as_mut()
}
pub fn set_body(&mut self, bytes: Bytes) {
self.body = Some(bytes);
}
#[must_use]
pub fn take_body(self) -> Option<Bytes> {
self.body
}
#[must_use]
pub const fn request_connection_close(&self) -> bool {
self.request_connection_close
}
pub fn should_close_connection(&mut self) {
self.request_connection_close = true;
}
pub fn dont_close_connection(&mut self) {
self.request_connection_close = false;
}
}
impl Default for Response {
    /// The default response has no body and does not request a connection
    /// close (the same value `Response::new_empty` produces).
    fn default() -> Self {
        Self {
            body: None,
            request_connection_close: false,
        }
    }
}
/// Anything convertible into [`Bytes`] becomes a response with that body and
/// no connection-close request.
impl<ByteTy> From<ByteTy> for Response
where
    ByteTy: Into<Bytes>,
{
    fn from(resp: ByteTy) -> Self {
        let body: Bytes = resp.into();
        Self::new_with_body(body)
    }
}
/// Field names exposed for `Response` through the `valuable` structured-value
/// API. The order must match the values array built in `Valuable::visit` for
/// `Response`.
const RESPONSE_FIELDS: &[NamedField<'static>] = &[
NamedField::new("body"),
NamedField::new("request_connection_close"),
];
impl Structable for Response {
    /// Describe `Response` as a struct named `"Response"` with the named
    /// fields listed in `RESPONSE_FIELDS`.
    fn definition(&self) -> StructDef<'_> {
        let fields = Fields::Named(RESPONSE_FIELDS);
        StructDef::new_static("Response", fields)
    }
}
impl Valuable for Response {
    /// Expose the response as a structable value.
    fn as_value(&self) -> Value<'_> {
        Value::Structable(self)
    }

    /// Visit the response's fields in `RESPONSE_FIELDS` order.
    ///
    /// The body is rendered as an upper-case hex debug string, or as
    /// `"<empty>"` when there is no body.
    fn visit(&self, visitor: &mut dyn Visit) {
        let body_repr = match &self.body {
            Some(body_ref) => format!("{body_ref:02X?}"),
            None => "<empty>".to_owned(),
        };
        let values = [
            Valuable::as_value(&body_repr),
            Valuable::as_value(&self.request_connection_close),
        ];
        visitor.visit_named_fields(&NamedValues::new(RESPONSE_FIELDS, &values));
    }
}
/// Types that can be built from a request without consuming it.
///
/// The implementation receives the request by mutable reference, so the
/// request remains available to the caller afterwards.
pub trait FromRequestParts<State: Clone + Send + Sync + 'static>: Sized {
/// Asynchronously build `Self` from `req`.
///
/// The returned future must be `Send`; it resolves to the extracted value
/// or a `CatBridgeError`.
fn from_request_parts(
req: &mut Request<State>,
) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
/// Types that can be built by consuming a whole request.
pub trait FromRequest<State: Clone + Send + Sync + 'static>: Sized {
/// Asynchronously build `Self` from `req`, taking ownership of it.
///
/// The returned future must be `Send`; it resolves to the extracted value
/// or a `CatBridgeError`.
fn from_request(
req: Request<State>,
) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
/// A request can always be "extracted" from itself.
impl<State: Clone + Send + Sync + 'static> FromRequest<State> for Request<State> {
    /// Hand the request back unchanged; this never fails.
    async fn from_request(req: Self) -> Result<Self, CatBridgeError> {
        Result::Ok(req)
    }
}
/// Values that can be turned into a [`Response`] (or into an error).
pub trait IntoResponse: Sized {
/// Consume `self` and produce either a response or a `CatBridgeError`.
fn to_response(self) -> Result<Response, CatBridgeError>;
}
impl IntoResponse for () {
    /// The unit value maps to an empty response that keeps the connection open.
    fn to_response(self) -> Result<Response, CatBridgeError> {
        let empty = Response::new_empty();
        Ok(empty)
    }
}
impl IntoResponse for Response {
    /// A response is already a response; it is passed through unchanged.
    fn to_response(self) -> Result<Response, CatBridgeError> {
        Result::Ok(self)
    }
}
macro_rules! impl_from_ok_response {
($ty:ty) => {
impl IntoResponse for $ty {
fn to_response(self) -> Result<Response, CatBridgeError> {
Ok(self.into())
}
}
};
}
impl_from_ok_response!(Bytes);
impl_from_ok_response!(BytesMut);
impl_from_ok_response!(String);
impl_from_ok_response!(Vec<u8>);
impl_from_ok_response!(&'static [u8]);
impl_from_ok_response!(&'static str);
impl IntoResponse for CatBridgeError {
    /// An error value always converts to the `Err` side, unchanged.
    fn to_response(self) -> Result<Response, CatBridgeError> {
        Result::Err(self)
    }
}
impl<SomeTy: IntoResponse> IntoResponse for Option<SomeTy> {
    /// `Some(value)` defers to the inner value's conversion; `None` becomes an
    /// empty response.
    fn to_response(self) -> Result<Response, CatBridgeError> {
        match self {
            Some(inner) => inner.to_response(),
            None => Ok(Response::new_empty()),
        }
    }
}
impl<OkTy: IntoResponse> IntoResponse for Result<OkTy, CatBridgeError> {
    /// `Ok(value)` defers to the inner value's conversion; `Err` is passed
    /// through untouched.
    fn to_response(self) -> Result<Response, CatBridgeError> {
        match self {
            Ok(inner) => inner.to_response(),
            Err(cause) => Err(cause),
        }
    }
}
/// Strategy for deciding where one complete packet ends inside a buffer of
/// received bytes; see `NagleGuard::split`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum NagleGuard {
/// A packet ends right after the first occurrence of the given byte
/// sequence. An empty sequence is rejected by `split` with an error.
EndSigilSearch(&'static [u8]),
/// Every packet is exactly this many bytes long.
StaticSize(usize),
/// The first two bytes of the buffer hold a `u16` length in the given byte
/// order. The packet spans that length plus the optional extra amount
/// (treated as `0` when `None`), measured from the start of the buffer.
U16LengthPrefixed(Endianness, Option<usize>),
/// Same as `U16LengthPrefixed`, but the length is a `u32` held in the first
/// four bytes of the buffer.
U32LengthPrefixed(Endianness, Option<usize>),
}
impl NagleGuard {
    /// Locate the first complete packet at the front of `buff`.
    ///
    /// Returns `Ok(Some((0, end)))` when `buff[0..end]` holds one complete
    /// packet according to this guard, and `Ok(None)` when more bytes are
    /// needed before a packet can be cut out. The start offset is always `0`.
    ///
    /// For the length-prefixed variants the packet end is the decoded length
    /// plus the optional extra amount. The sum saturates instead of
    /// overflowing, so an absurd peer-supplied length simply never matches
    /// rather than panicking or wrapping around to a tiny bogus length.
    ///
    /// # Errors
    ///
    /// Returns `CommonNetAPIError::NagleGuardEndSigilCannotBeEmpty` when this
    /// is an `EndSigilSearch` with an empty sigil.
    pub fn split(&self, buff: &BytesMut) -> Result<Option<(usize, usize)>, CommonNetAPIError> {
        // Each arm computes where the packet would end; whether the buffer
        // already holds that many bytes is checked once at the bottom.
        let packet_end = match *self {
            NagleGuard::EndSigilSearch(sigil) => {
                if sigil.is_empty() {
                    return Err(CommonNetAPIError::NagleGuardEndSigilCannotBeEmpty);
                }
                // `windows` yields nothing when the buffer is shorter than
                // the sigil, which covers the empty-buffer case as well.
                buff.windows(sigil.len())
                    .position(|window| window == sigil)
                    .map(|start| start + sigil.len())
            }
            NagleGuard::StaticSize(size) => Some(size),
            NagleGuard::U16LengthPrefixed(endianness, extra_len) => {
                if buff.len() < 2 {
                    return Ok(None);
                }
                let prefix = [buff[0], buff[1]];
                let declared = match endianness {
                    Endianness::Little => u16::from_le_bytes(prefix),
                    Endianness::Big => u16::from_be_bytes(prefix),
                };
                Some(usize::from(declared).saturating_add(extra_len.unwrap_or_default()))
            }
            NagleGuard::U32LengthPrefixed(endianness, extra_len) => {
                if buff.len() < 4 {
                    return Ok(None);
                }
                let prefix = [buff[0], buff[1], buff[2], buff[3]];
                let declared = match endianness {
                    Endianness::Little => u32::from_le_bytes(prefix),
                    Endianness::Big => u32::from_be_bytes(prefix),
                };
                // On targets where `usize` is narrower than `u32` the length
                // may not fit; treat it as "larger than any buffer".
                let declared = usize::try_from(declared).unwrap_or(usize::MAX);
                Some(declared.saturating_add(extra_len.unwrap_or_default()))
            }
        };

        Ok(packet_end
            .filter(|&end| end <= buff.len())
            .map(|end| (0, end)))
    }
}
impl From<usize> for NagleGuard {
fn from(value: usize) -> Self {
NagleGuard::StaticSize(value)
}
}
impl From<&'static [u8]> for NagleGuard {
fn from(value: &'static [u8]) -> Self {
NagleGuard::EndSigilSearch(value)
}
}
/// Byte order used when decoding a length prefix in `NagleGuard::split`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum Endianness {
/// Least-significant byte first (decoded with `from_le_bytes`).
Little,
/// Most-significant byte first (decoded with `from_be_bytes`).
Big,
}
/// Shorthand trait for thread-safe, `'static` callbacks of the shape
/// `Fn(u64, &mut BytesMut)`.
// NOTE(review): presumably the `u64` is a stream id and the hook runs on the raw buffer before `NagleGuard` splitting — confirm against the caller.
pub trait PreNagleFnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static {}
/// Blanket impl: every closure or function with the matching shape qualifies.
impl<FnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static> PreNagleFnTy for FnTy {}
/// Shorthand trait for thread-safe, `'static` callbacks of the shape
/// `Fn(u64, Bytes) -> Bytes`.
// NOTE(review): presumably the `u64` is a stream id and the hook transforms a packet after `NagleGuard` splitting — confirm against the caller.
pub trait PostNagleFnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static {}
/// Blanket impl: every closure or function with the matching shape qualifies.
impl<FnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static> PostNagleFnTy for FnTy {}