1use crate::{
7	errors::CatBridgeError,
8	net::{Extensions, errors::CommonNetAPIError},
9};
10use bytes::{Bytes, BytesMut};
11use fnv::FnvHasher;
12use futures::Future;
13use std::{
14	fmt::{Debug, Formatter, Result as FmtResult},
15	hash::{Hash, Hasher},
16	marker::Send,
17	net::SocketAddr,
18};
19use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
20
21#[cfg(feature = "servers")]
22use crate::{errors::NetworkError, net::errors::CommonNetNetworkError};
23#[cfg(feature = "servers")]
24use std::sync::Arc;
25#[cfg(feature = "servers")]
26use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
27#[cfg(feature = "servers")]
28use tracing::error;
29
/// Produces an owned `Self` from a borrowed `InputTy`.
///
/// NOTE(review): presumably used to pull a piece of state out of a larger shared state value
/// (in the spirit of axum's `FromRef`) — confirm against callers.
pub trait FromRef<InputTy> {
	/// Builds `Self` from a reference to `input`.
	fn from_ref(input: &InputTy) -> Self;
}

/// Any `Clone` type can be produced from a reference to itself, simply by cloning it.
impl<Same: Clone> FromRef<Same> for Same {
	fn from_ref(input: &Same) -> Self {
		Clone::clone(input)
	}
}
47
/// A request as handled by this networking layer: its raw body, the address it is associated
/// with, a state value, per-request extensions and (feature-gated) connection details.
pub struct Request<State: Clone + Send + Sync + 'static> {
	/// Raw body bytes.
	body: Bytes,
	/// Per-request extensions. Every constructor in this file starts them out empty, and the
	/// `Clone` impl does not carry them over.
	ext: Extensions,
	/// Socket address recorded as this request's source.
	source_address: SocketAddr,
	/// State value supplied at construction (`State::default()` when built via `Request::new`).
	state: State,
	/// Explicit stream id, if known. When `None`, `Request::stream_id` derives one by hashing
	/// `source_address`.
	stream_id: Option<u64>,
	/// Explicit number of bytes to read, when set.
	/// NOTE(review): presumably tells the client how many bytes to read for the reply — confirm.
	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
	#[cfg(feature = "clients")]
	explicit_read_amount: Option<usize>,
	/// Shared handle to the connection: an optional byte cache (drained before the socket by
	/// `unsafe_read_more_bytes_from_stream`) together with the TCP stream. `None` when the
	/// request was built without a stream; an inner `None` is reported as
	/// `StreamNoLongerProcessing`.
	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
	#[cfg(feature = "servers")]
	// The nested Arc/Mutex/Option/tuple type trips clippy's `type_complexity` lint.
	#[allow(
		clippy::type_complexity,
	)]
	stream_access: Option<Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>>,
}
77
78impl<State: Clone + Send + Sync + 'static> Request<State>
79where
80	State: Default,
81{
82	#[must_use]
83	pub fn new(body: Bytes, source_address: SocketAddr, stream_id: Option<u64>) -> Self {
84		Self {
85			body,
86			ext: Extensions::new(),
87			source_address,
88			state: Default::default(),
89			stream_id,
90			#[cfg(feature = "clients")]
91			explicit_read_amount: None,
92			#[cfg(feature = "servers")]
93			stream_access: None,
94		}
95	}
96}
97
98impl<State: Clone + Send + Sync + 'static> Request<State> {
99	#[must_use]
100	pub fn new_with_state(
101		body: Bytes,
102		source_address: SocketAddr,
103		state: State,
104		stream_id: Option<u64>,
105	) -> Self {
106		Self {
107			body,
108			ext: Extensions::new(),
109			source_address,
110			state,
111			stream_id,
112			#[cfg(feature = "clients")]
113			explicit_read_amount: None,
114			#[cfg(feature = "servers")]
115			stream_access: None,
116		}
117	}
118
119	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
120	#[cfg(feature = "clients")]
121	#[must_use]
122	pub fn new_with_state_and_read_amount(
123		body: Bytes,
124		source_address: SocketAddr,
125		state: State,
126		stream_id: Option<u64>,
127		explicit_read_amount: usize,
128	) -> Self {
129		Self {
130			body,
131			ext: Extensions::new(),
132			source_address,
133			state,
134			stream_id,
135			explicit_read_amount: Some(explicit_read_amount),
136			#[cfg(feature = "servers")]
137			stream_access: None,
138		}
139	}
140
141	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
142	#[cfg(feature = "servers")]
143	#[allow(
144		clippy::type_complexity,
146	)]
147	#[must_use]
148	pub fn new_with_state_and_stream(
149		body: Bytes,
150		source_address: SocketAddr,
151		state: State,
152		stream_id: Option<u64>,
153		stream_and_nagle_cache: Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>,
154	) -> Self {
155		Self {
156			body,
157			ext: Extensions::new(),
158			source_address,
159			state,
160			stream_id,
161			#[cfg(feature = "clients")]
162			explicit_read_amount: None,
163			stream_access: Some(stream_and_nagle_cache),
164		}
165	}
166
167	pub fn swap_body(&mut self, new_body: Bytes) {
169		self.body = new_body;
170	}
171
172	pub const fn update_request_source(&mut self, source: SocketAddr, stream_id: Option<u64>) {
174		self.source_address = source;
175		self.stream_id = stream_id;
176	}
177
178	#[must_use]
183	pub fn stream_id(&self) -> u64 {
184		if let Some(id) = self.stream_id {
185			id
186		} else {
187			let mut hasher = FnvHasher::default();
188			self.source_address.hash(&mut hasher);
189			hasher.finish()
190		}
191	}
192
193	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
200	#[cfg(feature = "clients")]
201	#[must_use]
202	pub const fn explicit_read_amount(&self) -> Option<usize> {
203		self.explicit_read_amount
204	}
205
206	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
209	#[cfg(feature = "clients")]
210	pub const fn set_explicit_read_amount(&mut self, new_read_amount: usize) {
211		self.explicit_read_amount = Some(new_read_amount);
212	}
213
214	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
228	#[cfg(feature = "servers")]
229	pub async fn unsafe_read_more_bytes_from_stream(
230		&self,
231		to_read: usize,
232	) -> Result<Bytes, CatBridgeError> {
233		if let Some(strm) = self.stream_access.as_ref() {
234			let mut guard = strm.lock().await;
235
236			if let Some((opt_cache, stream)) = guard.as_mut() {
237				let mut buff = BytesMut::with_capacity(to_read);
238
239				if let Some(cache) = opt_cache.as_mut() {
240					if cache.len() <= to_read {
241						buff = cache.split();
242					} else {
243						buff = cache.split_to(to_read);
244					}
245				}
246
247				if buff.len() < to_read {
248					stream.readable().await.map_err(NetworkError::IO)?;
249					let mut needed = to_read - buff.len();
250					while needed > 0 {
251						let read = stream.read_buf(&mut buff).await.map_err(NetworkError::IO)?;
252						needed -= read;
253					}
254				}
255				return Ok::<Bytes, CatBridgeError>(buff.freeze());
256			}
257		}
258
259		error!("called unsafe_read_more_bytes on a stream that is not processing!");
260		Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
261	}
262
263	#[must_use]
264	pub const fn body(&self) -> &Bytes {
265		&self.body
266	}
267	#[must_use]
268	pub fn body_mut(&mut self) -> &mut Bytes {
269		&mut self.body
270	}
271	pub fn set_body(&mut self, new_body: Bytes) {
272		self.body = new_body;
273	}
274	#[must_use]
275	pub fn body_owned(self) -> Bytes {
276		self.body
277	}
278
279	#[must_use]
280	pub const fn extensions(&self) -> &Extensions {
281		&self.ext
282	}
283	#[must_use]
284	pub fn extensions_mut(&mut self) -> &mut Extensions {
285		&mut self.ext
286	}
287	#[must_use]
288	pub fn extensions_owned(self) -> Extensions {
289		self.ext
290	}
291
292	#[must_use]
293	pub const fn state(&self) -> &State {
294		&self.state
295	}
296	#[must_use]
297	pub fn state_mut(&mut self) -> &mut State {
298		&mut self.state
299	}
300
301	#[must_use]
302	pub const fn source(&self) -> &SocketAddr {
303		&self.source_address
304	}
305	#[must_use]
306	pub fn is_ipv4(&self) -> bool {
307		self.source_address.ip().is_ipv4()
308	}
309	#[must_use]
310	pub fn is_ipv6(&self) -> bool {
311		self.source_address.ip().is_ipv6()
312	}
313}
314
315impl<State: Clone + Send + Sync + 'static> Clone for Request<State> {
316	fn clone(&self) -> Self {
317		Request {
318			body: self.body.clone(),
319			ext: Extensions::new(),
320			source_address: self.source_address,
321			state: self.state.clone(),
322			stream_id: self.stream_id,
323			#[cfg(feature = "clients")]
324			explicit_read_amount: self.explicit_read_amount,
325			#[cfg(feature = "servers")]
326			stream_access: self.stream_access.clone(),
327		}
328	}
329}
330
331impl<State: Clone + Send + Sync + 'static> Debug for Request<State>
332where
333	State: Debug,
334{
335	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
336		let mut dbg_struct = fmt.debug_struct("Request");
337
338		dbg_struct
339			.field("body", &self.body)
340			.field("source_address", &self.source_address)
343			.field("stream_id", &self.stream_id);
344
345		#[cfg(feature = "clients")]
346		dbg_struct.field("explicit_read_amount", &self.explicit_read_amount);
347		#[cfg(feature = "servers")]
348		dbg_struct.field(
349			"stream_access",
350			&if self.stream_access.is_some() {
351				"<stream>"
352			} else {
353				"<none>"
354			},
355		);
356
357		dbg_struct.finish_non_exhaustive()
358	}
359}
360
/// Field names reported for `Request` through `valuable`.
///
/// They are paired positionally with the values produced in `Request`'s `Valuable::visit`.
/// The order and the feature gating here must stay identical to that list.
const REQUEST_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("body"),
	NamedField::new("source_address"),
	NamedField::new("stream_id"),
	#[cfg(feature = "clients")]
	NamedField::new("explicit_read_amount"),
	#[cfg(feature = "servers")]
	NamedField::new("stream_access"),
];
370
371impl<State: Clone + Send + Sync + 'static> Structable for Request<State> {
372	fn definition(&self) -> StructDef<'_> {
373		StructDef::new_static("Request", Fields::Named(REQUEST_FIELDS))
374	}
375}
376
377impl<State: Clone + Send + Sync + 'static> Valuable for Request<State> {
378	fn as_value(&self) -> Value<'_> {
379		Value::Structable(self)
380	}
381
382	fn visit(&self, visitor: &mut dyn Visit) {
383		visitor.visit_named_fields(&NamedValues::new(
384			REQUEST_FIELDS,
385			&[
386				Valuable::as_value(&format!("{:02X?}", self.body)),
387				Valuable::as_value(&format!("{}", self.source_address)),
388				Valuable::as_value(&self.stream_id),
389				#[cfg(feature = "clients")]
390				Valuable::as_value(&self.explicit_read_amount),
391				#[cfg(feature = "servers")]
392				Valuable::as_value(&if self.stream_access.is_some() {
393					"<stream>"
394				} else {
395					"<none>"
396				}),
397			],
398		));
399	}
400}
401
/// The result of handling a request: an optional body plus a flag requesting that the
/// connection be closed.
#[derive(Clone, Debug)]
pub struct Response {
	/// Response body, if any.
	pub body: Option<Bytes>,
	/// Whether closing the connection is requested.
	/// NOTE(review): presumably honored after the body (if any) is sent — confirm in the server.
	pub request_connection_close: bool,
}
413
414impl Response {
415	#[must_use]
416	pub const fn new_empty() -> Self {
417		Self {
418			body: None,
419			request_connection_close: false,
420		}
421	}
422	#[must_use]
423	pub const fn empty_close() -> Self {
424		Self {
425			body: None,
426			request_connection_close: true,
427		}
428	}
429	#[must_use]
430	pub const fn new_with_body(body: Bytes) -> Self {
431		Self {
432			body: Some(body),
433			request_connection_close: false,
434		}
435	}
436
437	#[must_use]
438	pub const fn body(&self) -> Option<&Bytes> {
439		self.body.as_ref()
440	}
441	#[must_use]
442	pub fn body_mut(&mut self) -> Option<&mut Bytes> {
443		self.body.as_mut()
444	}
445	pub fn set_body(&mut self, bytes: Bytes) {
446		self.body = Some(bytes);
447	}
448	#[must_use]
449	pub fn take_body(self) -> Option<Bytes> {
450		self.body
451	}
452
453	#[must_use]
454	pub const fn request_connection_close(&self) -> bool {
455		self.request_connection_close
456	}
457	pub fn should_close_connection(&mut self) {
458		self.request_connection_close = true;
459	}
460	pub fn dont_close_connection(&mut self) {
461		self.request_connection_close = false;
462	}
463}
464
465impl Default for Response {
466	fn default() -> Self {
467		Self::new_empty()
468	}
469}
470
471impl<ByteTy: Into<Bytes>> From<ByteTy> for Response {
472	fn from(resp: ByteTy) -> Self {
473		Self::new_with_body(resp.into())
474	}
475}
476
/// Field names reported for `Response` through `valuable`. They are paired positionally with
/// the values produced in `Response`'s `Valuable::visit`, so the order must match it.
const RESPONSE_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("body"),
	NamedField::new("request_connection_close"),
];
481
482impl Structable for Response {
483	fn definition(&self) -> StructDef<'_> {
484		StructDef::new_static("Response", Fields::Named(RESPONSE_FIELDS))
485	}
486}
487
488impl Valuable for Response {
489	fn as_value(&self) -> Value<'_> {
490		Value::Structable(self)
491	}
492
493	fn visit(&self, visitor: &mut dyn Visit) {
494		visitor.visit_named_fields(&NamedValues::new(
495			RESPONSE_FIELDS,
496			&[
497				Valuable::as_value(&if let Some(body_ref) = self.body.as_ref() {
498					format!("{body_ref:02X?}")
499				} else {
500					"<empty>".to_owned()
501				}),
502				Valuable::as_value(&self.request_connection_close),
503			],
504		));
505	}
506}
507
/// Extracts a value from a mutable borrow of a [`Request`], without consuming the request.
///
/// The returned future must be `Send`.
pub trait FromRequestParts<State: Clone + Send + Sync + 'static>: Sized {
	/// Produces `Self` from `req`, or fails with a [`CatBridgeError`].
	fn from_request_parts(
		req: &mut Request<State>,
	) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
