Skip to main content

makepad_hub/
hubclient.rs

1use crate::hubmsg::*;
2use crate::hubrouter::*;
3
4use std::net::{TcpStream, UdpSocket, SocketAddr, SocketAddrV4, SocketAddrV6, Shutdown};
5use std::io::prelude::*;
6use std::sync::{mpsc, Arc, Mutex};
7use std::thread;
8use serde::{Serialize, Deserialize};
9
10#[cfg(any(target_os = "linux", target_os = "macos"))]
11use std::os::unix::io::AsRawFd;
12
13trait ResultMsg<T> {
14    fn expect_msg(self, msg: &str) -> Result<T, HubError>;
15}
16
17impl<T> ResultMsg<T> for Result<T, std::io::Error> {
18    fn expect_msg(self, msg: &str) -> Result<T, HubError> {
19        match self {
20            Err(v) => Err(HubError {msg: format!("{}: {}", msg.to_string(), v.to_string())}),
21            Ok(v) => Ok(v)
22        }
23    }
24}
25
26impl<T> ResultMsg<T> for Result<T, snap::Error> {
27    fn expect_msg(self, msg: &str) -> Result<T, HubError> {
28        match self {
29            Err(v) => Err(HubError {msg: format!("{}: {}", msg.to_string(), v.to_string())}),
30            Ok(v) => Ok(v)
31        }
32    }
33}
34
35type HubResult<T> = Result<T, HubError>;
36
/// Default UDP port listened on by `HubClient::wait_for_announce` for hub announce packets.
pub const HUB_ANNOUNCE_PORT: u16 = 46_243;
38
39pub fn read_exact_bytes_from_tcp_stream(tcp_stream: &mut TcpStream, bytes: &mut [u8]) -> HubResult<()> {
40    let bytes_total = bytes.len();
41    let mut bytes_left = bytes_total;
42    while bytes_left > 0 {
43        let buf = &mut bytes[(bytes_total - bytes_left)..bytes_total];
44        let bytes_read = tcp_stream.read(buf).expect_msg("read_exact_bytes_from_tcp_stream: read failed") ?;
45        if bytes_read == 0 {
46            return Err(HubError::new("read_exact_bytes_from_tcp_stream - cannot read bytes"));
47        }
48        bytes_left -= bytes_read;
49    }
50    Ok(())
51}
52
53pub fn read_block_from_tcp_stream(tcp_stream: &mut TcpStream, mut check_digest: Digest) -> HubResult<Vec<u8>> {
54    let mut dwd_read = DigestWithData::default();
55    
56    let dwd_u8 = unsafe {std::mem::transmute::<&mut DigestWithData, &mut [u8; 26 * 8]>(&mut dwd_read)};
57    read_exact_bytes_from_tcp_stream(tcp_stream, dwd_u8) ?;
58    
59    let bytes_total = dwd_read.data as usize;
60    if bytes_total > 250 * 1024 * 1024 {
61        return Err(HubError::new("read_block_from_tcp_stream: bytes_total more than 250mb"))
62    }
63    
64    let mut msg_buf = Vec::new();
65    msg_buf.resize(bytes_total, 0);
66    read_exact_bytes_from_tcp_stream(tcp_stream, &mut msg_buf) ?;
67    
68    check_digest.digest_buffer(&msg_buf);
69    
70    if check_digest != dwd_read.digest {
71        return Err(HubError::new("read_block_from_tcp_stream: block digest check failed"))
72    }
73    
74    let mut dec = snap::Decoder::new();
75    let decompressed = dec.decompress_vec(&msg_buf).expect_msg("read_block_from_tcp_stream: cannot decompress_vec");
76    
77    return decompressed;
78}
79
80pub fn write_exact_bytes_to_tcp_stream(tcp_stream: &mut TcpStream, bytes: &[u8]) -> HubResult<()> {
81    let bytes_total = bytes.len();
82    let mut bytes_left = bytes_total;
83    while bytes_left > 0 {
84        let buf = &bytes[(bytes_total - bytes_left)..bytes_total];
85        let bytes_written = tcp_stream.write(buf).expect_msg("write_exact_bytes_to_tcp_stream: block write fail") ?;
86        if bytes_written == 0 {
87            return Err(HubError::new("write_exact_bytes_to_tcp_stream - cannot write bytes"));
88        }
89        bytes_left -= bytes_written;
90    }
91    Ok(())
92}
93
94pub fn write_block_to_tcp_stream(tcp_stream: &mut TcpStream, msg_buf: &[u8], digest: Digest) -> HubResult<()> {
95    let bytes_total = msg_buf.len();
96    
97    if bytes_total > 250 * 1024 * 1024 {
98        return Err(HubError::new("read_block_from_tcp_stream: bytes_total more than 250mb"))
99    }
100    
101    let mut enc = snap::Encoder::new();
102    let compressed = enc.compress_vec(msg_buf).expect_msg("read_block_from_tcp_stream: cannot compress msgbuf") ?;
103    
104    let mut dwd_write = DigestWithData{
105        digest:digest,
106        data: compressed.len() as u64
107    };
108    
109    dwd_write.digest.digest_buffer(&compressed);
110    
111    let dwd_u8 = unsafe {std::mem::transmute::<&DigestWithData, &[u8; 26 * 8]>(&dwd_write)};
112    write_exact_bytes_to_tcp_stream(tcp_stream, dwd_u8) ?;
113    write_exact_bytes_to_tcp_stream(tcp_stream, &compressed) ?;
114    Ok(())
115}
116
117pub struct HubClient {
118    pub own_addr: HubAddr,
119    pub server_addr: HubAddr,
120    pub uid_alloc: u64,
121    read_thread: Option<thread::JoinHandle<()>>,
122    write_thread: Option<thread::JoinHandle<()>>,
123    pub tx_read: mpsc::Sender<FromHubMsg>,
124    pub rx_read: Option<mpsc::Receiver<FromHubMsg>>,
125    pub tx_write: mpsc::Sender<ToHubMsg>
126}
127
128#[derive(Default, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
129pub struct DigestWithData{
130    pub digest:Digest,
131    pub data: u64
132}
133
134impl HubClient {
135    pub fn connect_to_server(digest: Digest, server_address: SocketAddr, hub_log: HubLog) -> HubResult<HubClient> {
136        
137        // first try local address
138        let local_address = SocketAddr::from(([127, 0, 0, 1], server_address.port()));
139        let server_hubaddr;
140        let mut tcp_stream = if let Ok(stream) = TcpStream::connect(local_address) {
141            server_hubaddr = HubAddr::from_socket_addr(local_address);
142            stream
143        }
144        else {
145            server_hubaddr = HubAddr::from_socket_addr(server_address);
146            TcpStream::connect(server_address).expect_msg("connect_to_hub: cannot connect") ?
147        };
148        
149        let own_addr = HubAddr::from_socket_addr(tcp_stream.local_addr().expect("Cannot get client local address"));
150        
151        let (tx_read, rx_read) = mpsc::channel::<FromHubMsg>();
152        let (tx_write, rx_write) = mpsc::channel::<ToHubMsg>();
153        let tx_read_copy = tx_read.clone();
154        let tx_write_copy = tx_write.clone();
155        
156        let read_thread = {
157            let mut tcp_stream = tcp_stream.try_clone().expect_msg("connect_to_hub: cannot clone socket") ?;
158            let digest = digest.clone();
159            let server_hubaddr = server_hubaddr.clone();
160            let hub_log = hub_log.clone();
161            std::thread::spawn(move || {
162                loop {
163                    match read_block_from_tcp_stream(&mut tcp_stream, digest.clone()) {
164                        Ok(msg_buf) => {
165                            let htc_msg: FromHubMsg = bincode::deserialize(&msg_buf).expect("read_thread hub message deserialize fail - version conflict!");
166                            hub_log.msg("HubClient received", &htc_msg);
167                            tx_read.send(htc_msg).expect("tx_read.send fails - should never happen");
168                        },
169                        Err(e) => {
170                            let _ = tcp_stream.shutdown(Shutdown::Both);
171                            tx_read.send(FromHubMsg {
172                                from: server_hubaddr.clone(),
173                                msg: HubMsg::ConnectionError(e.clone())
174                            }).expect("tx_read.send fails - should never happen");
175                            // lets break rx write
176                            let _ = tx_write_copy.send(ToHubMsg {
177                                to: HubMsgTo::Hub,
178                                msg: HubMsg::ConnectionError(e)
179                            });
180                            return
181                        }
182                    }
183                }
184            })
185        };
186        
187        let write_thread = {
188            let digest = digest.clone();
189            let tx_read = tx_read_copy.clone();
190            let server_hubaddr = server_hubaddr.clone();
191            let hub_log = hub_log.clone();
192            std::thread::spawn(move || { // this one cannot send to the read channel.
193                while let Ok(cth_msg) = rx_write.recv() {
194                    hub_log.msg("HubClient sending", &cth_msg);
195                    match &cth_msg.msg {
196                        HubMsg::ConnectionError(_) => { // we are closed by the read loop
197                            return
198                        },
199                        _ => ()
200                    }
201                    
202                    let msg_buf = bincode::serialize(&cth_msg).expect("write_thread hub message serialize fail - should never happen");
203                    if let Err(e) = write_block_to_tcp_stream(&mut tcp_stream, &msg_buf, digest.clone()) {
204                        // disconnect the socket and send shutdown
205                        let _ = tcp_stream.shutdown(Shutdown::Both);
206                        let _ = tx_read.send(FromHubMsg {
207                            from: server_hubaddr.clone(),
208                            msg: HubMsg::ConnectionError(e)
209                        });
210                        return
211                    }
212                }
213            })
214        };
215        
216        Ok(HubClient {
217            uid_alloc: 0,
218            own_addr: own_addr,
219            server_addr: server_hubaddr,
220            read_thread: Some(read_thread),
221            write_thread: Some(write_thread),
222            tx_read: tx_read_copy,
223            rx_read: Some(rx_read),
224            tx_write: tx_write
225        })
226    }
227    
228    pub fn wait_for_announce(digest: Digest) -> Result<SocketAddr, std::io::Error> {
229        Self::wait_for_announce_on(digest, SocketAddr::from(([0, 0, 0, 0], HUB_ANNOUNCE_PORT)))
230    }
231    
232    pub fn wait_for_announce_on(digest: Digest, announce_address: SocketAddr) -> Result<SocketAddr, std::io::Error> {
233        
234        #[cfg(any(target_os = "linux", target_os = "macos"))]
235        fn reuse_addr(socket: &mut UdpSocket) {
236            unsafe {
237                let optval: libc::c_int = 1;
238                let _ = libc::setsockopt(
239                    socket.as_raw_fd(),
240                    libc::SOL_SOCKET,
241                    libc::SO_REUSEADDR,
242                    &optval as *const _ as *const libc::c_void,
243                    std::mem::size_of_val(&optval) as libc::socklen_t,
244                );
245            }
246        }
247        
248        #[cfg(any(target_os = "windows", target_arch = "wasm32"))]
249        fn reuse_addr(_socket: &mut UdpSocket) {
250        }
251        
252        loop {
253            if let Ok(mut socket) = UdpSocket::bind(announce_address) {
254                // TODO. FIX FOR WINDOWS
255                reuse_addr(&mut socket);
256                let mut dwd_read = DigestWithData::default();
257                let dwd_u8 = unsafe {std::mem::transmute::<&mut DigestWithData, &mut [u8; 26 * 8]>(&mut dwd_read)};
258                
259                let (bytes, from) = socket.recv_from(dwd_u8) ?;
260                if bytes != 26 * 8 {
261                    println!("Announce port wrong bytecount");
262                }
263                
264                let mut dwd_check = DigestWithData{
265                    digest: digest.clone(),
266                    data: dwd_read.data
267                };
268                dwd_check.data = dwd_read.data;
269                dwd_check.digest.buf[0] ^= dwd_read.data;
270                dwd_check.digest.digest_cycle();
271                
272                if dwd_check == dwd_read { // use this to support multiple hubs on one network
273                    let listen_port = dwd_read.data;
274                    return Ok(match from {
275                        SocketAddr::V4(v4) => SocketAddr::V4(SocketAddrV4::new(*v4.ip(), listen_port as u16)),
276                        SocketAddr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(*v6.ip(), listen_port as u16, v6.flowinfo(), v6.scope_id())),
277                    })
278                }
279            }
280            //else{
281            //    println!("wait for announce bind failed");
282            //}
283        }
284    }
285    
286    pub fn join_threads(&mut self) {
287        self.read_thread.take().expect("cant take read thread").join().expect("cant join read thread");
288        self.write_thread.take().expect("cant take write thread").join().expect("cant join write thread");
289    }
290    
291    pub fn alloc_uid(&mut self) -> HubUid {
292        self.uid_alloc += 1;
293        return HubUid {
294            addr: self.own_addr,
295            id: self.uid_alloc
296        }
297    }
298    
299    pub fn get_route_send(&self) -> HubRouteSend {
300        HubRouteSend::Networked {
301            uid_alloc: Arc::new(Mutex::new(0)),
302            tx_write_arc: Arc::new(Mutex::new(Some(self.tx_write.clone()))),
303            own_addr_arc: Arc::new(Mutex::new(Some(self.own_addr)))
304        }
305    }
306    
307    pub fn get_route_send_in_place(&self, route_send: &HubRouteSend) {
308        route_send.update_networked_in_place(Some(self.own_addr), Some(self.tx_write.clone()))
309    }
310}
311
312
313#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
314pub struct Digest {
315    pub buf: [u64; 25]
316}
317
318impl Default for Digest {
319    fn default() -> Self {Self {buf: [0u64; 25]}}
320}
321
322impl Digest {
323    
324    pub fn generate() -> Digest {
325        let mut result = Digest::default();
326        for i in 0..25 {
327            result.buf[i] ^= time::precise_time_ns();
328            std::thread::sleep(std::time::Duration::from_millis(1));
329            result.digest_cycle();
330        }
331        result
332    }
333    
334    pub fn digest_cycle(&mut self){
335        digest_cycle(self);
336    }
337
338    pub fn digest_other(&mut self, other: &Digest) {
339        for i in 0..25{
340            self.buf[i] ^= other.buf[i]
341        }
342        self.digest_cycle();
343    }
344    
345    pub fn digest_buffer(&mut self, msg_buf: &[u8]) {
346        let digest_u8 = unsafe {std::mem::transmute::<&mut Digest, &mut [u8; 26 * 8]>(self)};
347        let mut s = 0;
348        for i in 0..msg_buf.len() {
349            digest_u8[s] ^= msg_buf[i];
350            s += 1;
351            if s >= 25 * 8 {
352                self.digest_cycle();
353                s = 0;
354            }
355        }
356        self.digest_cycle();
357    }
358    
359}
360
// Digest function used to hash TCP data, both for error checking and to tell multiple servers on one network apart.
// It is the Keccak-f[1600] permutation; various similar MIT-licensed versions exist on crates.io and GitHub,
// and it is unclear which one to attribute it to. Thanks to whoever wrote it :)
363
// Step tables for the Keccak-style permutation in `digest_cycle`. These are fixed magic
// constants and must not be edited.

