razor_stream/server/
server.rs

1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use std::io;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8/// An RpcServer that listen, accept, and server connections, according to ServerFacts interface.
9pub struct RpcServer<F>
10where
11    F: ServerFacts,
12{
13    listeners_abort: Vec<(<F as AsyncExec>::AsyncHandle<()>, String)>,
14    logger: Arc<LogFilter>,
15    facts: Arc<F>,
16    conn_ref_count: Arc<()>,
17    server_close_tx: Mutex<Option<crossfire::MTx<()>>>,
18    server_close_rx: crossfire::MAsyncRx<()>,
19}
20
21impl<F> RpcServer<F>
22where
23    F: ServerFacts,
24{
25    pub fn new(facts: Arc<F>) -> Self {
26        let (tx, rx) = crossfire::mpmc::unbounded_async();
27        Self {
28            listeners_abort: Vec::new(),
29            logger: facts.new_logger(),
30            facts,
31            conn_ref_count: Arc::new(()),
32            server_close_tx: Mutex::new(Some(tx)),
33            server_close_rx: rx,
34        }
35    }
36
37    pub async fn listen<T: ServerTransport, D: Dispatch>(
38        &mut self, addr: &str, dispatch: D,
39    ) -> io::Result<String> {
40        match T::bind(addr).await {
41            Err(e) => {
42                error!("bind addr {:?} err: {}", addr, e);
43                return Err(e);
44            }
45            Ok(mut listener) => {
46                let local_addr = match listener.local_addr() {
47                    Ok(addr) => addr,
48                    Err(e) => {
49                        if e.kind() == std::io::ErrorKind::AddrNotAvailable {
50                            // For Unix sockets, return a dummy address
51                            "0.0.0.0:0".parse().unwrap()
52                        } else {
53                            return Err(e);
54                        }
55                    }
56                };
57                let facts = self.facts.clone();
58                let conn_ref_count = self.conn_ref_count.clone();
59                let listener_info = format!("listener {:?}", addr);
60                let server_close_rx = self.server_close_rx.clone();
61                debug!("listening on {:?}", listener);
62                let handle = self.facts.spawn(async move {
63                    loop {
64                        match listener.accept().await {
65                            Err(e) => {
66                                warn!("{:?} accept error: {}", listener, e);
67                                return;
68                            }
69                            Ok(stream) => {
70                                let conn =
71                                    T::new_conn(stream, facts.get_config(), conn_ref_count.clone());
72                                Self::server_conn::<T, D>(
73                                    conn,
74                                    &facts,
75                                    dispatch.clone(),
76                                    server_close_rx.clone(),
77                                )
78                            }
79                        }
80                    }
81                });
82                self.listeners_abort.push((handle, listener_info));
83                return Ok(local_addr);
84            }
85        }
86    }
87
88    fn server_conn<T: ServerTransport, D: Dispatch>(
89        conn: T, facts: &F, dispatch: D, server_close_rx: crossfire::MAsyncRx<()>,
90    ) {
91        let conn = Arc::new(conn);
92
93        let (done_tx, done_rx) = crossfire::mpsc::unbounded_async();
94        let codec = Arc::new(D::Codec::default());
95
96        let noti = RespNoti(done_tx);
97        struct Reader<T: ServerTransport, D: Dispatch> {
98            noti: RespNoti<D::RespTask>,
99            conn: Arc<T>,
100            server_close_rx: crossfire::MAsyncRx<()>,
101            codec: Arc<D::Codec>,
102            dispatch: D,
103            logger: Arc<LogFilter>,
104        }
105        let reader = Reader::<T, D> {
106            noti,
107            codec: codec.clone(),
108            dispatch,
109            conn: conn.clone(),
110            server_close_rx,
111            logger: facts.new_logger(),
112        };
113        facts.spawn_detach(async move { reader.run().await });
114
115        impl<T: ServerTransport, D: Dispatch> Reader<T, D> {
116            async fn run(self) -> Result<(), ()> {
117                loop {
118                    match self.conn.read_req(&self.logger, &self.server_close_rx).await {
119                        Ok(req) => {
120                            if req.action == RpcAction::Num(0) && req.msg.len() == 0 {
121                                // ping request
122                                self.send_quick_resp(req.seq, None)?;
123                            } else {
124                                let seq = req.seq;
125                                if self
126                                    .dispatch
127                                    .dispatch_req(&self.codec, req, self.noti.clone())
128                                    .await
129                                    .is_err()
130                                {
131                                    self.send_quick_resp(seq, Some(RpcIntErr::Decode.into()))?;
132                                }
133                            }
134                        }
135                        Err(_e) => {
136                            // XXX read_req return error not used
137                            return Err(());
138                        }
139                    }
140                }
141            }
142
143            #[inline]
144            fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
145                if self.noti.send_err(seq, err).is_err() {
146                    logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
147                    return Err(());
148                }
149                Ok(())
150            }
151        }
152
153        struct Writer<T: ServerTransport, D: Dispatch> {
154            codec: Arc<D::Codec>,
155            done_rx: crossfire::AsyncRx<Result<D::RespTask, (u64, Option<RpcIntErr>)>>,
156            conn: Arc<T>,
157            logger: Arc<LogFilter>,
158        }
159        let writer = Writer::<T, D> { done_rx, codec, conn, logger: facts.new_logger() };
160        facts.spawn_detach(async move { writer.run().await });
161
162        impl<T: ServerTransport, D: Dispatch> Writer<T, D> {
163            async fn run(self) -> Result<(), io::Error> {
164                macro_rules! process {
165                    ($task: expr) => {{
166                        match $task {
167                            Ok(_task) => {
168                                logger_trace!(self.logger, "write_resp {:?}", _task);
169                                self.conn
170                                    .write_resp::<D::RespTask>(
171                                        &self.logger,
172                                        self.codec.as_ref(),
173                                        _task,
174                                    )
175                                    .await?;
176                            }
177                            Err((seq, err)) => {
178                                self.conn.write_resp_internal(&self.logger, seq, err).await?;
179                            }
180                        }
181                    }};
182                }
183                while let Ok(task) = self.done_rx.recv().await {
184                    process!(task);
185                    while let Ok(task) = self.done_rx.try_recv() {
186                        process!(task);
187                    }
188                    self.conn.flush_resp(&self.logger).await?;
189                }
190                logger_trace!(self.logger, "{:?} writer exits", self.conn);
191                self.conn.close_conn(&self.logger).await;
192                Ok(())
193            }
194        }
195    }
196
197    #[inline]
198    fn get_alive_conn(&self) -> usize {
199        Arc::strong_count(&self.conn_ref_count) - 1
200    }
201
202    /// Gracefully close the server
203    ///
204    /// Steps:
205    /// - listeners coroutine is abort
206    /// - drop the close channel to notify connection read coroutines.
207    /// - the writer coroutines will exit after all the reference of RespNoti channel drop to 0
208    /// - wait for connection coroutines to exit with a timeout defined by
209    /// ServerConfig.server_close_wait
210    pub async fn close(&mut self) {
211        // close listeners
212        for h in self.listeners_abort.drain(0..) {
213            h.0.abort();
214            logger_info!(self.logger, "{} has closed", h.1);
215        }
216        // Notify all reader connection exit, then the reader will notify writer
217        let _ = self.server_close_tx.lock().unwrap().take();
218
219        let mut exists_count = self.get_alive_conn();
220        // wait client close all connections
221        let start_ts = Instant::now();
222        let config = self.facts.get_config();
223        while exists_count > 0 {
224            F::sleep(Duration::from_secs(1)).await;
225            exists_count = self.get_alive_conn();
226            if Instant::now().duration_since(start_ts) > config.server_close_wait {
227                logger_warn!(
228                    self.logger,
229                    "closed as wait too long for all conn closed voluntarily({} conn left)",
230                    exists_count,
231                );
232                break;
233            }
234        }
235        logger_info!(self.logger, "server closed with alive conn {}", exists_count);
236    }
237}