cat_dev/fsemul/pcfs/sata/server/
connection_flags.rs

//! Implement connection-level flags for each connection that is present.
//!
//! The known connection flags (things that can change, or be affected, for a
//! whole connection at once with PCFS) are:
//!
//! 1. "Fast-File I/O Enabled": whether or not we should be using the 'ffio'
//!    version of PCFS over SATA, where file data can be inline on the stream
//!    without needing to be in a series of chunked packets.
//! 2. "Combined Send/Recv Enabled": whether or not we should support using
//!    'csr' in PCFS over SATA, where sends/receives can be bundled together.
//! 3. "Version": the version of PCFS that we should report for this client.
//!
//! Each connection additionally tracks a first read size, a first write size,
//! and whether the ffio buffer should have grown.
13
14use crate::{
15	errors::CatBridgeError,
16	fsemul::pcfs::sata::proto::DEFAULT_PCFS_VERSION,
17	net::models::{FromRequest, FromRequestParts, Request, Response},
18};
19use scc::HashMap as ConcurrentHashMap;
20use std::{
21	convert::Infallible,
22	sync::{
23		Arc, LazyLock,
24		atomic::{AtomicBool, AtomicU32, Ordering},
25	},
26	task::{Context, Poll},
27};
28use tower::{Layer, Service};
29use valuable::Valuable;
30
31/// A series of connection flags that apply to a connnection.
32///
33/// See this modules documentation for more information about more of what
34/// these flags indicate, and what flags can be set.
35pub(super) static SATA_CONNECTION_FLAGS: LazyLock<ConcurrentHashMap<u64, SataConnectionFlags>> =
36	LazyLock::new(|| ConcurrentHashMap::with_capacity(1));
37
38/// The flags that are set on this particular connection.
39#[derive(Clone, Debug, Valuable)]
40pub struct SataConnectionFlags {
41	fast_file_io_enabled: Arc<AtomicBool>,
42	combined_send_recv_enabled: Arc<AtomicBool>,
43	version: Arc<AtomicU32>,
44	first_read_size: Arc<AtomicU32>,
45	first_write_size: Arc<AtomicU32>,
46	ffio_buffer_should_have_grown: Arc<AtomicBool>,
47}
48
49impl SataConnectionFlags {
50	#[must_use]
51	pub fn new() -> Self {
52		Self {
53			fast_file_io_enabled: Arc::new(AtomicBool::new(true)),
54			combined_send_recv_enabled: Arc::new(AtomicBool::new(true)),
55			version: Arc::new(AtomicU32::new(DEFAULT_PCFS_VERSION)),
56			first_read_size: Arc::new(AtomicU32::new(196_672)),
57			first_write_size: Arc::new(AtomicU32::new(196_640)),
58			ffio_buffer_should_have_grown: Arc::new(AtomicBool::new(false)),
59		}
60	}
61
62	#[must_use]
63	pub fn new_with_flags(ffio_enabled: bool, csr_enabled: bool) -> Self {
64		Self {
65			fast_file_io_enabled: Arc::new(AtomicBool::new(ffio_enabled)),
66			combined_send_recv_enabled: Arc::new(AtomicBool::new(csr_enabled)),
67			version: Arc::new(AtomicU32::new(DEFAULT_PCFS_VERSION)),
68			first_read_size: Arc::new(AtomicU32::new(196_672)),
69			first_write_size: Arc::new(AtomicU32::new(196_640)),
70			ffio_buffer_should_have_grown: Arc::new(AtomicBool::new(false)),
71		}
72	}
73
74	#[must_use]
75	pub fn ffio_enabled(&self) -> bool {
76		self.fast_file_io_enabled.load(Ordering::Acquire)
77	}
78
79	pub fn set_ffio_enabled(&self, enabled: bool) {
80		self.fast_file_io_enabled.store(enabled, Ordering::Release);
81	}
82
83	#[must_use]
84	pub fn csr_enabled(&self) -> bool {
85		self.combined_send_recv_enabled.load(Ordering::Acquire)
86	}
87
88	pub fn set_csr_enabled(&self, enabled: bool) {
89		self.combined_send_recv_enabled
90			.store(enabled, Ordering::Release);
91	}
92
93	#[must_use]
94	pub fn version(&self) -> u32 {
95		self.version.load(Ordering::Acquire)
96	}
97
98	pub fn set_version(&self, version_num: u32) {
99		self.version.store(version_num, Ordering::Release);
100	}
101
102	#[must_use]
103	pub fn first_read_size(&self) -> u32 {
104		self.first_read_size.load(Ordering::Acquire)
105	}
106
107	pub fn set_first_read_size(&self, new_size: u32) {
108		self.first_read_size.store(new_size, Ordering::Release);
109	}
110
111	#[must_use]
112	pub fn first_write_size(&self) -> u32 {
113		self.first_write_size.load(Ordering::Acquire)
114	}
115
116	pub fn set_first_write_size(&self, new_size: u32) {
117		self.first_write_size.store(new_size, Ordering::Release);
118	}
119
120	#[must_use]
121	pub fn ffio_buffer_should_have_grown(&self) -> bool {
122		self.ffio_buffer_should_have_grown.load(Ordering::Acquire)
123	}
124
125	pub fn set_ffio_buffer_should_have_grown(&self, did_grow: bool) {
126		self.ffio_buffer_should_have_grown
127			.store(did_grow, Ordering::Release);
128	}
129}
130
131impl Default for SataConnectionFlags {
132	fn default() -> Self {
133		Self::new()
134	}
135}
136
/// A layer that wraps a service in [`LayeredSataConnectionFlags`], which
/// attaches each connection's [`SataConnectionFlags`] to incoming requests.
#[derive(Debug, Clone)]
pub struct SataConnectionFlagsLayer;
139
140impl<Layered> Layer<Layered> for SataConnectionFlagsLayer
141where
142	Layered: Clone,
143{
144	type Service = LayeredSataConnectionFlags<Layered>;
145
146	fn layer(&self, inner: Layered) -> Self::Service {
147		LayeredSataConnectionFlags { inner }
148	}
149}
150
/// The service produced by [`SataConnectionFlagsLayer`].
///
/// Before forwarding a request to the wrapped service, it attaches that
/// connection's flags (when any are registered) as a request extension.
#[derive(Clone)]
pub struct LayeredSataConnectionFlags<Inner> {
	/// The wrapped service every request is forwarded to.
	inner: Inner,
}
155
156impl<Layered, State: Clone + Send + Sync + 'static> Service<Request<State>>
157	for LayeredSataConnectionFlags<Layered>
158where
159	Layered:
160		Service<Request<State>, Response = Response, Error = Infallible> + Clone + Send + 'static,
161	Layered::Future: Send + 'static,
162{
163	type Response = Layered::Response;
164	type Error = Layered::Error;
165	type Future = Layered::Future;
166
167	#[inline]
168	fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169		self.inner.poll_ready(ctx)
170	}
171
172	fn call(&mut self, mut req: Request<State>) -> Self::Future {
173		if let Some(flags) = SATA_CONNECTION_FLAGS.get_sync(&req.stream_id()) {
174			req.extensions_mut().insert(flags.clone());
175		}
176
177		self.inner.call(req)
178	}
179}
180
181impl<State: Clone + Send + Sync + 'static> FromRequestParts<State> for SataConnectionFlags {
182	async fn from_request_parts(req: &mut Request<State>) -> Result<Self, CatBridgeError> {
183		Ok(req
184			.extensions()
185			.get::<SataConnectionFlags>()
186			.cloned()
187			.unwrap_or_default())
188	}
189}
190
191impl<State: Clone + Send + Sync + 'static> FromRequest<State> for SataConnectionFlags {
192	async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
193		Ok(req
194			.extensions()
195			.get::<SataConnectionFlags>()
196			.cloned()
197			.unwrap_or_default())
198	}
199}