/// Rho step: rotation amount applied in step `x` of the combined rho/pi lane walk.
const RHO: [u32; 24] = [1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44,];
/// Pi step: index into `Digest::buf` of the lane written in step `x` of the rho/pi walk.
const PI: [usize; 24] = [10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1,];
/// Iota step: constant XOR-ed into lane 0 at the end of round `i` (one per each of the 24 rounds).
const RC: [u64; 24] = [
    0x0000000000000001,
    0x0000000000008082,
    0x800000000000808a,
    0x8000000080008000,
    0x000000000000808b,
    0x0000000080000001,
    0x8000000080008081,
    0x8000000000008009,
    0x000000000000008a,
    0x0000000000000088,
    0x0000000080008009,
    0x000000008000000a,
    0x000000008000808b,
    0x800000000000008b,
    0x8000000000008089,
    0x8000000000008003,
    0x8000000000008002,
    0x8000000000000080,
    0x000000000000800a,
    0x800000008000000a,
    0x8000000080008081,
    0x8000000000008080,
    0x0000000080000001,
    0x8000000080008008,
];
392
// `unroll5!(i, { body })` runs `body` once for each i in 0..5.
// Default build: fully unrolled, with `i` bound as a `const usize` in each copy (so the loop
// index is a compile-time constant; the lowercase const name is why `digest_cycle` allows
// `non_upper_case_globals`).
#[cfg(not(feature = "no_unroll"))]
macro_rules!unroll5 {
    ( $ var: ident, $ body: block) => {
        {const $ var: usize = 0; $ body;}
        {const $ var: usize = 1; $ body;}
        {const $ var: usize = 2; $ body;}
        {const $ var: usize = 3; $ body;}
        {const $ var: usize = 4; $ body;}
    };
}

