use crate::api::ChannelError::AccessError;
use crate::api::{ChannelError, WriteError, Writer};
use crate::header::Header;
use crate::utils::{align, store_atomic_u64, CLOSE, REC_HEADER_LEN, WATERMARK};
use kekbit_codecs::codecs::DataFormat;
use kekbit_codecs::codecs::Encodable;
use log::{debug, error, info};
use memmap::MmapMut;
use std::cmp::min;
use std::io::Write;
use std::ptr::copy_nonoverlapping;
use std::result::Result;
use std::sync::atomic::Ordering;
/// Writer side of a kekbit shared-memory channel, backed by a writable memory map.
///
/// Records are appended at `write_offset` inside the data region that starts right after
/// the channel header at the beginning of the mapping.
pub struct ShmWriter<D: DataFormat> {
    // Channel metadata parsed from the start of the mapping (capacity, max message length, ...).
    header: Header,
    // Start of the data region (just past the header) inside `mmap`.
    data_ptr: *mut u8,
    // Byte offset, relative to `data_ptr`, where the next record header will be written.
    write_offset: u32,
    // The mapping `data_ptr` points into; it must stay alive as long as `data_ptr` is used.
    mmap: MmapMut,
    // Data format handed to encoders when a record is written.
    df: D,
    // Reusable bounded sink that encoders write payload bytes into.
    write: KekWrite,
}
impl<D: DataFormat> ShmWriter<D> {
    /// Builds a writer over `mmap`, which must already start with a valid channel header.
    ///
    /// The header is parsed from the beginning of the mapping, the data region is located
    /// right after it, and an initial heartbeat record is written before the writer is
    /// returned.
    ///
    /// # Errors
    /// * Errors from `Header::read`, propagated with `?`.
    /// * `ChannelError::AccessError` if the initial heartbeat cannot be written.
    #[allow(clippy::cast_ptr_alignment)]
    pub(super) fn new(mut mmap: MmapMut, df: D) -> Result<ShmWriter<D>, ChannelError> {
        let buf = &mut mmap[..];
        let header = Header::read(buf)?;
        // Derive the pointer from a *mutable* borrow: the writer stores through it, and a
        // pointer obtained from `as_ptr()` (a shared borrow) must not be written through.
        let header_ptr = buf.as_mut_ptr() as *mut u64;
        let head_len = header.len();
        // NOTE(review): `header_ptr` is `*mut u64`, so `add(head_len)` advances
        // `head_len * 8` bytes. This is only correct if `Header::len()` is expressed in u64
        // words rather than bytes -- confirm against `Header` and the reader side.
        // SAFETY: presumably `Header::read` has validated that the mapping is large enough
        // to hold the header followed by the data region -- TODO confirm.
        let data_ptr = unsafe { header_ptr.add(head_len) } as *mut u8;
        let write = KekWrite::new(data_ptr, header.max_msg_len() as usize);
        let mut writer = ShmWriter {
            header,
            data_ptr,
            write_offset: 0,
            mmap,
            df,
            write,
        };
        info!(
            "Kekbit channel writer created. Size is {}MB. Max msg size {}KB",
            writer.header.capacity() / 1_000_000,
            writer.header.max_msg_len() / 1_000
        );
        // On failure `writer` is dropped on return, which runs its `Drop` impl.
        writer.heartbeat().map_err(|we| AccessError {
            reason: format!("Initial heartbeat failed!. Reason {:?}", we),
        })?;
        info!("Initial heartbeat successfully sent!");
        Ok(writer)
    }

