#![allow(dead_code)]
use std::fmt::Debug;
use vm_memory::ByteValued;
/// Largest value accepted in `MsgHeader::size`; headers announcing more are rejected by
/// `MsgHeader::is_valid`.
pub(crate) const MAX_MSG_SIZE: usize = 0x1000;
/// NOTE(review): not referenced in the visible part of this file; presumably the maximum
/// number of file descriptors attached to one message -- confirm against its users.
pub(crate) const MAX_ATTACHED_FD_ENTRIES: usize = 4;
/// Behaviour required from a request-code type.
///
/// Implementors must be `Copy`, totally ordered, `Send + Sync`, and convertible into their
/// raw `u32` value.
pub(crate) trait Req:
Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into<u32>
{
/// Returns whether the code is considered valid; the exact rule is up to the implementor.
fn is_valid(&self) -> bool;
}
/// Command identifiers stored (as raw `u32`) in the `request` field of a message header.
///
/// The discriminants are spelled out explicitly because the numeric value is what ends up
/// in the header.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RequestCode {
    /// Command with numeric value 0.
    Noop = 0,
    /// Command with numeric value 1.
    GetBlob = 1,
    /// Command with numeric value 2.
    FetchRange = 2,
    /// Upper bound marker: the `Req::is_valid` implementation treats every code below this
    /// one as valid and this one as invalid.
    MaxCommand = 3,
}

impl From<RequestCode> for u32 {
    /// Returns the numeric discriminant of `code`.
    fn from(code: RequestCode) -> Self {
        code as u32
    }
}
impl Req for RequestCode {
    /// Reports whether the code lies in the half-open range `Noop..MaxCommand`;
    /// `MaxCommand` itself is therefore not a valid command.
    fn is_valid(&self) -> bool {
        (RequestCode::Noop..RequestCode::MaxCommand).contains(self)
    }
}
/// Validation hook for message headers and payloads.
pub trait MsgValidator {
/// Returns whether the value is acceptable. The default implementation accepts
/// everything; implementors override it to add their own checks.
fn is_valid(&self) -> bool {
true
}
}
// Bit layout of `MsgHeader::flags`.
bitflags! {
pub struct HeaderFlag: u32 {
// Version value written into the low bits by `MsgHeader::new` and `MsgHeader::default`
// (the version accessors use the two-bit mask 0x3).
const VERSION = 0x1;
// Tested by `MsgHeader::is_reply`.
const REPLY = 0x4;
// Tested by `MsgHeader::is_need_reply`.
const NEED_REPLY = 0x8;
// REPLY | NEED_REPLY: the only caller-supplied bits `MsgHeader::new` keeps.
const ALL_FLAGS = 0xc;
// Everything above the low four bits; `MsgHeader::is_valid` requires these to be zero.
const RESERVED_BITS = !0xf;
}
}
/// Fixed-layout message header.
///
/// `#[repr(C, packed)]` removes all padding, so the struct is exactly 20 bytes
/// (8 + 4 + 4 + 4). Because the fields may be unaligned, they are always copied out by
/// value and never borrowed.
#[repr(C, packed)]
#[derive(Copy)]
pub(crate) struct MsgHeader {
    /// Identifier compared by `is_reply_for` to pair a reply with its request.
    tag: u64,
    /// Raw `RequestCode` value.
    request: u32,
    /// Version in the two low bits plus the `HeaderFlag` bits.
    flags: u32,
    /// Size value checked against `MAX_MSG_SIZE`; presumably the length of the data that
    /// follows the header -- confirm against the transport code.
    size: u32,
}

impl Debug for MsgHeader {
    /// Formats the header as a regular struct, listing all four fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy every field into a local first so the formatter only sees aligned values.
        let tag = self.tag;
        let request = self.request;
        let flags = self.flags;
        let size = self.size;

        let mut out = f.debug_struct("MsgHeader");
        out.field("tag", &tag);
        out.field("request", &request);
        out.field("flags", &flags);
        out.field("size", &size);
        out.finish()
    }
}