// With the `no_unroll` feature the same macro expands to an ordinary `for` loop.
#[cfg(feature = "no_unroll")]
macro_rules!unroll5 {
    ( $ var: ident, $ body: block) => {
        for $ var in 0..5 $ body
    }
}
410
// `unroll24!(i, { body })` runs `body` once for each i in 0..24.
// Default build: fully unrolled, with `i` bound as a `const usize` in each copy.
#[cfg(not(feature = "no_unroll"))]
macro_rules!unroll24 {
    ( $ var: ident, $ body: block) => {
        {const $ var: usize = 0; $ body;}
        {const $ var: usize = 1; $ body;}
        {const $ var: usize = 2; $ body;}
        {const $ var: usize = 3; $ body;}
        {const $ var: usize = 4; $ body;}
        {const $ var: usize = 5; $ body;}
        {const $ var: usize = 6; $ body;}
        {const $ var: usize = 7; $ body;}
        {const $ var: usize = 8; $ body;}
        {const $ var: usize = 9; $ body;}
        {const $ var: usize = 10; $ body;}
        {const $ var: usize = 11; $ body;}
        {const $ var: usize = 12; $ body;}
        {const $ var: usize = 13; $ body;}
        {const $ var: usize = 14; $ body;}
        {const $ var: usize = 15; $ body;}
        {const $ var: usize = 16; $ body;}
        {const $ var: usize = 17; $ body;}
        {const $ var: usize = 18; $ body;}
        {const $ var: usize = 19; $ body;}
        {const $ var: usize = 20; $ body;}
        {const $ var: usize = 21; $ body;}
        {const $ var: usize = 22; $ body;}
        {const $ var: usize = 23; $ body;}
    };
}