    /// Publishes one record whose header slot is at `write_ptr`.
    ///
    /// `aligned_rec_len` is the record length in u64 words (callers pass the byte length
    /// shifted right by 3). `WATERMARK` is stored first in the slot right after the record,
    /// then the payload length `len` is stored in the record's own header slot. Both stores
    /// use `Release` ordering.
    #[inline(always)]
    fn write_metadata(&mut self, write_ptr: *mut u64, len: u64, aligned_rec_len: u32) {
        // SAFETY: the watermark slot lies `aligned_rec_len` words past the record header.
        // This relies on the callers' `available()` checks plus room for one u64 slot after
        // the last record -- TODO confirm the mapping reserves it.
        unsafe {
            store_atomic_u64(write_ptr.add(aligned_rec_len as usize), WATERMARK, Ordering::Release);
        }
        store_atomic_u64(write_ptr, len, Ordering::Release);
    }
}
impl<D: DataFormat> Writer<D> for ShmWriter<D> {
    /// Encodes `data` as a single record at the current write offset.
    ///
    /// The payload is encoded in place, right after the record's header slot. The record is
    /// published through `write_metadata` only when encoding succeeds and the whole payload
    /// fits. The write offset then advances by the aligned record length, which is returned.
    ///
    /// # Errors
    /// * `WriteError::ChannelFull` when no more than `REC_HEADER_LEN` bytes remain.
    /// * `WriteError::NoSpaceForRecord` when the encoder reports `0` bytes, or a chunk did
    ///   not fit into the space allowed for the payload.
    /// * `WriteError::EncodingError` when the encoder returns an I/O error.
    #[allow(clippy::cast_ptr_alignment)]
    fn write(&mut self, data: &impl Encodable<D>) -> Result<u32, WriteError> {
        let rec_start = unsafe { self.data_ptr.add(self.write_offset as usize) };
        let payload_start = unsafe { rec_start.add(REC_HEADER_LEN as usize) };
        let available = self.available();
        if available <= REC_HEADER_LEN {
            return Err(WriteError::ChannelFull);
        }
        // The payload may use at most the smaller of the message limit and the free space
        // left after the record header.
        let max_payload = min(self.header.max_msg_len(), available - REC_HEADER_LEN) as usize;
        let sink = self.write.reset(payload_start, max_payload);
        match data.encode_to(&self.df, sink) {
            Err(io_err) => Err(WriteError::EncodingError(io_err)),
            Ok(0) => Err(WriteError::NoSpaceForRecord),
            Ok(_) if self.write.failed => Err(WriteError::NoSpaceForRecord),
            Ok(_) => {
                let payload_len = self.write.total;
                let aligned_rec_len = align(payload_len as u32 + REC_HEADER_LEN);
                self.write_metadata(rec_start as *mut u64, payload_len as u64, aligned_rec_len >> 3);
                self.write_offset += aligned_rec_len;
                Ok(aligned_rec_len)
            }
        }
    }

    /// Appends a header-only record with payload length `0`.
    ///
    /// Returns the number of bytes the write offset advanced, which is `REC_HEADER_LEN`.
    /// Fails with `WriteError::ChannelFull` when no more than `REC_HEADER_LEN` bytes remain.
    #[allow(clippy::cast_ptr_alignment)]
    #[inline]
    fn heartbeat(&mut self) -> Result<u32, WriteError> {
        let rec_start = unsafe { self.data_ptr.add(self.write_offset as usize) };
        if self.available() <= REC_HEADER_LEN {
            return Err(WriteError::ChannelFull);
        }
        self.write_metadata(rec_start as *mut u64, 0u64, REC_HEADER_LEN >> 3);
        self.write_offset += REC_HEADER_LEN;
        Ok(REC_HEADER_LEN)
    }

    /// Flushes outstanding modifications of the memory map.
    #[inline]
    fn flush(&mut self) -> Result<(), std::io::Error> {
        debug!("Flushing the channel");
        self.mmap.flush()
    }
}
impl<D: DataFormat> Drop for ShmWriter<D> {
fn drop(&mut self) {
let write_index = self.write_offset;
info!("Closing message queue..");
unsafe {
#[allow(clippy::cast_ptr_alignment)]
let write_ptr = self.data_ptr.offset(write_index as isize) as *mut u64;
store_atomic_u64(write_ptr, CLOSE, Ordering::Release);
info!("Closing message sent")
}
self.write_offset = self.mmap.len() as u32;
if self.mmap.flush().is_ok() {
info!("All changes flushed");
} else {
error!("Flush Failed");
}
}
}
impl<D: DataFormat> ShmWriter<D> {
    /// Bytes still free in the data region, rounded down to a multiple of 8.
    #[inline]
    pub fn available(&self) -> u32 {
        let remaining = self.header.capacity() - self.write_offset;
        // Clearing the three low bits keeps the reported space 8-byte aligned.
        remaining & !7u32
    }

    /// Current byte offset, relative to the start of the data region, of the next record.
    #[inline]
    pub fn write_offset(&self) -> u32 {
        self.write_offset
    }

    /// The channel header this writer was created from.
    #[inline]
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// The data format used to encode records.
    #[inline]
    pub fn data_format(&self) -> &D {
        &self.df
    }
}
/// Bounded, all-or-nothing `Write` sink that copies bytes straight to a raw pointer.
///
/// Each chunk is copied to `write_ptr + total`. A chunk that would push `total` past
/// `max_size` is rejected whole and sets `failed`. From then on every `write` is a no-op
/// returning `Ok(0)` until `reset` is called.
struct KekWrite {
    /// Destination of the first byte.
    write_ptr: *mut u8,
    /// Maximum number of bytes accepted since the last `new`/`reset`.
    max_size: usize,
    /// Bytes copied since the last `new`/`reset`; never exceeds `max_size`.
    total: usize,
    /// Set when a chunk did not fit; stays set until `reset`.
    failed: bool,
}