impl Clone for MsgHeader {
    /// Bitwise copy; the type is `Copy`.
    fn clone(&self) -> Self {
        *self
    }
}

impl PartialEq for MsgHeader {
    /// Two headers are equal when all four fields are equal.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.tag, self.request, self.flags, self.size);
        let rhs = (other.tag, other.request, other.flags, other.size);
        lhs == rhs
    }
}
impl MsgHeader {
pub fn new(tag: u64, request: RequestCode, flags: u32, size: u32) -> Self {
let fl = (flags & HeaderFlag::ALL_FLAGS.bits()) | 0x1;
MsgHeader {
tag,
request: request.into(),
flags: fl,
size,
}
}
pub fn get_tag(&self) -> u64 {
self.tag
}
pub fn set_tag(&mut self, tag: u64) {
self.tag = tag;
}
pub fn get_code(&self) -> RequestCode {
unsafe { std::mem::transmute_copy::<u32, RequestCode>(&{ self.request }) }
}
pub fn set_code(&mut self, request: RequestCode) {
self.request = request.into();
}
pub fn get_version(&self) -> u32 {
self.flags & 0x3
}
pub fn set_version(&mut self, ver: u32) {
self.flags &= !0x3;
self.flags |= ver & 0x3;
}
pub fn is_reply(&self) -> bool {
(self.flags & HeaderFlag::REPLY.bits()) != 0
}
pub fn set_reply(&mut self, is_reply: bool) {
if is_reply {
self.flags |= HeaderFlag::REPLY.bits();
} else {
self.flags &= !HeaderFlag::REPLY.bits();
}
}
pub fn is_need_reply(&self) -> bool {
(self.flags & HeaderFlag::NEED_REPLY.bits()) != 0
}
pub fn set_need_reply(&mut self, need_reply: bool) {
if need_reply {
self.flags |= HeaderFlag::NEED_REPLY.bits();
} else {
self.flags &= !HeaderFlag::NEED_REPLY.bits();
}
}
pub fn is_reply_for(&self, req: &MsgHeader) -> bool {
self.is_reply()
&& !req.is_reply()
&& self.get_code() == req.get_code()
&& req.tag == self.tag
}
pub fn get_size(&self) -> u32 {
self.size
}
pub fn set_size(&mut self, size: u32) {
self.size = size;
}
}
impl Default for MsgHeader {
fn default() -> Self {
MsgHeader {
tag: 0,
request: 0,
flags: 0x1,
size: 0,
}
}
}
// SAFETY: `MsgHeader` is `#[repr(C, packed)]`, `Copy`, and consists solely of integer
// fields, so it has no padding and every bit pattern is a valid value.
unsafe impl ByteValued for MsgHeader {}
impl MsgValidator for MsgHeader {
    /// A header is valid when all of the following hold, checked in this order:
    /// the request code is valid, the tag is non-zero, `size` does not exceed
    /// `MAX_MSG_SIZE`, the version is `0x1`, and no reserved flag bit is set.
    fn is_valid(&self) -> bool {
        self.get_code().is_valid()
            && self.tag != 0
            && self.size as usize <= MAX_MSG_SIZE
            && self.get_version() == 0x1
            && (self.flags & HeaderFlag::RESERVED_BITS.bits()) == 0
    }
}
/// Payload of a `GetBlob` request (presumably -- the pairing with `RequestCode::GetBlob`
/// is by name only in this file).
#[repr(C, packed)]
#[derive(Clone, Copy)]
pub(crate) struct GetBlobRequest {
    /// Generation number supplied by the caller.
    pub generation: u32,
    /// Identifier bytes, padded with NUL bytes up to 256; the validator requires at least
    /// one NUL byte.
    pub id: [u8; 256],
}

impl Default for GetBlobRequest {
    /// Returns a request with generation 0 and an all-zero identifier.
    fn default() -> Self {
        Self {
            generation: 0,
            id: [0u8; 256],
        }
    }
}

