rndz/tcp/server.rs

1use crate::proto::{
2    Bye, Fsync, Isync, Ping, Pong, Redirect, Request, Request_oneof_cmd as ReqCmd, Response,
3    Response_oneof_cmd as RespCmd, Rsync,
4};
5use log;
6use protobuf::Message;
7use std::collections::HashMap;
8use std::io::{Error, ErrorKind::Other, Result};
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc::{channel, Receiver, Sender};
13use tokio::{
14    io::{AsyncReadExt, AsyncWriteExt},
15    net::{
16        tcp::{ReadHalf, WriteHalf},
17        TcpListener, ToSocketAddrs,
18    },
19    select, task,
20    time::timeout,
21};
22
23/// Tcp rendezvous server
24///
25/// keep traces of all peers, and forward connection request.
26pub struct Server {
27    listener: TcpListener,
28    peers: PeerMap,
29    count: u64,
30}
31
32impl Server {
33    pub async fn new<A: ToSocketAddrs>(listen_addr: A) -> Result<Self> {
34        let listener = TcpListener::bind(listen_addr).await?;
35
36        Ok(Self {
37            listener: listener,
38            peers: Default::default(),
39            count: 0,
40        })
41    }
42
43    fn next_id(&mut self) -> u64 {
44        self.count += 1;
45        self.count
46    }
47
48    pub async fn run(mut self) -> Result<()> {
49        while let Ok((mut stream, _addr)) = self.listener.accept().await {
50            let id = self.next_id();
51            let peers = self.peers.clone();
52
53            task::spawn(async move {
54                let (tx, rx) = channel(10);
55                let (r, w) = stream.split();
56                let h = PeerHandler {
57                    stream: w,
58                    peers: peers,
59                    peer_id: "".to_string(),
60                    req_rx: rx,
61                    req_tx: tx,
62                    id: id,
63                };
64
65                h.handle_stream(r).await;
66            });
67        }
68
69        Ok(())
70    }
71}
72
73struct PeerState {
74    id: u64,
75    last_ping: Instant,
76    req_tx: Sender<Request>,
77    addr: SocketAddr,
78}
79
80type PeerMap = Arc<Mutex<HashMap<String, PeerState>>>;
81
82struct PeerHandler<'a> {
83    stream: WriteHalf<'a>,
84    peers: PeerMap,
85    peer_id: String,
86    id: u64,
87    req_tx: Sender<Request>,
88    req_rx: Receiver<Request>,
89}
90
91impl<'a> PeerHandler<'a> {
92    async fn handle_stream(mut self, r: ReadHalf<'a>) {
93        let req_tx = self.req_tx.clone();
94
95        select! {
96            _ = Self::read_reqs(r, req_tx) => {},
97            _ = self.handle_cmds()=> {},
98        }
99
100        let mut peers = self.peers.lock().unwrap();
101        if let Some(p) = (*peers).get(&self.peer_id) {
102            if p.id == self.id {
103                peers.remove(&self.peer_id);
104            }
105        }
106
107        if self.peer_id != "" {
108            log::debug!("peer {} disconnect", self.peer_id);
109        }
110    }
111
112    async fn read_reqs(mut r: ReadHalf<'a>, req_tx: Sender<Request>) -> Result<()> {
113        loop {
114            match timeout(Duration::from_secs(30), Self::read_req(&mut r)).await? {
115                Ok(req) => req_tx
116                    .send(req)
117                    .await
118                    .map_err(|_| Error::new(Other, "mpsc closed?")),
119                _ => {
120                    let mut bye = Request::new();
121                    bye.set_Bye(Bye::new());
122                    let _ = req_tx.send(bye);
123
124                    Err(Error::new(Other, "byte"))
125                }
126            }?;
127        }
128    }
129
130    async fn handle_cmds(&mut self) -> Result<()> {
131        loop {
132            match self.req_rx.recv().await {
133                Some(req) => {
134                    let src_id = req.get_id().to_string();
135                    match req.cmd {
136                        Some(ReqCmd::Ping(ping)) => self.handle_ping(src_id, ping).await,
137                        Some(ReqCmd::Isync(isync)) => self.handle_isync(src_id, isync).await,
138                        Some(ReqCmd::Fsync(fsync)) => self.handle_fsync(fsync).await,
139                        Some(ReqCmd::Rsync(rsync)) => self.handle_rsync(src_id, rsync).await,
140                        Some(ReqCmd::Bye(_)) => Err(Error::new(Other, "bye")),
141                        _ => Err(Error::new(Other, "uknown cmd")),
142                    }
143                }
144                _ => Err(Error::new(Other, "bye")),
145            }?;
146        }
147    }
148
149    async fn read_req(stream: &mut ReadHalf<'a>) -> Result<Request> {
150        let mut buf = [0u8; 2];
151        stream.read_exact(&mut buf).await?;
152
153        let size = u16::from_be_bytes(buf).into();
154        if size > 1500 {
155            Err(Error::new(Other, "invalid message"))?;
156        }
157        let mut buf = vec![0u8; size];
158        stream.read_exact(&mut buf).await?;
159
160        Request::parse_from_bytes(&mut buf).map_err(|_| Error::new(Other, "invalid message"))
161    }
162
163    async fn send_response(&mut self, cmd: RespCmd) -> Result<()> {
164        let mut resp = Response::new();
165        resp.cmd = Some(cmd);
166
167        let vec = resp.write_to_bytes().unwrap();
168
169        let _ = self
170            .stream
171            .write_all(&(vec.len() as u16).to_be_bytes())
172            .await?;
173        let _ = self.stream.write_all(&vec).await?;
174        Ok(())
175    }
176
177    fn insert_peerstate(&mut self, id: String) {
178        self.peer_id = id;
179        let mut peers = self.peers.lock().unwrap();
180
181        if match (*peers).get(&self.peer_id) {
182            Some(p) => p.id != self.id,
183            _ => false,
184        } {
185            log::debug!("update peer {}", self.peer_id);
186            peers.remove(&self.peer_id);
187        }
188        let mut p = peers
189            .entry(self.peer_id.clone())
190            .or_insert_with(|| PeerState {
191                id: self.id,
192                req_tx: self.req_tx.clone(),
193                last_ping: Instant::now(),
194                addr: self.stream.as_ref().peer_addr().unwrap(),
195            });
196
197        p.last_ping = Instant::now();
198    }
199
200    async fn handle_ping(&mut self, src_id: String, _ping: Ping) -> Result<()> {
201        log::trace!("ping {}", src_id);
202
203        self.insert_peerstate(src_id);
204
205        self.send_response(RespCmd::Pong(Pong::new())).await
206    }
207
208    async fn handle_isync(&mut self, src_id: String, isync: Isync) -> Result<()> {
209        let dst_id = isync.get_id();
210        log::debug!("isync {} -> {}", src_id, dst_id);
211
212        let mut rdr = Redirect::new();
213        rdr.set_id(dst_id.to_string());
214
215        if let Some((req_tx, freq)) = {
216            let peers = self.peers.lock().unwrap();
217            let p = match (*peers).get(dst_id) {
218                Some(p) => Some(p),
219                None => {
220                    log::debug!("{} not found", dst_id);
221                    None
222                }
223            };
224
225            if let Some(p) = p {
226                rdr.set_addr(p.addr.to_string());
227
228                //forward
229                let mut fsync = Fsync::new();
230                fsync.set_id(src_id.clone());
231                fsync.set_addr(self.stream.as_ref().peer_addr().unwrap().to_string());
232
233                let mut freq = Request::new();
234
235                freq.set_Fsync(fsync);
236
237                Some((p.req_tx.clone(), freq))
238            } else {
239                None
240            }
241        } {
242            //wait for reply (rsync)
243            self.insert_peerstate(src_id);
244
245            let _ = req_tx.send(freq).await;
246            Ok(())
247        } else {
248            self.send_response(RespCmd::Redirect(rdr)).await
249        }
250    }
251
252    async fn handle_fsync(&mut self, fsync: Fsync) -> Result<()> {
253        log::debug!("fsync {} -> {} ", fsync.get_id(), self.peer_id);
254        self.send_response(RespCmd::Fsync(fsync)).await
255    }
256
257    async fn handle_rsync(&mut self, src_id: String, rsync: Rsync) -> Result<()> {
258        let dst_id = rsync.get_id();
259        log::debug!("rsync {} -> {}", src_id, dst_id);
260
261        if dst_id == self.peer_id {
262            let rdr = if let Some(p) = self.peers.lock().unwrap().get(&src_id) {
263                let mut rdr = Redirect::new();
264                rdr.set_id(src_id.to_string());
265                rdr.set_addr(p.addr.to_string());
266                rdr
267            } else {
268                log::debug!("{} not found", src_id);
269                return Ok(());
270            };
271
272            return self.send_response(RespCmd::Redirect(rdr)).await;
273        }
274
275        let req_tx = {
276            let peers = self.peers.lock().unwrap();
277            match (*peers).get(dst_id) {
278                Some(p) => p.req_tx.clone(),
279                None => {
280                    log::debug!("{} not found", dst_id);
281                    return Ok(());
282                }
283            }
284        };
285
286        let mut req = Request::new();
287        req.set_id(src_id);
288        req.cmd = Some(ReqCmd::Rsync(rsync));
289
290        let _ = req_tx.send(req).await;
291
292        Ok(())
293    }
294}