cat_dev/net/
models.rs

1//! Common models for all L4 Services.
2//!
3//! This mostly includes the [`Request`], and [`Response`] structures that
4//! services are actively passed, and expected to return.
5
6use crate::{
7	errors::CatBridgeError,
8	net::{Extensions, errors::CommonNetAPIError},
9};
10use bytes::{Bytes, BytesMut};
11use fnv::FnvHasher;
12use futures::Future;
13use std::{
14	fmt::{Debug, Formatter, Result as FmtResult},
15	hash::{Hash, Hasher},
16	marker::Send,
17	net::SocketAddr,
18};
19use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
20
21#[cfg(feature = "servers")]
22use crate::{errors::NetworkError, net::errors::CommonNetNetworkError};
23#[cfg(feature = "servers")]
24use std::sync::Arc;
25#[cfg(feature = "servers")]
26use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
27#[cfg(feature = "servers")]
28use tracing::error;
29
/// Performs a reference-to-value conversion, leaving the input untouched.
///
/// This is mainly used with states, to extract "substates" from a reference
/// to the main application state.
pub trait FromRef<InputTy> {
	/// Builds a value of this type from a reference to the input type.
	fn from_ref(input: &InputTy) -> Self;
}

/// Any cloneable type can be produced from a reference to itself.
impl<InnerTy: Clone> FromRef<InnerTy> for InnerTy {
	fn from_ref(input: &InnerTy) -> Self {
		<InnerTy as Clone>::clone(input)
	}
}
47
/// A request that came from either a TCP/UDP source.
///
/// Besides the raw body this carries the source address, per-request
/// extensions, a copy of the service state, and (feature dependent) extra
/// knobs for clients and servers.
pub struct Request<State: Clone + Send + Sync + 'static> {
	/// The actual body of the underlying request.
	body: Bytes,
	/// Extensions that can in particular be attached to this request.
	ext: Extensions,
	/// The source address of where the request came from.
	source_address: SocketAddr,
	/// The active state for this request.
	state: State,
	/// The stream ID this request came in on.
	///
	/// When this is `None`, `Request::stream_id` falls back to a hash of the
	/// source address.
	stream_id: Option<u64>,
	/// Indicate that the response needs to read a certain size, ignoring
	/// whatever the current NAGLE algorithm says.
	///
	/// This will still call 'post nagle hook', 'trace io', and will still
	/// obey NAGLE timeouts. It just overrides the _kind_ of NAGLE we do.
	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
	#[cfg(feature = "clients")]
	explicit_read_amount: Option<usize>,
	/// Allow accessing the raw underlying stream while processing the request.
	///
	/// The tuple pairs an optional buffer of bytes that were already received
	/// with the TCP stream itself; the buffer is drained before the stream is
	/// read from.
	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
	#[cfg(feature = "servers")]
	#[allow(
		// TODO(mythra): refactor to type.
		clippy::type_complexity,
	)]
	stream_access: Option<Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>>,
}
77
78impl<State: Clone + Send + Sync + 'static> Request<State>
79where
80	State: Default,
81{
82	#[must_use]
83	pub fn new(body: Bytes, source_address: SocketAddr, stream_id: Option<u64>) -> Self {
84		Self {
85			body,
86			ext: Extensions::new(),
87			source_address,
88			state: Default::default(),
89			stream_id,
90			#[cfg(feature = "clients")]
91			explicit_read_amount: None,
92			#[cfg(feature = "servers")]
93			stream_access: None,
94		}
95	}
96}
97
98impl<State: Clone + Send + Sync + 'static> Request<State> {
99	#[must_use]
100	pub fn new_with_state(
101		body: Bytes,
102		source_address: SocketAddr,
103		state: State,
104		stream_id: Option<u64>,
105	) -> Self {
106		Self {
107			body,
108			ext: Extensions::new(),
109			source_address,
110			state,
111			stream_id,
112			#[cfg(feature = "clients")]
113			explicit_read_amount: None,
114			#[cfg(feature = "servers")]
115			stream_access: None,
116		}
117	}
118
119	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
120	#[cfg(feature = "clients")]
121	#[must_use]
122	pub fn new_with_state_and_read_amount(
123		body: Bytes,
124		source_address: SocketAddr,
125		state: State,
126		stream_id: Option<u64>,
127		explicit_read_amount: usize,
128	) -> Self {
129		Self {
130			body,
131			ext: Extensions::new(),
132			source_address,
133			state,
134			stream_id,
135			explicit_read_amount: Some(explicit_read_amount),
136			#[cfg(feature = "servers")]
137			stream_access: None,
138		}
139	}
140
141	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
142	#[cfg(feature = "servers")]
143	#[allow(
144		// TODO(mythra): refactor to type.
145		clippy::type_complexity,
146	)]
147	#[must_use]
148	pub fn new_with_state_and_stream(
149		body: Bytes,
150		source_address: SocketAddr,
151		state: State,
152		stream_id: Option<u64>,
153		stream_and_nagle_cache: Arc<Mutex<Option<(Option<BytesMut>, TcpStream)>>>,
154	) -> Self {
155		Self {
156			body,
157			ext: Extensions::new(),
158			source_address,
159			state,
160			stream_id,
161			#[cfg(feature = "clients")]
162			explicit_read_amount: None,
163			stream_access: Some(stream_and_nagle_cache),
164		}
165	}
166
167	/// Swap the body of the request to something new.
168	pub fn swap_body(&mut self, new_body: Bytes) {
169		self.body = new_body;
170	}
171
172	/// Update the core request source.
173	pub const fn update_request_source(&mut self, source: SocketAddr, stream_id: Option<u64>) {
174		self.source_address = source;
175		self.stream_id = stream_id;
176	}
177
178	/// A unique identifier for the "stream" or connection of a packet.
179	///
180	/// In UDP which doesn't have stream this uses the source address as
181	/// the core identifier.
182	#[must_use]
183	pub fn stream_id(&self) -> u64 {
184		if let Some(id) = self.stream_id {
185			id
186		} else {
187			let mut hasher = FnvHasher::default();
188			self.source_address.hash(&mut hasher);
189			hasher.finish()
190		}
191	}
192
193	/// A client has requested we send this request, and then read an explicit
194	/// amount of bytes, ignoring whatever the current NAGLE method is.
195	///
196	/// This is a utility only available when we are a client, and are receiving
197	/// a packet that changes what our nagle split is for it's specific response
198	/// while keeping the nagle the same otherwise.
199	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
200	#[cfg(feature = "clients")]
201	#[must_use]
202	pub const fn explicit_read_amount(&self) -> Option<usize> {
203		self.explicit_read_amount
204	}
205
206	/// Override the current NAGLE algorithm being used by this client for this
207	/// single request/response pair. Do a single non-nagle'd receive.
208	#[cfg_attr(docsrs, doc(cfg(feature = "clients")))]
209	#[cfg(feature = "clients")]
210	pub const fn set_explicit_read_amount(&mut self, new_read_amount: usize) {
211		self.explicit_read_amount = Some(new_read_amount);
212	}
213
214	/// Attempt to read more bytes from the TCP Stream directly.
215	///
216	/// This is a utility only available when we are a server, and need to request
217	/// more info from the client.
218	///
219	/// ***THIS WILL BYPASS EVERYYTHING PROVIDED BY TCP SERVER, AND JUST READ RAW
220	/// BYTES FROM THE STREAM.*** This is only for requests like File I/O which
221	/// _need_ to bypass all the logic provided by the stream classes.
222	///
223	/// ## Errors
224	///
225	/// If the request has been moved outside of it's original processing place,
226	/// and it is no longer possible to read from the stream.
227	#[cfg_attr(docsrs, doc(cfg(feature = "servers")))]
228	#[cfg(feature = "servers")]
229	pub async fn unsafe_read_more_bytes_from_stream(
230		&self,
231		to_read: usize,
232	) -> Result<Bytes, CatBridgeError> {
233		if let Some(strm) = self.stream_access.as_ref() {
234			let mut guard = strm.lock().await;
235
236			if let Some((opt_cache, stream)) = guard.as_mut() {
237				let mut buff = BytesMut::with_capacity(to_read);
238
239				if let Some(cache) = opt_cache.as_mut() {
240					if cache.len() <= to_read {
241						buff = cache.split();
242					} else {
243						buff = cache.split_to(to_read);
244					}
245				}
246
247				if buff.len() < to_read {
248					stream.readable().await.map_err(NetworkError::IO)?;
249					let mut needed = to_read - buff.len();
250					while needed > 0 {
251						let read = stream.read_buf(&mut buff).await.map_err(NetworkError::IO)?;
252						needed -= read;
253					}
254				}
255				return Ok::<Bytes, CatBridgeError>(buff.freeze());
256			}
257		}
258
259		error!("called unsafe_read_more_bytes on a stream that is not processing!");
260		Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
261	}
262
263	#[must_use]
264	pub const fn body(&self) -> &Bytes {
265		&self.body
266	}
267	#[must_use]
268	pub fn body_mut(&mut self) -> &mut Bytes {
269		&mut self.body
270	}
271	pub fn set_body(&mut self, new_body: Bytes) {
272		self.body = new_body;
273	}
274	#[must_use]
275	pub fn body_owned(self) -> Bytes {
276		self.body
277	}
278
279	#[must_use]
280	pub const fn extensions(&self) -> &Extensions {
281		&self.ext
282	}
283	#[must_use]
284	pub fn extensions_mut(&mut self) -> &mut Extensions {
285		&mut self.ext
286	}
287	#[must_use]
288	pub fn extensions_owned(self) -> Extensions {
289		self.ext
290	}
291
292	#[must_use]
293	pub const fn state(&self) -> &State {
294		&self.state
295	}
296	#[must_use]
297	pub fn state_mut(&mut self) -> &mut State {
298		&mut self.state
299	}
300
301	#[must_use]
302	pub const fn source(&self) -> &SocketAddr {
303		&self.source_address
304	}
305	#[must_use]
306	pub fn is_ipv4(&self) -> bool {
307		self.source_address.ip().is_ipv4()
308	}
309	#[must_use]
310	pub fn is_ipv6(&self) -> bool {
311		self.source_address.ip().is_ipv6()
312	}
313}
314
315impl<State: Clone + Send + Sync + 'static> Clone for Request<State> {
316	fn clone(&self) -> Self {
317		Request {
318			body: self.body.clone(),
319			ext: Extensions::new(),
320			source_address: self.source_address,
321			state: self.state.clone(),
322			stream_id: self.stream_id,
323			#[cfg(feature = "clients")]
324			explicit_read_amount: self.explicit_read_amount,
325			#[cfg(feature = "servers")]
326			stream_access: self.stream_access.clone(),
327		}
328	}
329}
330
331impl<State: Clone + Send + Sync + 'static> Debug for Request<State>
332where
333	State: Debug,
334{
335	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
336		let mut dbg_struct = fmt.debug_struct("Request");
337
338		dbg_struct
339			.field("body", &self.body)
340			// Extensions can't be printed in debug by hyper, and in order to keep
341			// compatability ours don't.
342			.field("source_address", &self.source_address)
343			.field("stream_id", &self.stream_id);
344
345		#[cfg(feature = "clients")]
346		dbg_struct.field("explicit_read_amount", &self.explicit_read_amount);
347		#[cfg(feature = "servers")]
348		dbg_struct.field(
349			"stream_access",
350			&if self.stream_access.is_some() {
351				"<stream>"
352			} else {
353				"<none>"
354			},
355		);
356
357		dbg_struct.finish_non_exhaustive()
358	}
359}
360
/// The named fields a [`Request`] exposes through `valuable`.
///
/// The order (and the feature gates) must line up with the values produced
/// in the `Valuable::visit` implementation for [`Request`], as both are paired
/// together positionally.
const REQUEST_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("body"),
	NamedField::new("source_address"),
	NamedField::new("stream_id"),
	#[cfg(feature = "clients")]
	NamedField::new("explicit_read_amount"),
	#[cfg(feature = "servers")]
	NamedField::new("stream_access"),
];
370
371impl<State: Clone + Send + Sync + 'static> Structable for Request<State> {
372	fn definition(&self) -> StructDef<'_> {
373		StructDef::new_static("Request", Fields::Named(REQUEST_FIELDS))
374	}
375}
376
377impl<State: Clone + Send + Sync + 'static> Valuable for Request<State> {
378	fn as_value(&self) -> Value<'_> {
379		Value::Structable(self)
380	}
381
382	fn visit(&self, visitor: &mut dyn Visit) {
383		visitor.visit_named_fields(&NamedValues::new(
384			REQUEST_FIELDS,
385			&[
386				Valuable::as_value(&format!("{:02X?}", self.body)),
387				Valuable::as_value(&format!("{}", self.source_address)),
388				Valuable::as_value(&self.stream_id),
389				#[cfg(feature = "clients")]
390				Valuable::as_value(&self.explicit_read_amount),
391				#[cfg(feature = "servers")]
392				Valuable::as_value(&if self.stream_access.is_some() {
393					"<stream>"
394				} else {
395					"<none>"
396				}),
397			],
398		));
399	}
400}
401
/// Just a generic response on an L4 Layer.
#[derive(Clone, Debug)]
pub struct Response {
	/// The body of the actual response to send. If empty, no response is sent.
	pub body: Option<Bytes>,
	/// Whether we should request that any long-lived connection be closed.
	///
	/// NOTE: not every type of Level 4 connection has a long lived connection,
	/// or stream. UDP is a prime example of this, this is not guaranteed.
	pub request_connection_close: bool,
}
413
414impl Response {
415	#[must_use]
416	pub const fn new_empty() -> Self {
417		Self {
418			body: None,
419			request_connection_close: false,
420		}
421	}
422	#[must_use]
423	pub const fn empty_close() -> Self {
424		Self {
425			body: None,
426			request_connection_close: true,
427		}
428	}
429	#[must_use]
430	pub const fn new_with_body(body: Bytes) -> Self {
431		Self {
432			body: Some(body),
433			request_connection_close: false,
434		}
435	}
436
437	#[must_use]
438	pub const fn body(&self) -> Option<&Bytes> {
439		self.body.as_ref()
440	}
441	#[must_use]
442	pub fn body_mut(&mut self) -> Option<&mut Bytes> {
443		self.body.as_mut()
444	}
445	pub fn set_body(&mut self, bytes: Bytes) {
446		self.body = Some(bytes);
447	}
448	#[must_use]
449	pub fn take_body(self) -> Option<Bytes> {
450		self.body
451	}
452
453	#[must_use]
454	pub const fn request_connection_close(&self) -> bool {
455		self.request_connection_close
456	}
457	pub fn should_close_connection(&mut self) {
458		self.request_connection_close = true;
459	}
460	pub fn dont_close_connection(&mut self) {
461		self.request_connection_close = false;
462	}
463}
464
465impl Default for Response {
466	fn default() -> Self {
467		Self::new_empty()
468	}
469}
470
471impl<ByteTy: Into<Bytes>> From<ByteTy> for Response {
472	fn from(resp: ByteTy) -> Self {
473		Self::new_with_body(resp.into())
474	}
475}
476
/// The named fields a [`Response`] exposes through `valuable`.
///
/// The order must line up with the values produced in the `Valuable::visit`
/// implementation for [`Response`], as both are paired together positionally.
const RESPONSE_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("body"),
	NamedField::new("request_connection_close"),
];
481
482impl Structable for Response {
483	fn definition(&self) -> StructDef<'_> {
484		StructDef::new_static("Response", Fields::Named(RESPONSE_FIELDS))
485	}
486}
487
488impl Valuable for Response {
489	fn as_value(&self) -> Value<'_> {
490		Value::Structable(self)
491	}
492
493	fn visit(&self, visitor: &mut dyn Visit) {
494		visitor.visit_named_fields(&NamedValues::new(
495			RESPONSE_FIELDS,
496			&[
497				Valuable::as_value(&if let Some(body_ref) = self.body.as_ref() {
498					format!("{body_ref:02X?}")
499				} else {
500					"<empty>".to_owned()
501				}),
502				Valuable::as_value(&self.request_connection_close),
503			],
504		));
505	}
506}
507
/// Extract any value from a Request by mutable reference, so the request can
/// still be used afterwards.
///
/// Kept as our own trait so it can be async like axum.
pub trait FromRequestParts<State: Clone + Send + Sync + 'static>: Sized {
	/// Build `Self` out of the given request without consuming it.
	///
	/// Implementations resolve to a [`CatBridgeError`] when the value cannot
	/// be extracted.
	fn from_request_parts(
		req: &mut Request<State>,
	) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
