Skip to main content

razor_stream/server/
server.rs

1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use crossfire::{AsyncRx, MAsyncRx, mpmc, mpsc, null::CloseHandle};
5use orb::prelude::{AsyncJoiner, AsyncTime};
6use std::io;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// An RPC server that listens on one or more addresses, accepts incoming connections, and
/// serves them according to the [`ServerFacts`] interface.
///
/// Listeners are added with `listen()`; `close()` aborts them and waits (bounded by the
/// configured `server_close_wait`) for the remaining connections to go away.
pub struct RpcServer<F>
where
    F: ServerFacts,
{
    // Join handles of the spawned listener tasks, each paired with a description
    // (`"listener <addr>"`) that is logged when the listener is aborted in `close()`.
    listeners_abort: Vec<(Box<dyn AsyncJoiner<()>>, String)>,
    // Server-level logger, obtained from `facts.new_logger()`.
    logger: Arc<LogFilter>,
    facts: Arc<F>,
    // A clone is handed to every new connection via `P::new_conn`; `strong_count - 1` is used
    // as the number of connections still alive.
    conn_ref_count: Arc<()>,
    // Sending side of the server-close channel; `close()` takes and drops it to notify the
    // connection readers. Multi-consumer because there may be multiple listeners.
    server_close_tx: Mutex<Option<CloseHandle<mpmc::Null>>>,
    // Receiving side of the server-close channel, cloned for every listener and every
    // connection. Multi-consumer because there may be multiple listeners.
    server_close_rx: MAsyncRx<mpmc::Null>,
}
25
26impl<F> RpcServer<F>
27where
28    F: ServerFacts,
29{
30    pub fn new(facts: Arc<F>) -> Self {
31        let (tx, rx) = crossfire::mpmc::new::<mpmc::Null, _, _>();
32        Self {
33            listeners_abort: Vec::new(),
34            logger: facts.new_logger(),
35            facts,
36            conn_ref_count: Arc::new(()),
37            server_close_tx: Mutex::new(Some(tx)),
38            server_close_rx: rx,
39        }
40    }
41
42    pub async fn listen<P: ServerTransport, D: Dispatch>(
43        &mut self, addr: &str, dispatch: D,
44    ) -> io::Result<String> {
45        match P::bind(addr).await {
46            Err(e) => {
47                error!("bind addr {:?} err: {}", addr, e);
48                return Err(e);
49            }
50            Ok(mut listener) => {
51                let local_addr = match listener.local_addr() {
52                    Ok(addr) => addr,
53                    Err(e) => {
54                        if e.kind() == std::io::ErrorKind::AddrNotAvailable {
55                            // For Unix sockets, return a dummy address
56                            "0.0.0.0:0".parse().unwrap()
57                        } else {
58                            return Err(e);
59                        }
60                    }
61                };
62                let facts = self.facts.clone();
63                let conn_ref_count = self.conn_ref_count.clone();
64                let listener_info = format!("listener {:?}", addr);
65                let server_close_rx = self.server_close_rx.clone();
66                debug!("listening on {:?}", listener);
67                let handle = P::RT::spawn(async move {
68                    loop {
69                        match listener.accept().await {
70                            Err(e) => {
71                                warn!("{:?} accept error: {}", listener, e);
72                                return;
73                            }
74                            Ok(stream) => {
75                                let conn =
76                                    P::new_conn(stream, facts.get_config(), conn_ref_count.clone());
77                                Self::server_conn::<P, D>(
78                                    conn,
79                                    &facts,
80                                    dispatch.clone(),
81                                    server_close_rx.clone(),
82                                )
83                            }
84                        }
85                    }
86                });
87                self.listeners_abort.push((Box::new(handle), listener_info));
88                return Ok(local_addr);
89            }
90        }
91    }
92
93    fn server_conn<P: ServerTransport, D: Dispatch>(
94        conn: P, facts: &F, dispatch: D, server_close_rx: MAsyncRx<mpmc::Null>,
95    ) {
96        let conn = Arc::new(conn);
97
98        let (done_tx, done_rx) = mpsc::unbounded_async();
99        let codec = Arc::new(D::Codec::default());
100
101        let noti = RespNoti(done_tx);
102        struct Reader<P: ServerTransport, D: Dispatch> {
103            noti: RespNoti<D::RespTask>,
104            conn: Arc<P>,
105            server_close_rx: MAsyncRx<mpmc::Null>,
106            codec: Arc<D::Codec>,
107            dispatch: D,
108            logger: Arc<LogFilter>,
109        }
110        let reader = Reader::<P, D> {
111            noti,
112            codec: codec.clone(),
113            dispatch,
114            conn: conn.clone(),
115            server_close_rx,
116            logger: facts.new_logger(),
117        };
118        P::RT::spawn_detach(async move { reader.run().await });
119
120        impl<P: ServerTransport, D: Dispatch> Reader<P, D> {
121            async fn run(self) -> Result<(), ()> {
122                loop {
123                    match self.conn.read_req(&self.logger, &self.server_close_rx).await {
124                        Ok(req) => {
125                            if req.action == RpcAction::Num(0) && req.msg.is_empty() {
126                                // ping request
127                                self.send_quick_resp(req.seq, None)?;
128                            } else {
129                                let seq = req.seq;
130                                if self
131                                    .dispatch
132                                    .dispatch_req(&self.codec, req, self.noti.clone())
133                                    .await
134                                    .is_err()
135                                {
136                                    self.send_quick_resp(seq, Some(RpcIntErr::Decode))?;
137                                }
138                            }
139                        }
140                        Err(_e) => {
141                            // XXX read_req return error not used
142                            return Err(());
143                        }
144                    }
145                }
146            }
147
148            #[inline]
149            fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
150                if self.noti.send_err(seq, err).is_err() {
151                    logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
152                    return Err(());
153                }
154                Ok(())
155            }
156        }
157
158        struct Writer<P: ServerTransport, D: Dispatch> {
159            codec: Arc<D::Codec>,
160            done_rx: AsyncRx<mpsc::List<Result<D::RespTask, (u64, Option<RpcIntErr>)>>>,
161            conn: Arc<P>,
162            logger: Arc<LogFilter>,
163        }
164        let writer = Writer::<P, D> { done_rx, codec, conn, logger: facts.new_logger() };
165        P::RT::spawn_detach(async move { writer.run().await });
166
167        impl<P: ServerTransport, D: Dispatch> Writer<P, D> {
168            async fn run(self) -> Result<(), io::Error> {
169                macro_rules! process {
170                    ($task: expr) => {{
171                        match $task {
172                            Ok(_task) => {
173                                logger_trace!(self.logger, "write_resp {:?}", _task);
174                                self.conn
175                                    .write_resp::<D::RespTask>(
176                                        &self.logger,
177                                        self.codec.as_ref(),
178                                        _task,
179                                    )
180                                    .await?;
181                            }
182                            Err((seq, err)) => {
183                                self.conn.write_resp_internal(&self.logger, seq, err).await?;
184                            }
185                        }
186                    }};
187                }
188                while let Ok(task) = self.done_rx.recv().await {
189                    process!(task);
190                    while let Ok(task) = self.done_rx.try_recv() {
191                        process!(task);
192                    }
193                    self.conn.flush_resp(&self.logger).await?;
194                }
195                logger_trace!(self.logger, "{:?} writer exits", self.conn);
196                self.conn.close_conn(&self.logger).await;
197                Ok(())
198            }
199        }
200    }
201
202    #[inline]
203    fn get_alive_conn(&self) -> usize {
204        Arc::strong_count(&self.conn_ref_count) - 1
205    }
206
207    /// Gracefully close the server
208    ///
209    /// Steps:
210    /// - listeners coroutine is abort
211    /// - drop the close channel to notify connection read coroutines.
212    /// - the writer coroutines will exit after all the reference of RespNoti channel drop to 0
213    /// - wait for connection coroutines to exit with a timeout defined by
214    ///   ServerConfig.server_close_wait
215    pub async fn close<RT: AsyncTime>(&mut self) {
216        // close listeners
217        while let Some((h, addr)) = self.listeners_abort.pop() {
218            h.abort_boxed();
219            logger_info!(self.logger, "{} has closed", addr);
220        }
221        // Notify all reader connection exit, then the reader will notify writer
222        let _ = self.server_close_tx.lock().unwrap().take();
223
224        let mut exists_count = self.get_alive_conn();
225        // wait client close all connections
226        let start_ts = Instant::now();
227        let config = self.facts.get_config();
228        while exists_count > 0 {
229            RT::sleep(Duration::from_secs(1)).await;
230            exists_count = self.get_alive_conn();
231            if Instant::now().duration_since(start_ts) > config.server_close_wait {
232                logger_warn!(
233                    self.logger,
234                    "closed as wait too long for all conn closed voluntarily({} conn left)",
235                    exists_count,
236                );
237                break;
238            }
239        }
240        logger_info!(self.logger, "server closed with alive conn {}", exists_count);
241    }
242}