#![warn(missing_docs)]
#![allow(incomplete_features)]
#![feature(slice_as_chunks)]
#![feature(generic_const_exprs)]
use aes::Aes256;
use bincode::{
config::{RejectTrailing, VarintEncoding, WithOtherIntEncoding, WithOtherTrailing},
DefaultOptions, Options,
};
use cipher::{
block_padding::Pkcs7, generic_array::GenericArray, BlockDecrypt, BlockEncrypt,
KeyInit,
};
use rand::random;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::error::Error;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
/// Reasons an inbound packet can be rejected by `CovertChannel::put_packet`.
///
/// In release builds (without `debug_assertions`) both the [`Display`] text and
/// the deprecated [`Error::description`] are empty, so an error exposes nothing
/// beyond its variant. Debug builds print a short human-readable message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CovertError {
    /// The ciphertext could not be decrypted, e.g. a wrong length or invalid
    /// PKCS#7 padding.
    DecryptionError,
    /// The decrypted packet failed its checksum test.
    InvalidHash,
    /// The packet, or the message reassembled from packets, could not be
    /// deserialized.
    DeserializeError,
}

impl Display for CovertError {
    // Release builds deliberately print nothing.
    // NOTE(review): presumably so error text cannot leak channel details — confirm.
    #[cfg(not(debug_assertions))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("")
    }

    #[cfg(debug_assertions)]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            CovertError::DecryptionError => "Decryption error",
            CovertError::InvalidHash => "Invalid hash error",
            CovertError::DeserializeError => "Deserialization error",
        })
    }
}

impl Error for CovertError {
    // `source` and `cause` keep their std defaults (`None` and `self.source()`),
    // which is exactly what explicit overrides would return.

