use core::ffi::c_void;
use core::marker::PhantomData;
use serde::de::DeserializeOwned;
use serde::Serialize;
/// C-ABI handle through which a producer hands a stream of serialized items to a consumer.
///
/// The struct is `#[repr(C)]` so both sides of the FFI boundary agree on the field layout.
#[repr(C)]
pub struct FidiusStreamHandle {
/// Asks the producer to write its next serialized item into a buffer.
///
/// Arguments: this handle, a writable buffer, the buffer's capacity in bytes, and an
/// out-pointer for a byte count. The consumer in this file interprets the returned
/// `crate::status` code as follows: `STATUS_OK` means the item's length is in the
/// out-pointer; `STATUS_BUFFER_TOO_SMALL` means the out-pointer holds the size the item
/// needs; any other code (e.g. `STATUS_STREAM_END`) ends consumption.
pub next: unsafe extern "C" fn(*mut FidiusStreamHandle, *mut u8, u32, *mut u32) -> i32,
/// Releases the handle; the consumer does not touch the handle after calling this.
pub drop_fn: unsafe extern "C" fn(*mut FidiusStreamHandle),
/// Opaque producer state; only the producer's own callbacks interpret it.
pub state: *mut c_void,
}
/// Outcome of one `StreamState::next_into` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NextStatus {
    /// An item was written; carries the number of bytes written to the buffer.
    Item(usize),
    /// The underlying stream has no more items.
    End,
    /// The buffer was too small; carries the number of bytes the pending item needs.
    /// The item is kept and will be offered again on the next call.
    TooSmall(usize),
    /// The pending item could not be sized or serialized.
    SerErr,
}
/// Producer-side adapter that serializes items pulled from a `Stream<T>` into caller buffers.
pub struct StreamState<T> {
/// Source of items.
stream: crate::stream_marker::Stream<T>,
/// An item already taken from `stream` but not yet written out. It is cleared only after
/// a successful write, so after a too-small buffer (or a serialization error) the next
/// call retries the same item.
pending: Option<T>,
}
impl<T: Serialize> StreamState<T> {
    /// Creates an adapter over `stream` with nothing buffered yet.
    pub fn new(stream: crate::stream_marker::Stream<T>) -> Self {
        let pending: Option<T> = None;
        Self { pending, stream }
    }

    /// Serializes the next item of the stream into the front of `buf`.
    ///
    /// Returns `Item(n)` after writing `n` bytes, `End` once the stream is exhausted,
    /// `TooSmall(n)` when `buf` is shorter than the `n` bytes the item needs, and
    /// `SerErr` when sizing or serializing fails. In the `TooSmall` and `SerErr` cases
    /// the item stays buffered, so the next call offers the same item again.
    pub fn next_into(&mut self, buf: &mut [u8]) -> NextStatus {
        // Only draw a fresh item once the previous one has been delivered.
        if self.pending.is_none() {
            self.pending = self.stream.next_item();
        }
        let item = match self.pending.as_ref() {
            Some(item) => item,
            None => return NextStatus::End,
        };

        let needed = match crate::wire::serialized_size(item) {
            Err(_) => return NextStatus::SerErr,
            Ok(n) => n as usize,
        };

        // `get_mut` yields `None` exactly when the item would overrun the caller's buffer.
        let dest = match buf.get_mut(..needed) {
            Some(dest) => dest,
            None => return NextStatus::TooSmall(needed),
        };
        if crate::wire::serialize_into(dest, item).is_err() {
            return NextStatus::SerErr;
        }

        // Delivered: empty the slot so the next call pulls a new item.
        self.pending = None;
        NextStatus::Item(needed)
    }
}
/// Consumer-side iterator over a [`FidiusStreamHandle`], deserializing each item as `T`.
pub struct HostStream<T> {
/// Producer handle; its `drop_fn` is called when this value is dropped.
handle: *mut FidiusStreamHandle,
/// Capacity in bytes of the buffer offered to the producer on the next pull;
/// starts at 256 and only ever grows.
cap: usize,
/// Ties the item type `T` to the stream without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: DeserializeOwned> HostStream<T> {
    /// Wraps a producer handle obtained over FFI as a typed consumer.
    ///
    /// The first pull offers a 256-byte buffer; the buffer grows whenever the producer
    /// reports that an item does not fit.
    ///
    /// # Safety
    ///
    /// `handle` must be non-null and point to a live [`FidiusStreamHandle`] whose `next`
    /// and `drop_fn` are valid to call with it. Ownership of the handle passes to the
    /// returned value: its `drop_fn` is invoked when the `HostStream` is dropped, so the
    /// caller must neither use nor release the handle afterwards.
    pub unsafe fn from_handle(handle: *mut FidiusStreamHandle) -> Self {
        Self {
            handle,
            cap: 256,
            _marker: PhantomData,
        }
    }

