1use tokio::net::TcpListener;
3use tokio::net::TcpStream;
4use tokio::runtime::Runtime;
5use tokio::sync::Mutex; use tokio::task::JoinHandle;
7use hex_slice::*;
9use tracing::*;
10use postcard::*;
12use std::ops::{Deref, DerefMut};
14use std::sync::Arc;
15use std::sync::Mutex as StdMutex;
16use std::error::Error;
18use std::net::SocketAddr;
19use std::result::Result;
20
21use crate::*;
22
23#[derive(Debug)]
25pub struct Connection {
26 handle: JoinHandle<()>,
27 stream_addr: SocketAddr,
28 name: String,
29}
30
31#[derive(Debug)]
33pub struct Host {
34 pub cfg: HostConfig,
35 pub runtime: Runtime,
36 pub connections: Arc<StdMutex<Vec<Connection>>>,
37 pub task_listen: Option<JoinHandle<()>>,
38 pub store: Option<sled::Db>,
39 pub reply_count: Arc<Mutex<usize>>,
40}
41
42impl Host {
43 #[tracing::instrument]
45 pub fn start(&mut self) -> Result<(), Box<dyn Error>> {
46 let ip = crate::get_ip(&self.cfg.interface)?;
47 let raw_addr = ip + ":" + &self.cfg.socket_num.to_string();
48 let addr: SocketAddr = raw_addr.parse()?;
49
50 let (max_buffer_size, max_name_size) = (self.cfg.max_buffer_size, self.cfg.max_name_size);
51
52 let connections_clone = self.connections.clone();
53
54 let db = match self.store.clone() {
55 Some(db) => db,
56 None => {
57 error!("Must open a sled database to start the Host");
58 panic!("Must open a sled database to start the Host");
59 }
60 };
61
62 let counter = self.reply_count.clone();
63
64 let task_listen = self.runtime.spawn(async move {
65 let listener = TcpListener::bind(addr).await.unwrap();
66
67 loop {
68 let (stream, stream_addr) = listener.accept().await.unwrap();
69 let (stream, name) = handshake(stream, max_buffer_size, max_name_size).await;
71 info!("Host received connection from {:?}", &name);
72
73 let db = db.clone();
74 let counter = counter.clone();
75 let connections = Arc::clone(&connections_clone.clone());
76
77 let handle = tokio::spawn(async move {
78 process(stream, db, counter, max_buffer_size).await;
79 });
80 let connection = Connection {
81 handle,
82 stream_addr,
83 name,
84 };
85 connections.lock().unwrap().push(connection);
88 }
89 });
90
91 self.task_listen = Some(task_listen);
92
93 Ok(())
94 }
95
96 #[tracing::instrument]
100 pub fn stop(mut self) -> Result<(), Box<dyn Error>> {
101 for conn in self.connections.lock().unwrap().deref_mut() {
102 info!("Aborting connection: {}", conn.name);
104 conn.handle.abort();
105 }
106 self.store = None;
107
108 Ok(())
109 }
110
111 #[no_mangle]
113 pub fn print_connections(&mut self) -> Result<(), Box<dyn Error + '_>> {
114 println!("Connections:");
115 for conn in self.connections.lock()?.deref() {
116 let name = conn.name.clone();
117 println!("\t- {}:{}", name, &conn.stream_addr);
118 }
119 Ok(())
120 }
121}
122
123#[inline]
125#[tracing::instrument]
126async fn handshake(
127 stream: TcpStream,
128 max_buffer_size: usize,
129 max_name_size: usize,
130) -> (TcpStream, String) {
131 let mut buf = Vec::with_capacity(max_buffer_size);
135 info!("Starting handshake");
136 let mut name: String = String::with_capacity(max_name_size);
137 let mut count = 0;
138 stream.readable().await.unwrap();
139 loop {
140 info!("In handshake loop");
141 match stream.try_read_buf(&mut buf) {
142 Ok(n) => {
143 name = match std::str::from_utf8(&buf[..n]) {
144 Ok(name) => {
145 info!("Received connection from {}", &name);
146 name.to_owned()
147 }
148 Err(e) => {
149 error!("Error occurred during handshake on host-side: {} on byte string: {:?}, which in hex is: {:x}", e,&buf[..n],&buf[..n].as_hex());
150
151 "HOST_CONNECTION_ERROR".to_owned()
156 }
157 };
158 break;
159 }
160 Err(e) => {
161 if e.kind() == std::io::ErrorKind::WouldBlock {
162 count += 1;
163 if count > 20 {
164 error!("Host Handshake not unblocking!");
165 panic!("Stream won't unblock");
166 }
167 } else {
168 error!("{:?}", e);
169 }
171 }
172 }
173 }
174 info!("Returning from handshake: ({:?}, {})", &stream, &name);
175 (stream, name)
176}
177
178#[tracing::instrument]
180#[inline]
181async fn process(
182 stream: TcpStream,
183 db: sled::Db,
184 count: Arc<Mutex<usize>>,
185 max_buffer_size: usize,
186) {
187 let mut buf = [0u8; 10_000];
188 loop {
191 stream.readable().await.unwrap();
192 match stream.try_read(&mut buf) {
194 Ok(0) => break, Ok(n) => {
196 stream.writable().await.unwrap();
197
198 let bytes = &buf[..n];
199 let msg: GenericMsg = match from_bytes(bytes) {
200 Ok(msg) => msg,
201 Err(e) => {
202 error!("Had received Msg of {} bytes: {:?}, Error: {}", n, bytes, e);
203 panic!("{}", e);
204 }
205 };
206 match msg.msg_type {
209 MsgType::SET => {
210 let db_result = match db.insert(msg.topic.as_bytes(), bytes) {
212 Ok(_prev_msg) => "SUCCESS".to_string(),
213 Err(e) => e.to_string(),
214 };
215
216 loop {
217 match stream.try_write(db_result.as_bytes()) {
218 Ok(_n) => {
219 let mut count = count.lock().await; *count += 1;
222 break;
223 }
224 Err(e) => {
225 if e.kind() == std::io::ErrorKind::WouldBlock {}
226 continue;
227 }
228 }
229 }
230 }
231 MsgType::GET => loop {
232 let return_bytes = match db.get(&msg.topic).unwrap() {
239 Some(msg) => msg,
240 None => {
241 let e: String =
242 format!("Error: no message with the name {} exists", &msg.name);
243 error!("{}", &e);
244 e.as_bytes().into()
245 }
246 };
247
248 match stream.try_write(&return_bytes) {
249 Ok(_n) => {
250 let mut count = count.lock().await; *count += 1;
253 break;
254 }
255 Err(e) => {
256 if e.kind() == std::io::ErrorKind::WouldBlock {}
257 continue;
258 }
259 }
260 },
261 }
262 }
263 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
264 continue;
266 }
267 Err(e) => {
268 error!("Error: {:?}", e);
270 }
271 }
272 }
273}