516
/// Extract any value from a Request, finalizing it.
///
/// Unlike [`FromRequestParts`] this takes the request by value, consuming it.
///
/// Kept as our own trait so it can be async like axum.
pub trait FromRequest<State: Clone + Send + Sync + 'static>: Sized {
	/// Build `Self` by consuming the given request.
	///
	/// Implementations resolve to a [`CatBridgeError`] when the value cannot
	/// be extracted.
	fn from_request(
		req: Request<State>,
	) -> impl Future<Output = Result<Self, CatBridgeError>> + Send;
}
525impl<State: Clone + Send + Sync + 'static> FromRequest<State> for Request<State> {
526	async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
527		Ok(req)
528	}
529}
530
/// A blanket trait for converting a value into a full response.
///
/// This was mainly implemented so functions that return things like
/// `Bytes`, can naturally get wrapped into a result without needing
/// to return a result themselves.
pub trait IntoResponse: Sized {
	/// Convert an arbitrary type to a Response, consuming the value.
	///
	/// # Errors
	///
	/// If for whatever reason the type can't be turned into a response.
	fn to_response(self) -> Result<Response, CatBridgeError>;
}
544
545impl IntoResponse for () {
546	fn to_response(self) -> Result<Response, CatBridgeError> {
547		Ok(Response::new_empty())
548	}
549}
550impl IntoResponse for Response {
551	fn to_response(self) -> Result<Response, CatBridgeError> {
552		Ok(self)
553	}
554}
555
556macro_rules! impl_from_ok_response {
557	($ty:ty) => {
558		impl IntoResponse for $ty {
559			fn to_response(self) -> Result<Response, CatBridgeError> {
560				Ok(self.into())
561			}
562		}
563	};
564}
565impl_from_ok_response!(Bytes);
566impl_from_ok_response!(BytesMut);
567impl_from_ok_response!(String);
568impl_from_ok_response!(Vec<u8>);
569impl_from_ok_response!(&'static [u8]);
570impl_from_ok_response!(&'static str);
571
572impl IntoResponse for CatBridgeError {
573	fn to_response(self) -> Result<Response, CatBridgeError> {
574		Err(self)
575	}
576}
577
578impl<SomeTy: IntoResponse> IntoResponse for Option<SomeTy> {
579	fn to_response(self) -> Result<Response, CatBridgeError> {
580		if let Some(val) = self {
581			val.to_response()
582		} else {
583			Ok(Response::new_empty())
584		}
585	}
586}
587impl<OkTy: IntoResponse> IntoResponse for Result<OkTy, CatBridgeError> {
588	fn to_response(self) -> Result<Response, CatBridgeError> {
589		self.and_then(IntoResponse::to_response)
590	}
591}
592
/// Nagle guard is what determines when a packet "begins", and "ends".
///
/// These are the various types of ways we determine where a packet begins,
/// and "ends".
#[derive(Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum NagleGuard {
	/// Search for a specific sequence of bytes to determine the "end" of a
	/// packet.
	EndSigilSearch(&'static [u8]),
	/// All packets are guaranteed to be the exact same length.
	StaticSize(usize),
	/// Packets will prefix their total length with a u16.
	///
	/// This includes the 'endianness' to parse the number as, and you can apply
	/// an extra length to add (in case the length doesn't, say, include the
	/// length of a header).
	U16LengthPrefixed(Endianness, Option<usize>),
	/// Packets will prefix their total length with a u32.
	///
	/// This includes the 'endianness' to parse the number as, and you can apply
	/// an extra length to add (in case the length doesn't, say, include the
	/// length of a header).
	U32LengthPrefixed(Endianness, Option<usize>),
}
617
618impl NagleGuard {
619	/// Split a buffer of bytes from potentially multiple packets.
620	///
621	/// ## Errors
622	///
623	/// If we are an "end sigil search" without a an actual sigil.
624	pub fn split(&self, buff: &BytesMut) -> Result<Option<(usize, usize)>, CommonNetAPIError> {
625		match *self {
626			NagleGuard::EndSigilSearch(sigil) => {
627				if sigil.is_empty() {
628					return Err(CommonNetAPIError::NagleGuardEndSigilCannotBeEmpty);
629				}
630				if buff.is_empty() {
631					return Ok(None);
632				}
633
634				for (idx, byte) in buff.iter().enumerate() {
635					// Not enough room!
636					if idx + sigil.len() > buff.len() {
637						break;
638					}
639					if *byte == sigil[0] && sigil == &buff[idx..(idx + sigil.len())] {
640						return Ok(Some((0, idx + sigil.len())));
641					}
642				}
643			}
644			NagleGuard::StaticSize(size) => {
645				if buff.len() < size {
646					return Ok(None);
647				}
648
649				return Ok(Some((0, size)));
650			}
651			NagleGuard::U16LengthPrefixed(endianness, extra_len) => {
652				if buff.len() < 2 {
653					return Ok(None);
654				}
655				let extra_len_frd = extra_len.unwrap_or_default();
656
657				let total_size = match endianness {
658					Endianness::Little => u16::from_le_bytes([buff[0], buff[1]]),
659					Endianness::Big => u16::from_be_bytes([buff[0], buff[1]]),
660				};
661				if buff.len() >= usize::from(total_size) + extra_len_frd {
662					return Ok(Some((0, usize::from(total_size) + extra_len_frd)));
663				}
664			}
665			NagleGuard::U32LengthPrefixed(endianness, extra_len) => {
666				if buff.len() < 4 {
667					return Ok(None);
668				}
669				let extra_len_frd = extra_len.unwrap_or_default();
670
671				let total_size = match endianness {
672					Endianness::Little => u32::from_le_bytes([buff[0], buff[1], buff[2], buff[3]]),
673					Endianness::Big => u32::from_be_bytes([buff[0], buff[1], buff[2], buff[3]]),
674				};
675				if buff.len() >= usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd {
676					return Ok(Some((
677						0,
678						usize::try_from(total_size).unwrap_or(usize::MAX) + extra_len_frd,
679					)));
680				}
681			}
682		}
683
684		Ok(None)
685	}
686}
687
688impl From<usize> for NagleGuard {
689	fn from(value: usize) -> Self {
690		NagleGuard::StaticSize(value)
691	}
692}
693
694impl From<&'static [u8]> for NagleGuard {
695	fn from(value: &'static [u8]) -> Self {
696		NagleGuard::EndSigilSearch(value)
697	}
698}
699
/// The endianness of a particular number coming in over the network.
///
/// Used by the length prefixed [`NagleGuard`] variants to decide how the
/// length prefix bytes are decoded.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Valuable)]
pub enum Endianness {
	/// The data is in little endian.
	Little,
	/// The data is in big endian.
	Big,
}
708
/// A function type that can be used to convert data before passing it onto
/// nagle.
///
/// This is useful when we have an encrypted stream that needs to be decrypted,
/// before we end up applying any NAGLE, or other splitting logic to the
/// stream.
///
/// The function receives a `u64` and the buffer to modify in place.
/// NOTE(review): the `u64` is presumably the stream id -- confirm with callers.
pub trait PreNagleFnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static {}
// Blanket impl: any matching closure or function is a `PreNagleFnTy`.
impl<FnTy: Fn(u64, &mut BytesMut) + Send + Sync + 'static> PreNagleFnTy for FnTy {}
716
/// A function type that can be used to convert data right before sending it
/// out.
///
/// This is useful when we have an encrypted stream that needs to be encrypted
/// before it goes out to the client.
///
/// The function receives a `u64` and the outgoing bytes, and returns the
/// bytes to send instead.
/// NOTE(review): the `u64` is presumably the stream id -- confirm with callers.
pub trait PostNagleFnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static {}
// Blanket impl: any matching closure or function is a `PostNagleFnTy`.
impl<FnTy: Fn(u64, Bytes) -> Bytes + Send + Sync + 'static> PostNagleFnTy for FnTy {}