    /// Pulls and deserializes one item from the producer.
    ///
    /// Returns `None` in any of these cases:
    /// - the producer returns a status other than OK or buffer-too-small (e.g. end of stream);
    /// - the producer reports a length larger than the buffer it was given;
    /// - the item does not fit even the largest buffer the `u32` ABI can describe;
    /// - deserialization fails.
    fn pull(&mut self) -> Option<T> {
        // Largest capacity expressible through the ABI's `u32` capacity/length parameters.
        const ABI_MAX_CAP: usize = u32::MAX as usize;

        let mut buf = vec![0u8; self.cap];
        loop {
            // Lossless: clamped to the u32 range first. Advertising `buf.len()` keeps the
            // promised capacity tied to the real allocation.
            let cap = buf.len().min(ABI_MAX_CAP) as u32;
            let mut out_len: u32 = 0;
            // SAFETY: `from_handle`'s contract guarantees `self.handle` is a live handle with
            // a callable `next`; `buf` is writable for at least `cap` bytes, and `out_len` is
            // a valid, exclusively borrowed `u32` for the duration of the call.
            let status = unsafe {
                ((*self.handle).next)(self.handle, buf.as_mut_ptr(), cap, &mut out_len)
            };
            match status {
                crate::status::STATUS_OK => {
                    // Never trust a length reported across the FFI boundary: an oversized
                    // value would otherwise panic on slicing.
                    let bytes = buf.get(..out_len as usize)?;
                    return crate::wire::deserialize::<T>(bytes).ok();
                }
                crate::status::STATUS_BUFFER_TOO_SMALL => {
                    // Use the producer's requested size, but at least double, so a producer
                    // that under-reports still makes progress. Clamp to the ABI limit and
                    // saturate instead of overflowing.
                    let grown = (out_len as usize)
                        .max(self.cap.saturating_mul(2))
                        .min(ABI_MAX_CAP);
                    if grown <= self.cap {
                        // Already at the ABI ceiling: this item can never be delivered.
                        return None;
                    }
                    self.cap = grown;
                    buf = vec![0u8; self.cap];
                }
                _ => return None,
            }
        }
    }
}
impl<T: DeserializeOwned> Iterator for HostStream<T> {
    type Item = T;

    /// Pulls the next item from the producer; `None` ends iteration.
    ///
    /// Note: the iterator is not fused. Calling `next` again after `None` asks the
    /// producer again.
    fn next(&mut self) -> Option<Self::Item> {
        Self::pull(self)
    }
}
// SAFETY: NOTE(review) — nothing in this file establishes that the producer behind
// `handle` tolerates being driven from a thread other than the one that created it;
// confirm the plugin ABI guarantees this. The impl is also unconditional in `T` (no
// `T: Send` bound). Items are only constructed on whichever thread calls `next`, but
// verify that this is the intended contract.
unsafe impl<T> Send for HostStream<T> {}
impl<T> Drop for HostStream<T> {
    /// Hands the handle back to the producer's `drop_fn` for release.
    fn drop(&mut self) {
        let handle = self.handle;
        // SAFETY: `handle` came from `from_handle`, whose caller promised it stays valid
        // until released; this is the only place in this file that releases it.
        unsafe {
            let release = (*handle).drop_fn;
            release(handle);
        }
    }
}
#[cfg(test)]
mod host_stream_tests {
    use super::*;

    /// Producer state behind the mock handle: pre-serialized frames, emitted in order.
    struct MockProducer {
        frames: Vec<Vec<u8>>,
        idx: usize,
    }

    /// `next` callback: copies the current frame into `buf`, or reports the size it needs.
    unsafe extern "C" fn mock_next(
        h: *mut FidiusStreamHandle,
        buf: *mut u8,
        cap: u32,
        out_len: *mut u32,
    ) -> i32 {
        let p = &mut *((*h).state as *mut MockProducer);
        let bytes = match p.frames.get(p.idx) {
            Some(bytes) => bytes,
            None => return crate::status::STATUS_STREAM_END,
        };
        *out_len = bytes.len() as u32;
        if bytes.len() > cap as usize {
            // Leave `idx` alone so the same frame is retried with a larger buffer.
            return crate::status::STATUS_BUFFER_TOO_SMALL;
        }
        core::ptr::copy_nonoverlapping(bytes.as_ptr(), buf, bytes.len());
        p.idx += 1;
        crate::status::STATUS_OK
    }

    /// `drop_fn` callback: frees the producer state and the handle allocation.
    unsafe extern "C" fn mock_drop(h: *mut FidiusStreamHandle) {
        drop(Box::from_raw((*h).state as *mut MockProducer));
        drop(Box::from_raw(h));
    }

    /// Builds a heap-allocated handle that streams `items` in order.
    fn mock_handle<T: Serialize>(items: Vec<T>) -> *mut FidiusStreamHandle {
        let frames: Vec<Vec<u8>> = items
            .iter()
            .map(|item| crate::wire::serialize(item).unwrap())
            .collect();
        let producer = Box::into_raw(Box::new(MockProducer { frames, idx: 0 }));
        Box::into_raw(Box::new(FidiusStreamHandle {
            next: mock_next,
            drop_fn: mock_drop,
            state: producer as *mut c_void,
        }))
    }

    #[test]
    fn host_stream_consumes_all_items_then_drops_cleanly() {
        let h = mock_handle(vec![10u64, 20, 30]);
        let consumer = unsafe { HostStream::<u64>::from_handle(h) };
        let got: Vec<u64> = consumer.collect();
        assert_eq!(got, vec![10, 20, 30]);
    }

    #[test]
    fn host_stream_grows_buffer_for_items_larger_than_initial_capacity() {
        // A 1000-byte payload should not fit the initial 256-byte buffer under any
        // non-compressing encoding. That forces `pull` through the BUFFER_TOO_SMALL retry.
        let items: Vec<Vec<u8>> = vec![
            (0..1000u32).map(|i| i as u8).collect(),
            Vec::new(),
            vec![1u8; 300],
        ];
        let h = mock_handle(items.clone());
        let consumer = unsafe { HostStream::<Vec<u8>>::from_handle(h) };
        let got: Vec<Vec<u8>> = consumer.collect();
        assert_eq!(got, items);
    }

    #[test]
    fn host_stream_dropped_before_exhaustion_releases_handle() {
        // Runs `mock_drop` with frames still unread. Leaks or double frees show up under
        // Miri or a sanitizer.
        let h = mock_handle(vec![1u64, 2, 3]);
        let mut consumer = unsafe { HostStream::<u64>::from_handle(h) };
        assert_eq!(consumer.next(), Some(1));
        drop(consumer);
    }
}