richat_shared/transports/
tcp.rs1use {
2 crate::{
3 config::deserialize_x_token_set,
4 shutdown::Shutdown,
5 transports::{RecvError, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 },
7 futures::stream::StreamExt,
8 prost::Message,
9 richat_proto::richat::{
10 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeResponse,
11 QuicSubscribeResponseError, TcpSubscribeRequest,
12 },
13 serde::Deserialize,
14 std::{
15 borrow::Cow,
16 collections::HashSet,
17 future::Future,
18 io::{self, IoSlice},
19 mem,
20 net::{IpAddr, Ipv4Addr, SocketAddr},
21 sync::Arc,
22 },
23 tokio::{
24 io::{AsyncReadExt, AsyncWriteExt},
25 net::{TcpListener, TcpSocket, TcpStream},
26 task::JoinError,
27 },
28 tracing::{error, info, warn},
29};
30
/// Configuration of the TCP subscription server.
///
/// Unknown fields are rejected during deserialization (`deny_unknown_fields`), and any field
/// missing from the input is filled from `ConfigTcpServer::default()` (container-level
/// `#[serde(default)]`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigTcpServer {
    /// Address the listening socket is bound to.
    pub endpoint: SocketAddr,
    /// Backlog value passed to `listen`.
    pub backlog: u32,
    /// When set, TCP keepalive is enabled/disabled on every accepted socket;
    /// `None` leaves the socket's setting untouched.
    pub keepalive: Option<bool>,
    /// When set, `TCP_NODELAY` is enabled/disabled on every accepted socket;
    /// `None` leaves the socket's setting untouched.
    pub nodelay: Option<bool>,
    /// When set, the send buffer size applied to every accepted socket;
    /// `None` leaves the socket's setting untouched.
    pub send_buffer_size: Option<usize>,
    /// Maximum length, in bytes, of a client's subscribe request; a longer request is
    /// answered with a "request size too large" error instead of being read.
    pub max_request_size: usize,
    /// Set of accepted x-tokens. When empty, clients are not required to present a token.
    /// The accepted input format is whatever `deserialize_x_token_set` parses.
    #[serde(deserialize_with = "deserialize_x_token_set")]
    pub x_tokens: HashSet<Vec<u8>>,
}
44
45impl Default for ConfigTcpServer {
46 fn default() -> Self {
47 Self {
48 endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101),
49 backlog: 1024,
50 keepalive: None,
51 nodelay: None,
52 send_buffer_size: None,
53 max_request_size: 1024,
54 x_tokens: HashSet::new(),
55 }
56 }
57}
58
59impl ConfigTcpServer {
60 pub fn listen(&self) -> io::Result<TcpListener> {
61 let socket = match self.endpoint {
62 SocketAddr::V4(_) => TcpSocket::new_v4(),
63 SocketAddr::V6(_) => TcpSocket::new_v6(),
64 }?;
65 socket.bind(self.endpoint)?;
66 socket.listen(self.backlog)
67 }
68
69 pub fn set_accepted_socket_options(&self, stream: &TcpStream) -> io::Result<()> {
70 if let Some(keepalive) = self.keepalive {
71 let sock_ref = socket2::SockRef::from(&stream);
72 sock_ref.set_keepalive(keepalive)?;
73 }
74 if let Some(nodelay) = self.nodelay {
75 stream.set_nodelay(nodelay)?;
76 }
77 if let Some(send_buffer_size) = self.send_buffer_size {
78 let sock_ref = socket2::SockRef::from(&stream);
79 sock_ref.set_send_buffer_size(send_buffer_size)?;
80 }
81 Ok(())
82 }
83}
84
/// Error that ends the handling of a single client connection.
///
/// The per-connection task logs it as "connection failed"; the accept loop keeps running.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Reading from or writing to the client socket failed.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// The client's subscribe request body could not be decoded as protobuf.
    #[error(transparent)]
    Prost(#[from] prost::DecodeError),
}
92
/// Stateless entry point of the TCP transport; see [`TcpServer::spawn`].
#[derive(Debug)]
pub struct TcpServer;
95
96impl TcpServer {
97 pub async fn spawn(
98 mut config: ConfigTcpServer,
99 messages: impl Subscribe + Clone + Send + 'static,
100 on_conn_new_cb: impl Fn() + Copy + Send + 'static,
101 on_conn_drop_cb: impl Fn() + Copy + Send + 'static,
102 shutdown: Shutdown,
103 ) -> io::Result<impl Future<Output = Result<(), JoinError>>> {
104 let listener = config.listen()?;
105 info!("start server at {}", config.endpoint);
106
107 Ok(tokio::spawn(async move {
108 let mut id = 0;
109 let x_tokens = Arc::new(mem::take(&mut config.x_tokens));
110 tokio::pin!(shutdown);
111 loop {
112 tokio::select! {
113 incoming = listener.accept() => {
114 let stream = match incoming {
115 Ok((stream, addr)) => {
116 if let Err(error) = config.set_accepted_socket_options(&stream) {
117 warn!("#{id}: failed to set socket options {error:?}");
118 }
119 info!("#{id}: new connection from {addr:?}");
120 stream
121 }
122 Err(error) => {
123 error!("failed to accept new connection: {error}");
124 break;
125 }
126 };
127
128 let messages = messages.clone();
129 let x_tokens = Arc::clone(&x_tokens);
130 tokio::spawn(async move {
131 on_conn_new_cb();
132 if let Err(error) = Self::handle_incoming(
133 id,
134 stream,
135 messages,
136 config.max_request_size as u64,
137 x_tokens
138 ).await {
139 error!("#{id}: connection failed: {error}");
140 } else {
141 info!("#{id}: connection closed");
142 }
143 on_conn_drop_cb();
144 });
145 id += 1;
146 }
147 () = &mut shutdown => {
148 info!("shutdown");
149 break
150 },
151 }
152 }
153 }))
154 }
155
156 async fn handle_incoming(
157 id: u64,
158 mut stream: TcpStream,
159 messages: impl Subscribe,
160 max_request_size: u64,
161 x_tokens: Arc<HashSet<Vec<u8>>>,
162 ) -> Result<(), ConnectionError> {
163 let (response, maybe_rx) =
165 Self::handle_request(id, &mut stream, messages, max_request_size, x_tokens).await?;
166
167 let buf = response.encode_to_vec();
169 stream.write_u64(buf.len() as u64).await?;
170 stream.write_all(&buf).await?;
171
172 let Some(mut rx) = maybe_rx else {
173 return Ok(());
174 };
175
176 loop {
178 match rx.next().await {
179 Some(Ok(message)) => {
180 WriteVectored::new(
181 &mut stream,
182 &mut [
183 IoSlice::new(&(message.len() as u64).to_be_bytes()),
184 IoSlice::new(&message),
185 ],
186 )
187 .await?;
188 }
189 Some(Err(error)) => {
190 error!("#{id}: failed to get message: {error}");
191 let msg = QuicSubscribeClose {
192 error: match error {
193 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
194 RecvError::Closed => QuicSubscribeCloseError::Closed,
195 } as i32,
196 };
197 let message = msg.encode_to_vec();
198
199 stream.write_u64(u64::MAX).await?;
200 stream.write_u64(message.len() as u64).await?;
201 stream.write_all(&message).await?;
202 }
203 None => break,
204 }
205 }
206
207 Ok(())
208 }
209
210 async fn handle_request(
211 id: u64,
212 stream: &mut TcpStream,
213 messages: impl Subscribe,
214 max_request_size: u64,
215 x_tokens: Arc<HashSet<Vec<u8>>>,
216 ) -> Result<(QuicSubscribeResponse, Option<RecvStream>), ConnectionError> {
217 let size = stream.read_u64().await?;
219 if size > max_request_size {
220 let msg = QuicSubscribeResponse {
221 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
222 ..Default::default()
223 };
224 return Ok((msg, None));
225 }
226 let mut buf = vec![0; size as usize]; stream.read_exact(buf.as_mut_slice()).await?;
228
229 let TcpSubscribeRequest {
231 x_token,
232 replay_from_slot,
233 filter,
234 } = Message::decode(buf.as_slice())?;
235
236 if !x_tokens.is_empty() {
238 if let Some(error) = match x_token {
239 Some(x_token) if !x_tokens.contains(&x_token) => {
240 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
241 }
242 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
243 _ => None,
244 } {
245 let msg = QuicSubscribeResponse {
246 error: Some(error),
247 ..Default::default()
248 };
249 return Ok((msg, None));
250 }
251 }
252
253 Ok(match messages.subscribe(replay_from_slot, filter) {
254 Ok(rx) => {
255 let pos = replay_from_slot
256 .map(|slot| format!("slot {slot}").into())
257 .unwrap_or(Cow::Borrowed("latest"));
258 info!("#{id}: subscribed from {pos}");
259 (QuicSubscribeResponse::default(), Some(rx))
260 }
261 Err(SubscribeError::NotInitialized) => {
262 let msg = QuicSubscribeResponse {
263 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
264 ..Default::default()
265 };
266 (msg, None)
267 }
268 Err(SubscribeError::SlotNotAvailable { first_available }) => {
269 let msg = QuicSubscribeResponse {
270 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
271 first_available_slot: Some(first_available),
272 ..Default::default()
273 };
274 (msg, None)
275 }
276 })
277 }
278}