use crate::error::ZmqError;
use crate::message::Msg;
use crate::protocol::zmtp::{ZmtpCodec, manual_parser::ZmtpManualParser};
use crate::security::IDataCipher;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_util::codec::Encoder;
/// Translates between [`Msg`] values and the raw bytes carried by the
/// transport, applying whatever framing (and, for some implementations,
/// encryption) the security layer in use requires.
pub(crate) trait ISecureFramer: Send + Sync + 'static {
    /// Attempts to produce one complete inbound message from `network_buffer`.
    ///
    /// `Ok(None)` presumably means "not enough bytes yet" — the implementations
    /// in this module behave that way; confirm for any new implementor.
    fn try_read_msg(
        &mut self,
        network_buffer: &mut BytesMut,
    ) -> Result<Option<Msg>, ZmqError>;

    /// Encodes every message in `msgs`, in order, into a single byte buffer
    /// that can be written to the transport as-is.
    fn write_msg_multipart(
        &mut self,
        msgs: Vec<Msg>,
    ) -> Result<Bytes, ZmqError>;
}
/// Framer that adds no security layer: inbound transport bytes are handed
/// directly to a ZMTP parser and outbound messages are plain ZMTP encoding.
pub(crate) struct NullFramer {
    /// Incremental ZMTP parser fed straight from the transport buffer.
    parser: ZmtpManualParser,
}

impl NullFramer {
    /// Creates a framer; `max_msg_size` is forwarded unchanged to
    /// `ZmtpManualParser::new`.
    pub(crate) fn new(max_msg_size: i64) -> Self {
        let parser = ZmtpManualParser::new(max_msg_size);
        NullFramer { parser }
    }
}
impl ISecureFramer for NullFramer {
    /// Delegates directly to the ZMTP parser; there is no security wrapping
    /// to remove first.
    fn try_read_msg(&mut self, network_buffer: &mut BytesMut) -> Result<Option<Msg>, ZmqError> {
        let parser = &mut self.parser;
        parser.decode_from_buffer(network_buffer)
    }

    /// ZMTP-encodes each message in order with one codec instance and returns
    /// the concatenated bytes. Stops at the first encoding error.
    fn write_msg_multipart(&mut self, msgs: Vec<Msg>) -> Result<Bytes, ZmqError> {
        let mut encoder = ZmtpCodec::new();
        let encoded = msgs.into_iter().try_fold(BytesMut::new(), |mut out, msg| {
            encoder.encode(msg, &mut out)?;
            Ok::<BytesMut, ZmqError>(out)
        })?;
        Ok(encoded.freeze())
    }
}
/// Framer for encrypted sessions. Outbound bytes are ZMTP-encoded, encrypted
/// with `cipher` and written behind a big-endian `u16` ciphertext-length
/// prefix; inbound frames are decrypted and the plaintext is parsed as ZMTP.
pub(crate) struct LengthPrefixedFramer {
    /// Encrypts outbound payloads and decrypts inbound frame payloads.
    cipher: Box<dyn IDataCipher>,
    /// Incremental ZMTP parser run over `decrypted_buffer`.
    parser: ZmtpManualParser,
    /// Decrypted plaintext not yet consumed by `parser`; bytes from several
    /// encrypted frames can accumulate here before a message is complete.
    decrypted_buffer: BytesMut,
}

impl LengthPrefixedFramer {
    /// Wraps `cipher`; `max_msg_size` is forwarded unchanged to
    /// `ZmtpManualParser::new`.
    pub(crate) fn new(cipher: Box<dyn IDataCipher>, max_msg_size: i64) -> Self {
        // 128 KiB (2 x 64 KiB) up front to avoid early reallocations.
        const INITIAL_DECRYPTED_CAPACITY: usize = 65536 * 2;

        let parser = ZmtpManualParser::new(max_msg_size);
        let decrypted_buffer = BytesMut::with_capacity(INITIAL_DECRYPTED_CAPACITY);
        LengthPrefixedFramer {
            cipher,
            parser,
            decrypted_buffer,
        }
    }
}
impl ISecureFramer for LengthPrefixedFramer {
fn try_read_msg(&mut self, network_buffer: &mut BytesMut) -> Result<Option<Msg>, ZmqError> {
loop {
if let Some(msg) = self.parser.decode_from_buffer(&mut self.decrypted_buffer)? {
return Ok(Some(msg));
}
if network_buffer.len() < 2 {
return Ok(None); }
let len = network_buffer.as_ref().get_u16() as usize;
if network_buffer.len() < 2 + len {
return Ok(None); }
network_buffer.advance(2); let encrypted_frame = network_buffer.split_to(len);
let plaintext = self.cipher.decrypt(&encrypted_frame)?;
self.decrypted_buffer.extend_from_slice(&plaintext);
}
}
fn write_msg_multipart(&mut self, msgs: Vec<Msg>) -> Result<Bytes, ZmqError> {
let mut codec = ZmtpCodec::new();
let mut plaintext_buffer = BytesMut::new();
for msg in msgs {
codec.encode(msg, &mut plaintext_buffer)?;
}
let ciphertext = self.cipher.encrypt(&plaintext_buffer)?;
let mut final_buffer = BytesMut::with_capacity(2 + ciphertext.len());
final_buffer.put_u16(ciphertext.len() as u16);
final_buffer.extend_from_slice(&ciphertext);
Ok(final_buffer.freeze())
}
}