use std::{mem, ops};
use std::marker::PhantomData;
use bytes::{BigEndian, BufMut, ByteOrder, BytesMut};
use iana::opt::OptionCode;
use super::compose::{Compose, Compress, Compressor};
use super::header::{Header, HeaderCounts, HeaderSection};
use super::message::Message;
use super::name::ToDname;
use super::opt::{OptData, OptHeader};
use super::parse::ShortBuf;
use super::question::Question;
use super::rdata::RecordData;
use super::record::Record;
/// Entry point for building a DNS message.
///
/// This builder writes the question section. `answer`, `authority` and
/// `additional` hand the underlying target on to the builders for the
/// later sections.
#[derive(Clone, Debug)]
pub struct MessageBuilder {
    /// Buffer plus bookkeeping; moved from builder to builder.
    target: MessageTarget,
}
impl MessageBuilder {
    /// Creates a builder for a UDP message.
    ///
    /// Initial capacity and limit are both 512 octets; the page size is 0.
    pub fn new_udp() -> Self {
        let (initial, limit, page_size) = (512, 512, 0);
        Self::with_params(initial, limit, page_size)
    }

    /// Creates a builder for a TCP message.
    ///
    /// A two-octet zero placeholder is written ahead of the message and is
    /// kept as the prelude. The limit is set to `u16::MAX` and the page
    /// size to `capacity`.
    pub fn new_tcp(capacity: usize) -> Self {
        let prefixed = {
            let mut bytes = BytesMut::with_capacity(capacity + 2);
            bytes.put_u16_be(0);
            bytes
        };
        let mut builder = Self::from_buf(prefixed);
        builder.set_limit(::std::u16::MAX as usize);
        builder.set_page_size(capacity);
        builder
    }

    /// Creates a builder that appends a message to `buf`.
    ///
    /// Existing contents of `buf` become the prelude.
    pub fn from_buf(buf: BytesMut) -> Self {
        let target = MessageTarget::from_buf(buf);
        MessageBuilder { target }
    }

    /// Creates a builder over a fresh buffer of the given capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        let buf = BytesMut::with_capacity(capacity);
        Self::from_buf(buf)
    }

    /// Creates a builder with explicit initial capacity, limit and page
    /// size.
    ///
    /// The latter two are forwarded to the compressor. Presumably the page
    /// size is the growth increment; confirm against `Compressor`.
    pub fn with_params(initial: usize, limit: usize, page_size: usize)
                       -> Self {
        let mut builder = Self::with_capacity(initial);
        builder.set_limit(limit);
        builder.set_page_size(page_size);
        builder
    }

    /// Turns on name compression in the underlying compressor.
    pub fn enable_compression(&mut self) {
        let compressor = &mut self.target.buf;
        compressor.enable_compression()
    }

    /// Forwards `limit` to the underlying compressor.
    pub fn set_limit(&mut self, limit: usize) {
        let compressor = &mut self.target.buf;
        compressor.set_limit(limit)
    }

    /// Forwards `page_size` to the underlying compressor.
    pub fn set_page_size(&mut self, page_size: usize) {
        let compressor = &mut self.target.buf;
        compressor.set_page_size(page_size)
    }
}
impl MessageBuilder {
    /// Appends a question and, if that succeeds, increments QDCOUNT.
    pub fn push<N: ToDname, Q: Into<Question<N>>>(&mut self, question: Q)
                                                  -> Result<(), ShortBuf> {
        let question: Question<N> = question.into();
        self.target.push(move |buf| question.compress(buf),
                         |counts| counts.inc_qdcount())
    }

    /// Finishes the question section and moves on to the answer section.
    pub fn answer(self) -> AnswerBuilder {
        AnswerBuilder::new(self.target)
    }

    /// Skips ahead to the authority section.
    pub fn authority(self) -> AuthorityBuilder {
        let answers = self.answer();
        answers.authority()
    }

    /// Skips ahead to the additional section.
    pub fn additional(self) -> AdditionalBuilder {
        self.authority().additional()
    }
}
impl SectionBuilder for MessageBuilder {
fn into_target(self) -> MessageTarget { self.target }
fn get_target(&self) -> &MessageTarget { &self.target }
fn get_target_mut(&mut self) -> &mut MessageTarget { &mut self.target }
}
/// Builds the answer section of a message.
///
/// Obtained from `MessageBuilder::answer`; records pushed here increment
/// ANCOUNT.
#[derive(Clone, Debug)]
pub struct AnswerBuilder {
    target: MessageTarget,
}