impl GetBlobRequest {
    /// Builds a request whose `id` field holds the bytes of `id` followed by NUL padding.
    ///
    /// `id` must be shorter than 256 bytes so that at least one terminating NUL remains.
    ///
    /// # Panics
    ///
    /// Panics in debug builds when `id.len() >= 256`, and in all builds when `id` is longer
    /// than the 256-byte buffer.
    pub fn new(generation: u32, id: &str) -> Self {
        debug_assert!(id.len() < 256);
        let bytes = id.as_bytes();
        let mut buf = [0x0u8; 256];
        // `copy_from_slice` requires equal lengths, so copy into the matching prefix only;
        // the rest of the buffer stays zeroed.
        buf[..bytes.len()].copy_from_slice(bytes);
        GetBlobRequest {
            generation,
            id: buf,
        }
    }
}
// SAFETY: `GetBlobRequest` is `#[repr(C, packed)]`, `Copy`, and holds only a `u32` and a
// byte array, so it has no padding and every bit pattern is a valid value.
unsafe impl ByteValued for GetBlobRequest {}
impl MsgValidator for GetBlobRequest {
    /// Accepts the request only when the `id` buffer contains at least one NUL byte.
    fn is_valid(&self) -> bool {
        self.id.iter().any(|&byte| byte == 0u8)
    }
}
/// Reply payload made of a token, a base value and a result code, laid out without padding
/// (20 bytes).
#[repr(C, packed)]
#[derive(Clone, Copy, Default)]
pub(crate) struct GetBlobReply {
    /// Token value; the validator accepts the reply when this or `result` is non-zero.
    pub token: u64,
    /// Base value supplied by the sender.
    pub base: u64,
    /// Result code supplied by the sender.
    pub result: u32,
}

impl GetBlobReply {
    /// Builds a reply from its three field values.
    pub fn new(token: u64, base: u64, result: u32) -> Self {
        GetBlobReply {
            token,
            base,
            result,
        }
    }
}
// SAFETY: `GetBlobReply` is `#[repr(C, packed)]`, `Copy`, and consists solely of integer
// fields, so it has no padding and every bit pattern is a valid value.
unsafe impl ByteValued for GetBlobReply {}
impl MsgValidator for GetBlobReply {
    /// Rejects only the reply in which both `token` and `result` are zero.
    fn is_valid(&self) -> bool {
        let token = self.token;
        let result = self.result;
        !(token == 0 && result == 0)
    }
}
/// Request payload made of a token, a start value and a count, laid out without padding
/// (24 bytes).
#[repr(C, packed)]
#[derive(Copy, Clone, Default)]
pub(crate) struct FetchRangeRequest {
    /// Token value supplied by the sender.
    pub token: u64,
    /// Start value supplied by the sender.
    pub start: u64,
    /// Count value supplied by the sender.
    pub count: u64,
}

impl FetchRangeRequest {
    /// Builds a request from its three field values.
    pub fn new(token: u64, start: u64, count: u64) -> Self {
        Self {
            token,
            start,
            count,
        }
    }
}
// SAFETY: `FetchRangeRequest` is `#[repr(C, packed)]`, `Copy`, and consists solely of
// integer fields, so it has no padding and every bit pattern is a valid value.
unsafe impl ByteValued for FetchRangeRequest {}
// Relies on the default `MsgValidator::is_valid`, which accepts every value.
impl MsgValidator for FetchRangeRequest {}
/// Result codes with explicit `u32` values.
///
/// NOTE(review): not referenced in the visible part of this file; presumably these are the
/// values carried in `FetchRangeReply::result` -- confirm against the users.
#[repr(u32)]
pub enum FetchRangeResult {
/// Numeric value 0.
Success = 0,
/// Numeric value 1.
Failure = 1,
/// Numeric value 2.
GenerationMismatch = 2,
}
/// Reply payload made of a token, a count and a result code, laid out without padding
/// (20 bytes).
#[repr(C, packed)]
#[derive(Copy, Clone, Default)]
pub(crate) struct FetchRangeReply {
    /// Token value supplied by the sender.
    pub token: u64,
    /// Count value supplied by the sender.
    pub count: u64,
    /// Result code supplied by the sender.
    pub result: u32,
}

