use crate::common::{ControlMessage, ControlMessageBuilder, ControlMessageRef, KnotSync};
use crate::error::KnotError;
use crate::{common, error};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use fallible_iterator::FallibleIterator;
use std::io;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use unix_socket::UnixStream;
/// Synchronous client for a control socket.
///
/// Requests are written to the socket handle itself, while replies are read
/// through a buffered reader wrapped around a clone of that socket (see
/// [`TryClone`]).
pub struct Control<S: Write + Read + TryClone> {
    /// Buffered reader over a cloned handle of the socket; used for replies.
    read: BufReader<S>,
    /// The original socket handle; used for sending requests.
    write: S,
    /// Set while `conf_transaction` is running its callback, so that nested
    /// configuration transactions can be rejected.
    in_transaction: bool,
}
impl<S: Write + Read + TryClone> Control<S> {
    /// Unit type byte for an END unit (`Type::End` in `recv_single_message`).
    const TYPE_END: u8 = 0x00;
    /// Unit type byte that opens a data unit.
    const TYPE_DATA: u8 = 0x01;
    /// Unit type byte for a BLOCK unit; `send_message` ends every request with it.
    const TYPE_BLOCK: u8 = 0x03;
    /// Offset added to an attribute index to form its on-wire code byte.
    const ATTR_OFFSET: u8 = 16;

    /// Wraps `socket`, reading through a buffered clone of it and writing to
    /// the original handle.
    ///
    /// # Errors
    /// Fails if the socket cannot be cloned.
    pub fn new(socket: S) -> Result<Control<S>, KnotError> {
        let reader = socket.try_clone()?;
        Ok(Control {
            read: BufReader::new(reader),
            write: socket,
            in_transaction: false,
        })
    }