// With the `no_unroll` feature the same macro expands to an ordinary `for` loop.
#[cfg(feature = "no_unroll")]
macro_rules!unroll24 {
    ( $ var: ident, $ body: block) => {
        for $ var in 0..24 $ body
    }
}
447
/// Applies 24 rounds of the Keccak-f[1600]-style permutation to `a` in place.
///
/// Each round performs theta, rho+pi, chi and iota over the 25 lanes of `a.buf`,
/// indexed as `5 * y + x`.
// `non_upper_case_globals`: the unrolled macros bind loop indices as lowercase consts.
// `unused_assignments`: the final `last = ...` in the unrolled rho/pi walk is never read.
#[allow(non_upper_case_globals, unused_assignments)]
pub fn digest_cycle(a:&mut Digest) {
    for i in 0..24 {
        // Scratch space: column parities in theta, one row copy in chi, a temp in rho/pi.
        let mut array = [0u64; 5];
        
        // Theta
        // array[x] = XOR of the 5 lanes in column x.
        unroll5!(x, {
            unroll5!(y, {
                array[x] ^= a.buf[5 * y + x];
            });
        });
        
        // Each lane is XOR-ed with the parity of column x-1 and the rotated parity of column x+1.
        unroll5!(x, {
            unroll5!(y, {
                let t1 = array[(x + 4) % 5];
                let t2 = array[(x + 1) % 5].rotate_left(1);
                a.buf[5 * y + x] ^= t1 ^ t2;
            });
        });
        
        // Rho and pi
        // Starting from lane 1, move each carried value (rotated by RHO[x]) into lane PI[x],
        // picking up that lane's old value to carry into the next step.
        let mut last = a.buf[1];
        unroll24!(x, {
            array[0] = a.buf[PI[x]];
            a.buf[PI[x]] = last.rotate_left(RHO[x]);
            last = array[0];
        });
        
        // Chi
        // Per row of 5 lanes (copied to `array` first): lane[x] ^= !lane[x+1] & lane[x+2].
        unroll5!(y_step, {
            let y = 5 * y_step;
            
            unroll5!(x, {
                array[x] = a.buf[y + x];
            });
            
            unroll5!(x, {
                let t1 = !array[(x + 1) % 5];
                let t2 = array[(x + 2) % 5];
                a.buf[y + x] = array[x] ^ (t1 & t2);
            });
        });
        
        // Iota
        // XOR this round's constant into lane 0.
        a.buf[0] ^= RC[i];
    }
}