bissel/host/
host.rs

1// Tokio for async
2use tokio::net::TcpListener;
3use tokio::net::TcpStream;
4use tokio::runtime::Runtime;
5use tokio::sync::Mutex; // as TokioMutex;
6use tokio::task::JoinHandle;
7// Tracing for logging
8use hex_slice::*;
9use tracing::*;
10// Postcard is the default de/serializer
11use postcard::*;
12// Multi-threading primitives
13use std::ops::{Deref, DerefMut};
14use std::sync::Arc;
15use std::sync::Mutex as StdMutex;
16// Misc other imports
17use std::error::Error;
18use std::net::SocketAddr;
19use std::result::Result;
20
21use crate::*;
22
/// Named task handle for each Hosted connection.
///
/// One `Connection` is recorded per accepted TCP stream once the name
/// handshake for that stream has finished.
#[derive(Debug)]
pub struct Connection {
    /// Handle of the spawned task that serves this connection's requests.
    handle: JoinHandle<()>,
    /// Peer address reported when the stream was accepted.
    stream_addr: SocketAddr,
    /// Name obtained from the handshake (may be an error sentinel if the
    /// handshake bytes could not be decoded).
    name: String,
}
30
/// Central coordination process, which stores published data and responds to requests.
#[derive(Debug)]
pub struct Host {
    /// Host configuration; `start` reads the interface, port number and the
    /// buffer/name size limits from it.
    pub cfg: HostConfig,
    /// Tokio runtime on which the listener task is spawned.
    pub runtime: Runtime,
    /// Connections accepted so far; shared with the listener task, which
    /// appends to it.
    pub connections: Arc<StdMutex<Vec<Connection>>>,
    /// Handle of the listener task, set by `start`.
    pub task_listen: Option<JoinHandle<()>>,
    /// Database holding published messages keyed by topic; must be `Some`
    /// before `start` is called.
    pub store: Option<sled::Db>,
    /// Number of replies sent to Nodes; shared with (and incremented by) the
    /// per-connection tasks. Guarded by a Tokio mutex.
    pub reply_count: Arc<Mutex<usize>>,
}
41
42impl Host {
43    /// Allow Host to begin accepting incoming connections
44    #[tracing::instrument]
45    pub fn start(&mut self) -> Result<(), Box<dyn Error>> {
46        let ip = crate::get_ip(&self.cfg.interface)?;
47        let raw_addr = ip + ":" + &self.cfg.socket_num.to_string();
48        let addr: SocketAddr = raw_addr.parse()?;
49
50        let (max_buffer_size, max_name_size) = (self.cfg.max_buffer_size, self.cfg.max_name_size);
51
52        let connections_clone = self.connections.clone();
53
54        let db = match self.store.clone() {
55            Some(db) => db,
56            None => {
57                error!("Must open a sled database to start the Host");
58                panic!("Must open a sled database to start the Host");
59            }
60        };
61
62        let counter = self.reply_count.clone();
63
64        let task_listen = self.runtime.spawn(async move {
65            let listener = TcpListener::bind(addr).await.unwrap();
66
67            loop {
68                let (stream, stream_addr) = listener.accept().await.unwrap();
69                // TO_DO: The handshake function is not always happy
70                let (stream, name) = handshake(stream, max_buffer_size, max_name_size).await;
71                info!("Host received connection from {:?}", &name);
72
73                let db = db.clone();
74                let counter = counter.clone();
75                let connections = Arc::clone(&connections_clone.clone());
76
77                let handle = tokio::spawn(async move {
78                    process(stream, db, counter, max_buffer_size).await;
79                });
80                let connection = Connection {
81                    handle,
82                    stream_addr,
83                    name,
84                };
85                // dbg!(&connection);
86
87                connections.lock().unwrap().push(connection);
88            }
89        });
90
91        self.task_listen = Some(task_listen);
92
93        Ok(())
94    }
95
96    /// Shuts down all networking connections and releases Host object handle
97    /// This also makes sure that temporary sled::Db's built are also dropped
98    /// following the shutdown of a Host
99    #[tracing::instrument]
100    pub fn stop(mut self) -> Result<(), Box<dyn Error>> {
101        for conn in self.connections.lock().unwrap().deref_mut() {
102            // println!("Aborting connection: {}", conn.name);
103            info!("Aborting connection: {}", conn.name);
104            conn.handle.abort();
105        }
106        self.store = None;
107
108        Ok(())
109    }
110
111    /// Print information about all Host connections
112    #[no_mangle]
113    pub fn print_connections(&mut self) -> Result<(), Box<dyn Error + '_>> {
114        println!("Connections:");
115        for conn in self.connections.lock()?.deref() {
116            let name = conn.name.clone();
117            println!("\t- {}:{}", name, &conn.stream_addr);
118        }
119        Ok(())
120    }
121}
122
123/// Initiate a connection with a Node
124#[inline]
125#[tracing::instrument]
126async fn handshake(
127    stream: TcpStream,
128    max_buffer_size: usize,
129    max_name_size: usize,
130) -> (TcpStream, String) {
131    // Handshake
132    // let mut buf = [0u8; 4096];
133    // TO_DO_PART_A: This seems fine, but PART_B errors out for some reason?
134    let mut buf = Vec::with_capacity(max_buffer_size);
135    info!("Starting handshake");
136    let mut name: String = String::with_capacity(max_name_size);
137    let mut count = 0;
138    stream.readable().await.unwrap();
139    loop {
140        info!("In handshake loop");
141        match stream.try_read_buf(&mut buf) {
142            Ok(n) => {
143                name = match std::str::from_utf8(&buf[..n]) {
144                    Ok(name) => {
145                        info!("Received connection from {}", &name);
146                        name.to_owned()
147                    }
148                    Err(e) => {
149                        error!("Error occurred during handshake on host-side: {} on byte string: {:?}, which in hex is: {:x}", e,&buf[..n],&buf[..n].as_hex());
150
151                        //let emsg = format!("Error parsing the following bytes: {:?}",&buf[..n]);
152                        //panic!("{}",emsg);
153
154                        // println!("Error during handshake (Host-side): {:?}", e);
155                        "HOST_CONNECTION_ERROR".to_owned()
156                    }
157                };
158                break;
159            }
160            Err(e) => {
161                if e.kind() == std::io::ErrorKind::WouldBlock {
162                    count += 1;
163                    if count > 20 {
164                        error!("Host Handshake not unblocking!");
165                        panic!("Stream won't unblock");
166                    }
167                } else {
168                    error!("{:?}", e);
169                    // println!("Error: {:?}", e);
170                }
171            }
172        }
173    }
174    info!("Returning from handshake: ({:?}, {})", &stream, &name);
175    (stream, name)
176}
177
178/// Host process for handling incoming connections from Nodes
179#[tracing::instrument]
180#[inline]
181async fn process(
182    stream: TcpStream,
183    db: sled::Db,
184    count: Arc<Mutex<usize>>,
185    max_buffer_size: usize,
186) {
187    let mut buf = [0u8; 10_000];
188    // TO_DO_PART_B: Tried to with try_read_buf(), but seems to panic?
189    // let mut buf = Vec::with_capacity(max_buffer_size);
190    loop {
191        stream.readable().await.unwrap();
192        // dbg!(&count);
193        match stream.try_read(&mut buf) {
194            Ok(0) => break, // TO_DO: break or continue?
195            Ok(n) => {
196                stream.writable().await.unwrap();
197
198                let bytes = &buf[..n];
199                let msg: GenericMsg = match from_bytes(bytes) {
200                    Ok(msg) => msg,
201                    Err(e) => {
202                        error!("Had received Msg of {} bytes: {:?}, Error: {}", n, bytes, e);
203                        panic!("{}", e);
204                    }
205                };
206                // dbg!(&msg);
207
208                match msg.msg_type {
209                    MsgType::SET => {
210                        // println!("received {} bytes, to be assigned to: {}", n, &msg.name);
211                        let db_result = match db.insert(msg.topic.as_bytes(), bytes) {
212                            Ok(_prev_msg) => "SUCCESS".to_string(),
213                            Err(e) => e.to_string(),
214                        };
215
216                        loop {
217                            match stream.try_write(db_result.as_bytes()) {
218                                Ok(_n) => {
219                                    // println!("Successfully replied with {} bytes", n);
220                                    let mut count = count.lock().await; //.unwrap();
221                                    *count += 1;
222                                    break;
223                                }
224                                Err(e) => {
225                                    if e.kind() == std::io::ErrorKind::WouldBlock {}
226                                    continue;
227                                }
228                            }
229                        }
230                    }
231                    MsgType::GET => loop {
232                        /*
233                        println!(
234                            "received {} bytes, asking for reply on topic: {}",
235                            n, &msg.name
236                        );*/
237
238                        let return_bytes = match db.get(&msg.topic).unwrap() {
239                            Some(msg) => msg,
240                            None => {
241                                let e: String =
242                                    format!("Error: no message with the name {} exists", &msg.name);
243                                error!("{}", &e);
244                                e.as_bytes().into()
245                            }
246                        };
247
248                        match stream.try_write(&return_bytes) {
249                            Ok(_n) => {
250                                // println!("Successfully replied with {} bytes", n);
251                                let mut count = count.lock().await; //.unwrap();
252                                *count += 1;
253                                break;
254                            }
255                            Err(e) => {
256                                if e.kind() == std::io::ErrorKind::WouldBlock {}
257                                continue;
258                            }
259                        }
260                    },
261                }
262            }
263            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
264                // println!("Error::WouldBlock: {:?}", e);
265                continue;
266            }
267            Err(e) => {
268                // println!("Error: {:?}", e);
269                error!("Error: {:?}", e);
270            }
271        }
272    }
273}