richat_shared/transports/
tcp.rs1use {
2 crate::{
3 config::deserialize_x_token_set,
4 shutdown::Shutdown,
5 transports::{RecvError, RecvStream, Subscribe, SubscribeError, WriteVectored},
6 },
7 futures::stream::StreamExt,
8 prost::Message,
9 richat_proto::richat::{
10 QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeResponse,
11 QuicSubscribeResponseError, TcpSubscribeRequest,
12 },
13 serde::Deserialize,
14 std::{
15 borrow::Cow,
16 collections::HashSet,
17 future::Future,
18 io::{self, IoSlice},
19 mem,
20 net::{IpAddr, Ipv4Addr, SocketAddr},
21 sync::Arc,
22 },
23 tokio::{
24 io::{AsyncReadExt, AsyncWriteExt},
25 net::{TcpListener, TcpSocket, TcpStream},
26 task::JoinError,
27 },
28 tracing::{error, info, warn},
29};
30
31#[derive(Debug, Clone, Deserialize)]
32#[serde(deny_unknown_fields, default)]
33pub struct ConfigTcpServer {
34 pub endpoint: SocketAddr,
35 pub backlog: u32,
36 pub keepalive: Option<bool>,
37 pub nodelay: Option<bool>,
38 pub send_buffer_size: Option<usize>,
39 pub max_request_size: usize,
41 #[serde(deserialize_with = "deserialize_x_token_set")]
42 pub x_tokens: HashSet<Vec<u8>>,
43}
44
45impl Default for ConfigTcpServer {
46 fn default() -> Self {
47 Self {
48 endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101),
49 backlog: 1024,
50 keepalive: None,
51 nodelay: None,
52 send_buffer_size: None,
53 max_request_size: 1024,
54 x_tokens: HashSet::new(),
55 }
56 }
57}
58
59impl ConfigTcpServer {
60 pub fn listen(&self) -> io::Result<TcpListener> {
61 let socket = match self.endpoint {
62 SocketAddr::V4(_) => TcpSocket::new_v4(),
63 SocketAddr::V6(_) => TcpSocket::new_v6(),
64 }?;
65 socket.bind(self.endpoint)?;
66 socket.listen(self.backlog)
67 }
68
69 pub fn set_accepted_socket_options(&self, stream: &TcpStream) -> io::Result<()> {
70 if let Some(keepalive) = self.keepalive {
71 let sock_ref = socket2::SockRef::from(&stream);
72 sock_ref.set_keepalive(keepalive)?;
73 }
74 if let Some(nodelay) = self.nodelay {
75 stream.set_nodelay(nodelay)?;
76 }
77 if let Some(send_buffer_size) = self.send_buffer_size {
78 let sock_ref = socket2::SockRef::from(&stream);
79 sock_ref.set_send_buffer_size(send_buffer_size)?;
80 }
81 Ok(())
82 }
83}
84
85#[derive(Debug, thiserror::Error)]
86pub enum ConnectionError {
87 #[error(transparent)]
88 Io(#[from] io::Error),
89 #[error(transparent)]
90 Prost(#[from] prost::DecodeError),
91}
92
/// Stateless handle for the TCP subscription server; start it with `TcpServer::spawn`.
#[derive(Debug)]
pub struct TcpServer;
95
96impl TcpServer {
97 pub async fn spawn(
98 mut config: ConfigTcpServer,
99 messages: impl Subscribe + Clone + Send + 'static,
100 on_conn_new_cb: impl Fn() + Clone + Send + 'static,
101 on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
102 shutdown: Shutdown,
103 ) -> io::Result<impl Future<Output = Result<(), JoinError>>> {
104 let listener = config.listen()?;
105 info!("start server at {}", config.endpoint);
106
107 Ok(tokio::spawn(async move {
108 let mut id = 0;
109 let x_tokens = Arc::new(mem::take(&mut config.x_tokens));
110 tokio::pin!(shutdown);
111 loop {
112 tokio::select! {
113 incoming = listener.accept() => {
114 let stream = match incoming {
115 Ok((stream, addr)) => {
116 if let Err(error) = config.set_accepted_socket_options(&stream) {
117 warn!("#{id}: failed to set socket options {error:?}");
118 }
119 info!("#{id}: new connection from {addr:?}");
120 stream
121 }
122 Err(error) => {
123 error!("failed to accept new connection: {error}");
124 break;
125 }
126 };
127
128 let messages = messages.clone();
129 let on_conn_new_cb = on_conn_new_cb.clone();
130 let on_conn_drop_cb = on_conn_drop_cb.clone();
131 let x_tokens = Arc::clone(&x_tokens);
132 tokio::spawn(async move {
133 on_conn_new_cb();
134 if let Err(error) = Self::handle_incoming(
135 id,
136 stream,
137 messages,
138 config.max_request_size as u64,
139 x_tokens
140 ).await {
141 error!("#{id}: connection failed: {error}");
142 } else {
143 info!("#{id}: connection closed");
144 }
145 on_conn_drop_cb();
146 });
147 id += 1;
148 }
149 () = &mut shutdown => {
150 info!("shutdown");
151 break
152 },
153 }
154 }
155 }))
156 }
157
158 async fn handle_incoming(
159 id: u64,
160 mut stream: TcpStream,
161 messages: impl Subscribe,
162 max_request_size: u64,
163 x_tokens: Arc<HashSet<Vec<u8>>>,
164 ) -> Result<(), ConnectionError> {
165 let (response, maybe_rx) =
167 Self::handle_request(id, &mut stream, messages, max_request_size, x_tokens).await?;
168
169 let buf = response.encode_to_vec();
171 stream.write_u64(buf.len() as u64).await?;
172 stream.write_all(&buf).await?;
173
174 let Some(mut rx) = maybe_rx else {
175 return Ok(());
176 };
177
178 loop {
180 match rx.next().await {
181 Some(Ok(message)) => {
182 WriteVectored::new(
183 &mut stream,
184 &mut [
185 IoSlice::new(&(message.len() as u64).to_be_bytes()),
186 IoSlice::new(&message),
187 ],
188 )
189 .await?;
190 }
191 Some(Err(error)) => {
192 error!("#{id}: failed to get message: {error}");
193 let msg = QuicSubscribeClose {
194 error: match error {
195 RecvError::Lagged => QuicSubscribeCloseError::Lagged,
196 RecvError::Closed => QuicSubscribeCloseError::Closed,
197 } as i32,
198 };
199 let message = msg.encode_to_vec();
200
201 stream.write_u64(u64::MAX).await?;
202 stream.write_u64(message.len() as u64).await?;
203 stream.write_all(&message).await?;
204 }
205 None => break,
206 }
207 }
208
209 Ok(())
210 }
211
212 async fn handle_request(
213 id: u64,
214 stream: &mut TcpStream,
215 messages: impl Subscribe,
216 max_request_size: u64,
217 x_tokens: Arc<HashSet<Vec<u8>>>,
218 ) -> Result<(QuicSubscribeResponse, Option<RecvStream>), ConnectionError> {
219 let size = stream.read_u64().await?;
221 if size > max_request_size {
222 let msg = QuicSubscribeResponse {
223 error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
224 ..Default::default()
225 };
226 return Ok((msg, None));
227 }
228 let mut buf = vec![0; size as usize]; stream.read_exact(buf.as_mut_slice()).await?;
230
231 let TcpSubscribeRequest {
233 x_token,
234 replay_from_slot,
235 filter,
236 } = Message::decode(buf.as_slice())?;
237
238 if !x_tokens.is_empty() {
240 if let Some(error) = match x_token {
241 Some(x_token) if !x_tokens.contains(&x_token) => {
242 Some(QuicSubscribeResponseError::XTokenInvalid as i32)
243 }
244 None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
245 _ => None,
246 } {
247 let msg = QuicSubscribeResponse {
248 error: Some(error),
249 ..Default::default()
250 };
251 return Ok((msg, None));
252 }
253 }
254
255 Ok(match messages.subscribe(replay_from_slot, filter) {
256 Ok(rx) => {
257 let pos = replay_from_slot
258 .map(|slot| format!("slot {slot}").into())
259 .unwrap_or(Cow::Borrowed("latest"));
260 info!("#{id}: subscribed from {pos}");
261 (QuicSubscribeResponse::default(), Some(rx))
262 }
263 Err(SubscribeError::NotInitialized) => {
264 let msg = QuicSubscribeResponse {
265 error: Some(QuicSubscribeResponseError::NotInitialized as i32),
266 ..Default::default()
267 };
268 (msg, None)
269 }
270 Err(SubscribeError::SlotNotAvailable { first_available }) => {
271 let msg = QuicSubscribeResponse {
272 error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
273 first_available_slot: Some(first_available),
274 ..Default::default()
275 };
276 (msg, None)
277 }
278 })
279 }
280}