use std::io;
use std::io::Result;
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use byteorder::{ReadBytesExt, WriteBytesExt};
use columnar::Columnar;
use serde::{Deserialize, Serialize};
type ByteOrder = byteorder::BigEndian;
/// Fixed-size header describing a message: `length` payload bytes follow the
/// encoded header (see `MessageHeader::required_bytes`).
///
/// `MessageHeader::write_to` / `MessageHeader::try_read` encode the fields as
/// six `u64`s in `ByteOrder` (big-endian), in declaration order. Keep those
/// methods in sync if fields are added, removed, or reordered.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, Columnar)]
pub struct MessageHeader {
/// Channel the message belongs to (NOTE(review): meaning inferred from the name — confirm with callers).
pub channel: usize,
/// Sending worker (NOTE(review): presumably a worker index — confirm with callers).
pub source: usize,
/// Lower bound of the target range (NOTE(review): inclusive/exclusive not visible here — confirm).
pub target_lower: usize,
/// Upper bound of the target range (NOTE(review): inclusive/exclusive not visible here — confirm).
pub target_upper: usize,
/// Number of payload bytes that follow the encoded header.
pub length: usize,
/// Sequence number (NOTE(review): scope of the sequence — per channel/source? — confirm with callers).
pub seqno: usize,
}
impl MessageHeader {
    /// Number of `u64` fields in the encoded header.
    const FIELDS: usize = 6;

    /// Attempts to decode a header from the front of `bytes`.
    ///
    /// The header is read as `FIELDS` consecutive `u64`s in `ByteOrder`, in the
    /// order channel, source, target_lower, target_upper, length, seqno.
    ///
    /// Returns `None` when `bytes` is too short to hold the header itself, or
    /// too short to hold the header plus the `length` payload bytes it announces.
    #[inline]
    pub fn try_read(bytes: &[u8]) -> Option<MessageHeader> {
        let mut cursor = io::Cursor::new(bytes);
        let mut buffer = [0; Self::FIELDS];
        // Fails (and so yields `None`) when fewer than `header_bytes()` bytes are present.
        cursor.read_u64_into::<ByteOrder>(&mut buffer).ok()?;
        // NOTE(review): on targets where `usize` is narrower than 64 bits these
        // casts truncate; only relevant if 32-bit peers are supported.
        let header = MessageHeader {
            channel: buffer[0] as usize,
            source: buffer[1] as usize,
            target_lower: buffer[2] as usize,
            target_upper: buffer[3] as usize,
            length: buffer[4] as usize,
            seqno: buffer[5] as usize,
        };
        if bytes.len() >= header.required_bytes() {
            Some(header)
        } else {
            None
        }
    }

    /// Encodes the header as `FIELDS` `u64`s in `ByteOrder` and writes them to `writer`
    /// with a single `write_all` call.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `writer.write_all`.
    #[inline]
    pub fn write_to<W: ::std::io::Write>(&self, writer: &mut W) -> Result<()> {
        let mut buffer = [0u8; std::mem::size_of::<u64>() * Self::FIELDS];
        let mut cursor = io::Cursor::new(&mut buffer[..]);
        // The buffer holds exactly `FIELDS` u64s, so none of these writes can run out of space.
        cursor.write_u64::<ByteOrder>(self.channel as u64)?;
        cursor.write_u64::<ByteOrder>(self.source as u64)?;
        cursor.write_u64::<ByteOrder>(self.target_lower as u64)?;
        cursor.write_u64::<ByteOrder>(self.target_upper as u64)?;
        cursor.write_u64::<ByteOrder>(self.length as u64)?;
        cursor.write_u64::<ByteOrder>(self.seqno as u64)?;
        writer.write_all(&buffer[..])
    }

    /// Total number of bytes occupied by the encoded header plus its payload.
    ///
    /// Saturates at `usize::MAX` instead of overflowing: `length` is decoded
    /// directly from input bytes in `try_read` and may be arbitrarily large,
    /// and a wrapped sum would let `try_read` accept a truncated message.
    #[inline]
    pub fn required_bytes(&self) -> usize {
        self.header_bytes().saturating_add(self.length)
    }

    /// Size in bytes of the encoded header (`FIELDS` u64s); independent of field values.
    #[inline(always)]
    pub fn header_bytes(&self) -> usize {
        std::mem::size_of::<u64>() * Self::FIELDS
    }
}
pub fn create_sockets(addresses: Vec<String>, my_index: usize, noisy: bool) -> Result<Vec<Option<TcpStream>>> {
let hosts1 = Arc::new(addresses);
let hosts2 = Arc::clone(&hosts1);
let start_task = thread::spawn(move || start_connections(hosts1, my_index, noisy));
let await_task = thread::spawn(move || await_connections(hosts2, my_index, noisy));
let mut results = start_task.join().unwrap()?;
results.push(None);
let to_extend = await_task.join().unwrap()?;
results.extend(to_extend);
if noisy { println!("worker {}:\tinitialization complete", my_index) }
Ok(results)
}
pub fn start_connections(addresses: Arc<Vec<String>>, my_index: usize, noisy: bool) -> Result<Vec<Option<TcpStream>>> {
let results = addresses.iter().take(my_index).enumerate().map(|(index, address)| {
loop {
match TcpStream::connect(address) {
Ok(mut stream) => {
stream.set_nodelay(true).expect("set_nodelay call failed");
stream.write_u64::<ByteOrder>(my_index as u64).expect("failed to encode/send worker index");
if noisy { println!("worker {}:\tconnection to worker {}", my_index, index); }
break Some(stream);
},
Err(error) => {
println!("worker {}:\terror connecting to worker {}: {}; retrying", my_index, index, error);
sleep(Duration::from_secs(1));
},
}
}
}).collect();
Ok(results)
}
pub fn await_connections(addresses: Arc<Vec<String>>, my_index: usize, noisy: bool) -> Result<Vec<Option<TcpStream>>> {
let mut results: Vec<_> = (0..(addresses.len() - my_index - 1)).map(|_| None).collect();
let listeners = addresses[my_index].split_whitespace().map(TcpListener::bind).collect::<Result<Vec<_>>>()?;
for listener in listeners.iter() { listener.set_nonblocking(true).expect("Couldn't set nonblocking"); }
while results.iter().any(Option::is_none) {
let mut received = false;
for listener in listeners.iter() {
match listener.accept() {
Ok((mut stream, _)) => {
stream.set_nodelay(true).expect("set_nodelay call failed");
let identifier = stream.read_u64::<ByteOrder>().expect("failed to decode worker index") as usize;
results[identifier - my_index - 1] = Some(stream);
if noisy { println!("worker {}:\tconnection from worker {}", my_index, identifier); }
received = true;
}
Err(e) => { if e.kind() != io::ErrorKind::WouldBlock { return Err(e); } }
}
}
if !received {
println!("awaiting connections (at {:?}/{:?})", results.iter().filter(|x| x.is_some()).count(), results.len());
sleep(Duration::from_secs(1));
}
}
Ok(results)
}