use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use error::MemcacheError;
use rand;
use std::collections::HashMap;
use std::io;
use std::io::{Error, ErrorKind, Read, Write};
use std::net::UdpSocket;
use std::time::Duration;
use std::u16;
use url::Url;
/// A request/response stream to a memcached server carried over UDP.
///
/// Bytes passed to `write` are buffered and sent as one datagram by `flush`,
/// which also receives the response; `read` then serves that response from
/// an in-memory buffer.
pub struct UdpStream {
/// Socket that `new` binds locally and connects to the server address.
socket: UdpSocket,
/// Response payload assembled by `flush` and drained by `read`.
read_buf: Vec<u8>,
/// Request bytes accumulated by `write` until the next `flush`.
write_buf: Vec<u8>,
/// Id written into the header of each request; `flush` uses it to match
/// response datagrams. Seeded with a random value in `new`.
request_id: u16,
}
impl UdpStream {
    /// Opens a UDP stream to the server described by `addr`.
    ///
    /// A socket is bound to an OS-assigned port on `0.0.0.0` and then
    /// connected to the socket addresses resolved from `addr` (no fallback
    /// port is supplied to the resolver). The request id counter starts at a
    /// random value.
    ///
    /// # Errors
    ///
    /// Returns an error if binding, address resolution or connecting fails.
    pub fn new(addr: &Url) -> Result<Self, MemcacheError> {
        let socket = UdpSocket::bind("0.0.0.0:0")?;
        let remote_addrs = addr.socket_addrs(|| None)?;
        socket.connect(remote_addrs.as_slice())?;

        let request_id = rand::random::<u16>();
        Ok(UdpStream {
            socket,
            read_buf: vec![],
            write_buf: vec![],
            request_id,
        })
    }

    /// Sets the timeout applied to receive calls on the underlying socket;
    /// `None` removes the timeout.
    pub(crate) fn set_read_timeout(&self, duration: Option<Duration>) -> Result<(), MemcacheError> {
        self.socket.set_read_timeout(duration)?;
        Ok(())
    }

    /// Sets the timeout applied to send calls on the underlying socket;
    /// `None` removes the timeout.
    pub(crate) fn set_write_timeout(&self, duration: Option<Duration>) -> Result<(), MemcacheError> {
        self.socket.set_write_timeout(duration)?;
        Ok(())
    }
}
impl Read for UdpStream {
    /// Serves bytes from the buffered response.
    ///
    /// Copies as many bytes as fit into `buf` from the front of the internal
    /// read buffer, removes them from that buffer, and returns the number of
    /// bytes copied. Returns `Ok(0)` when either `buf` or the internal buffer
    /// is empty. This method never touches the socket.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let count = buf.len().min(self.read_buf.len());

        // Only the first `count` bytes of the caller's buffer are written.
        let (dest, _) = buf.split_at_mut(count);
        dest.copy_from_slice(&self.read_buf[..count]);

        // Drop the bytes that were just handed out.
        self.read_buf.drain(..count);
        Ok(count)
    }
}
/// Buffers an outgoing request and, on `flush`, performs one complete
/// request/response exchange over the UDP socket.
impl Write for UdpStream {
    /// Appends `buf` to the pending request. Nothing is sent until `flush`.
    ///
    /// Always accepts the whole slice and returns its length.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_buf.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Sends the buffered request as a single datagram and collects the
    /// response into the read buffer.
    ///
    /// The request is prefixed with an 8-byte frame header of four big-endian
    /// `u16` values: request id, sequence number (0), total datagram count (1)
    /// and a zero word. Response datagrams carrying a different request id
    /// are skipped; the remaining ones are gathered until every sequence
    /// number announced by the header has arrived, then concatenated in
    /// sequence order into the buffer served by `read`.
    ///
    /// Both internal buffers are cleared and the request id is advanced
    /// whether or not the exchange succeeds, so a failed flush cannot leak a
    /// stale request, a stale response, or a reusable id into the next one.
    ///
    /// # Errors
    ///
    /// * Any error from sending or receiving on the socket (including
    ///   timeouts configured on it).
    /// * `ErrorKind::Other` if a received datagram is shorter than the frame
    ///   header.
    /// * `ErrorKind::InvalidData` if a matching datagram announces zero
    ///   datagrams, a total that differs from the first one seen, or a
    ///   sequence number outside the announced range.
    fn flush(&mut self) -> io::Result<()> {
        // Size of the frame header that precedes every datagram.
        const UDP_HEADER_LEN: usize = 8;
        // Receive buffer size. NOTE(review): assumes the server never sends
        // a datagram larger than this; confirm against the memcached setup.
        const MAX_DATAGRAM_LEN: usize = 1400;

        // Claim the id for this exchange and advance the counter right away.
        // If the exchange fails (e.g. a receive timeout), a late reply to it
        // must not be mistaken for the reply to the next request.
        let request_id = self.request_id;
        self.request_id = (self.request_id % (u16::MAX)) + 1;

        // Build the datagram in a local buffer so `write_buf` is left clean
        // even when sending fails.
        let mut datagram: Vec<u8> = Vec::with_capacity(UDP_HEADER_LEN + self.write_buf.len());
        datagram.write_u16::<BigEndian>(request_id)?;
        datagram.write_u16::<BigEndian>(0)?; // sequence number of this datagram
        datagram.write_u16::<BigEndian>(1)?; // total datagrams in the request
        datagram.write_u16::<BigEndian>(0)?; // unused header word, sent as zero
        datagram.extend_from_slice(&self.write_buf);
        self.write_buf.clear();
        self.read_buf.clear();

        self.socket.send(datagram.as_slice())?;

        // Payloads keyed by sequence number; datagrams may arrive in any order.
        let mut response_datagrams: HashMap<u16, Vec<u8>> = HashMap::new();
        let mut total_datagrams: u16 = 0;
        // Only `buf[..bytes_read]` is ever inspected, so one buffer can be
        // reused for every receive call.
        let mut buf = [0u8; MAX_DATAGRAM_LEN];
        loop {
            let bytes_read = self.socket.recv(&mut buf)?;
            if bytes_read < UDP_HEADER_LEN {
                return Err(Error::new(ErrorKind::Other, "Invalid UDP header received"));
            }

            let response_id = BigEndian::read_u16(&buf[0..]);
            if response_id != request_id {
                // Belongs to some other exchange; ignore it.
                continue;
            }

            let sequence_no = BigEndian::read_u16(&buf[2..]);
            let announced_total = BigEndian::read_u16(&buf[4..]);
            if response_datagrams.is_empty() {
                // The first matching datagram fixes the expected count.
                total_datagrams = announced_total;
            }
            if announced_total == 0 || announced_total != total_datagrams || sequence_no >= total_datagrams {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    "Inconsistent UDP header received",
                ));
            }

            // A duplicated datagram must not count twice towards completion,
            // so completion is judged by the number of distinct sequence
            // numbers collected rather than by a running counter.
            response_datagrams
                .entry(sequence_no)
                .or_insert_with(|| buf[UDP_HEADER_LEN..bytes_read].to_vec());
            if response_datagrams.len() == total_datagrams as usize {
                break;
            }
        }

        // Every key is unique and below `total_datagrams`, and there are
        // exactly `total_datagrams` of them, so each lookup below succeeds.
        for sequence_no in 0..total_datagrams {
            if let Some(mut payload) = response_datagrams.remove(&sequence_no) {
                self.read_buf.append(&mut payload);
            }
        }
        Ok(())
    }
}