1use crate::proto::RpcAction;
2use crate::server::*;
3use captains_log::filter::LogFilter;
4use std::io;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8pub struct RpcServer<F>
10where
11 F: ServerFacts,
12{
13 listeners_abort: Vec<(<F as AsyncExec>::AsyncHandle<()>, String)>,
14 logger: Arc<LogFilter>,
15 facts: Arc<F>,
16 conn_ref_count: Arc<()>,
17 server_close_tx: Mutex<Option<crossfire::MTx<()>>>,
18 server_close_rx: crossfire::MAsyncRx<()>,
19}
20
21impl<F> RpcServer<F>
22where
23 F: ServerFacts,
24{
25 pub fn new(facts: Arc<F>) -> Self {
26 let (tx, rx) = crossfire::mpmc::unbounded_async();
27 Self {
28 listeners_abort: Vec::new(),
29 logger: facts.new_logger(),
30 facts,
31 conn_ref_count: Arc::new(()),
32 server_close_tx: Mutex::new(Some(tx)),
33 server_close_rx: rx,
34 }
35 }
36
37 pub async fn listen<T: ServerTransport, D: Dispatch>(
38 &mut self, addr: &str, dispatch: D,
39 ) -> io::Result<String> {
40 match T::bind(addr).await {
41 Err(e) => {
42 error!("bind addr {:?} err: {}", addr, e);
43 return Err(e);
44 }
45 Ok(mut listener) => {
46 let local_addr = match listener.local_addr() {
47 Ok(addr) => addr,
48 Err(e) => {
49 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
50 "0.0.0.0:0".parse().unwrap()
52 } else {
53 return Err(e);
54 }
55 }
56 };
57 let facts = self.facts.clone();
58 let conn_ref_count = self.conn_ref_count.clone();
59 let listener_info = format!("listener {:?}", addr);
60 let server_close_rx = self.server_close_rx.clone();
61 debug!("listening on {:?}", listener);
62 let handle = self.facts.spawn(async move {
63 loop {
64 match listener.accept().await {
65 Err(e) => {
66 warn!("{:?} accept error: {}", listener, e);
67 return;
68 }
69 Ok(stream) => {
70 let conn =
71 T::new_conn(stream, facts.get_config(), conn_ref_count.clone());
72 Self::server_conn::<T, D>(
73 conn,
74 &facts,
75 dispatch.clone(),
76 server_close_rx.clone(),
77 )
78 }
79 }
80 }
81 });
82 self.listeners_abort.push((handle, listener_info));
83 return Ok(local_addr);
84 }
85 }
86 }
87
88 fn server_conn<T: ServerTransport, D: Dispatch>(
89 conn: T, facts: &F, dispatch: D, server_close_rx: crossfire::MAsyncRx<()>,
90 ) {
91 let conn = Arc::new(conn);
92
93 let (done_tx, done_rx) = crossfire::mpsc::unbounded_async();
94 let codec = Arc::new(D::Codec::default());
95
96 let noti = RespNoti(done_tx);
97 struct Reader<T: ServerTransport, D: Dispatch> {
98 noti: RespNoti<D::RespTask>,
99 conn: Arc<T>,
100 server_close_rx: crossfire::MAsyncRx<()>,
101 codec: Arc<D::Codec>,
102 dispatch: D,
103 logger: Arc<LogFilter>,
104 }
105 let reader = Reader::<T, D> {
106 noti,
107 codec: codec.clone(),
108 dispatch,
109 conn: conn.clone(),
110 server_close_rx,
111 logger: facts.new_logger(),
112 };
113 facts.spawn_detach(async move { reader.run().await });
114
115 impl<T: ServerTransport, D: Dispatch> Reader<T, D> {
116 async fn run(self) -> Result<(), ()> {
117 loop {
118 match self.conn.read_req(&self.logger, &self.server_close_rx).await {
119 Ok(req) => {
120 if req.action == RpcAction::Num(0) && req.msg.len() == 0 {
121 self.send_quick_resp(req.seq, None)?;
123 } else {
124 let seq = req.seq;
125 if self
126 .dispatch
127 .dispatch_req(&self.codec, req, self.noti.clone())
128 .await
129 .is_err()
130 {
131 self.send_quick_resp(seq, Some(RpcIntErr::Decode.into()))?;
132 }
133 }
134 }
135 Err(_e) => {
136 return Err(());
138 }
139 }
140 }
141 }
142
143 #[inline]
144 fn send_quick_resp(&self, seq: u64, err: Option<RpcIntErr>) -> Result<(), ()> {
145 if self.noti.send_err(seq, err).is_err() {
146 logger_warn!(self.logger, "{:?} reader abort due to writer has err", self.conn);
147 return Err(());
148 }
149 Ok(())
150 }
151 }
152
153 struct Writer<T: ServerTransport, D: Dispatch> {
154 codec: Arc<D::Codec>,
155 done_rx: crossfire::AsyncRx<Result<D::RespTask, (u64, Option<RpcIntErr>)>>,
156 conn: Arc<T>,
157 logger: Arc<LogFilter>,
158 }
159 let writer = Writer::<T, D> { done_rx, codec, conn, logger: facts.new_logger() };
160 facts.spawn_detach(async move { writer.run().await });
161
162 impl<T: ServerTransport, D: Dispatch> Writer<T, D> {
163 async fn run(self) -> Result<(), io::Error> {
164 macro_rules! process {
165 ($task: expr) => {{
166 match $task {
167 Ok(_task) => {
168 logger_trace!(self.logger, "write_resp {:?}", _task);
169 self.conn
170 .write_resp::<D::RespTask>(
171 &self.logger,
172 self.codec.as_ref(),
173 _task,
174 )
175 .await?;
176 }
177 Err((seq, err)) => {
178 self.conn.write_resp_internal(&self.logger, seq, err).await?;
179 }
180 }
181 }};
182 }
183 while let Ok(task) = self.done_rx.recv().await {
184 process!(task);
185 while let Ok(task) = self.done_rx.try_recv() {
186 process!(task);
187 }
188 self.conn.flush_resp(&self.logger).await?;
189 }
190 logger_trace!(self.logger, "{:?} writer exits", self.conn);
191 self.conn.close_conn(&self.logger).await;
192 Ok(())
193 }
194 }
195 }
196
197 #[inline]
198 fn get_alive_conn(&self) -> usize {
199 Arc::strong_count(&self.conn_ref_count) - 1
200 }
201
202 pub async fn close(&mut self) {
211 for h in self.listeners_abort.drain(0..) {
213 h.0.abort();
214 logger_info!(self.logger, "{} has closed", h.1);
215 }
216 let _ = self.server_close_tx.lock().unwrap().take();
218
219 let mut exists_count = self.get_alive_conn();
220 let start_ts = Instant::now();
222 let config = self.facts.get_config();
223 while exists_count > 0 {
224 F::sleep(Duration::from_secs(1)).await;
225 exists_count = self.get_alive_conn();
226 if Instant::now().duration_since(start_ts) > config.server_close_wait {
227 logger_warn!(
228 self.logger,
229 "closed as wait too long for all conn closed voluntarily({} conn left)",
230 exists_count,
231 );
232 break;
233 }
234 }
235 logger_info!(self.logger, "server closed with alive conn {}", exists_count);
236 }
237}