use armour_rpc::RpcError;
use compio::buf::{IoBuf, IoBufMut};
use compio::io::framed::codec::{Decoder, Encoder};
use super::protocol::*;
/// Stateless codec for the RPC wire format.
///
/// Decodes incoming frames into [`Request`]s and encodes [`Response`]s into bytes; see the
/// `Decoder` and `Encoder` impls that follow. It carries no fields, so one value can be
/// reused for any number of frames.
pub(crate) struct RpcCodec;
impl<B: IoBuf> Decoder<Request, B> for RpcCodec {
    type Error = RpcError;

    /// Decodes a single request from one frame.
    ///
    /// Layout: an opcode byte (`read_u8`), the collection `hashname` (`read_u64_be`), then
    /// an opcode-specific payload read with the `read_*` helpers from `protocol`.
    ///
    /// # Errors
    /// Returns `RpcError::Protocol` for an unknown opcode, an invalid upsert flag, or an
    /// invalid batch value-presence flag. Any error produced by the `read_*` helpers
    /// (presumably on a truncated frame — confirm in `protocol`) is propagated unchanged.
    fn decode(&mut self, buf: &compio::buf::Slice<B>) -> Result<Request, Self::Error> {
        let bytes: &[u8] = buf;
        let mut pos = 0;

        let op_byte = read_u8(bytes, &mut pos)?;
        let op = OpCode::from_repr(op_byte)
            .ok_or_else(|| RpcError::Protocol("unknown opcode".to_string()))?;
        let hashname = read_u64_be(bytes, &mut pos)?;

        let payload = match op {
            OpCode::Get | OpCode::Contains => {
                let key = read_bytes(bytes, &mut pos)?;
                RequestPayload::Key(key)
            }
            OpCode::First | OpCode::Last | OpCode::ListCollections => RequestPayload::Empty,
            OpCode::Count => {
                let exact = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Count { exact }
            }
            OpCode::Range | OpCode::RangeKeys => {
                let start = read_bound(bytes, &mut pos)?;
                let end = read_bound(bytes, &mut pos)?;
                RequestPayload::Range { start, end }
            }
            OpCode::Upsert => {
                let key = read_upsert_key(bytes, &mut pos)?;
                // 0 = no flag, 1 = Some(true), 2 = Some(false); anything else is rejected.
                let flag = match read_u8(bytes, &mut pos)? {
                    0 => None,
                    1 => Some(true),
                    2 => Some(false),
                    _ => {
                        return Err(RpcError::Protocol("invalid upsert flag".to_string()));
                    }
                };
                let value = read_bytes(bytes, &mut pos)?;
                RequestPayload::Upsert { key, flag, value }
            }
            OpCode::Remove => {
                let key = read_bytes(bytes, &mut pos)?;
                let soft = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Remove { key, soft }
            }
            OpCode::Take => {
                let key = read_bytes(bytes, &mut pos)?;
                let soft = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Take { key, soft }
            }
            OpCode::ApplyBatch => {
                let count = read_u32_be(bytes, &mut pos)? as usize;
                // `count` is client-controlled. Every item consumes at least one byte (its
                // value-presence flag), so a frame can never hold more items than it has
                // bytes; capping the pre-allocation there stops a tiny frame from forcing
                // a multi-gigabyte allocation. The loop itself still errors out as soon as
                // the frame runs dry.
                let mut items = Vec::with_capacity(count.min(bytes.len()));
                for _ in 0..count {
                    let key = read_bytes(bytes, &mut pos)?;
                    // Be strict: a lenient "anything but 1 is None" would silently treat
                    // the following value bytes as the next item's key.
                    let val = match read_u8(bytes, &mut pos)? {
                        0 => None,
                        1 => Some(read_bytes(bytes, &mut pos)?),
                        _ => {
                            return Err(RpcError::Protocol(
                                "invalid batch value flag".to_string(),
                            ));
                        }
                    };
                    items.push((key, val));
                }
                RequestPayload::Batch(items)
            }
        };

        Ok(Request {
            op,
            hashname,
            payload,
        })
    }
}
impl<B: IoBufMut> Encoder<Response, B> for RpcCodec {
    type Error = RpcError;

