#![warn(missing_docs)]
#[macro_use]
extern crate serde_derive;
extern crate serde_bytes;
extern crate serde_cbor;
mod frame;
mod local_db;
mod peer;
mod server;
use siphasher::sip::SipHasher;
use std::hash::{Hash, Hasher};
use std::io::Cursor;
use std::io::Write;
use std::io::{Error, Read};
use std::net::TcpStream;
use bytes::{Buf, BufMut, BytesMut};
use std::net::{IpAddr, SocketAddr};
/// Size in bytes of the frame header: the 32-bit type/length word followed by
/// the 32-bit connection id.
pub(crate) const HDRL: usize = 8;
/// Size in bytes of the connection-id field, which occupies the second half of
/// the header.
pub(crate) const CIDL: usize = 4;
/// Size in bytes of the 64-bit connection key that follows the header.
pub(crate) const KEYL: usize = 8;
/// Size of the fixed frame prefix (header plus key).
pub(crate) const HDRKEYL: usize = HDRL + KEYL;
/// Largest payload length representable in the 24-bit length field.
pub(crate) const MSGMAXSIZE: usize = (1 << 24) - 1;
/// Keepalive interval. NOTE(review): units are not visible here, presumably
/// seconds — confirm against the peer/server modules.
const KEEPALIVE: u64 = 5;
/// Decoded form of the fixed frame prefix.
///
/// The prefix holds three fields:
/// - a 32-bit type/length word, with the frame type in the high byte and the
///   payload length in the low 24 bits;
/// - a 32-bit connection id;
/// - a 64-bit connection key.
pub struct MsgHdr {
    thlen: u32,
    cid: u32,
    key: u64,
}

impl MsgHdr {
    /// Creates a header of type `'M'` for a payload of `len` bytes.
    ///
    /// `len` is truncated to its low 24 bits.
    pub fn new(len: u32, cid: u32, key: u64) -> MsgHdr {
        let thlen = hdr_set_len(len);
        MsgHdr { thlen, cid, key }
    }

    /// Returns the frame type stored in the high byte of the type/length word.
    pub fn get_type(&self) -> u8 {
        hdr_get_type(self.thlen)
    }

    /// Returns the combined size in bytes of header and key (`HDRKEYL`).
    pub fn get_hdrkey_len() -> usize {
        HDRKEYL
    }

    /// Sets the payload length, truncated to 24 bits.
    ///
    /// This also resets the type byte to `'M'`.
    pub fn set_len(&mut self, len: u32) {
        self.thlen = hdr_set_len(len);
    }

    /// Returns the 24-bit payload length.
    pub fn get_len(&self) -> u32 {
        hdr_get_len(self.thlen)
    }

    /// Sets the connection id.
    pub fn set_cid(&mut self, cid: u32) {
        self.cid = cid;
    }

    /// Returns the connection id.
    pub fn get_cid(&self) -> u32 {
        self.cid
    }

    /// Sets the connection key.
    pub fn set_key(&mut self, key: u64) {
        self.key = key;
    }

    /// Returns the connection key.
    pub fn get_key(&self) -> u64 {
        self.key
    }

    /// Serializes the header and key into their big-endian wire form.
    pub fn encode(&self) -> Vec<u8> {
        let mut wire = write_hdr(self.get_len() as usize, self.get_cid()).to_vec();
        wire.extend_from_slice(&write_key(self.get_key()));
        wire
    }

    /// Parses a header and key from `buf`.
    ///
    /// Any field that `buf` is too short to contain is read as 0. The type
    /// byte is not taken from `buf`; the result always has type `'M'`.
    pub fn decode(buf: Vec<u8>) -> MsgHdr {
        let len = read_hdr_len(&buf) as u32;
        let cid = read_cid_from_hdr(&buf);
        let key = read_key_from_hdr(&buf);
        MsgHdr::new(len, cid, key)
    }

    /// Feeds each string, in order, into a fresh `SipHasher` and returns the digest.
    #[inline]
    pub fn do_hash(t: &[String]) -> u64 {
        let mut hasher = SipHasher::new();
        t.iter().for_each(|item| item.hash(&mut hasher));
        hasher.finish()
    }