516
/// Extracts a value by consuming an entire [`Request`].
///
/// The returned future must be `Send`.
pub trait FromRequest<State: Clone + Send + Sync + 'static>: Sized {
	/// Produces `Self` from the owned `req`, or fails with a [`CatBridgeError`].
	fn from_request(
		req: Request<State>,
	) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
525impl<State: Clone + Send + Sync + 'static> FromRequest<State> for Request<State> {
526	async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
527		Ok(req)
528	}
529}
530
/// Converts a value into a [`Response`], or into the [`CatBridgeError`] to report instead.
pub trait IntoResponse: Sized {
	/// Performs the conversion, consuming `self`.
	fn to_response(self) -> Result<Response, CatBridgeError>;
}
544
545impl IntoResponse for () {
546	fn to_response(self) -> Result<Response, CatBridgeError> {
547		Ok(Response::new_empty())
548	}
549}
550impl IntoResponse for Response {
551	fn to_response(self) -> Result<Response, CatBridgeError> {
552		Ok(self)
553	}
554}
555
556macro_rules! impl_from_ok_response {
557	($ty:ty) => {
558		impl IntoResponse for $ty {
559			fn to_response(self) -> Result<Response, CatBridgeError> {
560				Ok(self.into())
561			}
562		}
563	};
564}
565impl_from_ok_response!(Bytes);
566impl_from_ok_response!(BytesMut);
567impl_from_ok_response!(String);
568impl_from_ok_response!(Vec<u8>);
569impl_from_ok_response!(&'static [u8]);
570impl_from_ok_response!(&'static str);
571
572impl IntoResponse for CatBridgeError {
573	fn to_response(self) -> Result<Response, CatBridgeError> {
574		Err(self)
575	}
576}
577
578impl<SomeTy: IntoResponse> IntoResponse for Option<SomeTy> {
579	fn to_response(self) -> Result<Response, CatBridgeError> {
580		if let Some(val) = self {
581			val.to_response()
582		} else {
583			Ok(Response::new_empty())
584		}
585	}
586}
587impl<OkTy: IntoResponse> IntoResponse for Result<OkTy, CatBridgeError> {
588	fn to_response(self) -> Result<Response, CatBridgeError> {
589		self.and_then(IntoResponse::to_response)
590	}
591}
592
/// Describes how to carve one complete message out of the front of a byte buffer (see
/// [`NagleGuard::split`]).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum NagleGuard {
	/// The message ends right after the first occurrence of this byte sequence, which is
	/// included in the message. The sequence must not be empty; `split` returns an error for
	/// an empty one.
	EndSigilSearch(&'static [u8]),
	/// Every message is exactly this many bytes long.
	StaticSize(usize),
	/// The first two bytes hold a `u16` length in the given byte order. The message length is
	/// that value plus the optional extra count (`0` when `None`). The two prefix bytes are not
	/// added automatically.
	U16LengthPrefixed(Endianness, Option<usize>),
	/// Same as `U16LengthPrefixed`, but with a four-byte `u32` length prefix.
	U32LengthPrefixed(Endianness, Option<usize>),
}
617
618impl NagleGuard {
619	pub fn split(&self, buff: &BytesMut) -> Result<Option<(usize, usize)>, CommonNetAPIError> {
625		match *self {
626			NagleGuard::EndSigilSearch(sigil) => {
627				if sigil.is_empty() {
628					return Err(CommonNetAPIError::NagleGuardEndSigilCannotBeEmpty);
629				}
630				if buff.is_empty() {
631					return Ok(None);
632				}
633
634				for (idx, byte) in buff.iter().enumerate() {
635					if idx + sigil.len() > buff.len() {
637						break;
638					}
639					if *byte == sigil[0] && sigil == &buff[idx..(idx + sigil.len())] {
640						return Ok(Some((0, idx + sigil.len())));
641					}
642				}
643			}
644			NagleGuard::StaticSize(size) => {
645				if buff.len() < size {
646					return Ok(None);
647				}
648
649				return Ok(Some((0, size)));
650			}
651			NagleGuard::U16LengthPrefixed(endianness, extra_len) => {
652				if buff.len() < 2 {
653					return Ok(None);
654				}
655				let extra_len_frd = extra_len.unwrap_or_default();
656
657				let total_size = match endianness {
658					Endianness::Little => u16::from_le_bytes([buff[0], buff[1]]),
659					Endianness::Big => u16::from_be_bytes([buff[0], buff[1]]),
660				};
661				if buff.len() >= usize::from(total_size) + extra_len_frd {
662					return Ok(Some((0, usize::from(total_size) + extra_len_frd)));
663				}
664			}
665			NagleGuard::U32LengthPrefixed(endianness, extra_len) => {
666				if buff.len() < 4 {
667					return Ok(None);
668				}
669				let extra_len_frd = extra_len.unwrap_or_default();
670
671				let total_size = match endianness {
672					Endianness::Little => u32::from_le_bytes([buff[0], buff[1], buff[2], buff[3]]),
673					Endianness::Big => u32::from_be_bytes([buff[0], buff[1], buff[2], buff[3]]),
674				};
675				if buff.len() >= usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd {
676					return Ok(Some((
677						0,
678						usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd,
679					)));
680				}
681			}
682		}
683
684		Ok(None)
685	}
686}
687
688impl From<usize> for NagleGuard {
689	fn from(value: usize) -> Self {
690		NagleGuard::StaticSize(value)
691	}
692}
693
694impl From<&'static [u8]> for NagleGuard {
695	fn from(value: &'static [u8]) -> Self {
696		NagleGuard::EndSigilSearch(value)
697	}
698}
699
/// Byte order used to decode the length prefix of [`NagleGuard::U16LengthPrefixed`] and
/// [`NagleGuard::U32LengthPrefixed`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum Endianness {
	/// Least-significant byte first (decoded with `from_le_bytes`).
	Little,
	/// Most-significant byte first (decoded with `from_be_bytes`).
	Big,
}
708
709pub trait PreNagleFnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static {}
715impl<FnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static> PreNagleFnTy for FnTy {}
716
717pub trait PostNagleFnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static {}
723impl<FnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static> PostNagleFnTy for FnTy {}