impl AnswerBuilder {
    fn new(target: MessageTarget) -> Self {
        Self { target }
    }

    /// Finishes the answer section and moves on to the authority section.
    pub fn authority(self) -> AuthorityBuilder {
        let AnswerBuilder { target } = self;
        AuthorityBuilder::new(target)
    }

    /// Skips the authority section and moves on to the additional section.
    pub fn additional(self) -> AdditionalBuilder {
        AuthorityBuilder::new(self.target).additional()
    }

    /// Starts an OPT record at the current end of the message.
    pub fn opt(self) -> Result<OptBuilder, ShortBuf> where Self: Sized {
        let target = self.into_target();
        OptBuilder::new(target)
    }
}

impl SectionBuilder for AnswerBuilder {
    fn into_target(self) -> MessageTarget {
        self.target
    }

    fn get_target(&self) -> &MessageTarget {
        &self.target
    }

    fn get_target_mut(&mut self) -> &mut MessageTarget {
        &mut self.target
    }
}

impl RecordSectionBuilder for AnswerBuilder {
    /// Appends a record and, if that succeeds, increments ANCOUNT.
    fn push<N, D, R>(&mut self, record: R) -> Result<(), ShortBuf>
    where N: ToDname, D: RecordData, R: Into<Record<N, D>> {
        let record: Record<N, D> = record.into();
        self.target.push(move |buf| record.compress(buf),
                         |counts| counts.inc_ancount())
    }
}
/// Builds the authority section of a message.
///
/// Records pushed here increment NSCOUNT.
#[derive(Clone, Debug)]
pub struct AuthorityBuilder {
    target: MessageTarget,
}

impl AuthorityBuilder {
    fn new(target: MessageTarget) -> Self {
        Self { target }
    }

    /// Finishes the authority section and moves on to the additional
    /// section.
    pub fn additional(self) -> AdditionalBuilder {
        let AuthorityBuilder { target } = self;
        AdditionalBuilder::new(target)
    }

    /// Starts an OPT record at the current end of the message.
    pub fn opt(self) -> Result<OptBuilder, ShortBuf> where Self: Sized {
        let target = self.into_target();
        OptBuilder::new(target)
    }
}

impl SectionBuilder for AuthorityBuilder {
    fn into_target(self) -> MessageTarget {
        self.target
    }

    fn get_target(&self) -> &MessageTarget {
        &self.target
    }

    fn get_target_mut(&mut self) -> &mut MessageTarget {
        &mut self.target
    }
}

impl RecordSectionBuilder for AuthorityBuilder {
    /// Appends a record and, if that succeeds, increments NSCOUNT.
    fn push<N, D, R>(&mut self, record: R) -> Result<(), ShortBuf>
    where N: ToDname, D: RecordData, R: Into<Record<N, D>> {
        let record: Record<N, D> = record.into();
        self.target.push(move |buf| record.compress(buf),
                         |counts| counts.inc_nscount())
    }
}
/// Builds the additional section of a message.
///
/// Reachable from `MessageBuilder`, `AnswerBuilder`, `AuthorityBuilder`
/// and `OptBuilder`. Records pushed here increment ARCOUNT.
#[derive(Clone, Debug)]
pub struct AdditionalBuilder {
    target: MessageTarget,
}

impl AdditionalBuilder {
    fn new(target: MessageTarget) -> Self {
        AdditionalBuilder { target }
    }

    /// Starts an OPT record at the current end of the message.
    ///
    /// # Errors
    ///
    /// Returns `ShortBuf` if the OPT header does not fit.
    pub fn opt(self) -> Result<OptBuilder, ShortBuf> {
        OptBuilder::new(self.into_target())
    }
}

impl SectionBuilder for AdditionalBuilder {
    fn into_target(self) -> MessageTarget { self.target }
    fn get_target(&self) -> &MessageTarget { &self.target }
    fn get_target_mut(&mut self) -> &mut MessageTarget { &mut self.target }
}