    /// Derives the connection id from a key by keeping its low 32 bits.
    #[inline]
    pub fn select_cid(key: u64) -> u32 {
        key as u32
    }

    /// Formats `addr` as a string.
    ///
    /// IPv4 becomes `a.b.c.d:port`. IPv6 becomes `[h:h:h:h:h:h:h:h]:port`,
    /// with all eight segments in lowercase hex and no `::` compression.
    /// The output is hashed into connection keys, so changing the format
    /// changes the derived keys.
    #[inline]
    pub fn addr2str(addr: &SocketAddr) -> String {
        let port = addr.port();
        match addr.ip() {
            IpAddr::V4(v4) => format!("{}:{}", v4, port),
            IpAddr::V6(v6) => {
                let segments: Vec<String> = v6
                    .segments()
                    .iter()
                    .map(|seg| format!("{:x}", seg))
                    .collect();
                format!("[{}]:{}", segments.join(":"), port)
            }
        }
    }
}
/// Packs the frame type `'M'` and a 24-bit length into one 32-bit word.
///
/// Bits of `len` above the low 24 are discarded.
fn hdr_set_len(len: u32) -> u32 {
    let type_bits = u32::from(b'M') << 24;
    type_bits | (len & 0x00ff_ffff)
}
/// Extracts the 24-bit length from a type/length word by clearing the type byte.
fn hdr_get_len(thlen: u32) -> u32 {
    thlen & !(0xffu32 << 24)
}
/// Extracts the frame type, stored in the most significant byte of the
/// type/length word.
fn hdr_get_type(thlen: u32) -> u8 {
    let high_byte = (thlen & 0xff00_0000) >> 24;
    high_byte as u8
}
/// A message consisting of a sender user id, a channel name and an opaque payload.
///
/// It is serialized as CBOR with `serde_cbor`. The payload uses `serde_bytes`,
/// so it is encoded as a byte string rather than a sequence of integers.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Msg {
    uid: String,
    channel: String,
    #[serde(with = "serde_bytes")]
    message: Vec<u8>,
}

impl Msg {
    /// Builds a message from its three parts.
    #[inline]
    pub fn new(uid: String, channel: String, message: Vec<u8>) -> Msg {
        Msg {
            uid,
            channel,
            message,
        }
    }

    /// Returns this message with the user id replaced.
    #[inline]
    pub fn set_uid(self, uid: String) -> Msg {
        Msg { uid, ..self }
    }

    /// Returns this message with the channel replaced.
    #[inline]
    pub fn set_channel(self, channel: String) -> Msg {
        Msg { channel, ..self }
    }

    /// Returns this message with the payload replaced.
    #[inline]
    pub fn set_message(self, message: Vec<u8>) -> Msg {
        Msg { message, ..self }
    }

    /// Borrows the user id.
    #[inline]
    pub fn get_uid(&self) -> &String {
        &self.uid
    }

    /// Borrows the channel name.
    #[inline]
    pub fn get_channel(&self) -> &String {
        &self.channel
    }

    /// Borrows the payload.
    #[inline]
    pub fn get_message(&self) -> &Vec<u8> {
        &self.message
    }

    /// Returns the payload length in bytes.
    #[inline]
    pub fn get_message_len(&self) -> usize {
        self.message.len()
    }

    /// Mutably borrows the payload so it can be edited in place.
    #[inline]
    pub fn get_mut_message(&mut self) -> &mut Vec<u8> {
        &mut self.message
    }

    /// Serializes the message to CBOR.
    ///
    /// On failure the error is printed to stdout and an empty vector is returned.
    #[inline]
    pub fn encode(&self) -> Vec<u8> {
        serde_cbor::to_vec(self).unwrap_or_else(|err| {
            println!("Error on encode: {}", err);
            Vec::new()
        })
    }

