Skip to main content

razor_stream/server/
server.rs

1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use crossfire::{AsyncRx, MAsyncRx, mpmc, mpsc, null::CloseHandle};
5use std::io;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8
9/// An RpcServer that listen, accept, and server connections, according to ServerFacts interface.
10pub struct RpcServer<F>
11where
12    F: ServerFacts,
13{
14    // join handles
15    listeners_abort: Vec<(<F as AsyncExec>::AsyncHandle<()>, String)>,
16    logger: Arc<LogFilter>,
17    facts: Arc<F>,
18    conn_ref_count: Arc<()>,
19    // because we have multiple listeners
20    server_close_tx: Mutex<Option<CloseHandle<mpmc::Null>>>,
21    // because we have multiple listeners
22    server_close_rx: MAsyncRx<mpmc::Null>,
23}
24
25impl<F> RpcServer<F>
26where
27    F: ServerFacts,
28{
29    pub fn new(facts: Arc<F>) -> Self {
30        let (tx, rx) = crossfire::mpmc::new::<mpmc::Null, _, _>();
31        Self {
32            listeners_abort: Vec::new(),
33            logger: facts.new_logger(),
34            facts,
35            conn_ref_count: Arc::new(()),
36            server_close_tx: Mutex::new(Some(tx)),
37            server_close_rx: rx,
38        }
39    }
40
41    pub async fn listen<T: ServerTransport, D: Dispatch>(
42        &mut self, addr: &str, dispatch: D,
43    ) -> io::Result<String> {
44        match T::bind(addr).await {
45            Err(e) => {
46                error!("bind addr {:?} err: {}", addr, e);
47                return Err(e);
48            }
49            Ok(mut listener) => {
50                let local_addr = match listener.local_addr() {
51                    Ok(addr) => addr,
52                    Err(e) => {
53                        if e.kind() == std::io::ErrorKind::AddrNotAvailable {
54                            // For Unix sockets, return a dummy address
55                            "0.0.0.0:0".parse().unwrap()
56                        } else {
57                            return Err(e);
58                        }
59                    }
60                };
61                let facts = self.facts.clone();
62                let conn_ref_count = self.conn_ref_count.clone();
63                let listener_info = format!("listener {:?}", addr);
64                let server_close_rx = self.server_close_rx.clone();
65                debug!("listening on {:?}", listener);
66                let handle = self.facts.spawn(async move {
67                    loop {
68                        match listener.accept().await {
69                            Err(e) => {
70                                warn!("{:?} accept error: {}", listener, e);
71                                return;
72                            }
73                            Ok(stream) => {
74                                let conn =
75                                    T::new_conn(stream, facts.get_config(), conn_ref_count.clone());
76                                Self::server_conn::<T, D>(
77                                    conn,
78                                    &facts,
79                                    dispatch.clone(),
80                                    server_close_rx.clone(),
81                                )
82                            }
83                        }
84                    }
85                });
86                self.listeners_abort.push((handle, listener_info));
87                return Ok(local_addr);
88            }
89        }
90    }
91
92    fn server_conn<T: ServerTransport, D: Dispatch>(
93        conn: T, facts: &F, dispatch: D, server_close_rx: MAsyncRx<mpmc::Null>,
94    ) {
95        let conn = Arc::new(conn);
96
97        let (done_tx, done_rx) = mpsc::unbounded_async();
98        let codec = Arc::new(D::Codec::default());
99
100        let noti = RespNoti(done_tx);
101        struct Reader<T: ServerTransport, D: Dispatch> {
102            noti: RespNoti<D::RespTask>,
103            conn: Arc<T>,
104            server_close_rx: MAsyncRx<mpmc::Null>,
105            codec: Arc<D::Codec>,
106            dispatch: D,
107            logger: Arc<LogFilter>,
108        }
109        let reader = Reader::<T, D> {
110            noti,
111            codec: codec.clone(),
112            dispatch,
113            conn: conn.clone(),
114            server_close_rx,
115            logger: facts.new_logger(),
116        };
117        facts.spawn_detach(async move { reader.run().await });
118
119        impl<T: ServerTransport, D: Dispatch> Reader<T, D> {
120            async fn run(self) -> Result<(), ()> {
121                loop {
122                    match self.conn.read_req(&self.logger, &self.server_close_rx).await {
123                        Ok(req) => {
124                            if req.action == RpcAction::Num(0) && req.msg.len() == 0 {
125                                // ping request
126                                self.send_quick_resp(req.seq, None)?;
127                            } else {
128                                let seq = req.seq;
129                                if self
130                                    .dispatch
131                                    .dispatch_req(&self.codec, req, self.noti.clone())
132                                    .await
133                                    .is_err()
134                                {
135                                    self.send_quick_resp(seq, Some(RpcIntErr::Decode.into()))?;
136                                }
137                            }
138                        }
139                        Err(_e) => {
140                            // XXX read_req return error not used
141                            return Err(());
142                        }
143                    }
144                }
145            }
146
147            #[inline]
148            fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
149                if self.noti.send_err(seq, err).is_err() {
150                    logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
151                    return Err(());
152                }
153                Ok(())
154            }
155        }
156
157        struct Writer<T: ServerTransport, D: Dispatch> {
158            codec: Arc<D::Codec>,
159            done_rx: AsyncRx<mpsc::List<Result<D::RespTask, (u64, Option<RpcIntErr>)>>>,
160            conn: Arc<T>,
161            logger: Arc<LogFilter>,
162        }
163        let writer = Writer::<T, D> { done_rx, codec, conn, logger: facts.new_logger() };
164        facts.spawn_detach(async move { writer.run().await });
165
166        impl<T: ServerTransport, D: Dispatch> Writer<T, D> {
167            async fn run(self) -> Result<(), io::Error> {
168                macro_rules! process {
169                    ($task: expr) => {{
170                        match $task {
171                            Ok(_task) => {
172                                logger_trace!(self.logger, "write_resp {:?}", _task);
173                                self.conn
174                                    .write_resp::<D::RespTask>(
175                                        &self.logger,
176                                        self.codec.as_ref(),
177                                        _task,
178                                    )
179                                    .await?;
180                            }
181                            Err((seq, err)) => {
182                                self.conn.write_resp_internal(&self.logger, seq, err).await?;
183                            }
184                        }
185                    }};
186                }
187                while let Ok(task) = self.done_rx.recv().await {
188                    process!(task);
189                    while let Ok(task) = self.done_rx.try_recv() {
190                        process!(task);
191                    }
192                    self.conn.flush_resp(&self.logger).await?;
193                }
194                logger_trace!(self.logger, "{:?} writer exits", self.conn);
195                self.conn.close_conn(&self.logger).await;
196                Ok(())
197            }
198        }
199    }
200
201    #[inline]
202    fn get_alive_conn(&self) -> usize {
203        Arc::strong_count(&self.conn_ref_count) - 1
204    }
205
206    /// Gracefully close the server
207    ///
208    /// Steps:
209    /// - listeners coroutine is abort
210    /// - drop the close channel to notify connection read coroutines.
211    /// - the writer coroutines will exit after all the reference of RespNoti channel drop to 0
212    /// - wait for connection coroutines to exit with a timeout defined by
213    /// ServerConfig.server_close_wait
214    pub async fn close(&mut self) {
215        // close listeners
216        for h in self.listeners_abort.drain(0..) {
217            h.0.abort();
218            logger_info!(self.logger, "{} has closed", h.1);
219        }
220        // Notify all reader connection exit, then the reader will notify writer
221        let _ = self.server_close_tx.lock().unwrap().take();
222
223        let mut exists_count = self.get_alive_conn();
224        // wait client close all connections
225        let start_ts = Instant::now();
226        let config = self.facts.get_config();
227        while exists_count > 0 {
228            F::sleep(Duration::from_secs(1)).await;
229            exists_count = self.get_alive_conn();
230            if Instant::now().duration_since(start_ts) > config.server_close_wait {
231                logger_warn!(
232                    self.logger,
233                    "closed as wait too long for all conn closed voluntarily({} conn left)",
234                    exists_count,
235                );
236                break;
237            }
238        }
239        logger_info!(self.logger, "server closed with alive conn {}", exists_count);
240    }
241}