impl RecordSectionBuilder for AdditionalBuilder {
    /// Appends a record and, if that succeeds, increments ARCOUNT.
    fn push<N, D, R>(&mut self, record: R) -> Result<(), ShortBuf>
    where N: ToDname, D: RecordData, R: Into<Record<N, D>> {
        // Additional-section records are counted by ARCOUNT, not NSCOUNT
        // (RFC 1035, 4.1.1); `OptBuilder::new` bumps the same counter.
        self.target.push(|target| record.into().compress(target),
                         |counts| counts.inc_arcount())
    }
}
/// Builds the OPT pseudo-record at the end of a message.
///
/// Dereferences to the record's `OptHeader` so its fields can be edited.
#[derive(Clone, Debug)]
pub struct OptBuilder {
    /// The message being built.
    target: MessageTarget,
    /// Buffer offset at which the OPT record starts.
    pos: usize,
}
impl OptBuilder {
    /// Starts an OPT record at the current end of `target`.
    ///
    /// Writes a default `OptHeader` and a zero RDLEN, then increments
    /// ARCOUNT.
    fn new(mut target: MessageTarget) -> Result<Self, ShortBuf> {
        let pos = target.len();
        target.compose(&OptHeader::default())?;
        target.compose(&0u16)?;
        target.counts_mut().inc_arcount();
        Ok(OptBuilder { pos, target })
    }

    /// Appends `option` (code, length, data) to the OPT record's data.
    ///
    /// # Errors
    ///
    /// Returns `ShortBuf` if the option does not fit. Any partially
    /// written bytes are removed again first.
    ///
    /// # Panics
    ///
    /// Panics if the option data is longer than 65535 octets.
    pub fn push<O: OptData>(&mut self, option: &O) -> Result<(), ShortBuf> {
        // Check the length before writing anything so a panic leaves no
        // half-written option behind.
        let len = option.compose_len();
        assert!(len <= ::std::u16::MAX as usize);
        self.append(|buf| {
            buf.compose(&option.code())?;
            buf.compose(&(len as u16))?;
            buf.compose(option)
        })
    }

    /// Appends an option whose data is written by `op`.
    ///
    /// `len` is written as the option length verbatim. The caller must make
    /// sure it matches what `op` writes.
    ///
    /// # Errors
    ///
    /// Returns `ShortBuf` if anything does not fit. Partially written bytes
    /// are removed again.
    pub(super) fn build<F>(&mut self, code: OptionCode, len: u16, op: F)
                           -> Result<(), ShortBuf>
    where F: FnOnce(&mut Compressor) -> Result<(), ShortBuf> {
        self.append(|buf| {
            buf.compose(&code)?;
            buf.compose(&len)?;
            op(buf)
        })
    }

    /// Finishes the OPT record and returns to the additional section.
    pub fn additional(self) -> AdditionalBuilder {
        AdditionalBuilder::new(self.target)
    }

    /// Runs `op` on the buffer, then refreshes RDLEN.
    ///
    /// On failure the buffer is truncated back to its previous length.
    /// Otherwise the partial option would sit after the OPT record without
    /// being covered by RDLEN.
    fn append<F>(&mut self, op: F) -> Result<(), ShortBuf>
    where F: FnOnce(&mut Compressor) -> Result<(), ShortBuf> {
        let rollback = self.target.buf.len();
        match op(&mut self.target.buf) {
            Ok(()) => {
                self.complete();
                Ok(())
            }
            Err(err) => {
                self.target.buf.truncate(rollback);
                Err(err)
            }
        }
    }

    /// Rewrites the RDLEN field to cover everything after it.
    ///
    /// # Panics
    ///
    /// Panics if the record data exceeds 65535 octets.
    fn complete(&mut self) {
        let len = self.target.len()
                  - (self.pos + mem::size_of::<OptHeader>() + 2);
        assert!(len <= ::std::u16::MAX as usize);
        let count_pos = self.pos + mem::size_of::<OptHeader>();
        BigEndian::write_u16(&mut self.target.as_slice_mut()[count_pos..],
                             len as u16);
    }
}
impl SectionBuilder for OptBuilder {
    fn into_target(self) -> MessageTarget {
        self.target
    }

    fn get_target(&self) -> &MessageTarget {
        &self.target
    }