impl KekWrite {
    /// Creates a sink that writes to `write_ptr` and accepts at most `max_size` bytes.
    ///
    /// The caller must ensure `write_ptr` is valid for `max_size` bytes of writes for as
    /// long as the sink is written to.
    #[inline]
    fn new(write_ptr: *mut u8, max_size: usize) -> Self {
        KekWrite {
            write_ptr,
            max_size,
            total: 0,
            failed: false,
        }
    }

    /// Re-targets the sink at `write_ptr` with a fresh `max_size` budget.
    ///
    /// Clears `total` and `failed`, and returns the sink so it can be handed straight to an
    /// encoder. The same pointer-validity requirement as `new` applies.
    #[inline]
    fn reset(&mut self, write_ptr: *mut u8, max_size: usize) -> &mut Self {
        *self = KekWrite::new(write_ptr, max_size);
        self
    }
}

impl Write for KekWrite {
    /// Copies all of `data` or nothing.
    ///
    /// Returns `Ok(data.len())` when the chunk fits into the remaining budget. Otherwise
    /// nothing is copied, `failed` is set and `Ok(0)` is returned. Per the `Write` contract,
    /// `Ok(0)` signals that no more bytes can be accepted. Never returns `Err`.
    #[inline]
    fn write(&mut self, data: &[u8]) -> Result<usize, std::io::Error> {
        if self.failed {
            return Ok(0);
        }
        let data_len = data.len();
        // `total <= max_size` always holds, so this subtraction cannot underflow, and unlike
        // `total + data_len` it cannot overflow either.
        if data_len > self.max_size - self.total {
            self.failed = true;
            return Ok(0);
        }
        // SAFETY: `write_ptr` is valid for `max_size` bytes (contract of `new`/`reset`) and
        // the check above guarantees `total + data_len <= max_size`. `data` is a separate
        // borrowed slice that must not alias the destination.
        unsafe {
            copy_nonoverlapping(data.as_ptr(), self.write_ptr.add(self.total), data_len);
        }
        self.total += data_len;
        Ok(data_len)
    }

    /// No-op: `write` copies bytes into place immediately, so nothing is buffered.
    #[inline]
    fn flush(&mut self) -> Result<(), std::io::Error> {
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Drives `KekWrite` through sequential writes, an overflow and a `reset` to a smaller
    /// budget. Checks the returned counts, `total`, `failed` and the copied bytes.
    #[test]
    fn test_write() {
        let mut raw_data = [0u8; 1000];
        let write_ptr = raw_data.as_mut_ptr();
        let mut kw = KekWrite::new(write_ptr, 20);
        let ones = [1u8; 10];

        // 10 of 20 bytes: accepted.
        kw.flush().unwrap();
        let first = kw.write(&ones).unwrap();
        assert_eq!(kw.total, first);
        assert!(!kw.failed);
        assert!(raw_data[..10].iter().all(|&b| b == 1));

        // 20 of 20 bytes: accepted, the budget is now exactly used up.
        kw.flush().unwrap();
        let second = kw.write(&ones).unwrap();
        assert_eq!(kw.total, first + second);
        assert!(!kw.failed);
        assert!(raw_data[10..20].iter().all(|&b| b == 1));

        // 30 of 20 bytes: rejected whole and the sink is marked failed.
        let third = kw.write(&ones).unwrap();
        assert_eq!(third, 0);
        assert!(kw.failed);

        // After a reset to a 15-byte budget, writing restarts at the same pointer.
        kw.reset(write_ptr, 15);
        assert!(!kw.failed);
        let twos = [2u8; 10];
        let fourth = kw.write(&twos).unwrap();
        assert_eq!(kw.total, fourth);
        assert!(!kw.failed);
        assert!(raw_data[..10].iter().all(|&b| b == 2));
        assert_eq!(kw.total, 10);

        // 20 of 15 bytes: rejected, total unchanged.
        let fifth = kw.write(&twos).unwrap();
        assert_eq!(fifth, 0);
        assert!(kw.failed);
        assert_eq!(kw.total, 10);

        // 13 of 15 would fit, but the failed flag stays set until the next reset.
        let sixth = kw.write(&twos[..3]).unwrap();
        assert_eq!(sixth, 0);
        assert!(kw.failed);
        assert_eq!(kw.total, 10);
    }
}