1use crate::{
7 errors::CatBridgeError,
8 net::{Extensions, errors::CommonNetAPIError},
9};
10use bytes::{Bytes, BytesMut};
11use fnv::FnvHasher;
12use futures::Future;
13use std::{
14 fmt::{Debug, Formatter, Result as FmtResult},
15 hash::{Hash, Hasher},
16 marker::Send,
17 net::SocketAddr,
18};
19use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
20
21#[cfg(feature = "servers")]
22use crate::{errors::NetworkError, net::errors::CommonNetNetworkError};
23#[cfg(feature = "servers")]
24use std::sync::Arc;
25#[cfg(feature = "servers")]
26use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
27#[cfg(feature = "servers")]
28use tracing::error;
29
/// Produce a value of `Self` from a borrowed `Source`.
///
/// This lets a value be derived from something it is only given a reference to.
pub trait FromRef<Source> {
    /// Build `Self` from a reference to `input`.
    fn from_ref(input: &Source) -> Self;
}

/// Any `Clone` type can be obtained from a reference to itself by cloning it.
impl<T> FromRef<T> for T
where
    T: Clone,
{
    fn from_ref(input: &T) -> Self {
        T::clone(input)
    }
}
47
/// A single incoming request: its body bytes, extensions, source address,
/// associated state, and (optionally) an explicit stream identifier.
pub struct Request<State: Clone + Send + Sync + 'static> {
    /// The raw bytes of this request.
    body: Bytes,
    /// Extensions attached to this request (start out empty in every constructor).
    ext: Extensions,
    /// The peer address this request is associated with.
    source_address: SocketAddr,
    /// The state value carried alongside this request.
    state: State,
    /// An explicit stream id. When `None`, `Request::stream_id` derives an id
    /// by hashing `source_address`.
    stream_id: Option<u64>,
    /// Optional explicit read amount (only with the `clients` feature). Set via
    /// `new_with_state_and_read_amount` or `set_explicit_read_amount`.
    /// NOTE(review): presumably tells the client how many bytes to read back — confirm.
    #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
    #[cfg(feature = "clients")]
    explicit_read_amount: Option<usize>,
    /// Shared handle to the underlying TCP stream plus its cache of pending
    /// ("nagle") bytes (only with the `servers` feature). `None` unless built via
    /// `new_with_state_and_stream`. An inner `None` is treated by
    /// `unsafe_read_more_bytes_from_stream` as "stream no longer processing".
    #[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
    #[cfg(feature = "servers")]
    #[allow(
        clippy::type_complexity,
    )]
    stream_access: Option<Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>>,
}
77
78impl<State: Clone + Send + Sync + 'static> Request<State>
79where
80 State: Default,
81{
82 #[must_use]
83 pub fn new(body: Bytes, source_address: SocketAddr, stream_id: Option<u64>) -> Self {
84 Self {
85 body,
86 ext: Extensions::new(),
87 source_address,
88 state: Default::default(),
89 stream_id,
90 #[cfg(feature = "clients")]
91 explicit_read_amount: None,
92 #[cfg(feature = "servers")]
93 stream_access: None,
94 }
95 }
96}
97
98impl<State: Clone + Send + Sync + 'static> Request<State> {
99 #[must_use]
100 pub fn new_with_state(
101 body: Bytes,
102 source_address: SocketAddr,
103 state: State,
104 stream_id: Option<u64>,
105 ) -> Self {
106 Self {
107 body,
108 ext: Extensions::new(),
109 source_address,
110 state,
111 stream_id,
112 #[cfg(feature = "clients")]
113 explicit_read_amount: None,
114 #[cfg(feature = "servers")]
115 stream_access: None,
116 }
117 }
118
119 #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
120 #[cfg(feature = "clients")]
121 #[must_use]
122 pub fn new_with_state_and_read_amount(
123 body: Bytes,
124 source_address: SocketAddr,
125 state: State,
126 stream_id: Option<u64>,
127 explicit_read_amount: usize,
128 ) -> Self {
129 Self {
130 body,
131 ext: Extensions::new(),
132 source_address,
133 state,
134 stream_id,
135 explicit_read_amount: Some(explicit_read_amount),
136 #[cfg(feature = "servers")]
137 stream_access: None,
138 }
139 }
140
141 #[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
142 #[cfg(feature = "servers")]
143 #[allow(
144 clippy::type_complexity,
146 )]
147 #[must_use]
148 pub fn new_with_state_and_stream(
149 body: Bytes,
150 source_address: SocketAddr,
151 state: State,
152 stream_id: Option<u64>,
153 stream_and_nagle_cache: Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>,
154 ) -> Self {
155 Self {
156 body,
157 ext: Extensions::new(),
158 source_address,
159 state,
160 stream_id,
161 #[cfg(feature = "clients")]
162 explicit_read_amount: None,
163 stream_access: Some(stream_and_nagle_cache),
164 }
165 }
166
167 pub fn swap_body(&mut self, new_body: Bytes) {
169 self.body = new_body;
170 }
171
172 pub const fn update_request_source(&mut self, source: SocketAddr, stream_id: Option<u64>) {
174 self.source_address = source;
175 self.stream_id = stream_id;
176 }
177
178 #[must_use]
183 pub fn stream_id(&self) -> u64 {
184 if let Some(id) = self.stream_id {
185 id
186 } else {
187 let mut hasher = FnvHasher::default();
188 self.source_address.hash(&mut hasher);
189 hasher.finish()
190 }
191 }
192
193 #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
200 #[cfg(feature = "clients")]
201 #[must_use]
202 pub const fn explicit_read_amount(&self) -> Option<usize> {
203 self.explicit_read_amount
204 }
205
206 #[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
209 #[cfg(feature = "clients")]
210 pub const fn set_explicit_read_amount(&mut self, new_read_amount: usize) {
211 self.explicit_read_amount = Some(new_read_amount);
212 }
213
214 #[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
228 #[cfg(feature = "servers")]
229 pub async fn unsafe_read_more_bytes_from_stream(
230 &self,
231 to_read: usize,
232 ) -> Result<Bytes, CatBridgeError> {
233 if let Some(strm) = self.stream_access.as_ref() {
234 let mut guard = strm.lock().await;
235
236 if let Some((opt_cache, stream)) = guard.as_mut() {
237 let mut buff = BytesMut::with_capacity(to_read);
238
239 if let Some(cache) = opt_cache.as_mut() {
240 if cache.len() <= to_read {
241 buff = cache.split();
242 } else {
243 buff = cache.split_to(to_read);
244 }
245 }
246
247 if buff.len() < to_read {
248 stream.readable().await.map_err(NetworkError::IO)?;
249 let mut needed = to_read - buff.len();
250 while needed > 0 {
251 let read = stream.read_buf(&mut buff).await.map_err(NetworkError::IO)?;
252 needed -= read;
253 }
254 }
255 return Ok::<Bytes, CatBridgeError>(buff.freeze());
256 }
257 }
258
259 error!("called unsafe_read_more_bytes on a stream that is not processing!");
260 Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
261 }
262
263 #[must_use]
264 pub const fn body(&self) -> &Bytes {
265 &self.body
266 }
267 #[must_use]
268 pub fn body_mut(&mut self) -> &mut Bytes {
269 &mut self.body
270 }
271 pub fn set_body(&mut self, new_body: Bytes) {
272 self.body = new_body;
273 }
274 #[must_use]
275 pub fn body_owned(self) -> Bytes {
276 self.body
277 }
278
279 #[must_use]
280 pub const fn extensions(&self) -> &Extensions {
281 &self.ext
282 }
283 #[must_use]
284 pub fn extensions_mut(&mut self) -> &mut Extensions {
285 &mut self.ext
286 }
287 #[must_use]
288 pub fn extensions_owned(self) -> Extensions {
289 self.ext
290 }
291
292 #[must_use]
293 pub const fn state(&self) -> &State {
294 &self.state
295 }
296 #[must_use]
297 pub fn state_mut(&mut self) -> &mut State {
298 &mut self.state
299 }
300
301 #[must_use]
302 pub const fn source(&self) -> &SocketAddr {
303 &self.source_address
304 }
305 #[must_use]
306 pub fn is_ipv4(&self) -> bool {
307 self.source_address.ip().is_ipv4()
308 }
309 #[must_use]
310 pub fn is_ipv6(&self) -> bool {
311 self.source_address.ip().is_ipv6()
312 }
313}
314
315impl<State: Clone + Send + Sync + 'static> Clone for Request<State> {
316 fn clone(&self) -> Self {
317 Request {
318 body: self.body.clone(),
319 ext: Extensions::new(),
320 source_address: self.source_address,
321 state: self.state.clone(),
322 stream_id: self.stream_id,
323 #[cfg(feature = "clients")]
324 explicit_read_amount: self.explicit_read_amount,
325 #[cfg(feature = "servers")]
326 stream_access: self.stream_access.clone(),
327 }
328 }
329}
330
331impl<State: Clone + Send + Sync + 'static> Debug for Request<State>
332where
333 State: Debug,
334{
335 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
336 let mut dbg_struct = fmt.debug_struct("Request");
337
338 dbg_struct
339 .field("body", &self.body)
340 .field("source_address", &self.source_address)
343 .field("stream_id", &self.stream_id);
344
345 #[cfg(feature = "clients")]
346 dbg_struct.field("explicit_read_amount", &self.explicit_read_amount);
347 #[cfg(feature = "servers")]
348 dbg_struct.field(
349 "stream_access",
350 &if self.stream_access.is_some() {
351 "<stream>"
352 } else {
353 "<none>"
354 },
355 );
356
357 dbg_struct.finish_non_exhaustive()
358 }
359}
360
361const REQUEST_FIELDS: &[NamedField<'static>] = &[
362 NamedField::new("body"),
363 NamedField::new("source_address"),
364 NamedField::new("stream_id"),
365 #[cfg(feature = "clients")]
366 NamedField::new("explicit_read_amount"),
367 #[cfg(feature = "servers")]
368 NamedField::new("stream_access"),
369];
370
371impl<State: Clone + Send + Sync + 'static> Structable for Request<State> {
372 fn definition(&self) -> StructDef<'_> {
373 StructDef::new_static("Request", Fields::Named(REQUEST_FIELDS))
374 }
375}
376
377impl<State: Clone + Send + Sync + 'static> Valuable for Request<State> {
378 fn as_value(&self) -> Value<'_> {
379 Value::Structable(self)
380 }
381
382 fn visit(&self, visitor: &mut dyn Visit) {
383 visitor.visit_named_fields(&NamedValues::new(
384 REQUEST_FIELDS,
385 &[
386 Valuable::as_value(&format!("{:02X?}", self.body)),
387 Valuable::as_value(&format!("{}", self.source_address)),
388 Valuable::as_value(&self.stream_id),
389 #[cfg(feature = "clients")]
390 Valuable::as_value(&self.explicit_read_amount),
391 #[cfg(feature = "servers")]
392 Valuable::as_value(&if self.stream_access.is_some() {
393 "<stream>"
394 } else {
395 "<none>"
396 }),
397 ],
398 ));
399 }
400}
401
/// A reply: an optional body plus a flag asking for the connection to be closed.
#[derive(Clone, Debug)]
pub struct Response {
    /// The response payload; `None` means there is no body.
    pub body: Option<Bytes>,
    /// When `true`, asks for the connection to be closed.
    /// NOTE(review): presumably honored by the connection loop after sending — confirm.
    pub request_connection_close: bool,
}
413
414impl Response {
415 #[must_use]
416 pub const fn new_empty() -> Self {
417 Self {
418 body: None,
419 request_connection_close: false,
420 }
421 }
422 #[must_use]
423 pub const fn empty_close() -> Self {
424 Self {
425 body: None,
426 request_connection_close: true,
427 }
428 }
429 #[must_use]
430 pub const fn new_with_body(body: Bytes) -> Self {
431 Self {
432 body: Some(body),
433 request_connection_close: false,
434 }
435 }
436
437 #[must_use]
438 pub const fn body(&self) -> Option<&Bytes> {
439 self.body.as_ref()
440 }
441 #[must_use]
442 pub fn body_mut(&mut self) -> Option<&mut Bytes> {
443 self.body.as_mut()
444 }
445 pub fn set_body(&mut self, bytes: Bytes) {
446 self.body = Some(bytes);
447 }
448 #[must_use]
449 pub fn take_body(self) -> Option<Bytes> {
450 self.body
451 }
452
453 #[must_use]
454 pub const fn request_connection_close(&self) -> bool {
455 self.request_connection_close
456 }
457 pub fn should_close_connection(&mut self) {
458 self.request_connection_close = true;
459 }
460 pub fn dont_close_connection(&mut self) {
461 self.request_connection_close = false;
462 }
463}
464
465impl Default for Response {
466 fn default() -> Self {
467 Self::new_empty()
468 }
469}
470
471impl<ByteTy: Into<Bytes>> From<ByteTy> for Response {
472 fn from(resp: ByteTy) -> Self {
473 Self::new_with_body(resp.into())
474 }
475}
476
477const RESPONSE_FIELDS: &[NamedField<'static>] = &[
478 NamedField::new("body"),
479 NamedField::new("request_connection_close"),
480];
481
482impl Structable for Response {
483 fn definition(&self) -> StructDef<'_> {
484 StructDef::new_static("Response", Fields::Named(RESPONSE_FIELDS))
485 }
486}
487
488impl Valuable for Response {
489 fn as_value(&self) -> Value<'_> {
490 Value::Structable(self)
491 }
492
493 fn visit(&self, visitor: &mut dyn Visit) {
494 visitor.visit_named_fields(&NamedValues::new(
495 RESPONSE_FIELDS,
496 &[
497 Valuable::as_value(&if let Some(body_ref) = self.body.as_ref() {
498 format!("{body_ref:02X?}")
499 } else {
500 "<empty>".to_owned()
501 }),
502 Valuable::as_value(&self.request_connection_close),
503 ],
504 ));
505 }
506}
507
/// Extract a value from a mutably borrowed [`Request`], leaving the request
/// available to the caller afterwards.
pub trait FromRequestParts<State: Clone + Send + Sync + 'static>: Sized {
    /// Build `Self` from `req`. Because `req` is `&mut`, an implementation may
    /// modify the request while extracting.
    ///
    /// The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Returns a [`CatBridgeError`] when the value cannot be extracted.
    fn from_request_parts(
        req: &mut Request<State>,
    ) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