    fn get_target_mut(&mut self) -> &mut MessageTarget {
        &mut self.target
    }
}

/// Read access to the header fields of the OPT record under construction.
impl ops::Deref for OptBuilder {
    type Target = OptHeader;

    fn deref(&self) -> &OptHeader {
        let record = &self.target.as_slice()[self.pos..];
        OptHeader::for_record_slice(record)
    }
}

/// Write access to the header fields of the OPT record under construction.
impl ops::DerefMut for OptBuilder {
    fn deref_mut(&mut self) -> &mut OptHeader {
        let start = self.pos;
        let record = &mut self.target.as_slice_mut()[start..];
        OptHeader::for_record_slice_mut(record)
    }
}
/// The buffer a message is built in, shared by all section builders.
#[derive(Clone, Debug)]
pub struct MessageTarget {
    /// The compressor wrapping the underlying byte buffer.
    buf: Compressor,
    /// Offset where the message begins. Bytes before it are the prelude.
    start: usize,
}
impl MessageTarget {
    /// Creates a target that appends a DNS message to `buf`.
    ///
    /// Bytes already in `buf` are kept as the prelude. A default header
    /// section is written right after them.
    fn from_buf(mut buf: BytesMut) -> Self {
        let start = buf.len();
        // `reserve(n)` guarantees room for at least `n` more bytes and is a
        // no-op when that room already exists. Asking only for the shortfall
        // (header size minus free space) is not enough: `BytesMut` may
        // consider the smaller request already satisfied and leave too
        // little room for the header composed below.
        buf.reserve(mem::size_of::<HeaderSection>());
        let mut buf = Compressor::from_buf(buf);
        HeaderSection::default().compose(&mut buf);
        MessageTarget { buf, start }
    }

    /// Returns the message header.
    ///
    /// Reads from `Compressor::so_far`, presumably the bytes written since
    /// the message start (after the prelude). TODO confirm.
    fn header(&self) -> &Header {
        Header::for_message_slice(self.buf.so_far())
    }

    /// Returns the message header for modification.
    fn header_mut(&mut self) -> &mut Header {
        Header::for_message_slice_mut(self.buf.so_far_mut())
    }

    /// Returns the header section counts.
    fn counts(&self) -> &HeaderCounts {
        HeaderCounts::for_message_slice(self.buf.so_far())
    }

    /// Returns the header section counts for modification.
    fn counts_mut(&mut self) -> &mut HeaderCounts {
        HeaderCounts::for_message_slice_mut(self.buf.so_far_mut())
    }

    /// Runs `composeop` on the buffer and, only if it succeeds, `incop` on
    /// the header counts.
    fn push<O, I, E>(&mut self, composeop: O, incop: I) -> Result<(), E>
    where O: FnOnce(&mut Compressor) -> Result<(), E>,
          I: FnOnce(&mut HeaderCounts) {
        composeop(&mut self.buf).map(|()| incop(self.counts_mut()))
    }

    /// Records the current buffer length and header counts.
    fn snapshot<T>(&self) -> Snapshot<T> {
        Snapshot {
            pos: self.buf.len(),
            counts: *self.counts(),
            marker: PhantomData,
        }
    }

    /// Truncates the buffer and restores the counts saved in `snapshot`.
    fn rewind<T>(&mut self, snapshot: &Snapshot<T>) {
        self.buf.truncate(snapshot.pos);
        self.counts_mut().set(snapshot.counts);
    }

    /// Returns everything written so far, prelude included.
    fn preview(&self) -> &[u8] {
        self.buf.as_slice()
    }

    /// Returns the bytes preceding the message.
    fn prelude(&self) -> &[u8] {
        &self.buf.as_slice()[..self.start]
    }

    /// Returns the bytes preceding the message for modification.
    fn prelude_mut(&mut self) -> &mut [u8] {
        &mut self.buf.as_slice_mut()[..self.start]
    }

    /// Returns the underlying buffer, presumably including the prelude.
    fn unwrap(self) -> BytesMut {
        self.buf.unwrap()
    }