    /// Deserializes a message from CBOR.
    ///
    /// On failure the error is printed to stdout and a message with empty
    /// uid, channel and payload is returned.
    #[inline]
    pub fn decode(slice: &[u8]) -> Msg {
        serde_cbor::from_slice(slice).unwrap_or_else(|err| {
            println!("Error on decode: {}", err);
            Msg::new(String::new(), String::new(), Vec::new())
        })
    }
}
/// One already-encoded message carried inside a `ResyncMsg`.
///
/// The bytes are serialized as a CBOR byte string via `serde_bytes`.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct MsgVec {
    #[serde(with = "serde_bytes")]
    encoded_msg: Vec<u8>,
}

impl MsgVec {
    /// Wraps a copy of `encoded_msg`.
    pub fn new(encoded_msg: &Vec<u8>) -> MsgVec {
        let encoded_msg = encoded_msg.to_vec();
        MsgVec { encoded_msg }
    }

    /// Borrows the wrapped encoded message.
    pub fn get(&self) -> &Vec<u8> {
        &self.encoded_msg
    }
}
/// A batch of already-encoded messages bundled into a single CBOR value.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResyncMsg {
    resync_message: Vec<MsgVec>,
}

impl ResyncMsg {
    /// Builds a batch holding a copy of each encoded message, in order.
    #[inline]
    pub fn new(messages: &Vec<Vec<u8>>) -> ResyncMsg {
        ResyncMsg {
            resync_message: messages.iter().map(MsgVec::new).collect(),
        }
    }

    /// Returns the number of messages in the batch.
    #[inline]
    pub fn len(&self) -> usize {
        self.resync_message.len()
    }

    /// Returns owned copies of the encoded messages, in order.
    #[inline]
    pub fn get_messages(&self) -> Vec<Vec<u8>> {
        self.resync_message
            .iter()
            .map(|entry| entry.get().clone())
            .collect()
    }

    /// Serializes the batch to CBOR.
    ///
    /// On failure the error is printed to stdout and an empty vector is returned.
    #[inline]
    pub fn encode(&self) -> Vec<u8> {
        serde_cbor::to_vec(self).unwrap_or_else(|err| {
            println!("Error on resync encode: {}", err);
            Vec::new()
        })
    }