516
/// Extract a value by taking ownership of the whole [`Request`].
pub trait FromRequest<State: Clone + Send + Sync + 'static>: Sized {
    /// Build `Self` by consuming `req`.
    ///
    /// The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Returns a [`CatBridgeError`] when the value cannot be extracted.
    fn from_request(
        req: Request<State>,
    ) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
// A `Request` can always be "extracted" as itself; this never fails.
impl<State: Clone + Send + Sync + 'static> FromRequest<State> for Request<State> {
    async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
        Ok(req)
    }
}
530
531pub trait IntoResponse: Sized {
537 fn to_response(self) -> Result<Response, CatBridgeError>;
543}
544
545impl IntoResponse for () {
546 fn to_response(self) -> Result<Response, CatBridgeError> {
547 Ok(Response::new_empty())
548 }
549}
550impl IntoResponse for Response {
551 fn to_response(self) -> Result<Response, CatBridgeError> {
552 Ok(self)
553 }
554}
555
556macro_rules! impl_from_ok_response {
557 ($ty:ty) => {
558 impl IntoResponse for $ty {
559 fn to_response(self) -> Result<Response, CatBridgeError> {
560 Ok(self.into())
561 }
562 }
563 };
564}
565impl_from_ok_response!(Bytes);
566impl_from_ok_response!(BytesMut);
567impl_from_ok_response!(String);
568impl_from_ok_response!(Vec<u8>);
569impl_from_ok_response!(&'static [u8]);
570impl_from_ok_response!(&'static str);
571
572impl IntoResponse for CatBridgeError {
573 fn to_response(self) -> Result<Response, CatBridgeError> {
574 Err(self)
575 }
576}
577
578impl<SomeTy: IntoResponse> IntoResponse for Option<SomeTy> {
579 fn to_response(self) -> Result<Response, CatBridgeError> {
580 if let Some(val) = self {
581 val.to_response()
582 } else {
583 Ok(Response::new_empty())
584 }
585 }
586}
587impl<OkTy: IntoResponse> IntoResponse for Result<OkTy, CatBridgeError> {
588 fn to_response(self) -> Result<Response, CatBridgeError> {
589 self.and_then(IntoResponse::to_response)
590 }
591}
592
/// Strategy for finding where one complete frame ends inside buffered bytes
/// (see `NagleGuard::split`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum NagleGuard {
    /// A frame ends right after the first occurrence of this byte sequence.
    /// The sequence must be non-empty, or `split` returns an error.
    EndSigilSearch(&'static [u8]),
    /// Every frame is exactly this many bytes.
    StaticSize(usize),
    /// The first 2 bytes are a length in the given byte order. The frame ends
    /// at that length plus the optional extra length, counted from the start
    /// of the buffer.
    U16LengthPrefixed(Endianness, Option<usize>),
    /// Like `U16LengthPrefixed`, but with a 4-byte length prefix.
    U32LengthPrefixed(Endianness, Option<usize>),
}
617
618impl NagleGuard {
619 pub fn split(&self, buff: &BytesMut) -> Result<Option<(usize, usize)>, CommonNetAPIError> {
625 match *self {
626 NagleGuard::EndSigilSearch(sigil) => {
627 if sigil.is_empty() {
628 return Err(CommonNetAPIError::NagleGuardEndSigilCannotBeEmpty);
629 }
630 if buff.is_empty() {
631 return Ok(None);
632 }
633
634 for (idx, byte) in buff.iter().enumerate() {
635 if idx + sigil.len() > buff.len() {
637 break;
638 }
639 if *byte == sigil[0] && sigil == &buff[idx..(idx + sigil.len())] {
640 return Ok(Some((0, idx + sigil.len())));
641 }
642 }
643 }
644 NagleGuard::StaticSize(size) => {
645 if buff.len() < size {
646 return Ok(None);
647 }
648
649 return Ok(Some((0, size)));
650 }
651 NagleGuard::U16LengthPrefixed(endianness, extra_len) => {
652 if buff.len() < 2 {
653 return Ok(None);
654 }
655 let extra_len_frd = extra_len.unwrap_or_default();
656
657 let total_size = match endianness {
658 Endianness::Little => u16::from_le_bytes([buff[0], buff[1]]),
659 Endianness::Big => u16::from_be_bytes([buff[0], buff[1]]),
660 };
661 if buff.len() >= usize::from(total_size) + extra_len_frd {
662 return Ok(Some((0, usize::from(total_size) + extra_len_frd)));
663 }
664 }
665 NagleGuard::U32LengthPrefixed(endianness, extra_len) => {
666 if buff.len() < 4 {
667 return Ok(None);
668 }
669 let extra_len_frd = extra_len.unwrap_or_default();
670
671 let total_size = match endianness {
672 Endianness::Little => u32::from_le_bytes([buff[0], buff[1], buff[2], buff[3]]),
673 Endianness::Big => u32::from_be_bytes([buff[0], buff[1], buff[2], buff[3]]),
674 };
675 if buff.len() >= usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd {
676 return Ok(Some((
677 0,
678 usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd,
679 )));
680 }
681 }
682 }
683
684 Ok(None)
685 }
686}
687
688impl From<usize> for NagleGuard {
689 fn from(value: usize) -> Self {
690 NagleGuard::StaticSize(value)
691 }
692}
693
694impl From<&'static [u8]> for NagleGuard {
695 fn from(value: &'static [u8]) -> Self {
696 NagleGuard::EndSigilSearch(value)
697 }
698}
699
/// Byte order used to decode a length prefix (see `NagleGuard`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum Endianness {
    /// Least-significant byte first (decoded with `from_le_bytes`).
    Little,
    /// Most-significant byte first (decoded with `from_be_bytes`).
    Big,
}
708
709pub trait PreNagleFnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static {}
715impl<FnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static> PreNagleFnTy for FnTy {}
716
717pub trait PostNagleFnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static {}
723impl<FnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static> PostNagleFnTy for FnTy {}