    /// Always empty; callers should use `Display` instead.
    fn description(&self) -> &str {
        ""
    }
}
/// Plaintext form of a single packet; it is bincode-serialized and then
/// encrypted by `get_packet`, and decrypted and deserialized by `put_packet`.
///
/// The field order is part of the wire format. `get_packet` and `put_packet`
/// treat byte 0 of the serialized bytes as `hash`. That requires `hash` to stay
/// the first field, and relies on bincode writing a `u8` as a single byte.
/// Do not reorder.
#[derive(Deserialize, Serialize, Debug)]
struct CovertPacket {
/// Always 0 in the struct. On the wire, the serialized byte is overwritten with
/// the low byte of a CRC-32 computed over the serialized packet while this byte
/// is 0.
hash: u8,
/// Id of the logical stream this packet belongs to.
stream: u16,
/// Sequence number of this packet within its stream.
syn: u32,
/// Acknowledgement: the next inbound sequence number the sender expects.
/// Outbound packets with a smaller `syn` are treated as delivered.
want: u32,
/// Set on the final packet of a message.
last: bool,
/// Slice of the serialized message carried by this packet (empty for filler packets).
payload: Vec<u8>,
/// Random bytes sized so that `payload.len() + pad.len()` is `(blocks * 16) - 18`.
pad: Vec<u8>,
}
impl CovertPacket {
fn new(stream: u16, syn: u32, last: bool, payload: &[u8], blocks: usize) -> Self {
let buf_size = (blocks * 16) - 18;
let pad = vec![0u8; buf_size - payload.len()];
let pad: Vec<u8> = pad.iter().map(|_| random()).collect();
Self {
hash: 0,
stream,
syn,
want: 0,
last,
payload: payload.to_vec(),
pad,
}
}
}
/// Bookkeeping for one logical stream multiplexed over a `CovertChannel`.
struct CovertStream<T> {
    /// Sequence number assigned to the next outbound packet.
    out_count: u32,
    /// Largest `want` (acknowledgement) received from the peer; outbound packets
    /// with a smaller sequence number have been discarded.
    out_syn: u32,
    /// Sequence number of the next inbound packet that will be accepted.
    in_syn: u32,
    /// Outbound packets not yet acknowledged, in sequence order; the front one
    /// is what gets (re)transmitted.
    out_bound_packet_cache: VecDeque<CovertPacket>,
    /// In-order inbound packets buffered until a `last` packet completes a message.
    in_bound_packet_cache: VecDeque<CovertPacket>,
    /// Fully reassembled messages waiting to be taken by the caller.
    message_cache: VecDeque<T>,
}
pub struct CovertChannel<T, const BLOCKS: usize>
where
T: DeserializeOwned + Serialize,
{
engine: Aes256,
encoder: WithOtherTrailing<
WithOtherIntEncoding<DefaultOptions, VarintEncoding>,
RejectTrailing,
>,
streams: HashMap<u16, CovertStream<T>>,
}
impl<T, const BLOCKS: usize> CovertChannel<T, BLOCKS>
where
    T: DeserializeOwned + Serialize,
{
    /// Creates a channel with no streams, encrypting with AES-256 under `key`.
    ///
    /// Peers must share the same `key` to read each other's packets. Packets and
    /// messages are encoded with bincode using varint integers, and trailing
    /// bytes are rejected.
    pub fn new(key: [u8; 32]) -> CovertChannel<T, BLOCKS> {
        let encoder = bincode::DefaultOptions::new()
            .with_varint_encoding()
            .reject_trailing_bytes();
        CovertChannel {
            encoder,
            engine: Aes256::new(&GenericArray::from(key)),
            streams: HashMap::new(),
        }
    }

    /// Removes and returns the oldest fully reassembled message on `stream_id`.
    ///
    /// Returns `None` if no complete message is waiting, including when the
    /// stream has never been used.
    pub fn get_message(&mut self, stream_id: u16) -> Option<T> {
        self.streams.get_mut(&stream_id)?.message_cache.pop_front()
    }

    /// Serializes `msg` and queues it on `stream_id` for transmission by
    /// [`get_packet`](Self::get_packet).
    ///
    /// The serialized bytes are cut into chunks of `(BLOCKS * 16) - 18` bytes,
    /// one packet per chunk. Each packet gets the next outbound sequence number,
    /// and the final one is flagged `last`. A message that serializes to zero
    /// bytes (e.g. `()`) is sent as a single empty `last` packet.
    ///
    /// # Panics
    ///
    /// Panics if bincode cannot serialize `msg`.
    pub fn put_message(&mut self, msg: T, stream_id: u16)
    where
        [(); (BLOCKS * 16) - 18]:,
    {
        let encode = self.encoder;
        let bytes = encode
            .serialize(&msg)
            .expect("message could not be serialized with bincode");
        let (parts, end) = bytes.as_chunks::<{ (BLOCKS * 16) - 18 }>();
        let stream = self.get_stream_by_id(stream_id);
        for part in parts {
            let packet =
                CovertPacket::new(stream_id, stream.out_count, false, part.as_slice(), BLOCKS);
            stream.out_count += 1;
            stream.out_bound_packet_cache.push_back(packet);
        }
        if end.is_empty() && !parts.is_empty() {
            // The message ended exactly on a chunk boundary: the packet queued
            // last in the loop above is the final one.
            stream
                .out_bound_packet_cache
                .back_mut()
                .expect("a full chunk was queued just above")
                .last = true;
        } else {
            // Either a partial trailing chunk, or a message that serialized to
            // nothing. In both cases a dedicated `last` packet is needed. In the
            // empty case, flagging the queue's back would instead mark an
            // unrelated earlier packet, or panic on an empty queue.
            let packet = CovertPacket::new(stream_id, stream.out_count, true, end, BLOCKS);
            stream.out_count += 1;
            stream.out_bound_packet_cache.push_back(packet);
        }
    }

    /// Encrypts and returns the packet to transmit next on `stream_id`.
    ///
    /// The front of the outbound queue is stamped with `want` (the next inbound
    /// sequence number expected, which acknowledges everything before it) and
    /// then sent. The packet stays queued, so repeated calls resend it until the
    /// peer's `want` passes it (see [`put_packet`](Self::put_packet)). If nothing
    /// is queued, an empty filler packet is created first. It consumes a
    /// sequence number like any other packet.
    ///
    /// Byte 0 of the serialized packet (the `hash` field) is replaced by the low
    /// byte of a CRC-32 over the serialized bytes. The result is then PKCS#7
    /// padded and encrypted.
    ///
    /// NOTE(review): `encrypt_padded_vec` on the bare `Aes256` block cipher
    /// appears to encrypt each 16-byte block independently (ECB-style), so
    /// identical plaintext blocks yield identical ciphertext blocks. Consider an
    /// AEAD mode if confidentiality matters.
    ///
    /// # Panics
    ///
    /// Only if serializing an internal packet fails, which is not expected.
    pub fn get_packet(&mut self, stream_id: u16) -> Vec<u8> {
        let encode = self.encoder;
        let stream = self.get_stream_by_id(stream_id);
        if stream.out_bound_packet_cache.is_empty() {
            let filler = CovertPacket::new(stream_id, stream.out_count, false, &[], BLOCKS);
            stream.out_count += 1;
            stream.out_bound_packet_cache.push_back(filler);
        }
        let ack = stream.in_syn;
        let front = stream
            .out_bound_packet_cache
            .front_mut()
            .expect("outbound queue is non-empty after the refill above");
        front.want = ack;
        let mut plain = encode
            .serialize(&*front)
            .expect("CovertPacket serialization cannot fail");
        // `front.hash` is always 0, so the checksum is computed with byte 0
        // zeroed; `put_packet` zeroes it again before re-hashing.
        let checksum = crc32fast::hash(&plain).to_le_bytes()[0];
        plain[0] = checksum;
        self.engine.encrypt_padded_vec::<Pkcs7>(&plain)
    }

    /// Decrypts one packet received from the peer and applies it to its stream.
    ///
    /// The packet's `want` acknowledges outbound packets: every queued outbound
    /// packet with a smaller sequence number is dropped. The packet itself is
    /// accepted only if its sequence number is exactly the next one expected.
    /// Duplicate or out-of-order packets contribute only their acknowledgement.
    ///
    /// Returns `(stream_id, true)` when this packet completed a message, which
    /// is then available from [`get_message`](Self::get_message). Otherwise
    /// returns `(stream_id, false)`.
    ///
    /// NOTE(review): the stream id comes from the decrypted packet, so any
    /// key holder can create state for any of the 65536 stream ids.
    ///
    /// # Errors
    ///
    /// * [`CovertError::DecryptionError`]: decryption or PKCS#7 unpadding failed.
    /// * [`CovertError::InvalidHash`]: the plaintext is empty, or its checksum
    ///   byte does not match. Only the low byte of a CRC-32 is compared, so
    ///   roughly 1 in 256 corrupted plaintexts still pass.
    /// * [`CovertError::DeserializeError`]: the packet, or the message
    ///   reassembled from the buffered packets, failed to deserialize. In the
    ///   latter case the buffered packets are already consumed and the message
    ///   is lost.
    pub fn put_packet(&mut self, pkt: &[u8]) -> Result<(u16, bool), CovertError> {
        let encode = self.encoder;
        let mut plain = self
            .engine
            .decrypt_padded_vec::<Pkcs7>(pkt)
            .map_err(|_| CovertError::DecryptionError)?;
        // A block of pure padding unpads to nothing; reject it instead of
        // panicking on the checksum byte. The byte is zeroed to match the state
        // it was hashed in by the sender.
        let received = match plain.first_mut() {
            Some(byte) => std::mem::replace(byte, 0),
            None => return Err(CovertError::InvalidHash),
        };
        if crc32fast::hash(&plain).to_le_bytes()[0] != received {
            return Err(CovertError::InvalidHash);
        }
        let mut in_packet = encode
            .deserialize::<CovertPacket>(&plain)
            .map_err(|_| CovertError::DeserializeError)?;
        // Padding only exists to fix the wire size. Free it rather than keeping
        // it buffered: filler packets sit in the inbound queue until a message
        // completes.
        in_packet.pad = Vec::new();

        let stream_id = in_packet.stream;
        let stream = self.get_stream_by_id(stream_id);
        stream.out_syn = stream.out_syn.max(in_packet.want);
        let acked = stream.out_syn;
        stream
            .out_bound_packet_cache
            .retain(|item| item.syn >= acked);

        if in_packet.syn != stream.in_syn {
            // Duplicate or out of order: the peer resends until acknowledged.
            return Ok((stream_id, false));
        }
        stream.in_syn += 1;
        let is_last = in_packet.last;
        stream.in_bound_packet_cache.push_back(in_packet);
        if !is_last {
            return Ok((stream_id, false));
        }

        let payload: Vec<u8> = stream
            .in_bound_packet_cache
            .drain(..)
            .flat_map(|packet| packet.payload)
            .collect();
        let message = encode
            .deserialize::<T>(&payload)
            .map_err(|_| CovertError::DeserializeError)?;
        stream.message_cache.push_back(message);
        Ok((stream_id, true))
    }

    /// Returns `(inbound, outbound)` packet counts for `stream_id`.
    ///
    /// `inbound` counts in-order packets buffered toward a not-yet-complete
    /// message, including empty filler packets. `outbound` counts packets not
    /// yet acknowledged, including the one currently being resent. Unknown
    /// streams report `(0, 0)`.
    pub fn packets_in_queue(&self, stream_id: u16) -> (usize, usize) {
        self.streams.get(&stream_id).map_or((0, 0), |stream| {
            (
                stream.in_bound_packet_cache.len(),
                stream.out_bound_packet_cache.len(),
            )
        })
    }

    /// Returns the state for `stream_id`, creating empty state on first use.
    fn get_stream_by_id(&mut self, stream_id: u16) -> &mut CovertStream<T> {
        self.streams
            .entry(stream_id)
            .or_insert_with(|| CovertStream {
                out_bound_packet_cache: VecDeque::new(),
                in_bound_packet_cache: VecDeque::new(),
                message_cache: VecDeque::new(),
                out_count: 0,
                out_syn: 0,
                in_syn: 0,
            })
    }
}
#[cfg(test)]
mod test;