1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use crossfire::{AsyncRx, MAsyncRx, mpmc, mpsc, null::CloseHandle};
5use std::io;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8
9pub struct RpcServer<F>
11where
12 F: ServerFacts,
13{
14 listeners_abort: Vec<(<F as AsyncExec>::AsyncHandle<()>, String)>,
16 logger: Arc<LogFilter>,
17 facts: Arc<F>,
18 conn_ref_count: Arc<()>,
19 server_close_tx: Mutex<Option<CloseHandle<mpmc::Null>>>,
21 server_close_rx: MAsyncRx<mpmc::Null>,
23}
24
25impl<F> RpcServer<F>
26where
27 F: ServerFacts,
28{
29 pub fn new(facts: Arc<F>) -> Self {
30 let (tx, rx) = crossfire::mpmc::new::<mpmc::Null, _, _>();
31 Self {
32 listeners_abort: Vec::new(),
33 logger: facts.new_logger(),
34 facts,
35 conn_ref_count: Arc::new(()),
36 server_close_tx: Mutex::new(Some(tx)),
37 server_close_rx: rx,
38 }
39 }
40
41 pub async fn listen<T: ServerTransport, D: Dispatch>(
42 &mut self, addr: &str, dispatch: D,
43 ) -> io::Result<String> {
44 match T::bind(addr).await {
45 Err(e) => {
46 error!("bind addr {:?} err: {}", addr, e);
47 return Err(e);
48 }
49 Ok(mut listener) => {
50 let local_addr = match listener.local_addr() {
51 Ok(addr) => addr,
52 Err(e) => {
53 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
54 "0.0.0.0:0".parse().unwrap()
56 } else {
57 return Err(e);
58 }
59 }
60 };
61 let facts = self.facts.clone();
62 let conn_ref_count = self.conn_ref_count.clone();
63 let listener_info = format!("listener {:?}", addr);
64 let server_close_rx = self.server_close_rx.clone();
65 debug!("listening on {:?}", listener);
66 let handle = self.facts.spawn(async move {
67 loop {
68 match listener.accept().await {
69 Err(e) => {
70 warn!("{:?} accept error: {}", listener, e);
71 return;
72 }
73 Ok(stream) => {
74 let conn =
75 T::new_conn(stream, facts.get_config(), conn_ref_count.clone());
76 Self::server_conn::<T, D>(
77 conn,
78 &facts,
79 dispatch.clone(),
80 server_close_rx.clone(),
81 )
82 }
83 }
84 }
85 });
86 self.listeners_abort.push((handle, listener_info));
87 return Ok(local_addr);
88 }
89 }
90 }
91
92 fn server_conn<T: ServerTransport, D: Dispatch>(
93 conn: T, facts: &F, dispatch: D, server_close_rx: MAsyncRx<mpmc::Null>,
94 ) {
95 let conn = Arc::new(conn);
96
97 let (done_tx, done_rx) = mpsc::unbounded_async();
98 let codec = Arc::new(D::Codec::default());
99
100 let noti = RespNoti(done_tx);
101 struct Reader<T: ServerTransport, D: Dispatch> {
102 noti: RespNoti<D::RespTask>,
103 conn: Arc<T>,
104 server_close_rx: MAsyncRx<mpmc::Null>,
105 codec: Arc<D::Codec>,
106 dispatch: D,
107 logger: Arc<LogFilter>,
108 }
109 let reader = Reader::<T, D> {
110 noti,
111 codec: codec.clone(),
112 dispatch,
113 conn: conn.clone(),
114 server_close_rx,
115 logger: facts.new_logger(),
116 };
117 facts.spawn_detach(async move { reader.run().await });
118
119 impl<T: ServerTransport, D: Dispatch> Reader<T, D> {
120 async fn run(self) -> Result<(), ()> {
121 loop {
122 match self.conn.read_req(&self.logger, &self.server_close_rx).await {
123 Ok(req) => {
124 if req.action == RpcAction::Num(0) && req.msg.len() == 0 {
125 self.send_quick_resp(req.seq, None)?;
127 } else {
128 let seq = req.seq;
129 if self
130 .dispatch
131 .dispatch_req(&self.codec, req, self.noti.clone())
132 .await
133 .is_err()
134 {
135 self.send_quick_resp(seq, Some(RpcIntErr::Decode.into()))?;
136 }
137 }
138 }
139 Err(_e) => {
140 return Err(());
142 }
143 }
144 }
145 }
146
147 #[inline]
148 fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
149 if self.noti.send_err(seq, err).is_err() {
150 logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
151 return Err(());
152 }
153 Ok(())
154 }
155 }
156
157 struct Writer<T: ServerTransport, D: Dispatch> {
158 codec: Arc<D::Codec>,
159 done_rx: AsyncRx<mpsc::List<Result<D::RespTask, (u64, Option<RpcIntErr>)>>>,
160 conn: Arc<T>,
161 logger: Arc<LogFilter>,
162 }
163 let writer = Writer::<T, D> { done_rx, codec, conn, logger: facts.new_logger() };
164 facts.spawn_detach(async move { writer.run().await });
165
166 impl<T: ServerTransport, D: Dispatch> Writer<T, D> {
167 async fn run(self) -> Result<(), io::Error> {
168 macro_rules! process {
169 ($task: expr) => {{
170 match $task {
171 Ok(_task) => {
172 logger_trace!(self.logger, "write_resp {:?}", _task);
173 self.conn
174 .write_resp::<D::RespTask>(
175 &self.logger,
176 self.codec.as_ref(),
177 _task,
178 )
179 .await?;
180 }
181 Err((seq, err)) => {
182 self.conn.write_resp_internal(&self.logger, seq, err).await?;
183 }
184 }
185 }};
186 }
187 while let Ok(task) = self.done_rx.recv().await {
188 process!(task);
189 while let Ok(task) = self.done_rx.try_recv() {
190 process!(task);
191 }
192 self.conn.flush_resp(&self.logger).await?;
193 }
194 logger_trace!(self.logger, "{:?} writer exits", self.conn);
195 self.conn.close_conn(&self.logger).await;
196 Ok(())
197 }
198 }
199 }
200
201 #[inline]
202 fn get_alive_conn(&self) -> usize {
203 Arc::strong_count(&self.conn_ref_count) - 1
204 }
205
206 pub async fn close(&mut self) {
215 for h in self.listeners_abort.drain(0..) {
217 h.0.abort();
218 logger_info!(self.logger, "{} has closed", h.1);
219 }
220 let _ = self.server_close_tx.lock().unwrap().take();
222
223 let mut exists_count = self.get_alive_conn();
224 let start_ts = Instant::now();
226 let config = self.facts.get_config();
227 while exists_count > 0 {
228 F::sleep(Duration::from_secs(1)).await;
229 exists_count = self.get_alive_conn();
230 if Instant::now().duration_since(start_ts) > config.server_close_wait {
231 logger_warn!(
232 self.logger,
233 "closed as wait too long for all conn closed voluntarily({} conn left)",
234 exists_count,
235 );
236 break;
237 }
238 }
239 logger_info!(self.logger, "server closed with alive conn {}", exists_count);
240 }
241}