use std::io::{Read, Write, Cursor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::iter::once;
use std::collections::HashMap;
use tag::Tag;
use error::Error;
/// A tagged message: an ordered sequence of `(Tag, value)` fields.
///
/// NOTE(review): the tag names used with this type (NONC, CERT, DELE, MAXT,
/// PAD) suggest this is the Roughtime message format — confirm.
#[derive(Debug, Clone)]
pub struct RtMessage {
// Field tags; `add_field` only appends a tag that sorts strictly after the last one.
tags: Vec<Tag>,
// Field values; `values[i]` is the value belonging to `tags[i]`.
values: Vec<Vec<u8>>,
}
impl RtMessage {
pub fn new(num_fields: u32) -> Self {
RtMessage {
tags: Vec::with_capacity(num_fields as usize),
values: Vec::with_capacity(num_fields as usize),
}
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
let mut msg = Cursor::new(bytes);
let num_tags = msg.read_u32::<LittleEndian>()?;
let mut rt_msg = RtMessage::new(num_tags);
if num_tags == 1 {
let pos = msg.position() as usize;
let tag = Tag::from_wire(&bytes[pos..pos+4])?;
msg.set_position((pos + 4) as u64);
let mut value = Vec::new();
msg.read_to_end(&mut value).unwrap();
rt_msg.add_field(tag, &value)?;
return Ok(rt_msg)
}
let mut offsets = Vec::with_capacity((num_tags - 1) as usize);
let mut tags = Vec::with_capacity(num_tags as usize);
for _ in 0..num_tags - 1 {
let offset = msg.read_u32::<LittleEndian>()?;
if offset % 4 != 0 {
panic!("Invalid offset {:?} in message {:?}", offset, bytes);
}
offsets.push(offset as usize);
}
let mut buf = [0; 4];
for _ in 0..num_tags {
msg.read_exact(&mut buf).unwrap();
let tag = Tag::from_wire(&buf)?;
if let Some(last_tag) = tags.last() {
if tag <= *last_tag {
return Err(Error::TagNotStrictlyIncreasing(tag))
}
}
tags.push(tag);
}
let header_end = msg.position() as usize;
let msg_end = bytes.len() - header_end;
assert_eq!(offsets.len(), tags.len() - 1);
for (tag, (value_start, value_end)) in tags.into_iter().zip(
once(&0).chain(offsets.iter()).zip(
offsets.iter().chain(once(&msg_end))
)
) {
let value = bytes[(header_end + value_start)..(header_end + value_end)].to_vec();
rt_msg.add_field(tag, &value)?;
}
Ok(rt_msg)
}
pub fn add_field(&mut self, tag: Tag, value: &[u8]) -> Result<(), Error> {
if let Some(last_tag) = self.tags.last() {
if tag <= *last_tag {
return Err(Error::TagNotStrictlyIncreasing(tag));
}
}
self.tags.push(tag);
self.values.push(value.to_vec());
Ok(())
}
pub fn num_fields(&self) -> u32 {
self.tags.len() as u32
}
pub fn tags(&self) -> &[Tag] {
&self.tags
}
pub fn values(&self) -> &[Vec<u8>] {
&self.values
}
pub fn into_hash_map(self) -> HashMap<Tag, Vec<u8>> {
self.tags.into_iter().zip(self.values.into_iter()).collect()
}
pub fn encode(&self) -> Result<Vec<u8>, Error> {
let num_tags = self.tags.len();
let mut out = Vec::with_capacity(self.encoded_size());
out.write_u32::<LittleEndian>(num_tags as u32)?;
if num_tags > 1 {
let mut offset_sum = self.values[0].len();
for val in &self.values[1..] {
out.write_u32::<LittleEndian>(offset_sum as u32)?;
offset_sum += val.len();
}
}
for tag in &self.tags {
out.write_all(tag.wire_value())?;
}
for value in &self.values {
out.write_all(value)?;
}
assert_eq!(out.len(), self.encoded_size(), "unexpected length");
Ok(out)
}
pub fn encoded_size(&self) -> usize {
let num_tags = self.tags.len();
let tags_size = 4 * num_tags;
let offsets_size = if num_tags < 2 { 0 } else { 4 * (num_tags - 1) };
let values_size: usize = self.values.iter().map(|v| v.len()).sum();
4 + tags_size + offsets_size + values_size
}
pub fn pad_to_kilobyte(&mut self) {
let size = self.encoded_size();
if size >= 1024 {
return;
}
let mut padding_needed = 1024 - size;
if self.tags.len() == 1 {
padding_needed -= 4;
}
padding_needed -= Tag::PAD.wire_value().len();
let padding = vec![0; padding_needed];
self.add_field(Tag::PAD, &padding).unwrap();
assert_eq!(self.encoded_size(), 1024);
}
}
#[cfg(test)]
mod test {
    use std::io::{Cursor, Read};
    use byteorder::{LittleEndian, ReadBytesExt};
    use message::*;
    use tag::Tag;

    /// Reads the next little-endian `u32` from the encoded message.
    fn next_u32(wire: &mut Cursor<Vec<u8>>) -> u32 {
        wire.read_u32::<LittleEndian>().unwrap()
    }

    /// Reads the next four raw bytes (one tag as it appears on the wire).
    fn next_word(wire: &mut Cursor<Vec<u8>>) -> [u8; 4] {
        let mut word = [0u8; 4];
        wire.read_exact(&mut word).unwrap();
        word
    }

    /// Reads exactly `len` raw bytes from the encoded message.
    fn next_bytes(wire: &mut Cursor<Vec<u8>>, len: usize) -> Vec<u8> {
        let mut chunk = vec![0u8; len];
        wire.read_exact(&mut chunk).unwrap();
        chunk
    }

    #[test]
    fn empty_message_size() {
        let empty = RtMessage::new(0);
        assert_eq!(empty.num_fields(), 0);
        // Only the 4-byte tag count.
        assert_eq!(empty.encoded_size(), 4);
    }

    #[test]
    fn single_field_message_size() {
        let mut single = RtMessage::new(1);
        single.add_field(Tag::NONC, b"1234").unwrap();
        assert_eq!(single.num_fields(), 1);
        // count + one tag + 4-byte value, no offsets.
        assert_eq!(single.encoded_size(), 12);
    }

    #[test]
    fn two_field_message_size() {
        let mut pair = RtMessage::new(2);
        pair.add_field(Tag::NONC, b"1234").unwrap();
        pair.add_field(Tag::PAD, b"abcd").unwrap();
        assert_eq!(pair.num_fields(), 2);
        // count + one offset + two tags + two 4-byte values.
        assert_eq!(pair.encoded_size(), 24);
    }

    #[test]
    fn empty_message_encoding() {
        let mut wire = Cursor::new(RtMessage::new(0).encode().unwrap());
        assert_eq!(next_u32(&mut wire), 0);
    }

    #[test]
    fn single_field_message_encoding() {
        let cert_value = vec![b'a'; 64];

        let mut msg = RtMessage::new(1);
        msg.add_field(Tag::CERT, &cert_value).unwrap();

        let mut wire = Cursor::new(msg.encode().unwrap());
        assert_eq!(next_u32(&mut wire), 1);
        assert_eq!(next_word(&mut wire), Tag::CERT.wire_value());
        assert_eq!(next_bytes(&mut wire, 64), cert_value);
        assert_eq!(wire.position(), 72);
    }

    #[test]
    fn two_field_message_encoding() {
        let dele_value = vec![b'a'; 24];
        let maxt_value = vec![b'z'; 32];

        let mut msg = RtMessage::new(2);
        msg.add_field(Tag::DELE, &dele_value).unwrap();
        msg.add_field(Tag::MAXT, &maxt_value).unwrap();

        let mut wire = Cursor::new(msg.encode().unwrap());

        // Header: tag count, the single offset, then both tags.
        assert_eq!(next_u32(&mut wire), 2);
        assert_eq!(next_u32(&mut wire), dele_value.len() as u32);
        assert_eq!(next_word(&mut wire), Tag::DELE.wire_value());
        assert_eq!(next_word(&mut wire), Tag::MAXT.wire_value());

        // Values, in tag order.
        assert_eq!(next_bytes(&mut wire, 24), dele_value);
        assert_eq!(next_bytes(&mut wire, 32), maxt_value);

        assert_eq!(wire.position() as usize, msg.encoded_size());
    }
}