    /// Converts into a `Message`, dropping the prelude.
    fn freeze(self) -> Message {
        let bytes = if self.start == 0 {
            self.buf.unwrap().freeze()
        }
        else {
            self.buf.unwrap().freeze().slice_from(self.start)
        };
        // SAFETY (assumed): the bytes from `start` onwards begin with the
        // header section written in `from_buf` and were only appended to
        // through the builders. TODO confirm that every builder path keeps
        // the message well-formed.
        unsafe { Message::from_bytes_unchecked(bytes) }
    }
}
impl ops::Deref for MessageTarget {
type Target = Compressor;
fn deref(&self) -> &Self::Target {
&self.buf
}
}
impl ops::DerefMut for MessageTarget {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.buf
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Snapshot<T> {
pos: usize,
counts: HeaderCounts,
marker: PhantomData<T>,
}
/// Functionality shared by all section builders.
///
/// Implementors only provide access to their `MessageTarget`. Everything
/// else has a default implementation built on top of that.
pub trait SectionBuilder : Sized {
    /// Consumes the builder and returns its message target.
    fn into_target(self) -> MessageTarget;

    /// Borrows the builder's message target.
    fn get_target(&self) -> &MessageTarget;

    /// Mutably borrows the builder's message target.
    fn get_target_mut(&mut self) -> &mut MessageTarget;

    /// Forwards `limit` to the underlying compressor.
    fn set_limit(&mut self, limit: usize) {
        let target = self.get_target_mut();
        target.buf.set_limit(limit)
    }

    /// Returns the message header.
    fn header(&self) -> &Header {
        let target = self.get_target();
        target.header()
    }

    /// Returns the message header for modification.
    fn header_mut(&mut self) -> &mut Header {
        let target = self.get_target_mut();
        target.header_mut()
    }

    /// Returns all bytes written so far, prelude included.
    fn preview(&self) -> &[u8] {
        let target = self.get_target();
        target.preview()
    }

    /// Returns the bytes that precede the message.
    fn prelude(&self) -> &[u8] {
        let target = self.get_target();
        target.prelude()
    }

    /// Returns the bytes that precede the message for modification.
    fn prelude_mut(&mut self) -> &mut [u8] {
        let target = self.get_target_mut();
        target.prelude_mut()
    }

    /// Ends building and returns the raw buffer.
    fn finish(self) -> BytesMut {
        let target = self.into_target();
        target.unwrap()
    }

    /// Ends building and returns the message, without the prelude.
    fn freeze(self) -> Message {
        let target = self.into_target();
        target.freeze()
    }

    /// Saves the current buffer length and header counts.
    fn snapshot(&self) -> Snapshot<Self> {
        let target = self.get_target();
        target.snapshot::<Self>()
    }

    /// Restores the state saved by `snapshot`.
    fn rewind(&mut self, snapshot: &Snapshot<Self>) {
        let target = self.get_target_mut();
        target.rewind(snapshot)
    }
}
/// A section builder that accepts resource records.
pub trait RecordSectionBuilder : SectionBuilder {
    /// Appends `record` to this builder's section.
    ///
    /// Implementations increment the section's header count only when the
    /// record was written successfully.
    fn push<N, D, R>(&mut self, record: R) -> Result<(), ShortBuf>
    where
        N: ToDname,
        D: RecordData,
        R: Into<Record<N, D>>;
}
#[cfg(test)]
mod test {
    use std::str::FromStr;
    use super::*;
    use rdata::*;
    use bits::name::*;
    use bits::rdata::*;
    use bits::message::*;

    /// Builds a message with one CNAME in the answer section and one in
    /// the authority section.
    fn get_built_message() -> Message {
        let builder = MessageBuilder::with_capacity(512);

        let mut answer = builder.answer();
        let owner = Dname::from_str("foo.example.com.").unwrap();
        let cname = Dname::from_str("baz.example.com.").unwrap();
        answer.push((owner, 86000, Cname::new(cname))).unwrap();

        let mut authority = answer.authority();
        let owner = Dname::from_str("bar.example.com.").unwrap();
        let cname = Dname::from_str("baz.example.com.").unwrap();
        authority.push((owner, 86000, Cname::new(cname))).unwrap();

        authority.freeze()
    }

    #[test]
    fn build_message() {
        let msg = get_built_message();
        let counts = msg.header_counts();
        assert_eq!(1, counts.ancount());
        assert_eq!(1, counts.nscount());
    }
}