impl FetchRangeReply {
    /// Builds a reply from its three field values.
    pub fn new(token: u64, count: u64, result: u32) -> Self {
        Self {
            token,
            count,
            result,
        }
    }
}
// SAFETY: `FetchRangeReply` is `#[repr(C, packed)]`, `Copy`, and consists solely of
// integer fields, so it has no padding and every bit pattern is a valid value.
unsafe impl ByteValued for FetchRangeReply {}
// Relies on the default `MsgValidator::is_valid`, which accepts every value.
impl MsgValidator for FetchRangeReply {}
// Unit tests for the request codes and the message header accessors/validator.
#[cfg(test)]
mod tests {
use super::*;
use std::mem;
// Checks `Req::is_valid`, ordering and equality for every `RequestCode` variant.
#[test]
fn check_master_request_code() {
let code = RequestCode::Noop;
assert!(code.is_valid());
// `MaxCommand` is the exclusive upper bound and must be rejected.
let code = RequestCode::MaxCommand;
assert!(!code.is_valid());
assert!(code > RequestCode::Noop);
let code = RequestCode::GetBlob;
assert!(code.is_valid());
let code = RequestCode::FetchRange;
assert!(code.is_valid());
assert_eq!(code, code.clone());
// NOTE(review): 10000 is not a declared discriminant of `RequestCode`; producing such an
// enum value is undefined behaviour per the Rust reference. Consider exercising the
// out-of-range case through a raw header value instead.
let code: RequestCode = unsafe { std::mem::transmute::<u32, RequestCode>(10000u32) };
assert!(!code.is_valid());
}
// Walks one header through every setter/getter pair and re-validates after each change.
// The assertions depend on the mutations made before them, so their order matters.
#[test]
fn msg_header_ops() {
let mut hdr = MsgHeader::new(2, RequestCode::GetBlob, 0, 0x100);
assert_eq!(hdr.get_code(), RequestCode::GetBlob);
hdr.set_code(RequestCode::FetchRange);
assert_eq!(hdr.get_code(), RequestCode::FetchRange);
assert_eq!(hdr.get_version(), 0x1);
// REPLY flag round trip.
assert!(!hdr.is_reply());
hdr.set_reply(true);
assert!(hdr.is_reply());
hdr.set_reply(false);
// NEED_REPLY flag round trip.
assert!(!hdr.is_need_reply());
hdr.set_need_reply(true);
assert!(hdr.is_need_reply());
hdr.set_need_reply(false);
assert_eq!(hdr.get_size(), 0x100);
hdr.set_size(0x200);
assert_eq!(hdr.get_size(), 0x200);
// Toggling flags and size must not have disturbed the other fields.
assert!(!hdr.is_need_reply());
assert!(!hdr.is_reply());
assert_eq!(hdr.get_version(), 0x1);
assert!(hdr.is_valid());
// 0x2000 exceeds MAX_MSG_SIZE (0x1000) and must be rejected.
hdr.set_size(0x2000);
assert!(!hdr.is_valid());
hdr.set_size(0x100);
assert_eq!(hdr.get_size(), 0x100);
assert!(hdr.is_valid());
hdr.set_size((MAX_MSG_SIZE - mem::size_of::<MsgHeader>()) as u32);
assert!(hdr.is_valid());
hdr.set_size(0x0);
assert!(hdr.is_valid());
// Only version 0x1 is accepted.
hdr.set_version(0x0);
assert!(!hdr.is_valid());
hdr.set_version(0x2);
assert!(!hdr.is_valid());
hdr.set_version(0x1);
assert!(hdr.is_valid());
assert_eq!(hdr.get_tag(), 2);
hdr.set_tag(200);
assert_eq!(hdr.get_tag(), 200);
// Clone, PartialEq and Debug must agree between a header and its copy.
assert_eq!(hdr, hdr.clone());
assert_eq!(hdr.clone().get_code(), hdr.get_code());
assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr));
}
}