richat_shared/transports/
tcp.rs

1use {
2    crate::{
3        config::deserialize_x_token_set,
4        shutdown::Shutdown,
5        transports::{RecvError, RecvStream, Subscribe, SubscribeError, WriteVectored},
6    },
7    futures::stream::StreamExt,
8    prost::Message,
9    richat_proto::richat::{
10        QuicSubscribeClose, QuicSubscribeCloseError, QuicSubscribeResponse,
11        QuicSubscribeResponseError, TcpSubscribeRequest,
12    },
13    serde::Deserialize,
14    std::{
15        borrow::Cow,
16        collections::HashSet,
17        future::Future,
18        io::{self, IoSlice},
19        mem,
20        net::{IpAddr, Ipv4Addr, SocketAddr},
21        sync::Arc,
22    },
23    tokio::{
24        io::{AsyncReadExt, AsyncWriteExt},
25        net::{TcpListener, TcpSocket, TcpStream},
26        task::JoinError,
27    },
28    tracing::{error, info, warn},
29};
30
31#[derive(Debug, Clone, Deserialize)]
32#[serde(deny_unknown_fields, default)]
33pub struct ConfigTcpServer {
34    pub endpoint: SocketAddr,
35    pub backlog: u32,
36    pub keepalive: Option<bool>,
37    pub nodelay: Option<bool>,
38    pub send_buffer_size: Option<usize>,
39    /// Max request size in bytes
40    pub max_request_size: usize,
41    #[serde(deserialize_with = "deserialize_x_token_set")]
42    pub x_tokens: HashSet<Vec<u8>>,
43}
44
45impl Default for ConfigTcpServer {
46    fn default() -> Self {
47        Self {
48            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10101),
49            backlog: 1024,
50            keepalive: None,
51            nodelay: None,
52            send_buffer_size: None,
53            max_request_size: 1024,
54            x_tokens: HashSet::new(),
55        }
56    }
57}
58
59impl ConfigTcpServer {
60    pub fn listen(&self) -> io::Result<TcpListener> {
61        let socket = match self.endpoint {
62            SocketAddr::V4(_) => TcpSocket::new_v4(),
63            SocketAddr::V6(_) => TcpSocket::new_v6(),
64        }?;
65        socket.bind(self.endpoint)?;
66        socket.listen(self.backlog)
67    }
68
69    pub fn set_accepted_socket_options(&self, stream: &TcpStream) -> io::Result<()> {
70        if let Some(keepalive) = self.keepalive {
71            let sock_ref = socket2::SockRef::from(&stream);
72            sock_ref.set_keepalive(keepalive)?;
73        }
74        if let Some(nodelay) = self.nodelay {
75            stream.set_nodelay(nodelay)?;
76        }
77        if let Some(send_buffer_size) = self.send_buffer_size {
78            let sock_ref = socket2::SockRef::from(&stream);
79            sock_ref.set_send_buffer_size(send_buffer_size)?;
80        }
81        Ok(())
82    }
83}
84
85#[derive(Debug, thiserror::Error)]
86pub enum ConnectionError {
87    #[error(transparent)]
88    Io(#[from] io::Error),
89    #[error(transparent)]
90    Prost(#[from] prost::DecodeError),
91}
92
93#[derive(Debug)]
94pub struct TcpServer;
95
96impl TcpServer {
97    pub async fn spawn(
98        mut config: ConfigTcpServer,
99        messages: impl Subscribe + Clone + Send + 'static,
100        on_conn_new_cb: impl Fn() + Clone + Send + 'static,
101        on_conn_drop_cb: impl Fn() + Clone + Send + 'static,
102        shutdown: Shutdown,
103    ) -> io::Result<impl Future<Output = Result<(), JoinError>>> {
104        let listener = config.listen()?;
105        info!("start server at {}", config.endpoint);
106
107        Ok(tokio::spawn(async move {
108            let mut id = 0;
109            let x_tokens = Arc::new(mem::take(&mut config.x_tokens));
110            tokio::pin!(shutdown);
111            loop {
112                tokio::select! {
113                    incoming = listener.accept() => {
114                        let stream = match incoming {
115                            Ok((stream, addr)) => {
116                                if let Err(error) = config.set_accepted_socket_options(&stream) {
117                                    warn!("#{id}: failed to set socket options {error:?}");
118                                }
119                                info!("#{id}: new connection from {addr:?}");
120                                stream
121                            }
122                            Err(error) => {
123                                error!("failed to accept new connection: {error}");
124                                break;
125                            }
126                        };
127
128                        let messages = messages.clone();
129                        let on_conn_new_cb = on_conn_new_cb.clone();
130                        let on_conn_drop_cb = on_conn_drop_cb.clone();
131                        let x_tokens = Arc::clone(&x_tokens);
132                        tokio::spawn(async move {
133                            on_conn_new_cb();
134                            if let Err(error) = Self::handle_incoming(
135                                id,
136                                stream,
137                                messages,
138                                config.max_request_size as u64,
139                                x_tokens
140                            ).await {
141                                error!("#{id}: connection failed: {error}");
142                            } else {
143                                info!("#{id}: connection closed");
144                            }
145                            on_conn_drop_cb();
146                        });
147                        id += 1;
148                    }
149                    () = &mut shutdown => {
150                        info!("shutdown");
151                        break
152                    },
153                }
154            }
155        }))
156    }
157
158    async fn handle_incoming(
159        id: u64,
160        mut stream: TcpStream,
161        messages: impl Subscribe,
162        max_request_size: u64,
163        x_tokens: Arc<HashSet<Vec<u8>>>,
164    ) -> Result<(), ConnectionError> {
165        // Read request and subscribe
166        let (response, maybe_rx) =
167            Self::handle_request(id, &mut stream, messages, max_request_size, x_tokens).await?;
168
169        // Send response
170        let buf = response.encode_to_vec();
171        stream.write_u64(buf.len() as u64).await?;
172        stream.write_all(&buf).await?;
173
174        let Some(mut rx) = maybe_rx else {
175            return Ok(());
176        };
177
178        // Send loop
179        loop {
180            match rx.next().await {
181                Some(Ok(message)) => {
182                    WriteVectored::new(
183                        &mut stream,
184                        &mut [
185                            IoSlice::new(&(message.len() as u64).to_be_bytes()),
186                            IoSlice::new(&message),
187                        ],
188                    )
189                    .await?;
190                }
191                Some(Err(error)) => {
192                    error!("#{id}: failed to get message: {error}");
193                    let msg = QuicSubscribeClose {
194                        error: match error {
195                            RecvError::Lagged => QuicSubscribeCloseError::Lagged,
196                            RecvError::Closed => QuicSubscribeCloseError::Closed,
197                        } as i32,
198                    };
199                    let message = msg.encode_to_vec();
200
201                    stream.write_u64(u64::MAX).await?;
202                    stream.write_u64(message.len() as u64).await?;
203                    stream.write_all(&message).await?;
204                }
205                None => break,
206            }
207        }
208
209        Ok(())
210    }
211
212    async fn handle_request(
213        id: u64,
214        stream: &mut TcpStream,
215        messages: impl Subscribe,
216        max_request_size: u64,
217        x_tokens: Arc<HashSet<Vec<u8>>>,
218    ) -> Result<(QuicSubscribeResponse, Option<RecvStream>), ConnectionError> {
219        // Read request
220        let size = stream.read_u64().await?;
221        if size > max_request_size {
222            let msg = QuicSubscribeResponse {
223                error: Some(QuicSubscribeResponseError::RequestSizeTooLarge as i32),
224                ..Default::default()
225            };
226            return Ok((msg, None));
227        }
228        let mut buf = vec![0; size as usize]; // TODO: use MaybeUninit
229        stream.read_exact(buf.as_mut_slice()).await?;
230
231        // Decode request
232        let TcpSubscribeRequest {
233            x_token,
234            replay_from_slot,
235            filter,
236        } = Message::decode(buf.as_slice())?;
237
238        // verify access token
239        if !x_tokens.is_empty() {
240            if let Some(error) = match x_token {
241                Some(x_token) if !x_tokens.contains(&x_token) => {
242                    Some(QuicSubscribeResponseError::XTokenInvalid as i32)
243                }
244                None => Some(QuicSubscribeResponseError::XTokenRequired as i32),
245                _ => None,
246            } {
247                let msg = QuicSubscribeResponse {
248                    error: Some(error),
249                    ..Default::default()
250                };
251                return Ok((msg, None));
252            }
253        }
254
255        Ok(match messages.subscribe(replay_from_slot, filter) {
256            Ok(rx) => {
257                let pos = replay_from_slot
258                    .map(|slot| format!("slot {slot}").into())
259                    .unwrap_or(Cow::Borrowed("latest"));
260                info!("#{id}: subscribed from {pos}");
261                (QuicSubscribeResponse::default(), Some(rx))
262            }
263            Err(SubscribeError::NotInitialized) => {
264                let msg = QuicSubscribeResponse {
265                    error: Some(QuicSubscribeResponseError::NotInitialized as i32),
266                    ..Default::default()
267                };
268                (msg, None)
269            }
270            Err(SubscribeError::SlotNotAvailable { first_available }) => {
271                let msg = QuicSubscribeResponse {
272                    error: Some(QuicSubscribeResponseError::SlotNotAvailable as i32),
273                    first_available_slot: Some(first_available),
274                    ..Default::default()
275                };
276                (msg, None)
277            }
278        })
279    }
280}