use crate::common::{ControlMessage, ControlMessageBuilder, ControlMessageRef};
use crate::error::KnotError;
use crate::{common, error, error::Result};
use std::future::Future;
use std::io;
use std::result::Result as StdResult;
use futures_util::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// A connection that exchanges framed control messages over the byte stream `S`.
///
/// `socket` is the underlying stream that requests are written to and replies are read from.
pub struct Control<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    socket: S,
}
impl<S: AsyncWrite + AsyncRead + Unpin> Control<S> {
    /// Wraps an already-connected stream.
    pub fn new(socket: S) -> Control<S> {
        Control { socket }
    }

    /// Writes `msg` to the socket as one framed message and flushes it.
    ///
    /// Wire format: a `0x01` type byte; then, for every attribute that is `Some`, a unit made
    /// of `code + 16` (one byte), the payload length (big-endian `u16`) and the payload
    /// bytes; and finally a `0x03` terminator byte.
    ///
    /// # Errors
    ///
    /// Returns an error if writing or flushing the socket fails, or if an attribute is longer
    /// than `u16::MAX` bytes and therefore cannot be length-prefixed.
    pub async fn send_message(&mut self, msg: &ControlMessageRef<'_>) -> Result<()> {
        // Writes one attribute unit; writes nothing when `attrib` is `None`.
        async fn write_attrib<W: AsyncWrite + Unpin>(
            sock: &mut W,
            code: u8,
            attrib: Option<&str>,
        ) -> io::Result<()> {
            if let Some(attrib) = attrib {
                // The length field is 16 bits. A plain `as u16` would silently truncate it
                // and desynchronize the peer's parser, so oversized values are rejected.
                if attrib.len() > usize::from(u16::MAX) {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "control message attribute is longer than 65535 bytes",
                    ));
                }
                write_u8(sock, code + 16).await?;
                // Cannot truncate: bounds-checked just above.
                write_u16(sock, attrib.len() as u16).await?;
                sock.write_all(attrib.as_bytes()).await?;
            }
            Ok(())
        }

        // Attribute codes, in the order the attributes are written on the wire.
        let attributes: [(u8, Option<&str>); 12] = [
            (0x00, msg.command()),
            (0x01, msg.flags()),
            (0x02, msg.error()),
            (0x03, msg.section()),
            (0x04, msg.item()),
            (0x05, msg.id()),
            (0x06, msg.zone()),
            (0x07, msg.owner()),
            (0x08, msg.ttl()),
            (0x09, msg.ty()),
            (0x0a, msg.data()),
            (0x0b, msg.filter()),
        ];

        write_u8(&mut self.socket, 0x01).await?;
        for &(code, attrib) in attributes.iter() {
            write_attrib(&mut self.socket, code, attrib).await?;
        }
        write_u8(&mut self.socket, 0x03).await?;
        self.socket.flush().await?;
        Ok(())
    }

    /// Reads one framed message from the socket.
    ///
    /// Returns the leading message-type byte together with the message assembled from every
    /// attribute unit that precedes the `0x03` terminator.
    ///
    /// # Errors
    ///
    /// Fails if reading from the socket fails, if an attribute code is below `0x10` (so it
    /// does not encode an attribute index), if a payload is not valid UTF-8, or if
    /// `common::build_message` rejects the collected attributes.
    pub async fn recv_single_message(&mut self) -> Result<(u8, ControlMessage)> {
        let ty = read_u8(&mut self.socket).await?;
        let mut things = Vec::with_capacity(12);
        loop {
            let code = read_u8(&mut self.socket).await?;
            if code == 0x03 {
                break;
            }
            // Attribute codes are `index + 16`. The byte comes from the peer, so a smaller
            // code is reported as malformed input. An unchecked `code - 16` would panic in
            // debug builds and wrap to a bogus index in release builds.
            let index = code.checked_sub(16).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unexpected control message unit code {:#04x}", code),
                )
            })?;
            let len = read_u16(&mut self.socket).await?;
            let mut buf = vec![0; usize::from(len)];
            self.socket.read_exact(&mut buf).await?;
            let payload = String::from_utf8(buf)?;
            things.push((index, payload));
        }
        Ok((ty, common::build_message(things)?))
    }

    // Sends `req` and collects every reply until one whose type byte is `0`, without checking
    // whether the replies report success. Callers decide how to react to error replies.
    async fn send_request_unchecked(
        &mut self,
        req: &ControlMessageRef<'_>,
    ) -> Result<Vec<ControlMessage>> {
        self.send_message(req).await?;
        let mut replies = Vec::new();
        loop {
            let (ty, msg) = self.recv_single_message().await?;
            if ty == 0 {
                break;
            }
            replies.push(msg);
        }
        Ok(replies)
    }

    /// Sends `req` and returns all replies received before the type-`0` end message.
    ///
    /// # Errors
    ///
    /// Fails on any I/O or decoding error, or if any reply does not pass
    /// `ensure_success_ref`.
    pub async fn send_request(
        &mut self,
        req: &ControlMessageRef<'_>,
    ) -> error::Result<Vec<ControlMessage>> {
        let replies = self.send_request_unchecked(req).await?;
        for reply in &replies {
            reply.ensure_success_ref()?;
        }
        Ok(replies)
    }

    // Runs `cb` between a `begin` request and either a `commit` (when `cb` succeeds) or an
    // `abort` (when `cb` fails). If any reply to the commit reports an error, `abort` is sent
    // before that error is returned. `build` decorates each of the three requests (for
    // example, adding the zone name). An error from the abort request itself takes precedence
    // over the error that triggered it.
    async fn transaction<'a, F, T, C, BCB, Fut>(
        &mut self,
        cb: C,
        build: BCB,
        begin: &'a str,
        commit: &'a str,
        abort: &'a str,
    ) -> StdResult<T, F>
    where
        C: FnOnce(&mut Self) -> Fut,
        Fut: Future<Output = StdResult<T, F>>,
        BCB: Fn(ControlMessageBuilder<'a>) -> ControlMessageBuilder<'a>,
        F: From<KnotError>,
    {
        // `send_request` already rejects an error reply to `begin`.
        self.send_request(&build(ControlMessage::build().command(begin)).build())
            .await?;

        let thing = match cb(self).await {
            Ok(thing) => thing,
            Err(e) => {
                self.send_request(&build(ControlMessage::build().command(abort)).build())
                    .await?;
                return Err(e);
            }
        };

        // Use the unchecked variant here. `send_request` would return on an error reply
        // before the abort below had a chance to run.
        let replies = self
            .send_request_unchecked(&build(ControlMessage::build().command(commit)).build())
            .await?;
        let committed = replies
            .into_iter()
            .try_for_each(|reply| reply.ensure_success().map(|_| ()));
        if let Err(e) = committed {
            self.send_request(&build(ControlMessage::build().command(abort)).build())
                .await?;
            return Err(e.into());
        }
        Ok(thing)
    }

    /// Runs `cb` inside a `conf-begin` / `conf-commit` transaction.
    ///
    /// `conf-abort` is sent if `cb` fails or if the commit is rejected.
    pub async fn conf_transaction<F, T, C, Fut>(&mut self, cb: C) -> StdResult<T, F>
    where
        C: FnOnce(&mut Self) -> Fut,
        Fut: Future<Output = StdResult<T, F>>,
        F: From<KnotError>,
    {
        self.transaction(cb, |build| build, "conf-begin", "conf-commit", "conf-abort")
            .await
    }

    /// Runs `cb` inside a `zone-begin` / `zone-commit` transaction for `zone`.
    ///
    /// `zone-abort` is sent if `cb` fails or if the commit is rejected. Every one of these
    /// requests carries `zone` as its zone attribute.
    pub async fn zone_transaction<'a, F, T, C, Fut>(
        &mut self,
        zone: &'a str,
        cb: C,
    ) -> StdResult<T, F>
    where
        C: FnOnce(&mut Self) -> Fut,
        Fut: Future<Output = StdResult<T, F>>,
        F: From<KnotError>,
    {
        let build = |build: ControlMessageBuilder<'a>| build.zone(zone);
        self.transaction(cb, build, "zone-begin", "zone-commit", "zone-abort")
            .await
    }

    // Builds a configuration request for `command` on `section`, adding the `id` and `item`
    // attributes only when they are given.
    fn conf_builder<'b>(
        command: &'b str,
        section: &'b str,
        id: Option<&'b str>,
        item: Option<&'b str>,
    ) -> ControlMessageBuilder<'b> {
        let mut builder = ControlMessage::build().command(command).section(section);
        if let Some(id) = id {
            builder = builder.id(id);
        }
        if let Some(item) = item {
            builder = builder.item(item);
        }
        builder
    }

    /// Sends a `conf-set` request for `section`, with optional `id` and `item`.
    ///
    /// `value` is sent as the data attribute when present.
    pub async fn conf_set(
        &mut self,
        section: &str,
        id: Option<&str>,
        item: Option<&str>,
        value: Option<&str>,
    ) -> Result<()> {
        let mut builder = Self::conf_builder("conf-set", section, id, item);
        if let Some(data) = value {
            builder = builder.data(data);
        }
        self.send_request(&builder.build()).await?;
        Ok(())
    }

    /// Sends a `conf-unset` request for `section`, with optional `id` and `item`.
    pub async fn conf_unset(
        &mut self,
        section: &str,
        id: Option<&str>,
        item: Option<&str>,
    ) -> Result<()> {
        let builder = Self::conf_builder("conf-unset", section, id, item);
        self.send_request(&builder.build()).await?;
        Ok(())
    }

    /// Sends a `conf-get` request for `section`, with optional `id` and `item`, and returns
    /// the replies.
    pub async fn conf_get(
        &mut self,
        section: &str,
        id: Option<&str>,
        item: Option<&str>,
    ) -> Result<Vec<ControlMessage>> {
        let builder = Self::conf_builder("conf-get", section, id, item);
        self.send_request(&builder.build()).await
    }
}
/// Writes the single byte `value` to `write`.
async fn write_u8<W: AsyncWrite + Unpin>(write: &mut W, value: u8) -> io::Result<()> {
    let byte = [value];
    write.write_all(&byte).await
}
/// Writes `value` as two big-endian bytes (most significant byte first).
async fn write_u16<W: AsyncWrite + Unpin>(write: &mut W, value: u16) -> io::Result<()> {
    let bytes = value.to_be_bytes();
    write.write_all(&bytes).await
}
/// Reads exactly one byte from `read`.
async fn read_u8<R: AsyncRead + Unpin>(read: &mut R) -> io::Result<u8> {
    let mut byte = [0u8; 1];
    read.read_exact(&mut byte).await?;
    let [value] = byte;
    Ok(value)
}
/// Reads a big-endian `u16` (most significant byte first) from `read`.
async fn read_u16<R: AsyncRead + Unpin>(read: &mut R) -> io::Result<u16> {
    let mut bytes = [0u8; 2];
    read.read_exact(&mut bytes).await?;
    Ok(u16::from_be_bytes(bytes))
}