Skip to main content

razor_stream/server/
server.rs

1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use crossfire::{AsyncRx, MAsyncRx, mpmc, mpsc, null::CloseHandle};
5use orb::prelude::AsyncTime;
6use std::io;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// An RPC server that listens on one or more addresses, accepts connections, and serves them
/// according to the [`ServerFacts`] interface.
///
/// Listeners are added with [`RpcServer::listen`]; [`RpcServer::close`] shuts the server down.
pub struct RpcServer<F>
where
    F: ServerFacts,
{
    // Handles of the spawned listener tasks (aborted by `close`), each paired with a
    // "listener <addr>" description used in the close log message.
    listeners_abort: Vec<(Box<dyn AsyncHandle<()>>, String)>,
    // Server-level logger, created from `facts` in `new`.
    logger: Arc<LogFilter>,
    // Source of the server config (`get_config`) and loggers (`new_logger`); cloned into every
    // listener task.
    facts: Arc<F>,
    // Connection counter: a clone is handed to `P::new_conn` for every accepted connection, and
    // `get_alive_conn` reports `strong_count - 1` (the `- 1` discounts this field itself).
    // NOTE(review): assumes each connection keeps its clone until the connection is dropped.
    conn_ref_count: Arc<()>,
    // Sender half of the server-close channel. `close` takes it out of the `Option` and drops
    // it to notify every connection reader that the server is shutting down.
    server_close_tx: Mutex<Option<CloseHandle<mpmc::Null>>>,
    // Receiver half of the server-close channel. It is multi-consumer because it is cloned into
    // every listener task (one per `listen` call) and from there into every connection reader.
    server_close_rx: MAsyncRx<mpmc::Null>,
}
25
26impl<F> RpcServer<F>
27where
28    F: ServerFacts,
29{
30    pub fn new(facts: Arc<F>) -> Self {
31        let (tx, rx) = crossfire::mpmc::new::<mpmc::Null, _, _>();
32        Self {
33            listeners_abort: Vec::new(),
34            logger: facts.new_logger(),
35            facts,
36            conn_ref_count: Arc::new(()),
37            server_close_tx: Mutex::new(Some(tx)),
38            server_close_rx: rx,
39        }
40    }
41
42    pub async fn listen<P: ServerTransport, D: Dispatch>(
43        &mut self, rt: P::RT, addr: &str, dispatch: D,
44    ) -> io::Result<String> {
45        match P::bind(addr).await {
46            Err(e) => {
47                error!("bind addr {:?} err: {}", addr, e);
48                return Err(e);
49            }
50            Ok(mut listener) => {
51                let local_addr = match listener.local_addr() {
52                    Ok(addr) => addr,
53                    Err(e) => {
54                        if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55                            // For Unix sockets, return a dummy address
56                            "0.0.0.0:0".parse().unwrap()
57                        } else {
58                            return Err(e);
59                        }
60                    }
61                };
62                let facts = self.facts.clone();
63                let conn_ref_count = self.conn_ref_count.clone();
64                let listener_info = format!("listener {:?}", addr);
65                let server_close_rx = self.server_close_rx.clone();
66                debug!("listening on {:?}", listener);
67                let _rt = rt.clone();
68                let handle = rt.spawn(async move {
69                    loop {
70                        match listener.accept().await {
71                            Err(e) => {
72                                warn!("{:?} accept error: {}", listener, e);
73                                return;
74                            }
75                            Ok(stream) => {
76                                let conn =
77                                    P::new_conn(stream, facts.get_config(), conn_ref_count.clone());
78                                Self::server_conn::<P, D>(
79                                    &_rt,
80                                    conn,
81                                    &facts,
82                                    dispatch.clone(),
83                                    server_close_rx.clone(),
84                                )
85                            }
86                        }
87                    }
88                });
89                self.listeners_abort.push((Box::new(handle), listener_info));
90                return Ok(local_addr);
91            }
92        }
93    }
94
95    fn server_conn<P: ServerTransport, D: Dispatch>(
96        rt: &P::RT, conn: P, facts: &F, dispatch: D, server_close_rx: MAsyncRx<mpmc::Null>,
97    ) {
98        let conn = Arc::new(conn);
99
100        let (done_tx, done_rx) = mpsc::unbounded_async();
101        let codec = Arc::new(D::Codec::default());
102
103        let noti = RespNoti(done_tx);
104        struct Reader<P: ServerTransport, D: Dispatch> {
105            noti: RespNoti<D::RespTask>,
106            conn: Arc<P>,
107            server_close_rx: MAsyncRx<mpmc::Null>,
108            codec: Arc<D::Codec>,
109            dispatch: D,
110            logger: Arc<LogFilter>,
111        }
112        let reader = Reader::<P, D> {
113            noti,
114            codec: codec.clone(),
115            dispatch,
116            conn: conn.clone(),
117            server_close_rx,
118            logger: facts.new_logger(),
119        };
120        rt.spawn_detach(async move { reader.run().await });
121
122        impl<P: ServerTransport, D: Dispatch> Reader<P, D> {
123            async fn run(self) -> Result<(), ()> {
124                loop {
125                    match self.conn.read_req(&self.logger, &self.server_close_rx).await {
126                        Ok(req) => {
127                            if req.action == RpcAction::Num(0) && req.msg.is_empty() {
128                                // ping request
129                                self.send_quick_resp(req.seq, None)?;
130                            } else {
131                                let seq = req.seq;
132                                if self
133                                    .dispatch
134                                    .dispatch_req(&self.codec, req, self.noti.clone())
135                                    .await
136                                    .is_err()
137                                {
138                                    self.send_quick_resp(seq, Some(RpcIntErr::Decode))?;
139                                }
140                            }
141                        }
142                        Err(_e) => {
143                            // XXX read_req return error not used
144                            return Err(());
145                        }
146                    }
147                }
148            }
149
150            #[inline]
151            fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
152                if self.noti.send_err(seq, err).is_err() {
153                    logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
154                    return Err(());
155                }
156                Ok(())
157            }
158        }
159
160        struct Writer<P: ServerTransport, D: Dispatch> {
161            codec: Arc<D::Codec>,
162            done_rx: AsyncRx<mpsc::List<Result<D::RespTask, (u64, Option<RpcIntErr>)>>>,
163            conn: Arc<P>,
164            logger: Arc<LogFilter>,
165        }
166        let writer = Writer::<P, D> { done_rx, codec, conn, logger: facts.new_logger() };
167        rt.spawn_detach(async move { writer.run().await });
168
169        impl<P: ServerTransport, D: Dispatch> Writer<P, D> {
170            async fn run(self) -> Result<(), io::Error> {
171                macro_rules! process {
172                    ($task: expr) => {{
173                        match $task {
174                            Ok(_task) => {
175                                logger_trace!(self.logger, "write_resp {:?}", _task);
176                                self.conn
177                                    .write_resp::<D::RespTask>(
178                                        &self.logger,
179                                        self.codec.as_ref(),
180                                        _task,
181                                    )
182                                    .await?;
183                            }
184                            Err((seq, err)) => {
185                                self.conn.write_resp_internal(&self.logger, seq, err).await?;
186                            }
187                        }
188                    }};
189                }
190                while let Ok(task) = self.done_rx.recv().await {
191                    process!(task);
192                    while let Ok(task) = self.done_rx.try_recv() {
193                        process!(task);
194                    }
195                    self.conn.flush_resp(&self.logger).await?;
196                }
197                logger_trace!(self.logger, "{:?} writer exits", self.conn);
198                self.conn.close_conn(&self.logger).await;
199                Ok(())
200            }
201        }
202    }
203
204    #[inline]
205    fn get_alive_conn(&self) -> usize {
206        Arc::strong_count(&self.conn_ref_count) - 1
207    }
208
209    /// Gracefully close the server
210    ///
211    /// Steps:
212    /// - listeners coroutine is abort
213    /// - drop the close channel to notify connection read coroutines.
214    /// - the writer coroutines will exit after all the reference of RespNoti channel drop to 0
215    /// - wait for connection coroutines to exit with a timeout defined by
216    ///   ServerConfig.server_close_wait
217    pub async fn close<RT: AsyncTime>(&mut self) {
218        // close listeners
219        while let Some((h, addr)) = self.listeners_abort.pop() {
220            h.abort_boxed();
221            logger_info!(self.logger, "{} has closed", addr);
222        }
223        // Notify all reader connection exit, then the reader will notify writer
224        let _ = self.server_close_tx.lock().unwrap().take();
225
226        let mut exists_count = self.get_alive_conn();
227        // wait client close all connections
228        let start_ts = Instant::now();
229        let config = self.facts.get_config();
230        while exists_count > 0 {
231            RT::sleep(Duration::from_secs(1)).await;
232            exists_count = self.get_alive_conn();
233            if Instant::now().duration_since(start_ts) > config.server_close_wait {
234                logger_warn!(
235                    self.logger,
236                    "closed as wait too long for all conn closed voluntarily({} conn left)",
237                    exists_count,
238                );
239                break;
240            }
241        }
242        logger_info!(self.logger, "server closed with alive conn {}", exists_count);
243    }
244}