1use crate::proto::{
2 Bye, Fsync, Isync, Ping, Pong, Redirect, Request, Request_oneof_cmd as ReqCmd, Response,
3 Response_oneof_cmd as RespCmd, Rsync,
4};
5use log;
6use protobuf::Message;
7use std::collections::HashMap;
8use std::io::{Error, ErrorKind::Other, Result};
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc::{channel, Receiver, Sender};
13use tokio::{
14 io::{AsyncReadExt, AsyncWriteExt},
15 net::{
16 tcp::{ReadHalf, WriteHalf},
17 TcpListener, ToSocketAddrs,
18 },
19 select, task,
20 time::timeout,
21};
22
23pub struct Server {
27 listener: TcpListener,
28 peers: PeerMap,
29 count: u64,
30}
31
32impl Server {
33 pub async fn new<A: ToSocketAddrs>(listen_addr: A) -> Result<Self> {
34 let listener = TcpListener::bind(listen_addr).await?;
35
36 Ok(Self {
37 listener: listener,
38 peers: Default::default(),
39 count: 0,
40 })
41 }
42
43 fn next_id(&mut self) -> u64 {
44 self.count += 1;
45 self.count
46 }
47
48 pub async fn run(mut self) -> Result<()> {
49 while let Ok((mut stream, _addr)) = self.listener.accept().await {
50 let id = self.next_id();
51 let peers = self.peers.clone();
52
53 task::spawn(async move {
54 let (tx, rx) = channel(10);
55 let (r, w) = stream.split();
56 let h = PeerHandler {
57 stream: w,
58 peers: peers,
59 peer_id: "".to_string(),
60 req_rx: rx,
61 req_tx: tx,
62 id: id,
63 };
64
65 h.handle_stream(r).await;
66 });
67 }
68
69 Ok(())
70 }
71}
72
73struct PeerState {
74 id: u64,
75 last_ping: Instant,
76 req_tx: Sender<Request>,
77 addr: SocketAddr,
78}
79
/// Registry of connected peers, keyed by the peer id that clients send in their requests.
/// It is shared by every connection task behind a blocking `std::sync::Mutex`. The handlers
/// here release the guard before any `.await`; keep it that way.
type PeerMap = Arc<Mutex<HashMap<String, PeerState>>>;
81
82struct PeerHandler<'a> {
83 stream: WriteHalf<'a>,
84 peers: PeerMap,
85 peer_id: String,
86 id: u64,
87 req_tx: Sender<Request>,
88 req_rx: Receiver<Request>,
89}
90
91impl<'a> PeerHandler<'a> {
92 async fn handle_stream(mut self, r: ReadHalf<'a>) {
93 let req_tx = self.req_tx.clone();
94
95 select! {
96 _ = Self::read_reqs(r, req_tx) => {},
97 _ = self.handle_cmds()=> {},
98 }
99
100 let mut peers = self.peers.lock().unwrap();
101 if let Some(p) = (*peers).get(&self.peer_id) {
102 if p.id == self.id {
103 peers.remove(&self.peer_id);
104 }
105 }
106
107 if self.peer_id != "" {
108 log::debug!("peer {} disconnect", self.peer_id);
109 }
110 }
111
112 async fn read_reqs(mut r: ReadHalf<'a>, req_tx: Sender<Request>) -> Result<()> {
113 loop {
114 match timeout(Duration::from_secs(30), Self::read_req(&mut r)).await? {
115 Ok(req) => req_tx
116 .send(req)
117 .await
118 .map_err(|_| Error::new(Other, "mpsc closed?")),
119 _ => {
120 let mut bye = Request::new();
121 bye.set_Bye(Bye::new());
122 let _ = req_tx.send(bye);
123
124 Err(Error::new(Other, "byte"))
125 }
126 }?;
127 }
128 }
129
130 async fn handle_cmds(&mut self) -> Result<()> {
131 loop {
132 match self.req_rx.recv().await {
133 Some(req) => {
134 let src_id = req.get_id().to_string();
135 match req.cmd {
136 Some(ReqCmd::Ping(ping)) => self.handle_ping(src_id, ping).await,
137 Some(ReqCmd::Isync(isync)) => self.handle_isync(src_id, isync).await,
138 Some(ReqCmd::Fsync(fsync)) => self.handle_fsync(fsync).await,
139 Some(ReqCmd::Rsync(rsync)) => self.handle_rsync(src_id, rsync).await,
140 Some(ReqCmd::Bye(_)) => Err(Error::new(Other, "bye")),
141 _ => Err(Error::new(Other, "uknown cmd")),
142 }
143 }
144 _ => Err(Error::new(Other, "bye")),
145 }?;
146 }
147 }
148
149 async fn read_req(stream: &mut ReadHalf<'a>) -> Result<Request> {
150 let mut buf = [0u8; 2];
151 stream.read_exact(&mut buf).await?;
152
153 let size = u16::from_be_bytes(buf).into();
154 if size > 1500 {
155 Err(Error::new(Other, "invalid message"))?;
156 }
157 let mut buf = vec![0u8; size];
158 stream.read_exact(&mut buf).await?;
159
160 Request::parse_from_bytes(&mut buf).map_err(|_| Error::new(Other, "invalid message"))
161 }
162
163 async fn send_response(&mut self, cmd: RespCmd) -> Result<()> {
164 let mut resp = Response::new();
165 resp.cmd = Some(cmd);
166
167 let vec = resp.write_to_bytes().unwrap();
168
169 let _ = self
170 .stream
171 .write_all(&(vec.len() as u16).to_be_bytes())
172 .await?;
173 let _ = self.stream.write_all(&vec).await?;
174 Ok(())
175 }
176
177 fn insert_peerstate(&mut self, id: String) {
178 self.peer_id = id;
179 let mut peers = self.peers.lock().unwrap();
180
181 if match (*peers).get(&self.peer_id) {
182 Some(p) => p.id != self.id,
183 _ => false,
184 } {
185 log::debug!("update peer {}", self.peer_id);
186 peers.remove(&self.peer_id);
187 }
188 let mut p = peers
189 .entry(self.peer_id.clone())
190 .or_insert_with(|| PeerState {
191 id: self.id,
192 req_tx: self.req_tx.clone(),
193 last_ping: Instant::now(),
194 addr: self.stream.as_ref().peer_addr().unwrap(),
195 });
196
197 p.last_ping = Instant::now();
198 }
199
200 async fn handle_ping(&mut self, src_id: String, _ping: Ping) -> Result<()> {
201 log::trace!("ping {}", src_id);
202
203 self.insert_peerstate(src_id);
204
205 self.send_response(RespCmd::Pong(Pong::new())).await
206 }
207
208 async fn handle_isync(&mut self, src_id: String, isync: Isync) -> Result<()> {
209 let dst_id = isync.get_id();
210 log::debug!("isync {} -> {}", src_id, dst_id);
211
212 let mut rdr = Redirect::new();
213 rdr.set_id(dst_id.to_string());
214
215 if let Some((req_tx, freq)) = {
216 let peers = self.peers.lock().unwrap();
217 let p = match (*peers).get(dst_id) {
218 Some(p) => Some(p),
219 None => {
220 log::debug!("{} not found", dst_id);
221 None
222 }
223 };
224
225 if let Some(p) = p {
226 rdr.set_addr(p.addr.to_string());
227
228 let mut fsync = Fsync::new();
230 fsync.set_id(src_id.clone());
231 fsync.set_addr(self.stream.as_ref().peer_addr().unwrap().to_string());
232
233 let mut freq = Request::new();
234
235 freq.set_Fsync(fsync);
236
237 Some((p.req_tx.clone(), freq))
238 } else {
239 None
240 }
241 } {
242 self.insert_peerstate(src_id);
244
245 let _ = req_tx.send(freq).await;
246 Ok(())
247 } else {
248 self.send_response(RespCmd::Redirect(rdr)).await
249 }
250 }
251
252 async fn handle_fsync(&mut self, fsync: Fsync) -> Result<()> {
253 log::debug!("fsync {} -> {} ", fsync.get_id(), self.peer_id);
254 self.send_response(RespCmd::Fsync(fsync)).await
255 }
256
257 async fn handle_rsync(&mut self, src_id: String, rsync: Rsync) -> Result<()> {
258 let dst_id = rsync.get_id();
259 log::debug!("rsync {} -> {}", src_id, dst_id);
260
261 if dst_id == self.peer_id {
262 let rdr = if let Some(p) = self.peers.lock().unwrap().get(&src_id) {
263 let mut rdr = Redirect::new();
264 rdr.set_id(src_id.to_string());
265 rdr.set_addr(p.addr.to_string());
266 rdr
267 } else {
268 log::debug!("{} not found", src_id);
269 return Ok(());
270 };
271
272 return self.send_response(RespCmd::Redirect(rdr)).await;
273 }
274
275 let req_tx = {
276 let peers = self.peers.lock().unwrap();
277 match (*peers).get(dst_id) {
278 Some(p) => p.req_tx.clone(),
279 None => {
280 log::debug!("{} not found", dst_id);
281 return Ok(());
282 }
283 }
284 };
285
286 let mut req = Request::new();
287 req.set_id(src_id);
288 req.cmd = Some(ReqCmd::Rsync(rsync));
289
290 let _ = req_tx.send(req).await;
291
292 Ok(())
293 }
294}