1use crate::proto::{
2 Isync, Ping, Request, Request_oneof_cmd as ReqCmd, Response, Response_oneof_cmd as RespCmd,
3 Rsync,
4};
5use protobuf::Message;
6use socket2::{Domain, Protocol, Socket, Type};
7use std::io::{Error, ErrorKind::Other, Read, Result, Write};
8use std::net::{Shutdown::Both, SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
9use std::sync::mpsc::{sync_channel, Receiver, RecvTimeoutError::Timeout, SyncSender};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::spawn;
12use std::time::Duration;
13
/// Flags shared between a `Client` and its background threads.
///
/// Always accessed through the `Mutex` half of `Client::signal`; every change is
/// announced on the paired `Condvar`.
#[derive(Default)]
#[derive(Clone)]
struct Signal {
    /// Raised by `Client::shutdown`: every background thread should stop.
    exit: bool,
    /// Raised once a server connection's worker threads have finished; the
    /// reconnect loop clears it when it starts reconnecting.
    broken: bool,
}
19
/// Client of a coordinating server.
///
/// The client registers itself as `id`. It can accept direct peer connections via
/// `listen`/`accept` and open them via `connect`.
pub struct Client {
    /// Server address exactly as passed to `new`; it is re-resolved on every use.
    server_addr: String,
    /// This client's identifier, attached to every request sent to the server.
    id: String,
    /// Listening socket, set by `listen` and taken (and shut down) by `shutdown`.
    listener: Option<Socket>,
    /// Fixed local bind address.
    ///
    /// When `None`, the unspecified address of the server's address family is
    /// used (port 0).
    local_addr: Option<SocketAddr>,
    /// Shared flags plus the condition variable used to wake background threads.
    signal: Arc<(Mutex<Signal>, Condvar)>,
}
51
52impl Drop for Client {
53 fn drop(&mut self) {
54 self.shutdown().unwrap();
55 }
56}
57
58impl Client {
59 pub fn new(server_addr: &str, id: &str, local_addr: Option<SocketAddr>) -> Result<Self> {
62 Ok(Self {
63 server_addr: server_addr.to_owned(),
64 id: id.into(),
65 local_addr: local_addr,
66 listener: None,
67 signal: Default::default(),
68 })
69 }
70
71 pub fn as_socket(&self) -> Option<TcpListener> {
73 self.listener
74 .as_ref()
75 .map(|l| l.try_clone().unwrap().into())
76 }
77
78 fn choose_bind_addr(&self) -> Result<SocketAddr> {
79 if let Some(ref addr) = self.local_addr {
80 return Ok(*addr);
81 }
82
83 let server_addr = self
84 .server_addr
85 .to_socket_addrs()?
86 .next()
87 .ok_or(Error::new(Other, "server name resolve fail"))?;
88
89 let local_addr = match server_addr {
90 SocketAddr::V4(_) => "0.0.0.0:0".parse().unwrap(),
91 SocketAddr::V6(_) => "[::]:0".parse().unwrap(),
92 };
93
94 Ok(local_addr)
95 }
96
97 fn connect_server(local_addr: SocketAddr, server_addr: &str) -> Result<socket2::Socket> {
98 let svr = Self::bind(local_addr.into())?;
99 let server_addr = server_addr
100 .to_socket_addrs()?
101 .next()
102 .ok_or(Error::new(Other, "server name resolve fail"))?;
103 svr.connect(&server_addr.into())?;
104 Ok(svr)
105 }
106
107 fn bind(local_addr: SocketAddr) -> Result<Socket> {
108 let domain = match local_addr {
109 SocketAddr::V4(_) => Domain::IPV4,
110 SocketAddr::V6(_) => Domain::IPV6,
111 };
112
113 let s = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();
114 s.set_reuse_address(true)?;
115 #[cfg(unix)]
116 s.set_reuse_port(true)?;
117 s.bind(&local_addr.into())?;
118
119 Ok(s)
120 }
121
122 pub fn connect(&mut self, target_id: &str) -> Result<TcpStream> {
129 let mut svr = Self::connect_server(self.choose_bind_addr()?, &self.server_addr)?;
130
131 let mut isync = Isync::new();
132 isync.set_id(target_id.into());
133
134 Self::write_req(self.id.clone(), &mut svr, ReqCmd::Isync(isync))?;
135
136 let addr = match Self::read_resp(&mut svr)?.cmd {
137 Some(RespCmd::Redirect(rdr)) => rdr.addr,
138 _ => Err(Error::new(Other, "invalid server response"))?,
139 };
140
141 log::debug!("Redirect {}", addr);
142
143 let target_addr: SocketAddr = addr
144 .parse()
145 .map_err(|_| Error::new(Other, "target id not found"))?;
146
147 let local_addr = svr.local_addr().unwrap();
148
149 let s = Self::bind(local_addr.as_socket().unwrap())?;
150 s.connect(&target_addr.into())?;
151
152 Ok(s.into())
153 }
154
155 fn new_req(id: String) -> Request {
156 let mut req = Request::new();
157 req.set_id(id);
158
159 req
160 }
161
162 fn write_req(id: String, w: &mut dyn Write, cmd: ReqCmd) -> Result<()> {
163 let mut req = Self::new_req(id);
164 req.cmd = Some(cmd);
165 let buf = req.write_to_bytes()?;
166
167 w.write_all(&(buf.len() as u16).to_be_bytes())?;
168 w.write_all(&buf)?;
169 Ok(())
170 }
171
172 fn read_resp(r: &mut dyn Read) -> Result<Response> {
173 let mut buf = [0u8; 2];
174 r.read_exact(&mut buf)?;
175 let mut buf = vec![0; u16::from_be_bytes(buf).into()];
176 r.read_exact(&mut buf)?;
177 Response::parse_from_bytes(&buf).map_err(|_| Error::new(Other, "invalid message"))
178 }
179
180 fn write_loop(id: String, s: &mut dyn Write, rx: Receiver<ReqCmd>) -> Result<()> {
181 loop {
182 let req = match rx.recv_timeout(Duration::from_secs(10)) {
183 Ok(req) => req,
184 Err(Timeout) => ReqCmd::Ping(Ping::new()),
185 _ => break,
186 };
187
188 if Self::write_req(id.clone(), s, req).is_err() {
189 break;
190 }
191 }
192
193 Ok(())
194 }
195
196 fn read_loop(local_addr: SocketAddr, r: &mut dyn Read, tx: SyncSender<ReqCmd>) -> Result<()> {
197 loop {
198 let req = match Self::read_resp(r) {
199 Ok(resp) => match resp.cmd {
200 Some(RespCmd::Pong(_)) => None,
201 Some(RespCmd::Fsync(fsync)) => {
202 log::debug!("fsync {}", fsync.get_id());
203
204 let dst_addr: SocketAddr = fsync
205 .get_addr()
206 .parse()
207 .map_err(|_| Error::new(Other, "invalid fsync addr"))?;
208
209 log::debug!("connect {} -> {}", local_addr, dst_addr);
210
211 let _ = Self::bind(local_addr.into())
212 .map(|s| s.connect_timeout(&dst_addr.into(), Duration::from_micros(1)));
213
214 let mut rsync = Rsync::new();
215 rsync.set_id(fsync.get_id().to_string());
216 Some(ReqCmd::Rsync(rsync))
217 }
218 _ => None,
219 },
220 _ => break,
221 };
222
223 if let Some(req) = req {
224 tx.send(req).unwrap();
225 }
226 }
227
228 Ok(())
229 }
230
231 fn start_background(
232 id: String,
233 local_addr: SocketAddr,
234 server_addr: String,
235 signal: Arc<(Mutex<Signal>, Condvar)>,
236 ) -> Result<()> {
237 let svr_sk = Self::connect_server(local_addr, &server_addr)?;
238
239 let (tx, rx) = sync_channel(10);
240
241 tx.send(ReqCmd::Ping(Ping::new())).unwrap();
242
243 let mut hs = vec![];
244
245 hs.push({
246 let mut w = svr_sk.try_clone()?;
247 spawn(move || Self::write_loop(id, &mut w, rx).unwrap())
248 });
249
250 hs.push({
251 let local_addr = local_addr.clone();
252 let mut r = svr_sk.try_clone()?;
253 spawn(move || Self::read_loop(local_addr, &mut r, tx).unwrap())
254 });
255
256 {
257 let signal = signal.clone();
258 spawn(move || {
259 for h in hs {
260 h.join().unwrap();
261 }
262 let (lock, cvar) = &*signal;
263 let mut signal = lock.lock().unwrap();
264 (*signal).broken = true;
265 cvar.notify_all();
266 });
267 }
268
269 {
270 let signal = signal.clone();
271 spawn(move || {
272 let (lock, cvar) = &*signal;
273 let mut signal = lock.lock().unwrap();
274 if (*signal).exit {
275 let _ = svr_sk.shutdown(Both);
276 return;
277 }
278 signal = cvar.wait(signal).unwrap();
279 if (*signal).exit {
280 let _ = svr_sk.shutdown(Both);
281 }
282 });
283 }
284
285 return Ok(());
286 }
287
288 pub fn listen(&mut self) -> Result<()> {
295 let listener = Self::bind(self.choose_bind_addr()?)?;
296 listener.listen(10)?;
297
298 let id = self.id.clone();
299 let local_addr = listener.local_addr().unwrap().as_socket().unwrap();
300 let server_addr = self.server_addr.clone();
301 let signal = self.signal.clone();
302 Self::start_background(
303 id.clone(),
304 local_addr.clone(),
305 server_addr.clone(),
306 signal.clone(),
307 )?;
308
309 spawn(move || loop {
310 {
311 let (lock, cvar) = &*signal;
312 let mut signal = lock.lock().unwrap();
313 if (*signal).exit {
314 return;
315 }
316 signal = cvar.wait(signal).unwrap();
317 if (*signal).exit {
318 return;
319 }
320
321 assert_eq!((*signal).broken, true);
322
323 (*signal).broken = false;
324 }
325
326 log::debug!("connection with server is broken, try to reconnect.");
327
328 loop {
329 match Self::start_background(
330 id.clone(),
331 local_addr,
332 server_addr.clone(),
333 signal.clone(),
334 ) {
335 Ok(_) => {
336 log::debug!("connect server success");
337 break;
338 }
339 Err(err) => log::debug!("connect server fail, retry later. {}", err),
340 };
341
342 let (lock, cvar) = &*signal;
343 let mut signal = lock.lock().unwrap();
344 if (*signal).exit {
345 return;
346 }
347 signal = cvar
348 .wait_timeout(signal, Duration::from_secs(120))
349 .unwrap()
350 .0;
351 if (*signal).exit {
352 return;
353 }
354 }
355 });
356
357 self.listener = Some(listener);
358
359 Ok(())
360 }
361
362 pub fn accept(&mut self) -> Result<(TcpStream, SocketAddr)> {
364 self.listener
365 .as_ref()
366 .ok_or(Error::new(Other, "not listening"))?
367 .accept()
368 .map(|(s, a)| (s.into(), a.as_socket().unwrap()))
369 }
370
371 pub fn shutdown(&mut self) -> Result<()> {
373 let _ = self.listener.take().map(|l| l.shutdown(Both));
374
375 let (lock, cvar) = &*self.signal;
376 let mut signal = lock.lock().unwrap();
377 (*signal).exit = true;
378 cvar.notify_all();
379
380 Ok(())
381 }
382}