    /// Deserializes a batch from CBOR.
    ///
    /// Returns an empty batch, without printing anything, if `slice` is not a
    /// valid encoding.
    #[inline]
    pub fn decode(slice: &[u8]) -> ResyncMsg {
        serde_cbor::from_slice(slice).unwrap_or_else(|_| ResyncMsg {
            resync_message: Vec::new(),
        })
    }
}
pub struct MsgConn {
uid: String,
channel: String,
key: Option<u64>,
stream: Option<TcpStream>,
}
impl MsgConn {
#[inline]
pub fn new(uid: String, channel: String) -> MsgConn {
MsgConn {
uid,
channel,
key: None,
stream: None,
}
}
#[inline]
pub fn get_uid(&self) -> String {
self.uid.clone()
}
#[inline]
pub fn get_channel(&self) -> String {
self.channel.clone()
}
#[inline]
pub fn get_key(&self) -> Option<u64> {
self.key
}
#[inline]
pub fn connect_with_message(mut self, raddr: SocketAddr, msg: Vec<u8>) -> MsgConn {
let msg = Msg::new(self.get_uid(), self.get_channel(), msg);
match TcpStream::connect(raddr) {
Ok(mut stream) => {
let _val = stream.set_nodelay(true);
if self.get_key().is_none() {
let mut keys = Vec::new();
let laddr = match stream.local_addr() {
Ok(laddr) => laddr,
Err(_) => {
let addr = "0.0.0.0:0";
addr.parse::<SocketAddr>().unwrap()
}
};
keys.push(MsgHdr::addr2str(&laddr));
keys.push(self.get_uid());
keys.push(self.get_channel());
let key = MsgHdr::do_hash(&keys);
self.key = Some(key);
}
let encoded_msg = msg.encode();
let key = self.get_key().unwrap();
let keyv = write_key(key);
let mut msgv = write_hdr_with_capacity(
encoded_msg.len(),
MsgHdr::select_cid(key),
HDRKEYL + encoded_msg.len(),
);
msgv.extend(keyv);
msgv.extend(encoded_msg);
let msgv = msgv.freeze();
match stream.write_all(msgv.as_ref()) {
Ok(_) => self.stream = Some(stream),
Err(err) => {
println!("Send error {}", err);
self.stream = None;
}
}
self
}
Err(_) => {
println!("Could not connect to server {}", raddr);
self
}
}
}
#[inline]
pub fn connect(self, raddr: SocketAddr) -> MsgConn {
self.connect_with_message(raddr, Vec::new())
}
#[inline]
pub fn send_message(mut self, msg: Vec<u8>) -> MsgConn {
let message = Msg::new(self.get_uid(), self.get_channel(), msg);
let encoded_msg = message.encode();
let key = self.get_key().unwrap();
let keyv = write_key(key);
let mut msgv = write_hdr_with_capacity(
encoded_msg.len(),
MsgHdr::select_cid(key),
HDRKEYL + encoded_msg.len(),
);
msgv.extend(keyv);
msgv.extend(encoded_msg);
let msgv = msgv.freeze();
let mut stream = self.stream.unwrap();
match stream.write_all(msgv.as_ref()) {
Ok(_) => self.stream = Some(stream),
Err(err) => {
println!("Send error {}", err);
self.stream = None;
}
}
self
}
#[inline]
pub fn read_message(mut self) -> (MsgConn, Vec<u8>) {
let stream = self.stream.unwrap();
loop {
let tuple = read_n(&stream, HDRKEYL);
let status = tuple.0;
if let Ok(0) = status {
println!("Read failed: eof");
self.stream = None;
return (self, Vec::new());
}
let buf = tuple.1;
if buf.is_empty() {
continue;
}
if read_hdr_type(buf.as_slice()) != 'M' as u32 {
continue;
}
let hdr_len = read_hdr_len(buf.as_slice());
if 0 == hdr_len {
continue;
}
let tuple = read_n(&stream, hdr_len);
let status = tuple.0;
if let Ok(0) = status {
continue;
};
let payload = tuple.1;
if payload.len() != (hdr_len as usize) {
continue;
}
let decoded_message = Msg::decode(payload.as_slice());
if 0 == decoded_message.get_message_len() {
continue;
}
self.stream = Some(stream);
return (self, decoded_message.get_message().to_owned());
}
}
#[inline]
pub fn close(mut self) -> MsgConn {
if self.stream.is_some() {
drop(self.stream.unwrap());
}
self.stream = None;
self
}
}
#[inline]
pub(crate) fn read_hdr_type(hdr: &[u8]) -> u32 {
if hdr.len() < HDRL {
return 0;
}
let mut buf = Cursor::new(&hdr[..]);
let num = buf.get_u32_be();
num >> 24
}
fn read_hdr_len(hdr: &[u8]) -> usize {
if hdr.len() < HDRL {
return 0;
}
let mut buf = Cursor::new(&hdr[..]);
let num = buf.get_u32_be();
(num & 0xffffff) as usize
}
/// Writes an 8-byte big-endian header: the `'M'`/length word, then `cid`.
///
/// The buffer is allocated with room for the key (`HDRKEYL`) so callers can
/// append it without reallocating.
///
/// NOTE(review): `len` is not masked to 24 bits. A length above `MSGMAXSIZE`
/// spills into the type byte, so callers must keep `len <= MSGMAXSIZE`.
fn write_hdr(len: usize, cid: u32) -> BytesMut {
    let type_and_len = (u32::from(b'M') << 24) | len as u32;
    let mut hdrv = BytesMut::with_capacity(HDRKEYL);
    hdrv.put_u32_be(type_and_len);
    hdrv.put_u32_be(cid);
    hdrv
}
/// Same as `write_hdr`, but allocates `cap` bytes.
///
/// This lets the caller append key and body without reallocating. `cap` must
/// be at least 8, since both header words are put into the fresh buffer.
///
/// NOTE(review): `len` is not masked to 24 bits; callers must keep
/// `len <= MSGMAXSIZE`.
fn write_hdr_with_capacity(len: usize, cid: u32, cap: usize) -> BytesMut {
    let type_and_len = (u32::from(b'M') << 24) | len as u32;
    let mut hdrv = BytesMut::with_capacity(cap);
    hdrv.put_u32_be(type_and_len);
    hdrv.put_u32_be(cid);
    hdrv
}
/// Writes only the big-endian `'M'`/length word.
///
/// The buffer is allocated with room for a full header (`HDRL` bytes).
fn write_hdr_without_cid(len: usize) -> BytesMut {
    let type_and_len = (u32::from(b'M') << 24) | len as u32;
    let mut wordv = BytesMut::with_capacity(HDRL);
    wordv.put_u32_be(type_and_len);
    wordv
}
/// Rewrites the type/length word of an encoded frame with a new `len`.
///
/// Everything from the connection id onward is kept unchanged. Returns an
/// empty buffer if `hdrv` is shorter than a full header (`HDRL` bytes).
#[inline]
pub(crate) fn write_len_to_hdr(len: usize, mut hdrv: BytesMut) -> BytesMut {
    if hdrv.len() >= HDRL {
        let rest = hdrv.split_off(HDRL - CIDL);
        let mut rebuilt = write_hdr_without_cid(len);
        rebuilt.extend(rest);
        rebuilt
    } else {
        BytesMut::new()
    }
}
/// Encodes a connection key as 8 big-endian bytes.
fn write_key(val: u64) -> BytesMut {
    let mut keyv = BytesMut::with_capacity(KEYL);
    keyv.put_u64_be(val);
    keyv
}
/// Writes the header followed by the key.
///
/// The connection id in the header is derived from `key`.
fn write_hdr_with_key(len: usize, key: u64) -> BytesMut {
    let cid = MsgHdr::select_cid(key);
    let mut framed = write_hdr(len, cid);
    framed.extend(write_key(key));
    framed
}
/// Reads the big-endian 64-bit key that follows the 8-byte header.
///
/// Returns 0 if `keyv` is shorter than header plus key (`HDRKEYL` bytes).
fn read_key_from_hdr(keyv: &[u8]) -> u64 {
    match keyv.get(HDRL..HDRKEYL) {
        Some(key_bytes) => Cursor::new(key_bytes).get_u64_be(),
        None => 0,
    }
}
/// Reads the connection id, the second big-endian 32-bit word of the header.
///
/// Returns 0 if `hdrv` is shorter than a full header (`HDRL` bytes).
fn read_cid_from_hdr(hdrv: &[u8]) -> u32 {
    match hdrv.get(HDRL - CIDL..HDRL) {
        Some(cid_bytes) => Cursor::new(cid_bytes).get_u32_be(),
        None => 0,
    }
}
#[inline]
pub fn has_peer(peer: &Option<SocketAddr>) -> bool {
peer::has_peer(peer)
}
/// Reads up to `bytes_to_read` bytes from `reader`, stopping early at EOF.
///
/// Returns the status of the read together with whatever bytes were
/// collected. On success the status is the number of bytes read, which is
/// less than requested only if EOF was reached.
fn read_n<R>(reader: R, bytes_to_read: usize) -> (Result<usize, Error>, Vec<u8>)
where
    R: Read,
{
    let mut out = Vec::with_capacity(bytes_to_read);
    let status = reader.take(bytes_to_read as u64).read_to_end(&mut out);
    (status, out)
}
#[inline]
pub fn server_run(
address: SocketAddr,
peer: Option<SocketAddr>,
keyval: String,
keyaddr: String,
hist_limit: usize,
debug_flags: u64,
) {
server::run(address, peer, keyval, keyaddr, hist_limit, debug_flags);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::thread;
    use std::time::Duration;

    /// Spawns `server_run` on `addr` in a background thread, then sleeps one second.
    ///
    /// The sleep gives the server time to start listening; it is a heuristic,
    /// not a guarantee. Dropping the returned handle detaches the thread.
    fn spawn_server(
        addr: SocketAddr,
        peer: Option<SocketAddr>,
        hist_limit: usize,
    ) -> thread::JoinHandle<()> {
        let handle = thread::spawn(move || {
            server_run(addr, peer, "".to_string(), "".to_string(), hist_limit, 0)
        });
        thread::sleep(Duration::new(1, 0));
        handle
    }

    #[test]
    fn test_read_hdr_len_one() {
        let orig_len = 1;
        let hdrv = write_hdr(orig_len, 0x1);
        assert_eq!(read_hdr_len(hdrv.as_ref()), orig_len);
    }

    #[test]
    fn test_read_hdr_len_16k() {
        let orig_len = 16000;
        let hdrv = write_hdr_with_capacity(orig_len, 0x1, HDRKEYL + orig_len);
        assert_eq!(read_hdr_len(hdrv.as_ref()), orig_len);
    }

    #[test]
    fn test_read_hdr_len_16_7m() {
        // The largest length that fits in the 24-bit field (16777215).
        let orig_len = MSGMAXSIZE;
        let hdrv = write_hdr(orig_len, 0x1);
        assert_eq!(read_hdr_len(hdrv.as_ref()), orig_len);
    }

    #[test]
    fn test_encode_decode_msg() {
        let orig_msg = Msg::new(
            "User".to_string(),
            "Channel".to_string(),
            b"a test msg".to_vec(),
        );
        let decoded_msg = Msg::decode(&orig_msg.encode());
        assert_eq!(decoded_msg.uid, orig_msg.uid);
        assert_eq!(decoded_msg.channel, orig_msg.channel);
        assert_eq!(decoded_msg.message, orig_msg.message);
    }

    #[test]
    fn test_encode_decode_resync_msg() {
        let orig_msgs = vec![
            Msg::new(
                "User".to_string(),
                "Channel".to_string(),
                b"a test msg".to_vec(),
            ),
            Msg::new(
                "User two".to_string(),
                "Channel two".to_string(),
                b"a test msg two".to_vec(),
            ),
        ];
        let encoded: Vec<Vec<u8>> = orig_msgs.iter().map(Msg::encode).collect();
        let rmsg = ResyncMsg::new(&encoded);
        let decoded_resync_msg = ResyncMsg::decode(&rmsg.encode());
        let decoded = decoded_resync_msg.get_messages();
        // A failed decode yields an empty batch; check the count so the loop
        // below cannot pass vacuously.
        assert_eq!(decoded.len(), orig_msgs.len());
        for (msg, orig) in decoded.iter().zip(orig_msgs.iter()) {
            let decoded_msg = Msg::decode(msg);
            assert_eq!(decoded_msg.uid, orig.uid);
            assert_eq!(decoded_msg.channel, orig.channel);
            assert_eq!(decoded_msg.message, orig.message);
        }
    }

    #[test]
    fn test_set_get_msg() {
        let uid = "User".to_string();
        let channel = "Channel".to_string();
        let msg = b"a test msg".to_vec();
        let orig_msg = Msg::new(String::new(), channel.clone(), Vec::new())
            .set_uid(uid.clone())
            .set_channel(channel.clone())
            .set_message(msg.clone());
        assert_eq!(&uid, orig_msg.get_uid());
        assert_eq!(&channel, orig_msg.get_channel());
        assert_eq!(&msg, orig_msg.get_message());
    }

    #[test]
    fn test_set_get_mut_msg() {
        let uid = "User".to_string();
        let channel = "Channel".to_string();
        let nmsg = b"a test mut msg".to_vec();
        let mut orig_msg = Msg::new(String::new(), channel.clone(), b"a test ".to_vec())
            .set_uid(uid.clone())
            .set_channel(channel.clone());
        orig_msg.get_mut_message().extend_from_slice(b"mut msg");
        assert_eq!(&uid, orig_msg.get_uid());
        assert_eq!(&channel, orig_msg.get_channel());
        assert_eq!(&nmsg, orig_msg.get_message());
    }

    #[test]
    fn test_cid() {
        let orig_key = 0xffeffe;
        let hdrv = write_hdr_with_key(64, orig_key);
        let orig_len = hdrv.len();
        assert_eq!(orig_key, read_key_from_hdr(&hdrv));
        assert_eq!(orig_key as u32, read_cid_from_hdr(&hdrv));
        // Reading must neither consume nor modify the buffer.
        assert_eq!(orig_key, read_key_from_hdr(&hdrv));
        assert_eq!(orig_len, hdrv.len());
    }

    #[test]
    fn test_msgconn_send_read() {
        let addr = "127.0.0.1:8078".parse::<SocketAddr>().unwrap();
        let server = spawn_server(addr, None, 100);
        let channel = "Channel".to_string();
        MsgConn::new("User two".to_string(), channel.clone())
            .connect_with_message(addr, b"Hello World!".to_vec())
            .close();
        let conn = MsgConn::new("User".to_string(), channel).connect(addr);
        let (conn, msg) = conn.read_message();
        assert_eq!("Hello World!", String::from_utf8_lossy(&msg));
        conn.close();
        drop(server);
    }

    #[test]
    fn test_msgconn_read_send() {
        let addr = "127.0.0.1:8076".parse::<SocketAddr>().unwrap();
        let server = spawn_server(addr, None, 100);
        let channel = "Channel".to_string();
        let conn = MsgConn::new("User".to_string(), channel.clone()).connect(addr);
        MsgConn::new("User two".to_string(), channel)
            .connect_with_message(addr, b"Hello World!".to_vec())
            .close();
        let (conn, msg) = conn.read_message();
        assert_eq!("Hello World!", String::from_utf8_lossy(&msg));
        conn.close();
        drop(server);
    }

    #[test]
    fn test_msgconn_peer_send_read() {
        let addr = "127.0.0.1:8075".parse::<SocketAddr>().unwrap();
        let paddr = "127.0.0.1:8074".parse::<SocketAddr>().unwrap();
        let server = spawn_server(addr, None, 100);
        let pserver = spawn_server(paddr, Some(addr), 100);
        let channel = "Channel".to_string();
        MsgConn::new("User".to_string(), channel.clone())
            .connect_with_message(paddr, b"Hello World!".to_vec())
            .close();
        let conn = MsgConn::new("User two".to_string(), channel).connect(paddr);
        let (conn, msg) = conn.read_message();
        assert_eq!("Hello World!", String::from_utf8_lossy(&msg));
        conn.close();
        drop(pserver);
        drop(server);
    }

    #[test]
    fn test_msgconn_peer_read_send() {
        let addr = "127.0.0.1:8073".parse::<SocketAddr>().unwrap();
        let paddr = "127.0.0.1:8072".parse::<SocketAddr>().unwrap();
        let server = spawn_server(addr, None, 100);
        let pserver = spawn_server(paddr, Some(addr), 100);
        let channel = "Channel".to_string();
        let conn = MsgConn::new("User".to_string(), channel.clone()).connect(paddr);
        MsgConn::new("User two".to_string(), channel)
            .connect_with_message(paddr, b"Hello World!".to_vec())
            .close();
        let (conn, msg) = conn.read_message();
        assert_eq!("Hello World!", String::from_utf8_lossy(&msg));
        conn.close();
        drop(pserver);
        drop(server);
    }

    #[test]
    fn test_msgconn_basic_read_send() {
        let addr = "127.0.0.1:8071".parse::<SocketAddr>().unwrap();
        let server = spawn_server(addr, None, 0);
        let reader = thread::spawn(move || {
            let conn = MsgConn::new("User two".to_string(), "Channel".to_string()).connect(addr);
            let (conn, msg) = conn.read_message();
            assert_eq!("Hello World!", String::from_utf8_lossy(&msg));
            conn.close();
        });
        thread::sleep(Duration::new(1, 0));
        MsgConn::new("User".to_string(), "Channel".to_string())
            .connect_with_message(addr, b"Hello World!".to_vec())
            .close();
        // A failed assertion in the reader thread surfaces only through join();
        // ignoring the result would let this test pass even when the reader fails.
        assert!(reader.join().is_ok(), "reader thread panicked");
        drop(server);
    }
}