    /// Serializes `item` and appends the bytes to `buf`.
    ///
    /// Layout: a status byte (`0x00` for `Response::Ok`, `0x01` for `Response::Err`).
    /// - An ok response is followed by the payload written by `encode_ok_payload`.
    /// - An error response is followed by `code` (big-endian), then a big-endian `u16`
    ///   message length, then the message's UTF-8 bytes.
    ///
    /// Error messages longer than `u16::MAX` bytes are truncated on a UTF-8 character
    /// boundary, so the length prefix always matches the bytes actually written.
    ///
    /// # Errors
    /// Returns `RpcError::Other` if `buf` cannot be extended.
    fn encode(&mut self, item: Response, buf: &mut B) -> Result<(), Self::Error> {
        let mut tmp = Vec::new();
        match item {
            Response::Ok(payload) => {
                tmp.push(0x00);
                encode_ok_payload(&mut tmp, payload);
            }
            Response::Err { code, message } => {
                tmp.push(0x01);
                tmp.extend_from_slice(&code.to_be_bytes());
                // The wire prefix is only 16 bits wide: `len as u16` would wrap for long
                // messages, leaving the prefix out of sync with the bytes that follow and
                // corrupting every later frame. Clamp to the largest char boundary that
                // fits instead (index 0 is always a boundary, so the loop terminates).
                let mut end = message.len().min(usize::from(u16::MAX));
                while !message.is_char_boundary(end) {
                    end -= 1;
                }
                let msg_bytes = &message.as_bytes()[..end];
                // Cannot truncate: `end <= u16::MAX` by construction above.
                tmp.extend_from_slice(&(end as u16).to_be_bytes());
                tmp.extend_from_slice(msg_bytes);
            }
        }
        buf.extend_from_slice(&tmp)
            .map_err(|e| RpcError::Other(e.to_string()))?;
        Ok(())
    }
}
fn encode_ok_payload(buf: &mut Vec<u8>, payload: ResponsePayload) {
match payload {
ResponsePayload::Empty => {}
ResponsePayload::OptionalData(opt) => match opt {
None => buf.push(0),
Some(data) => {
buf.push(1);
buf.extend_from_slice(&(data.len() as u32).to_be_bytes());
buf.extend_from_slice(&data);
}
},
ResponsePayload::OptionalLen(opt) => match opt {
None => buf.push(0),
Some(len) => {
buf.push(1);
buf.extend_from_slice(&len.to_be_bytes());
}
},
ResponsePayload::OptionalKV(opt) => match opt {
None => buf.push(0),
Some((key, val)) => {
buf.push(1);
buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
buf.extend_from_slice(&key);
buf.extend_from_slice(&(val.len() as u32).to_be_bytes());
buf.extend_from_slice(&val);
}
},
ResponsePayload::KeyValues(pairs) => {
buf.extend_from_slice(&(pairs.len() as u32).to_be_bytes());
for (key, val) in pairs {
buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
buf.extend_from_slice(&key);
buf.extend_from_slice(&(val.len() as u32).to_be_bytes());
buf.extend_from_slice(&val);
}
}
ResponsePayload::Keys(keys) => {
buf.extend_from_slice(&(keys.len() as u32).to_be_bytes());
for key in keys {
buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
buf.extend_from_slice(&key);
}
}
ResponsePayload::Key(key) => {
buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
buf.extend_from_slice(&key);
}
ResponsePayload::Count(n) => {
buf.extend_from_slice(&n.to_be_bytes());
}
ResponsePayload::Collections(collections) => {
buf.extend_from_slice(&(collections.len() as u32).to_be_bytes());
for c in collections {
let name = c.name.as_bytes();
buf.extend_from_slice(&(name.len() as u32).to_be_bytes());
buf.extend_from_slice(name);
let partition = c.partition_name.as_bytes();
buf.extend_from_slice(&(partition.len() as u32).to_be_bytes());
buf.extend_from_slice(partition);
buf.extend_from_slice(&c.hashname.to_be_bytes());
buf.extend_from_slice(&c.typ_hash.to_be_bytes());
buf.extend_from_slice(&c.version.to_be_bytes());
buf.extend_from_slice(&c.count.to_be_bytes());
}
}
}
}