    /// Serializes `msg` as one DATA unit followed by a BLOCK terminator and
    /// writes it to the socket.
    ///
    /// Each present attribute is encoded as `index + 16`, a big-endian `u16`
    /// length and the raw bytes. The whole frame is assembled in memory first.
    /// An oversized attribute is then rejected before anything reaches the
    /// socket, and the request goes out in one `write_all` instead of one
    /// write per field.
    ///
    /// # Errors
    /// Returns an `InvalidInput` I/O error if an attribute is longer than
    /// 65535 bytes (the wire length field is a `u16`). Also returns any error
    /// from writing or flushing the socket.
    pub fn send_message(&mut self, msg: &ControlMessageRef<'_>) -> Result<(), KnotError> {
        log::debug!("Writing Message: {:?}", msg);
        // Attribute indices 0x00..=0x0b, in wire order.
        let attributes = [
            msg.command(),
            msg.flags(),
            msg.error(),
            msg.section(),
            msg.item(),
            msg.id(),
            msg.zone(),
            msg.owner(),
            msg.ttl(),
            msg.ty(),
            msg.data(),
            msg.filter(),
        ];
        let mut frame: Vec<u8> = Vec::with_capacity(64);
        frame.write_u8(Self::TYPE_DATA)?;
        for (index, attrib) in (0u8..).zip(attributes.iter().copied()) {
            let attrib = match attrib {
                Some(attrib) => attrib,
                None => continue,
            };
            // A silent `as u16` truncation would desynchronize the stream,
            // because the full payload would still be written after it.
            if attrib.len() > usize::from(u16::MAX) {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "control attribute {} is {} bytes long; the maximum is {}",
                        index,
                        attrib.len(),
                        u16::MAX
                    ),
                )
                .into());
            }
            frame.write_u8(index + Self::ATTR_OFFSET)?;
            frame.write_u16::<BigEndian>(attrib.len() as u16)?;
            frame.write_all(attrib.as_bytes())?;
        }
        frame.write_u8(Self::TYPE_BLOCK)?;
        self.write.write_all(&frame)?;
        self.write.flush()?;
        Ok(())
    }

    /// Earlier receive routine, kept for compatibility.
    ///
    /// Reads one type byte. For types other than 0 and 3, it then reads
    /// attributes until a `0x03` byte, consuming that byte as well. Prefer
    /// `recv_single_message`, which handles consecutive data units.
    ///
    /// # Errors
    /// Any other non-attribute code inside the attribute run is reported as an
    /// `InvalidData` error. Read and UTF-8 errors are propagated.
    pub fn recv_single_message_old(&mut self) -> Result<(u8, ControlMessage), KnotError> {
        let ty = self.read.read_u8()?;
        log::trace!("Type is {}", ty);
        let mut things = Vec::with_capacity(12);
        if ty != Self::TYPE_END && ty != Self::TYPE_BLOCK {
            loop {
                let code = self.read.read_u8()?;
                log::trace!("Code is {}", code);
                if code == Self::TYPE_BLOCK {
                    break;
                }
                things.push(self.read_attribute(code)?);
            }
        }
        let msg = common::build_message(things)?;
        log::debug!("Message of kind {} Read: {:?}", ty, msg);
        Ok((ty, msg))
    }

    /// Reads a single unit from the socket and returns `(type byte, message)`.
    ///
    /// For DATA/EXTRA units, the attributes are read up to the next type byte.
    /// That byte is left in the buffer so it starts the next call. END and
    /// BLOCK units are consumed with no attributes and returned with an empty
    /// message.
    ///
    /// # Errors
    /// - `UnexpectedEof` if the peer closes the stream mid-read.
    /// - `InvalidData` for a code byte that is neither a unit type nor an
    ///   attribute.
    /// - Any read or UTF-8 decoding error.
    pub fn recv_single_message(&mut self) -> Result<(u8, ControlMessage), KnotError> {
        #[derive(Eq, PartialEq, Copy, Clone, Debug)]
        enum Type {
            End = 0,
            Data,
            Extra,
            Block,
        }
        fn code_to_type(code: u8) -> Option<Type> {
            match code {
                0 => Some(Type::End),
                1 => Some(Type::Data),
                2 => Some(Type::Extra),
                3 => Some(Type::Block),
                _ => None,
            }
        }
        let mut ret_ty = Type::End;
        let mut have_type = false;
        let mut things = Vec::with_capacity(12);
        loop {
            log::trace!("Filling buffer");
            let buf = self.read.fill_buf()?;
            log::trace!("Buffer filled, size {}", buf.len());
            let code = match buf.first() {
                Some(&code) => code,
                None => return Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()),
            };
            log::trace!("Code is {}", code);
            if let Some(ty) = code_to_type(code) {
                log::trace!("Type is {:?}", ty);
                if have_type {
                    // This type byte opens the next unit; leave it unconsumed.
                    log::trace!("We already have a type, exiting");
                    break;
                }
                log::trace!("Consuming code from wire");
                self.read.consume(1);
                ret_ty = ty;
                if ty == Type::Data || ty == Type::Extra {
                    have_type = true;
                    continue;
                }
                break;
            }
            log::trace!("Consuming code from wire");
            self.read.consume(1);
            things.push(self.read_attribute(code)?);
        }
        let msg = common::build_message(things)?;
        log::debug!("Message of kind {:?} Read: {:?}", ret_ty, msg);
        Ok((ret_ty as u8, msg))
    }

    /// Reads the length-prefixed payload of an attribute whose code byte
    /// `code` has already been consumed. Returns `(attribute index, payload)`.
    ///
    /// # Errors
    /// `InvalidData` if `code` is below the attribute offset. Without this
    /// check, `code - 16` would underflow: it panics in debug builds and wraps
    /// in release. Read and UTF-8 errors are propagated.
    fn read_attribute(&mut self, code: u8) -> Result<(u8, String), KnotError> {
        let index = match code.checked_sub(Self::ATTR_OFFSET) {
            Some(index) => index,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unexpected control code {} where an attribute was expected", code),
                )
                .into())
            }
        };
        let len = self.read.read_u16::<BigEndian>()?;
        log::trace!("len is {}", len);
        let mut buf = vec![0; usize::from(len)];
        log::trace!("Buffer allocated, starting read");
        self.read.read_exact(&mut buf)?;
        log::trace!("Converting to string");
        let payload = String::from_utf8(buf)?;
        log::trace!("Index is {}", index);
        Ok((index, payload))
    }

    /// Sends `msg` and returns an iterator over the units of its reply.
    pub fn send_request(
        &mut self,
        msg: &ControlMessageRef<'_>,
    ) -> Result<MessageIterator<'_, S>, KnotError> {
        self.send_message(msg)?;
        Ok(MessageIterator {
            inner: self,
            burnt: false,
        })
    }

    /// Sends `msg` and consumes its whole reply, including the terminating
    /// END/BLOCK unit, so the next request starts on a clean stream.
    ///
    /// The outer `Result` carries transport and decoding failures. The inner
    /// one carries the first error the server reported, as detected by
    /// `ensure_success`, in any unit of the reply.
    fn exchange(
        &mut self,
        msg: &ControlMessageRef<'_>,
    ) -> Result<Result<(), KnotError>, KnotError> {
        self.send_message(msg)?;
        let mut verdict = Ok(());
        loop {
            let (ty, reply) = self.recv_single_message()?;
            if ty == Self::TYPE_END || ty == Self::TYPE_BLOCK {
                return Ok(verdict);
            }
            // Keep reading after an error so no units are left on the socket.
            if verdict.is_ok() {
                if let Err(e) = reply.ensure_success() {
                    verdict = Err(e);
                }
            }
        }
    }

    /// Runs `cb` inside a server-side transaction.
    ///
    /// `build` decorates each begin/commit/abort request (e.g. with a zone).
    /// - If the begin request is rejected, its error is returned.
    /// - If `cb` fails, `abort` is sent and `cb`'s error is returned.
    /// - If the server rejects `commit`, `abort` is sent and the commit error
    ///   is returned.
    ///
    /// Every reply is read to its end before the next request is sent.
    fn transaction<'a, F, T, C, BCB>(
        &mut self,
        cb: C,
        build: BCB,
        begin: &'a str,
        commit: &'a str,
        abort: &'a str,
    ) -> Result<T, F>
    where
        C: FnOnce(&mut Self) -> Result<T, F>,
        BCB: Fn(ControlMessageBuilder<'a>) -> ControlMessageBuilder<'a>,
        F: From<KnotError>,
    {
        self.exchange(&build(ControlMessage::build().command(begin)).build())??;
        match cb(self) {
            Ok(thing) => {
                let verdict =
                    self.exchange(&build(ControlMessage::build().command(commit)).build())?;
                match verdict {
                    Ok(()) => Ok(thing),
                    Err(e) => {
                        // The commit error is what the caller needs; the
                        // abort's own verdict is deliberately discarded.
                        let _ = self
                            .exchange(&build(ControlMessage::build().command(abort)).build())?;
                        Err(e.into())
                    }
                }
            }
            Err(e) => {
                let _ = self.exchange(&build(ControlMessage::build().command(abort)).build())?;
                Err(e)
            }
        }
    }
}
impl<S: Write + Read + TryClone> KnotSync for Control<S> {
fn in_transaction(&self) -> bool {
self.in_transaction
}
fn conf_transaction<F, T, C>(&mut self, cb: C) -> Result<T, F>
where
C: FnOnce(&mut Self) -> Result<T, F>,
F: From<KnotError>,
{
if self.in_transaction {
return Err(KnotError::NestedTransaction.into());
}
self.in_transaction = true;
let value = self.transaction(
cb,
|build| build.flags(""),
"conf-begin",
"conf-commit",
"conf-abort",
);
self.in_transaction = false;
value
}
fn zone_transaction<'a, F, T, C>(&mut self, zone: &'a str, cb: C) -> Result<T, F>
where
C: FnOnce(&mut Self) -> Result<T, F>,
F: From<KnotError>,
{
let build = |build: ControlMessageBuilder<'a>| build.zone(zone);
self.transaction(cb, build, "zone-begin", "zone-commit", "zone-abort")
}
fn conf_set(
&mut self,
section: &str,
id: Option<&str>,
item: Option<&str>,
value: Option<&str>,
) -> Result<(), KnotError> {
let mut builder = ControlMessage::build().command("conf-set").section(section);
if let Some(id) = id {
builder = builder.id(id);
}
if let Some(item) = item {
builder = builder.item(item);
}
if let Some(data) = value {
builder = builder.data(data);
}
self.send_request(&builder.build())?;
Ok(())
}
fn conf_unset(
&mut self,
section: &str,
id: Option<&str>,
item: Option<&str>,
value: Option<&str>,
) -> Result<(), KnotError> {
let mut builder = ControlMessage::build()
.command("conf-unset")
.section(section);
if let Some(id) = id {
builder = builder.id(id);
}
if let Some(item) = item {
builder = builder.item(item);
}
if let Some(value) = value {
builder = builder.data(value);
}
self.send_request(&builder.build())?;
Ok(())
}
fn conf_get<'a>(
&'a mut self,
section: Option<&str>,
id: Option<&str>,
item: Option<&str>,
) -> Result<Vec<ControlMessage>, KnotError> {
let mut builder = ControlMessage::build().command("conf-get");
if let Some(section) = id {
builder = builder.section(section);
}
if let Some(section) = section {
builder = builder.section(section);
}
if let Some(id) = id {
builder = builder.id(id);
}
if let Some(item) = item {
builder = builder.item(item);
}
fallible_iterator::convert(self.send_request(&builder.build())?).collect()
}
}
pub struct MessageIterator<'a, S: Write + Read + TryClone> {
inner: &'a mut Control<S>,
burnt: bool,
}
impl<'a, S: Write + Read + TryClone> Iterator for MessageIterator<'a, S> {
type Item = error::Result<ControlMessage>;
fn next(&mut self) -> Option<Self::Item> {
if self.burnt {
None
} else {
match self.inner.recv_single_message() {
Ok((0, _)) => {
self.burnt = true;
None
}
Ok(res) => Some(Ok(res.1)),
Err(e) => Some(Err(e)),
}
}
}
}
/// A handle that can be duplicated into an independent handle to the same
/// underlying resource, in the style of `TcpStream::try_clone`.
///
/// `Control::new` uses this to get separate read and write handles.
pub trait TryClone: Sized {
    /// Creates a new handle, or returns the I/O error that prevented it.
    fn try_clone(&self) -> io::Result<Self>;
}
/// Delegates to the inherent `TcpStream::try_clone`.
impl TryClone for TcpStream {
    fn try_clone(&self) -> std::io::Result<Self> {
        // The path resolves to the inherent method (inherent items take
        // precedence over trait items), so this does not recurse.
        let duplicate = TcpStream::try_clone(self)?;
        Ok(duplicate)
    }
}
/// Delegates to the inherent `UnixStream::try_clone` of the `unix_socket` crate.
///
/// NOTE(review): `std::os::unix::net::UnixStream` provides the same API in
/// the standard library; consider migrating off `unix_socket`.
impl TryClone for UnixStream {
    fn try_clone(&self) -> std::io::Result<Self> {
        // Inherent method, not this trait method: no recursion.
        let duplicate = UnixStream::try_clone(self)?;
        Ok(duplicate)
    }
}