richat_shared/transports/tcp.rs

1use {
2    crate::{
3        config::deserialize_x_token_set,
4        shutdown::Shutdown,
5        transports::{RecvError, RecvStream, Subscribe, SubscribeError, WriteVectored},
6    },
7    futures::stream::StreamExt,
8    prost::Message,
9    richat_proto::richat::{
10        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeResponse,
11        QuicSubscribeResponseError, TcpSubscribeRequest,
12    },
13    serde::Deserialize,
14    std::{
15        borrow::Cow,
16        collections::HashSet,
17        future::Future,
18        io::{self, IoSlice},
19        mem,
20        net::{IpAddr, Ipv4Addr, SocketAddr},
21        sync::Arc,
22    },
23    tokio::{
24        io::{AsyncReadExt, AsyncWriteExt},
25        net::{TcpListener, TcpSocket, TcpStream},
26        task::JoinError,
27    },
28    tracing::{error, info, warn},
29};
30
31#[derive(Debug, Clone, Deserialize)]
32#[serde(deny_unknown_fields, default)]
33pub struct ConfigTcpServer {
34    pub endpoint: SocketAddr,
35    pub backlog: u32,
36    pub keepalive: Option<bool>,
37    pub nodelay: Option<bool>,
38    pub send_buffer_size: Option<usize>,
39    /// Max request size in bytes
40    pub max_request_size: usize,
41    #[serde(deserialize_with = "deserialize_x_token_set")]
42    pub x_tokens: HashSet<Vec<u8>>,
43}
44
45impl Default for ConfigTcpServer {
46    fn default() -> Self {
47        Self {
48            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101),
49            backlog: 1024,
50            keepalive: None,
51            nodelay: None,
52            send_buffer_size: None,
53            max_request_size: 1024,
54            x_tokens: HashSet::new(),
55        }
56    }
57}
58
59impl ConfigTcpServer {
60    pub fn listen(&self) -> io::Result<TcpListener> {
61        let socket = match self.endpoint {
62            SocketAddr::V4(_) => TcpSocket::new_v4(),
63            SocketAddr::V6(_) => TcpSocket::new_v6(),
64        }?;
65        socket.bind(self.endpoint)?;
66        socket.listen(self.backlog)
67    }
68
69    pub fn set_accepted_socket_options(&self, stream: &TcpStream) -> io::Result<()> {
70        if let Some(keepalive) = self.keepalive {
71            let sock_ref = socket2::SockRef::from(&stream);
72            sock_ref.set_keepalive(keepalive)?;
73        }
74        if let Some(nodelay) = self.nodelay {
75            stream.set_nodelay(nodelay)?;
76        }
77        if let Some(send_buffer_size) = self.send_buffer_size {
78            let sock_ref = socket2::SockRef::from(&stream);
79            sock_ref.set_send_buffer_size(send_buffer_size)?;
80        }
81        Ok(())
82    }
83}
84
85#[derive(Debug, thiserror::Error)]
86pub enum ConnectionError {
87    #[error(transparent)]
88    Io(#[from] io::Error),
89    #[error(transparent)]
90    Prost(#[from] prost::DecodeError),
91}
92
93#[derive(Debug)]
94pub struct TcpServer;
95
96impl TcpServer {
97    pub async fn spawn(
98        mut config: ConfigTcpServer,
99        messages: impl Subscribe + Clone + Send + 'static,
100        on_conn_new_cb: impl Fn() + Copy + Send + 'static,
101        on_conn_drop_cb: impl Fn() + Copy + Send + 'static,
102        shutdown: Shutdown,
103    ) -> io::Result<impl Future<Output = Result<(), JoinError>>> {
104        let listener = config.listen()?;
105        info!("start server at {}", config.endpoint);
106
107        Ok(tokio::spawn(async move {
108            let mut id = 0;
109            let x_tokens = Arc::new(mem::take(&mut config.x_tokens));
110            tokio::pin!(shutdown);
111            loop {
112                tokio::select! {
113                    incoming = listener.accept() => {
114                        let stream = match incoming {
115                            Ok((stream, addr)) => {
116                                if let Err(error) = config.set_accepted_socket_options(&stream) {
117                                    warn!("#{id}: failed to set socket options {error:?}");
118                                }
119                                info!("#{id}: new connection from {addr:?}");
120                                stream
121                            }
122                            Err(error) => {
123                                error!("failed to accept new connection: {error}");
124                                break;
125                            }
126                        };
127
128                        let messages = messages.clone();
129                        let x_tokens = Arc::clone(&x_tokens);
130                        tokio::spawn(async move {
131                            on_conn_new_cb();
132                            if let Err(error) = Self::handle_incoming(
133                                id,
134                                stream,
135                                messages,
136                                config.max_request_size as u64,
137                                x_tokens
138                            ).await {
139                                error!("#{id}: connection failed: {error}");
140                            } else {
141                                info!("#{id}: connection closed");
142                            }
143                            on_conn_drop_cb();
144                        });
145                        id += 1;
146                    }
147                    () = &mut shutdown => {
148                        info!("shutdown");
149                        break
150                    },
151                }
152            }
153        }))
154    }
155
156    async fn handle_incoming(
157        id: u64,
158        mut stream: TcpStream,
159        messages: impl Subscribe,
160        max_request_size: u64,
161        x_tokens: Arc<HashSet<Vec<u8>>>,
162    ) -> Result<(), ConnectionError> {
163        // Read request and subscribe
164        let (response, maybe_rx) =
165            Self::handle_request(id, &mut stream, messages, max_request_size, x_tokens).await?;
166
167        // Send response
168        let buf = response.encode_to_vec();
169        stream.write_u64(buf.len() as u64).await?;
170        stream.write_all(&buf).await?;
171
172        let Some(mut rx) = maybe_rx else {
173            return Ok(());
174        };
175
176        // Send loop
177        loop {
178            match rx.next().await {
179                Some(Ok(message)) => {
180                    WriteVectored::new(
181                        &mut stream,
182                        &mut [
183                            IoSlice::new(&(message.len() as u64).to_be_bytes()),
184                            IoSlice::new(&message),
185                        ],
186                    )
187                    .await?;
188                }
189                Some(Err(error)) => {
190                    error!("#{id}: failed to get message: {error}");
191                    let msg = QuicSubscribeClose {
192                        error: match error {
193                            RecvError::Lagged => QuicSubscribeCloseError::Lagged,
194                            RecvError::Closed => QuicSubscribeCloseError::Closed,
195                        } as i32,
196                    };
197                    let message = msg.encode_to_vec();
198
199                    stream.write_u64(u64::MAX).await?;
200                    stream.write_u64(message.len() as u64).await?;
201                    stream.write_all(&message).await?;
202                }
203                None => break,
204            }
205        }
206
207        Ok(())
208    }
209
210    async fn handle_request(
211        id: u64,
212        stream: &mut TcpStream,
213        messages: impl Subscribe,
214        max_request_size: u64,
215        x_tokens: Arc<HashSet<Vec<u8>>>,
216    ) -> Result<(QuicSubscribeResponse, Option<RecvStream>), ConnectionError> {
217        // Read request
218        let size = stream.read_u64().await?;
219        if size > max_request_size {
220            let msg = QuicSubscribeResponse {
221                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
222                ..Default::default()
223            };
224            return Ok((msg, None));
225        }
226        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
227        stream.read_exact(buf.as_mut_slice()).await?;
228
229        // Decode request
230        let TcpSubscribeRequest {
231            x_token,
232            replay_from_slot,
233            filter,
234        } = Message::decode(buf.as_slice())?;
235
236        // verify access token
237        if !x_tokens.is_empty() {
238            if let Some(error) = match x_token {
239                Some(x_token) if !x_tokens.contains(&x_token) => {
240                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
241                }
242                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
243                _ => None,
244            } {
245                let msg = QuicSubscribeResponse {
246                    error: Some(error),
247                    ..Default::default()
248                };
249                return Ok((msg, None));
250            }
251        }
252
253        Ok(match messages.subscribe(replay_from_slot, filter) {
254            Ok(rx) => {
255                let pos = replay_from_slot
256                    .map(|slot| format!("slot {slot}").into())
257                    .unwrap_or(Cow::Borrowed("latest"));
258                info!("#{id}: subscribed from {pos}");
259                (QuicSubscribeResponse::default(), Some(rx))
260            }
261            Err(SubscribeError::NotInitialized) => {
262                let msg = QuicSubscribeResponse {
263                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
264                    ..Default::default()
265                };
266                (msg, None)
267            }
268            Err(SubscribeError::SlotNotAvailable { first_available }) => {
269                let msg = QuicSubscribeResponse {
270                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
271                    first_available_slot: Some(first_available),
272                    ..Default::default()
273                };
274                (msg, None)
275            }
276        })
277    }
278}