use std::ffi::CString;
use std::io::Cursor;
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_ipc::reader::StreamReader as IpcStreamReader;
use arrow_ipc::writer::StreamWriter as IpcStreamWriter;
use arrow_schema::{DataType, Schema};
use crate::errors::{Result, RpcError};
use crate::metadata::{LOG_LEVEL_KEY, SHM_LENGTH_KEY, SHM_OFFSET_KEY, SHM_SOURCE_KEY};
use crate::wire::{self, Metadata};
/// Size in bytes of the header region at the start of every segment; the
/// allocator hands out offsets at or beyond this value.
pub const HEADER_SIZE: usize = 65_536;
/// Magic bytes stored in header bytes 0..4.
const MAGIC: &[u8; 4] = b"VGIS";
/// Header format version, stored little-endian in header bytes 4..8.
const VERSION: u32 = 1;
/// Bytes taken by the fixed header fields; the allocation table starts here.
const HEADER_FIXED_SIZE: usize = 24;
/// Bytes per allocation-table entry: a little-endian u64 offset then a u64 length.
const ALLOC_ENTRY_SIZE: usize = 16;
/// Number of allocation-table entries that fit in the header after the fixed fields.
pub const MAX_ALLOCS: usize = (HEADER_SIZE - HEADER_FIXED_SIZE) / ALLOC_ENTRY_SIZE;
/// Eight-byte end-of-stream marker expected at the end of a serialized IPC
/// stream: four 0xFF bytes followed by a zero length.
const IPC_EOS: [u8; 8] = [0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00];
/// Handle to a mapped POSIX shared-memory object.
struct PosixShm {
// Name the handle was created or attached with; also the name unlinked on drop.
name: String,
// Base address of the mapping returned by `mmap`.
ptr: *mut u8,
// Length of the mapping in bytes.
size: usize,
// When true, dropping the handle also calls `shm_unlink` on `name`.
track: bool,
}
// NOTE(review): Send/Sync are asserted manually because of the raw pointer
// field; nothing in this type synchronizes access to the mapped bytes, so
// soundness presumably relies on how callers coordinate — confirm.
unsafe impl Send for PosixShm {}
unsafe impl Sync for PosixShm {}
impl std::panic::RefUnwindSafe for PosixShm {}
/// Builds a candidate shared-memory object name of the form
/// `/vgi_rpc_<pid hex>_<nanoseconds hex>`.
///
/// The timestamp is the time since the Unix epoch in nanoseconds, truncated to
/// 64 bits; it falls back to 0 when the system clock reads before the epoch.
#[cfg(unix)]
fn make_shm_name() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    // SAFETY: `getpid` takes no arguments and has no preconditions.
    let pid = unsafe { libc::getpid() };
    let stamp = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    format!("/vgi_rpc_{:x}_{:x}", pid, stamp as u64)
}
#[cfg(unix)]
impl PosixShm {
/// Creates and maps a new shared-memory object of `size` bytes under a
/// freshly generated name.
///
/// Up to 16 generated names are tried; only an `"AlreadyExists"` failure
/// triggers another attempt, any other outcome (success or error) is returned
/// immediately. Running out of attempts yields an `"IOError"`.
fn create(size: usize) -> Result<Self> {
    for _attempt in 0..16 {
        let candidate = make_shm_name();
        let outcome = Self::try_create(&candidate, size);
        let collided = matches!(&outcome, Err(e) if e.error_type == "AlreadyExists");
        if !collided {
            return outcome;
        }
    }
    Err(RpcError::new("IOError", "shm_open exhausted name retries"))
}
fn try_create(name: &str, size: usize) -> Result<Self> {
let cname = CString::new(name)
.map_err(|e| RpcError::new("ValueError", format!("invalid shm name: {e}")))?;
let fd = unsafe {
libc::shm_open(
cname.as_ptr(),
libc::O_RDWR | libc::O_CREAT | libc::O_EXCL,
0o600,
)
};
if fd < 0 {
let err = std::io::Error::last_os_error();
let kind = if err.raw_os_error() == Some(libc::EEXIST) {
"AlreadyExists"
} else {
"IOError"
};
return Err(RpcError::new(kind, format!("shm_open(create): {err}")));
}
let rc = unsafe { libc::ftruncate(fd, size as libc::off_t) };
if rc != 0 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fd);
libc::shm_unlink(cname.as_ptr());
}
return Err(RpcError::new("IOError", format!("ftruncate: {err}")));
}
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
fd,
0,
)
};
unsafe { libc::close(fd) };
if ptr == libc::MAP_FAILED {
let err = std::io::Error::last_os_error();
unsafe { libc::shm_unlink(cname.as_ptr()) };
return Err(RpcError::new("IOError", format!("mmap: {err}")));
}
Ok(Self {
name: name.to_string(),
ptr: ptr as *mut u8,
size,
track: true,
})
}
fn attach(name: &str, size: usize, track: bool) -> Result<Self> {
let try_open = |candidate: &str| -> std::io::Result<i32> {
let cname = CString::new(candidate).map_err(std::io::Error::other)?;
let fd = unsafe { libc::shm_open(cname.as_ptr(), libc::O_RDWR, 0o600) };
if fd < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(fd)
}
};
let fd = match try_open(name) {
Ok(fd) => fd,
Err(first) if !name.starts_with('/') => match try_open(&format!("/{name}")) {
Ok(fd) => fd,
Err(second) => {
return Err(RpcError::new(
"IOError",
format!("shm_open(attach) {name:?}: {first}; with leading slash: {second}"),
));
}
},
Err(e) => {
return Err(RpcError::new("IOError", format!("shm_open(attach): {e}")));
}
};
if fd < 0 {
let err = std::io::Error::last_os_error();
return Err(RpcError::new("IOError", format!("shm_open(attach): {err}")));
}
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
fd,
0,
)
};
unsafe { libc::close(fd) };
if ptr == libc::MAP_FAILED {
let err = std::io::Error::last_os_error();
return Err(RpcError::new("IOError", format!("mmap: {err}")));
}
Ok(Self {
name: name.to_string(),
ptr: ptr as *mut u8,
size,
track,
})
}
/// Views the whole mapping as a shared byte slice of `size` bytes.
fn as_slice(&self) -> &[u8] {
    let base: *const u8 = self.ptr;
    let len = self.size;
    // SAFETY: `ptr`/`size` are stored from a successful `mmap` of `size` bytes
    // and the mapping is only released in `Drop`.
    // NOTE(review): other handles or processes may write these bytes
    // concurrently; nothing here guards against that — confirm.
    unsafe { std::slice::from_raw_parts(base, len) }
}
/// Views the whole mapping as a mutable byte slice of `size` bytes.
///
/// NOTE(review): this hands out `&mut [u8]` from `&self`, so nothing stops two
/// overlapping mutable slices from existing at once; callers presumably keep
/// each one short-lived — confirm.
#[allow(clippy::mut_from_ref)]
fn as_mut_slice(&self) -> &mut [u8] {
    let base = self.ptr;
    let len = self.size;
    // SAFETY: `ptr`/`size` are stored from a successful read/write `mmap` of
    // `size` bytes and the mapping is only released in `Drop`.
    unsafe { std::slice::from_raw_parts_mut(base, len) }
}
}
/// Releases the mapping and, for tracked handles, removes the object's name.
#[cfg(unix)]
impl Drop for PosixShm {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `ptr`/`size` describe the mapping stored at construction.
            unsafe { libc::munmap(self.ptr.cast::<libc::c_void>(), self.size) };
        }
        if !self.track {
            return;
        }
        // A name with an interior NUL cannot be passed to `shm_unlink`; skip it.
        let Ok(cname) = CString::new(self.name.as_str()) else {
            return;
        };
        // SAFETY: `cname` is a valid NUL-terminated string. The return value is
        // ignored: there is nothing to be done about a failure during drop.
        unsafe { libc::shm_unlink(cname.as_ptr()) };
    }
}
/// Allocator whose bookkeeping (entry count and `(offset, length)` table)
/// lives in the header of the shared-memory segment itself.
pub struct ShmAllocator {
// Mapping that holds both the header and the data region.
shm: Arc<PosixShm>,
// Total mapping size in bytes; used as the upper bound for allocations.
total_size: usize,
}
impl ShmAllocator {
fn initialize(shm: &PosixShm) -> Result<()> {
if shm.size <= HEADER_SIZE {
return Err(RpcError::new(
"ValueError",
format!("segment size must be > {HEADER_SIZE}, got {}", shm.size),
));
}
let data_size = (shm.size - HEADER_SIZE) as u64;
let buf = shm.as_mut_slice();
buf[0..4].copy_from_slice(MAGIC);
buf[4..8].copy_from_slice(&VERSION.to_le_bytes());
buf[8..16].copy_from_slice(&data_size.to_le_bytes());
buf[16..20].copy_from_slice(&0u32.to_le_bytes());
buf[20..24].copy_from_slice(&0u32.to_le_bytes());
Ok(())
}
/// Validates the header of an already mapped segment and wraps it in an
/// allocator.
///
/// Checks, in order: the mapping is at least `HEADER_SIZE` bytes, the magic
/// bytes, the version, and that the recorded data size equals the mapping size
/// minus `HEADER_SIZE`. Any mismatch is a `"ValueError"`.
fn attach(shm: Arc<PosixShm>) -> Result<Self> {
    // Guard first: a mapping shorter than the header would make the header
    // reads below panic or the subtraction further down underflow.
    if shm.size < HEADER_SIZE {
        return Err(RpcError::new(
            "ValueError",
            format!(
                "segment size {} is smaller than the {HEADER_SIZE}-byte header",
                shm.size
            ),
        ));
    }
    let buf = shm.as_slice();
    if &buf[0..4] != MAGIC {
        return Err(RpcError::new(
            "ValueError",
            format!("bad SHM magic: {:?}", &buf[0..4]),
        ));
    }
    let version = u32::from_le_bytes(buf[4..8].try_into().unwrap());
    if version != VERSION {
        return Err(RpcError::new(
            "ValueError",
            format!("unsupported SHM version: {version}"),
        ));
    }
    // Compare as u64 so a large recorded size is not truncated on targets
    // where `usize` is narrower than 64 bits.
    let data_size = u64::from_le_bytes(buf[8..16].try_into().unwrap());
    let expected = (shm.size - HEADER_SIZE) as u64;
    if data_size != expected {
        return Err(RpcError::new(
            "ValueError",
            format!("data_size mismatch: header says {data_size}, expected {expected}"),
        ));
    }
    let total_size = shm.size;
    Ok(Self { shm, total_size })
}
/// Returns the allocation count stored in header bytes 16..20.
pub fn num_allocs(&self) -> usize {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&self.shm.as_slice()[16..20]);
    u32::from_le_bytes(raw) as usize
}
/// Returns `MAX_ALLOCS`, the number of table entries the header can hold.
pub fn max_allocs(&self) -> usize {
MAX_ALLOCS
}
/// Returns the current allocation table as `(offset, length)` pairs, in the
/// order they are stored in the header.
pub fn dump_allocs(&self) -> Vec<(u64, u64)> {
self.read_allocs()
}
fn read_allocs(&self) -> Vec<(u64, u64)> {
let buf = self.shm.as_slice();
let n = u32::from_le_bytes(buf[16..20].try_into().unwrap()) as usize;
let mut out = Vec::with_capacity(n);
for i in 0..n {
let base = HEADER_FIXED_SIZE + i * ALLOC_ENTRY_SIZE;
let off = u64::from_le_bytes(buf[base..base + 8].try_into().unwrap());
let len = u64::from_le_bytes(buf[base + 8..base + 16].try_into().unwrap());
out.push((off, len));
}
out
}
/// Stores `allocs` in the header: the entry count at bytes 16..20, then each
/// `(offset, length)` pair as two little-endian u64 values starting at
/// `HEADER_FIXED_SIZE`.
fn write_allocs(&self, allocs: &[(u64, u64)]) {
    let header = self.shm.as_mut_slice();
    let count = allocs.len() as u32;
    header[16..20].copy_from_slice(&count.to_le_bytes());
    for (i, &(off, len)) in allocs.iter().enumerate() {
        let start = HEADER_FIXED_SIZE + i * ALLOC_ENTRY_SIZE;
        let (off_dst, len_dst) = header[start..start + ALLOC_ENTRY_SIZE].split_at_mut(8);
        off_dst.copy_from_slice(&off.to_le_bytes());
        len_dst.copy_from_slice(&len.to_le_bytes());
    }
}
/// Reserves `size` bytes in the data region using first-fit over the table
/// sorted by offset, and returns the offset of the reservation.
///
/// Returns `None` for a zero `size`, when the table already holds
/// `MAX_ALLOCS` entries, or when no gap (between entries or after the last
/// one) is large enough. On success the table is written back in sorted order
/// with the new entry in place.
pub fn allocate(&self, size: usize) -> Option<u64> {
    if size == 0 {
        return None;
    }
    let mut table = self.read_allocs();
    if table.len() >= MAX_ALLOCS {
        return None;
    }
    table.sort_by_key(|entry| entry.0);
    let want = size as u64;
    let limit = self.total_size as u64;
    // `cursor` is the first byte not covered by any entry seen so far.
    let mut cursor = HEADER_SIZE as u64;
    let mut slot: Option<usize> = None;
    for (idx, &(start, len)) in table.iter().enumerate() {
        let entry_end = start.saturating_add(len);
        if start < cursor {
            // Entry begins inside already covered space: only extend the cursor.
            cursor = cursor.max(entry_end);
            continue;
        }
        if start - cursor >= want && cursor.saturating_add(want) <= limit {
            slot = Some(idx);
            break;
        }
        cursor = entry_end;
    }
    match slot {
        Some(idx) => table.insert(idx, (cursor, want)),
        None => {
            // No interior gap fits; try the space after the last entry.
            if cursor >= limit || limit - cursor < want {
                return None;
            }
            table.push((cursor, want));
        }
    }
    self.write_allocs(&table);
    Some(cursor)
}
/// Removes the first table entry whose offset equals `offset` and writes the
/// table back.
///
/// Fails with `"ValueError"` when no entry starts at `offset`.
pub fn free(&self, offset: u64) -> Result<()> {
    let mut table = self.read_allocs();
    let Some(hit) = table.iter().position(|entry| entry.0 == offset) else {
        return Err(RpcError::new(
            "ValueError",
            format!("no allocation at offset {offset}"),
        ));
    };
    table.remove(hit);
    self.write_allocs(&table);
    Ok(())
}
/// Sets the length of the entry starting at `offset` to `new_len`, which must
/// not exceed its current length.
///
/// Fails with `"ValueError"` when no entry starts at `offset` or when
/// `new_len` is larger than the current reservation; the table is left
/// untouched in both cases.
pub fn shrink(&self, offset: u64, new_len: usize) -> Result<()> {
    let mut table = self.read_allocs();
    let Some(entry) = table.iter_mut().find(|entry| entry.0 == offset) else {
        return Err(RpcError::new(
            "ValueError",
            format!("no allocation at offset {offset}"),
        ));
    };
    let requested = new_len as u64;
    let cur = entry.1;
    if requested > cur {
        return Err(RpcError::new(
            "ValueError",
            format!("shrink: new_len {new_len} > reservation {cur}"),
        ));
    }
    entry.1 = requested;
    self.write_allocs(&table);
    Ok(())
}
/// Drops every allocation by zeroing the entry count in header bytes 16..20;
/// the table entries themselves are left as they are.
pub fn reset(&self) {
    self.shm.as_mut_slice()[16..20].fill(0);
}
}
/// A mapped shared-memory segment together with the allocator that manages
/// its data region.
pub struct ShmSegment {
// The mapping; shared with `allocator`.
shm: Arc<PosixShm>,
// Allocator operating on the header of the same mapping.
allocator: ShmAllocator,
}
impl ShmSegment {
/// Creates a new segment of `size` bytes, writes a fresh header into it and
/// attaches an allocator.
pub fn create(size: usize) -> Result<Self> {
    let raw = PosixShm::create(size)?;
    ShmAllocator::initialize(&raw)?;
    let shm = Arc::new(raw);
    let allocator = ShmAllocator::attach(Arc::clone(&shm))?;
    Ok(Self { shm, allocator })
}
/// Maps the existing segment `name` with the given `size` and validates its
/// header. `track` is forwarded to the mapping and controls whether the name
/// is unlinked when the segment is dropped.
pub fn attach(name: &str, size: usize, track: bool) -> Result<Self> {
    let shm = PosixShm::attach(name, size, track).map(Arc::new)?;
    let allocator = ShmAllocator::attach(Arc::clone(&shm))?;
    Ok(Self { shm, allocator })
}
/// Returns the name stored on the underlying mapping.
pub fn name(&self) -> &str {
&self.shm.name
}
/// Returns the total mapping size in bytes, header included.
pub fn size(&self) -> usize {
self.shm.size
}
/// Returns the allocator managing this segment.
pub fn allocator(&self) -> &ShmAllocator {
&self.allocator
}
/// Drops every allocation by delegating to the allocator's `reset`.
pub fn reset(&self) {
self.allocator.reset();
}
/// Copies `length` bytes starting at `offset` out of the mapping.
///
/// Fails with `"ValueError"` when `offset` does not fit in `usize`, when
/// `offset + length` overflows, or when the region ends past the mapping.
pub fn read_bytes(&self, offset: u64, length: usize) -> Result<Vec<u8>> {
    let off = usize::try_from(offset)
        .map_err(|_| RpcError::new("ValueError", "shm offset overflow"))?;
    let end = match off.checked_add(length) {
        Some(end) => end,
        None => return Err(RpcError::new("ValueError", "shm region overflow")),
    };
    if end > self.shm.size {
        return Err(RpcError::new(
            "ValueError",
            format!(
                "shm region {off}..{end} out of bounds (size={})",
                self.shm.size
            ),
        ));
    }
    let bytes = &self.shm.as_slice()[off..end];
    Ok(bytes.to_vec())
}
/// Serializes `batch` as Arrow IPC into a newly allocated region of this
/// segment.
///
/// Returns `Ok(Some((offset, length)))` with the region's offset and the
/// number of bytes written. Returns `Ok(None)` when the allocator has no room,
/// or when writing in place fails with one of the two short-write messages
/// checked below (the reservation is freed first). Other failures are
/// returned as errors.
pub fn allocate_and_write(&self, batch: &RecordBatch) -> Result<Option<(u64, usize)>> {
let schema = batch.schema();
// Dictionary schemas: serialize into a temporary buffer first so the exact
// size is known, then copy the bytes into an allocation of that size.
if schema_has_dictionary(schema.as_ref()) {
let bytes = serialize_for_shm(batch, schema.as_ref())?;
let size = bytes.len();
let Some(offset) = self.allocator.allocate(size) else {
return Ok(None);
};
let off = offset as usize;
// Re-check the allocator's answer against the mapping before writing.
let end = off.checked_add(size).filter(|e| *e <= self.shm.size);
let Some(end) = end else {
let _ = self.allocator.free(offset);
return Err(RpcError::new(
"ValueError",
"shm allocator returned out-of-bounds offset",
));
};
let dst = &mut self.shm.as_mut_slice()[off..end];
dst.copy_from_slice(&bytes);
return Ok(Some((offset, size)));
}
// Other schemas: reserve an over-estimate (twice the in-memory array size
// plus 64 KiB, capped at the data region), write the stream straight into
// it, then shrink the reservation to what was actually written.
let body_estimate = batch.get_array_memory_size().saturating_mul(2);
let reserve = body_estimate.saturating_add(65_536);
let max_region = self.shm.size.saturating_sub(HEADER_SIZE);
let reserve = reserve.min(max_region);
if reserve == 0 {
return Ok(None);
}
let Some(offset) = self.allocator.allocate(reserve) else {
return Ok(None);
};
let off = offset as usize;
let end = match off.checked_add(reserve).filter(|e| *e <= self.shm.size) {
Some(e) => e,
None => {
let _ = self.allocator.free(offset);
return Err(RpcError::new(
"ValueError",
"shm allocator returned out-of-bounds offset",
));
}
};
let written: u64 = {
let dst = &mut self.shm.as_mut_slice()[off..end];
let mut cursor = Cursor::new(dst);
// The closure groups the fallible writer calls so a single match below
// can free the reservation on any failure.
let result: Result<u64> = (|| {
let mut w = IpcStreamWriter::try_new(&mut cursor, schema.as_ref())
.map_err(RpcError::from)?;
w.write(batch).map_err(RpcError::from)?;
w.finish().map_err(RpcError::from)?;
Ok(cursor.position())
})();
match result {
Ok(n) => n,
Err(e) => {
let _ = self.allocator.free(offset);
// NOTE(review): these messages are presumably what a write into
// the fixed-size cursor reports when the reservation is too
// small; they are treated as "does not fit" — confirm.
if e.message.contains("failed to write whole buffer")
|| e.message.contains("write zero")
{
return Ok(None);
}
return Err(e);
}
}
};
let actual = written as usize;
// Give the unused tail of the reservation back to the allocator.
self.allocator.shrink(offset, actual)?;
Ok(Some((offset, actual)))
}
/// Decodes the record batch stored in the region `offset..offset + length`.
///
/// `schema` is passed on to the decoder. Fails with `"ValueError"` when
/// `offset` does not fit in `usize`, when `offset + length` overflows, or when
/// the region ends past the mapping.
pub fn read_batch(&self, offset: u64, length: usize, schema: &Schema) -> Result<RecordBatch> {
    let off = usize::try_from(offset)
        .map_err(|_| RpcError::new("ValueError", "shm offset overflow"))?;
    let end = match off.checked_add(length) {
        Some(end) => end,
        None => return Err(RpcError::new("ValueError", "shm region overflow")),
    };
    if end > self.shm.size {
        return Err(RpcError::new(
            "ValueError",
            format!("shm region out of bounds: {off}..{end} > {}", self.shm.size),
        ));
    }
    deserialize_from_shm(&self.shm.as_slice()[off..end], schema)
}
/// Frees the allocation starting at `offset` by delegating to the allocator.
pub fn free(&self, offset: u64) -> Result<()> {
self.allocator.free(offset)
}
}
/// Reports whether any top-level field of `schema` has a dictionary data type.
/// Nested fields are not inspected.
fn schema_has_dictionary(schema: &Schema) -> bool {
    for field in schema.fields().iter() {
        if let DataType::Dictionary(_, _) = field.data_type() {
            return true;
        }
    }
    false
}
/// Serializes `batch` as an Arrow IPC stream.
///
/// Without dictionary fields the complete stream is returned. With dictionary
/// fields the first message and the trailing end-of-stream marker are removed
/// and only the bytes in between are returned.
fn serialize_for_shm(batch: &RecordBatch, schema: &Schema) -> Result<Vec<u8>> {
    let mut stream = Vec::new();
    let mut writer = IpcStreamWriter::try_new(&mut stream, schema).map_err(RpcError::from)?;
    writer.write(batch).map_err(RpcError::from)?;
    writer.finish().map_err(RpcError::from)?;
    // Release the writer's borrow of `stream` before reading it.
    drop(writer);
    if schema_has_dictionary(schema) {
        let body_start = skip_one_ipc_message(&stream)?;
        let body = strip_trailing_eos(&stream[body_start..])?;
        Ok(body.to_vec())
    } else {
        Ok(stream)
    }
}
/// Decodes the first record batch from `region`.
///
/// Without dictionary fields `region` is read as a complete IPC stream. With
/// dictionary fields a stream is rebuilt first: the schema message produced
/// by writing a schema-only stream for `schema` (minus its end-of-stream
/// marker), then `region`, then an end-of-stream marker.
fn deserialize_from_shm(region: &[u8], schema: &Schema) -> Result<RecordBatch> {
    /// Reads the first batch of the IPC stream in `source`.
    fn first_batch<R: std::io::Read>(source: R) -> Result<RecordBatch> {
        let mut reader = IpcStreamReader::try_new(source, None).map_err(RpcError::from)?;
        reader
            .next()
            .ok_or_else(|| RpcError::new("IPC", "empty SHM region"))?
            .map_err(RpcError::from)
    }

    if !schema_has_dictionary(schema) {
        return first_batch(Cursor::new(region));
    }
    let mut preamble = Vec::new();
    let mut writer = IpcStreamWriter::try_new(&mut preamble, schema).map_err(RpcError::from)?;
    writer.finish().map_err(RpcError::from)?;
    drop(writer);
    // Keep only the schema message: everything before the trailing marker.
    let Some(schema_msg_len) = preamble.len().checked_sub(IPC_EOS.len()) else {
        return Err(RpcError::new("IPC", "schema-only stream too short"));
    };
    let mut stream = Vec::with_capacity(schema_msg_len + region.len() + IPC_EOS.len());
    stream.extend_from_slice(&preamble[..schema_msg_len]);
    stream.extend_from_slice(region);
    stream.extend_from_slice(&IPC_EOS);
    first_batch(Cursor::new(stream))
}
/// Returns the offset in `buf` just past its first IPC message (length
/// prefix, metadata and body).
///
/// An optional four-byte `0xFFFFFFFF` continuation marker in front of the
/// length prefix is skipped. Every failure is an `"IPC"` error: input shorter
/// than eight bytes, a zero metadata length, metadata or body extending past
/// the end of `buf`, unparsable metadata, or a negative body length. The
/// returned offset is therefore always within `0..=buf.len()`, so callers can
/// slice `buf` with it without risking a panic.
fn skip_one_ipc_message(buf: &[u8]) -> Result<usize> {
    // Eight bytes cover the optional marker plus the length prefix, so the
    // reads below cannot go out of bounds.
    if buf.len() < 8 {
        return Err(RpcError::new("IPC", "IPC stream too short"));
    }
    let mut pos = 0;
    if buf[..4] == [0xFF; 4] {
        pos = 4;
    }
    let meta_len = u32::from_le_bytes(buf[pos..pos + 4].try_into().unwrap()) as usize;
    pos += 4;
    if meta_len == 0 {
        return Err(RpcError::new("IPC", "unexpected EOS while skipping schema"));
    }
    let meta_end = pos
        .checked_add(meta_len)
        .filter(|end| *end <= buf.len())
        .ok_or_else(|| RpcError::new("IPC", "IPC metadata truncated"))?;
    let msg = arrow_ipc::root_as_message(&buf[pos..meta_end])
        .map_err(|e| RpcError::new("IPC", format!("parse schema message: {e}")))?;
    // The body length is a signed field; reject negative values instead of
    // letting a cast turn them into a huge offset.
    let body_len = usize::try_from(msg.bodyLength())
        .map_err(|_| RpcError::new("IPC", "invalid IPC body length"))?;
    meta_end
        .checked_add(body_len)
        .filter(|end| *end <= buf.len())
        .ok_or_else(|| RpcError::new("IPC", "IPC body truncated"))
}
/// Returns `buf` without its trailing `IPC_EOS` marker, or an `"IPC"` error
/// when `buf` does not end with that marker.
fn strip_trailing_eos(buf: &[u8]) -> Result<&[u8]> {
    match buf.strip_suffix(IPC_EOS.as_slice()) {
        Some(body) => Ok(body),
        None => Err(RpcError::new("IPC", "stream missing trailing EOS marker")),
    }
}
pub fn make_shm_pointer_batch(
schema: &Schema,
offset: u64,
length: usize,
) -> Result<(RecordBatch, Metadata)> {
let batch = wire::empty_batch(schema)?;
let mut md: Metadata = std::collections::HashMap::new();
md.insert(SHM_OFFSET_KEY.into(), offset.to_string());
md.insert(SHM_LENGTH_KEY.into(), length.to_string());
Ok((batch, md))
}
/// Reports whether `batch`/`md` form a shared-memory pointer: the batch has no
/// rows, the metadata has `SHM_OFFSET_KEY` and does not have `LOG_LEVEL_KEY`.
pub fn is_shm_pointer_batch(batch: &RecordBatch, md: &Metadata) -> bool {
    batch.num_rows() == 0 && md.contains_key(SHM_OFFSET_KEY) && !md.contains_key(LOG_LEVEL_KEY)
}
/// Result of `resolve_shm_batch`.
pub struct ResolvedShm {
/// The batch read from shared memory, or the input batch when nothing was resolved.
pub batch: RecordBatch,
/// Metadata accompanying `batch`.
pub metadata: Metadata,
/// Offset of the shared-memory region the batch was read from; `None` when
/// the input was passed through unchanged.
pub release_offset: Option<u64>,
}
/// Replaces a shared-memory pointer batch with the batch it refers to.
///
/// When `shm` is `None` or `batch`/`md` are not a pointer (see
/// `is_shm_pointer_batch`), the inputs are returned unchanged with
/// `release_offset` set to `None`. Otherwise the offset and length are parsed
/// from the metadata, the batch is read from `shm`, the two location keys are
/// removed from the metadata, `SHM_SOURCE_KEY` is set to the segment name and
/// `release_offset` carries the offset that was read.
///
/// Missing or unparsable offset/length values yield a `"ValueError"`; read
/// failures are propagated.
pub fn resolve_shm_batch(
    batch: RecordBatch,
    md: Metadata,
    shm: Option<&ShmSegment>,
) -> Result<ResolvedShm> {
    let shm = match shm {
        Some(shm) if is_shm_pointer_batch(&batch, &md) => shm,
        _ => {
            return Ok(ResolvedShm {
                batch,
                metadata: md,
                release_offset: None,
            });
        }
    };
    let offset: u64 = md
        .get(SHM_OFFSET_KEY)
        .and_then(|s| s.parse().ok())
        .ok_or_else(|| RpcError::new("ValueError", "bad shm_offset"))?;
    let length: usize = md
        .get(SHM_LENGTH_KEY)
        .and_then(|s| s.parse().ok())
        .ok_or_else(|| RpcError::new("ValueError", "bad shm_length"))?;
    let resolved = shm.read_batch(offset, length, batch.schema().as_ref())?;
    // `md` is owned, so it is edited in place rather than cloned.
    let mut metadata = md;
    metadata.remove(SHM_OFFSET_KEY);
    metadata.remove(SHM_LENGTH_KEY);
    metadata.insert(SHM_SOURCE_KEY.into(), shm.name().to_string());
    Ok(ResolvedShm {
        batch: resolved,
        metadata,
        release_offset: Some(offset),
    })
}
/// Moves `batch` into shared memory when possible and returns a pointer batch
/// in its place.
///
/// The inputs are returned unchanged when `shm` is `None`, when the batch has
/// no rows, or when the segment could not take the batch. Otherwise the result
/// is the pointer batch together with `batch_md` extended by the pointer's
/// location entries (which overwrite equal keys).
pub fn maybe_write_to_shm(
    batch: RecordBatch,
    batch_md: Metadata,
    shm: Option<&ShmSegment>,
) -> Result<(RecordBatch, Metadata)> {
    let segment = match shm {
        Some(segment) if batch.num_rows() != 0 => segment,
        _ => return Ok((batch, batch_md)),
    };
    match segment.allocate_and_write(&batch)? {
        None => Ok((batch, batch_md)),
        Some((offset, length)) => {
            let (pointer, pointer_md) =
                make_shm_pointer_batch(batch.schema().as_ref(), offset, length)?;
            let mut merged = batch_md;
            merged.extend(pointer_md);
            Ok((pointer, merged))
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{types::Int32Type, DictionaryArray, Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
// Creates a segment with a 64 KiB data region; panics if creation fails.
fn small_seg() -> ShmSegment {
ShmSegment::create(HEADER_SIZE + 64 * 1024).expect("create")
}
// A freed slot is reused by the next allocation of the same size, and two
// adjacent freed slots can satisfy one larger allocation starting at the first.
#[test]
fn allocator_first_fit_and_coalesce() {
let seg = small_seg();
let a = seg.allocator().allocate(100).expect("a");
let b = seg.allocator().allocate(100).expect("b");
let c = seg.allocator().allocate(100).expect("c");
assert!(b > a && c > b);
seg.allocator().free(b).unwrap();
let b2 = seg.allocator().allocate(100).expect("b2");
assert_eq!(b, b2);
seg.allocator().free(b2).unwrap();
seg.allocator().free(c).unwrap();
let big = seg.allocator().allocate(200).expect("merged");
assert_eq!(big, b);
}
// Hand-writes a two-entry table stored out of offset order (last 8 bytes of
// the segment first, then 8 bytes at HEADER_SIZE) and checks that any offset
// the allocator still returns stays inside the segment.
#[test]
fn allocator_rejects_corrupted_unsorted_table() {
let seg = small_seg();
let total = seg.size();
let buf = seg.shm.as_mut_slice();
buf[16..20].copy_from_slice(&2u32.to_le_bytes());
let off0 = (total - 8) as u64;
buf[24..32].copy_from_slice(&off0.to_le_bytes());
buf[32..40].copy_from_slice(&8u64.to_le_bytes());
let off1 = HEADER_SIZE as u64;
buf[40..48].copy_from_slice(&off1.to_le_bytes());
buf[48..56].copy_from_slice(&8u64.to_le_bytes());
if let Some(off) = seg.allocator().allocate(total) {
assert!(
(off as usize)
.checked_add(total)
.map(|e| e <= total)
.unwrap_or(false),
"allocator returned out-of-bounds offset {off} for size {total} (segment {total})",
);
}
if let Some(off) = seg.allocator().allocate(1024) {
assert!((off as usize) + 1024 <= total);
}
}
// A request larger than the 64 KiB data region yields `None`.
#[test]
fn allocator_returns_none_when_full() {
let seg = small_seg();
assert!(seg.allocator().allocate(65 * 1024).is_none());
}
// Writes a batch without dictionary columns to shared memory, resolves the
// pointer again and checks row count, schema and the source-segment metadata.
#[test]
fn roundtrip_non_dict_batch() {
let seg = ShmSegment::create(HEADER_SIZE + 1024 * 1024).unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["x", "yy", "zzz"])),
],
)
.unwrap();
let (pointer, pointer_md) =
maybe_write_to_shm(batch.clone(), Metadata::new(), Some(&seg)).unwrap();
assert!(is_shm_pointer_batch(&pointer, &pointer_md));
let resolved = resolve_shm_batch(pointer, pointer_md, Some(&seg)).unwrap();
assert_eq!(resolved.batch.num_rows(), 3);
assert_eq!(resolved.batch.schema(), batch.schema());
assert_eq!(
resolved.metadata.get(SHM_SOURCE_KEY).map(String::as_str),
Some(seg.name()),
);
let off = resolved.release_offset.expect("release_offset must be set");
let _ = seg.allocator().free(off);
}
// Same round trip for a batch with one dictionary-encoded column.
#[test]
fn roundtrip_dict_batch() {
let seg = ShmSegment::create(HEADER_SIZE + 1024 * 1024).unwrap();
let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
let schema = Arc::new(Schema::new(vec![Field::new("d", dict_type, false)]));
let values = StringArray::from(vec!["alpha", "beta"]);
let keys = arrow_array::Int32Array::from(vec![0, 1, 0, 1, 0]);
let dict = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap();
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)]).unwrap();
let (pointer, pointer_md) =
maybe_write_to_shm(batch.clone(), Metadata::new(), Some(&seg)).unwrap();
assert!(is_shm_pointer_batch(&pointer, &pointer_md));
let resolved = resolve_shm_batch(pointer, pointer_md, Some(&seg)).unwrap();
assert_eq!(resolved.batch.num_rows(), 5);
assert_eq!(resolved.batch.schema(), batch.schema());
}
// An empty batch whose metadata also has the log-level key is not a pointer,
// even though the offset key is present.
#[test]
fn pointer_batch_distinct_from_log_batch() {
let schema = Schema::empty();
let mut md: Metadata = std::collections::HashMap::new();
md.insert(SHM_OFFSET_KEY.into(), "0".into());
md.insert(LOG_LEVEL_KEY.into(), "INFO".into());
let log_batch = wire::empty_batch(&schema).unwrap();
assert!(!is_shm_pointer_batch(&log_batch, &md));
}
// An allocation made through a second, attached handle is visible in the
// owner's view of the table.
#[test]
fn attach_existing_segment() {
let owner = ShmSegment::create(HEADER_SIZE + 64 * 1024).unwrap();
let attached = ShmSegment::attach(owner.name(), owner.size(), false).expect("attach");
let off = attached.allocator().allocate(123).expect("alloc");
let seen = owner.allocator().read_allocs();
assert_eq!(seen, vec![(off, 123)]);
}
// Expected first 56 header bytes, hex-encoded: magic "VGIS", version 1,
// data size 16384, count 2, four zero bytes, then the entries
// (65536, 100) and (65792, 50).
const SHM_HEADER_GOLDEN_HEX: &str =
"5647495301000000004000000000000002000000000000000000010000000000\
640000000000000000010100000000003200000000000000";
// Runs a fixed allocate/free sequence and compares the first 56 header bytes
// with the golden hex string.
#[test]
fn header_layout_matches_canonical_golden_bytes() {
let seg = ShmSegment::create(HEADER_SIZE + 16384).unwrap();
let off0 = seg.allocator().allocate(100).expect("alloc 0");
seg.allocator().free(off0).unwrap();
let _a = seg.allocator().allocate(100).expect("a");
let _gap = seg.allocator().allocate(156).expect("gap");
let _b = seg.allocator().allocate(50).expect("b");
seg.allocator().free(_gap).unwrap();
let buf = seg.shm.as_slice();
let observed = &buf[..56];
// Decode the golden hex string into bytes, ignoring any whitespace.
let golden_hex: String = SHM_HEADER_GOLDEN_HEX
.chars()
.filter(|c| !c.is_whitespace())
.collect();
let golden: Vec<u8> = (0..golden_hex.len())
.step_by(2)
.map(|i| u8::from_str_radix(&golden_hex[i..i + 2], 16).unwrap())
.collect();
assert_eq!(
observed,
golden.as_slice(),
"SHM header layout drifted; if intentional, update both this \
test and tests/test_shm_header_format.py in vgi-rpc canonical",
);
}
// After shrinking an 8 KiB allocation to 1 KiB, the next allocation starts
// directly behind the shrunken one.
#[test]
fn shrink_returns_tail_to_free_pool() {
let seg = ShmSegment::create(HEADER_SIZE + 64 * 1024).unwrap();
let off = seg.allocator().allocate(8 * 1024).expect("first alloc");
seg.allocator().shrink(off, 1024).expect("shrink");
let off2 = seg.allocator().allocate(4 * 1024).expect("second alloc");
assert_eq!(off2, off + 1024);
assert_eq!(seg.allocator().num_allocs(), 2);
}
// Growing via `shrink` and shrinking at an unknown offset both fail and leave
// the table unchanged.
#[test]
fn shrink_rejects_invalid_inputs() {
let seg = small_seg();
let off = seg.allocator().allocate(1024).unwrap();
assert!(seg.allocator().shrink(off, 2048).is_err());
assert!(seg.allocator().shrink(off + 1, 100).is_err());
assert_eq!(seg.allocator().read_allocs(), vec![(off, 1024)]);
}
// The bytes written directly into shared memory equal the output of
// `serialize_for_shm` for the same batch.
#[test]
fn inplace_write_is_byte_compatible_with_legacy_path() {
let seg = ShmSegment::create(HEADER_SIZE + 1024 * 1024).unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![10, 20, 30])),
Arc::new(StringArray::from(vec!["a", "bb", "ccc"])),
],
)
.unwrap();
let expected = serialize_for_shm(&batch, schema.as_ref()).unwrap();
let (offset, length) = seg.allocate_and_write(&batch).unwrap().expect("written");
assert_eq!(length, expected.len(), "length must match legacy path");
let region = seg.read_bytes(offset, length).unwrap();
assert_eq!(region, expected, "bytes must match legacy path");
}
// A batch far larger than the 4 KiB data region yields `None` and leaves no
// allocation behind.
#[test]
fn inplace_write_falls_back_when_segment_too_small() {
let seg = ShmSegment::create(HEADER_SIZE + 4 * 1024).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
let big = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int64Array::from_iter_values(0..200_000))],
)
.unwrap();
assert!(seg.allocate_and_write(&big).unwrap().is_none());
assert_eq!(seg.allocator().num_allocs(), 0);
}
// Write, resolve and free: the allocation count goes 1 -> 0 and the resolved
// batch has all 1024 rows.
#[test]
fn inplace_write_then_resolve_round_trip() {
let seg = ShmSegment::create(HEADER_SIZE + 1024 * 1024).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int64Array::from_iter_values(0..1024))],
)
.unwrap();
let (pointer, pointer_md) =
maybe_write_to_shm(batch.clone(), Metadata::new(), Some(&seg)).unwrap();
assert!(is_shm_pointer_batch(&pointer, &pointer_md));
assert_eq!(seg.allocator().num_allocs(), 1);
let resolved = resolve_shm_batch(pointer, pointer_md, Some(&seg)).unwrap();
assert_eq!(resolved.batch.num_rows(), 1024);
let off = resolved.release_offset.expect("release_offset must be set");
drop(resolved);
seg.allocator().free(off).unwrap();
assert_eq!(seg.allocator().num_allocs(), 0);
}
// A pointer whose offset equals the segment size is rejected with a
// `ValueError`.
#[test]
fn resolve_rejects_pointer_with_out_of_bounds_region() {
let seg = ShmSegment::create(HEADER_SIZE + 64 * 1024).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
let mut md: Metadata = std::collections::HashMap::new();
md.insert(SHM_OFFSET_KEY.into(), seg.size().to_string());
md.insert(SHM_LENGTH_KEY.into(), "1024".into());
let bogus = wire::empty_batch(schema.as_ref()).unwrap();
let err = match resolve_shm_batch(bogus, md, Some(&seg)) {
Err(e) => e,
Ok(_) => panic!("must reject"),
};
assert_eq!(err.error_type, "ValueError");
}
// A non-numeric offset in the pointer metadata is rejected with a
// `ValueError`.
#[test]
fn resolve_rejects_pointer_with_unparseable_metadata() {
let seg = small_seg();
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
let mut md: Metadata = std::collections::HashMap::new();
md.insert(SHM_OFFSET_KEY.into(), "not-a-number".into());
md.insert(SHM_LENGTH_KEY.into(), "100".into());
let bogus = wire::empty_batch(schema.as_ref()).unwrap();
let err = match resolve_shm_batch(bogus, md, Some(&seg)) {
Err(e) => e,
Ok(_) => panic!("must reject"),
};
assert_eq!(err.error_type, "ValueError");
}
// A length of `usize::MAX` produces a `ValueError` instead of a panic.
#[test]
fn read_batch_overflow_is_rejected_not_panicked() {
let seg = small_seg();
let schema = Schema::new(vec![Field::new("v", DataType::Int64, false)]);
let err = seg
.read_batch(0, usize::MAX, &schema)
.expect_err("must reject");
assert_eq!(err.error_type, "ValueError");
}
// Attaching to a segment whose magic bytes were overwritten fails with a
// `ValueError`.
#[test]
fn attach_rejects_bad_magic() {
let seg = ShmSegment::create(HEADER_SIZE + 4096).unwrap();
seg.shm.as_mut_slice()[0..4].copy_from_slice(b"XXXX");
let err = match ShmSegment::attach(seg.name(), seg.size(), false) {
Err(e) => e,
Ok(_) => panic!("must reject bad magic"),
};
assert_eq!(err.error_type, "ValueError");
}
// Attaching to a segment whose version field was set to 999 fails with a
// `ValueError`.
#[test]
fn attach_rejects_bad_version() {
let seg = ShmSegment::create(HEADER_SIZE + 4096).unwrap();
seg.shm.as_mut_slice()[4..8].copy_from_slice(&999u32.to_le_bytes());
let err = match ShmSegment::attach(seg.name(), seg.size(), false) {
Err(e) => e,
Ok(_) => panic!("must reject bad version"),
};
assert_eq!(err.error_type, "ValueError");
}
// Exactly MAX_ALLOCS allocations succeed, the next one fails even though
// space remains, and freeing one entry makes allocation possible again.
#[test]
fn allocator_respects_max_allocs_capacity() {
let bytes_per = 16;
let need = HEADER_SIZE + (MAX_ALLOCS + 1) * bytes_per;
let seg = ShmSegment::create(need).unwrap();
let mut offsets = Vec::with_capacity(MAX_ALLOCS);
for _ in 0..MAX_ALLOCS {
offsets.push(seg.allocator().allocate(bytes_per).expect("alloc"));
}
assert!(seg.allocator().allocate(bytes_per).is_none());
seg.allocator().free(offsets[0]).unwrap();
assert!(seg.allocator().allocate(bytes_per).is_some());
}
// Property test: applies random sequences of allocate (op 0), free (op 1) and
// shrink (any other op) while mirroring the live allocations in `live`, and
// checks after every step that allocations stay in bounds, never overlap and
// that the stored count matches the model.
proptest::proptest! {
#![proptest_config(proptest::test_runner::Config {
cases: 64,
.. proptest::test_runner::Config::default()
})]
#[test]
fn allocator_invariants_under_random_ops(
ops in proptest::collection::vec(
(0u8..3u8, 1usize..=4096usize),
1..200,
)
) {
let total = HEADER_SIZE + 64 * 1024;
let seg = ShmSegment::create(total).unwrap();
let mut live: Vec<(u64, u64)> = Vec::new();
for (op, size) in ops {
match op {
0 => {
if let Some(off) = seg.allocator().allocate(size) {
proptest::prop_assert!(off >= HEADER_SIZE as u64);
proptest::prop_assert!(
off as usize + size <= total,
"alloc returned off={off} size={size} but segment is {total}",
);
for &(o, l) in &live {
let a_end = off + size as u64;
let b_end = o + l;
proptest::prop_assert!(
a_end <= o || off >= b_end,
"overlap: new=({off},{size}) existing=({o},{l})",
);
}
live.push((off, size as u64));
}
}
1 => {
if !live.is_empty() {
// `size` doubles as the selector for which live entry to free.
let idx = size % live.len();
let (off, _) = live.remove(idx);
seg.allocator().free(off).unwrap();
}
}
_ => {
if !live.is_empty() {
let idx = size % live.len();
let (off, len) = live[idx];
// Derive a new length in 1..=len from `size`.
let new_len = ((size as u64) % len.max(1)).max(1);
if seg.allocator().shrink(off, new_len as usize).is_ok() {
live[idx] = (off, new_len);
}
}
}
}
proptest::prop_assert_eq!(seg.allocator().num_allocs(), live.len());
}
}
}
}