use std::{
io::{Cursor, Read, Seek, SeekFrom, Write},
slice,
};
use crate::{
box_tracked, cimpl_free, deref_mut_or_return_int, error::C2paError, ok_or_return_int,
CimplError,
};
/// Opaque marker type standing in for the caller-owned stream state.
///
/// It carries no fields; in this file it is only ever handled as a
/// `*mut StreamContext` that is passed back, untouched, to the stream callbacks.
#[repr(C)]
#[derive(Debug)]
pub struct StreamContext;
/// Origin that a seek offset is measured from, handed to the seek callback.
///
/// The explicit discriminants are part of the C-visible representation.
#[repr(C)]
#[derive(Debug)]
pub enum C2paSeekMode {
/// Offset is relative to the start of the stream (used for `SeekFrom::Start`).
Start = 0,
/// Offset is relative to the current position (used for `SeekFrom::Current`).
Current = 1,
/// Offset is relative to the end of the stream (used for `SeekFrom::End`).
End = 2,
}
/// Callback asked to read up to `len` bytes into `data`.
/// `C2paStream` treats a negative return as an error and any other value as the byte count.
type ReadCallback =
unsafe extern "C" fn(context: *mut StreamContext, data: *mut u8, len: isize) -> isize;
/// Callback asked to move the stream position by `offset` relative to `mode`.
/// `C2paStream` treats a negative return as an error and any other value as the new position.
type SeekCallback =
unsafe extern "C" fn(context: *mut StreamContext, offset: isize, mode: C2paSeekMode) -> isize;
/// Callback asked to write `len` bytes starting at `data`.
/// `C2paStream` treats a negative return as an error and any other value as the byte count.
type WriteCallback =
unsafe extern "C" fn(context: *mut StreamContext, data: *const u8, len: isize) -> isize;
/// Callback asked to flush the stream; `C2paStream` treats a negative return as an error.
type FlushCallback = unsafe extern "C" fn(context: *mut StreamContext) -> isize;
/// A stream whose I/O is delegated to foreign callbacks.
///
/// The `Read`, `Seek` and `Write` implementations in this file forward every
/// operation to the matching callback, passing `context` as the first argument.
/// The struct is `#[repr(C)]`, so the field order below is part of its layout.
#[repr(C)]
#[derive(Debug)]
pub struct C2paStream {
// Opaque pointer handed back to every callback; never dereferenced by this struct's trait impls.
context: *mut StreamContext,
// Invoked by `Read::read`.
reader: ReadCallback,
// Invoked by `Seek::seek`.
seeker: SeekCallback,
// Invoked by `Write::write`.
writer: WriteCallback,
// Invoked by `Write::flush`.
flusher: FlushCallback,
}
impl C2paStream {
    /// Bundles a context pointer and the four I/O callbacks into a stream.
    ///
    /// Nothing is validated or invoked here; the values are stored as given.
    ///
    /// # Safety
    ///
    /// The stream later calls the callbacks with `context` from safe trait
    /// methods, so the caller is presumably responsible for keeping `context`
    /// and the callbacks valid for as long as the stream is used — confirm
    /// against the C API contract.
    pub unsafe fn new(
        context: *mut StreamContext,
        reader: ReadCallback,
        seeker: SeekCallback,
        writer: WriteCallback,
        flusher: FlushCallback,
    ) -> Self {
        Self {
            context,
            reader,
            seeker,
            writer,
            flusher,
        }
    }

    /// Takes the context pointer out of the stream, leaving a null pointer
    /// behind, and returns it wrapped in a `Box`.
    ///
    /// NOTE(review): the stored pointer is converted with `Box::from_raw`
    /// without a null check, so calling this twice (or on a stream created
    /// with a null context) builds a `Box` from a null pointer — confirm that
    /// callers never do this.
    pub fn extract_context(&mut self) -> Box<StreamContext> {
        let raw_context = self.context;
        self.context = std::ptr::null_mut();
        unsafe { Box::from_raw(raw_context) }
    }
}
impl Read for C2paStream {
    /// Reads up to `buf.len()` bytes by delegating to the `reader` callback.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if the buffer length does not fit in an `isize`.
    /// * If the callback returns a negative value: the last `CimplError`
    ///   message (which is then taken/cleared) when one is present, otherwise
    ///   the last OS error.
    /// * `InvalidData` if the callback claims to have read more bytes than the
    ///   buffer holds, since `Read::read` must never return `n > buf.len()`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let len = isize::try_from(buf.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Read buffer is too large",
            )
        })?;
        // SAFETY: `buf` is a live, writable slice of exactly `len` bytes. The
        // validity of `context` and of the callback itself was the
        // responsibility of whoever called the `unsafe` constructor.
        let bytes_read = unsafe { (self.reader)(self.context, buf.as_mut_ptr(), len) };
        if bytes_read < 0 {
            return Err(CimplError::last_message()
                .map(|msg| {
                    // Consume the stored error so it is not reported twice.
                    let _ = CimplError::take_last();
                    std::io::Error::other(msg)
                })
                .unwrap_or_else(std::io::Error::last_os_error));
        }
        // Non-negative here, so the cast cannot change the value.
        let bytes_read = bytes_read as usize;
        if bytes_read > buf.len() {
            // Callers may index `buf[..n]`; never hand them an out-of-range count.
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Read callback reported more bytes than the buffer can hold",
            ));
        }
        Ok(bytes_read)
    }
}
impl Seek for C2paStream {
    /// Repositions the stream by delegating to the `seeker` callback and
    /// returns the new position reported by it.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if the requested offset cannot be represented as an
    ///   `isize` (a `Start` position above `isize::MAX`, or any offset that is
    ///   too wide on targets where `isize` is narrower than 64 bits). Plain
    ///   `as` casts would silently wrap or truncate such offsets and seek to
    ///   the wrong place.
    /// * If the callback returns a negative value: the last `CimplError`
    ///   message (which is then taken/cleared) when one is present, otherwise
    ///   the last OS error.
    fn seek(&mut self, from: std::io::SeekFrom) -> std::io::Result<u64> {
        let out_of_range = || {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Seek offset is out of range",
            )
        };
        let (offset, mode) = match from {
            std::io::SeekFrom::Start(pos) => (
                isize::try_from(pos).map_err(|_| out_of_range())?,
                C2paSeekMode::Start,
            ),
            std::io::SeekFrom::Current(pos) => (
                isize::try_from(pos).map_err(|_| out_of_range())?,
                C2paSeekMode::Current,
            ),
            std::io::SeekFrom::End(pos) => (
                isize::try_from(pos).map_err(|_| out_of_range())?,
                C2paSeekMode::End,
            ),
        };
        // SAFETY: only plain values are passed; the validity of `context` and
        // of the callback itself was the responsibility of whoever called the
        // `unsafe` constructor.
        let new_pos = unsafe { (self.seeker)(self.context, offset, mode) };
        if new_pos < 0 {
            return Err(CimplError::last_message()
                .map(|msg| {
                    // Consume the stored error so it is not reported twice.
                    let _ = CimplError::take_last();
                    std::io::Error::other(msg)
                })
                .unwrap_or_else(std::io::Error::last_os_error));
        }
        // Non-negative here, so widening to `u64` is lossless.
        Ok(new_pos as u64)
    }
}
impl Write for C2paStream {
    /// Writes up to `buf.len()` bytes by delegating to the `writer` callback.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if the buffer length does not fit in an `isize`.
    /// * If the callback returns a negative value: the last `CimplError`
    ///   message (which is then taken/cleared) when one is present, otherwise
    ///   the last OS error.
    /// * `InvalidData` if the callback claims to have written more bytes than
    ///   it was given, since `Write::write` must never return `n > buf.len()`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let len = isize::try_from(buf.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Write buffer is too large",
            )
        })?;
        // SAFETY: `buf` is a live slice of exactly `len` readable bytes. The
        // validity of `context` and of the callback itself was the
        // responsibility of whoever called the `unsafe` constructor.
        let bytes_written = unsafe { (self.writer)(self.context, buf.as_ptr(), len) };
        if bytes_written < 0 {
            return Err(last_write_callback_error());
        }
        // Non-negative here, so the cast cannot change the value.
        let bytes_written = bytes_written as usize;
        if bytes_written > buf.len() {
            // `write_all` slices `buf[n..]`; never hand it an out-of-range count.
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Write callback reported more bytes than it was given",
            ));
        }
        Ok(bytes_written)
    }

    /// Flushes the stream by delegating to the `flusher` callback.
    ///
    /// # Errors
    ///
    /// If the callback returns a negative value, reports the failure the same
    /// way `write` does: the last `CimplError` message when present, otherwise
    /// the last OS error.
    fn flush(&mut self) -> std::io::Result<()> {
        // SAFETY: only `context` is passed; its validity and that of the
        // callback were the responsibility of whoever called the `unsafe`
        // constructor.
        let status = unsafe { (self.flusher)(self.context) };
        if status < 0 {
            return Err(last_write_callback_error());
        }
        Ok(())
    }
}

/// Builds the `io::Error` for a failed write/flush callback: the pending
/// `CimplError` message (consumed so it is not reported twice) when there is
/// one, otherwise the last OS error.
fn last_write_callback_error() -> std::io::Error {
    CimplError::last_message()
        .map(|msg| {
            let _ = CimplError::take_last();
            std::io::Error::other(msg)
        })
        .unwrap_or_else(std::io::Error::last_os_error)
}
// `C2paStream` holds a raw `*mut StreamContext`, so it is not `Send`/`Sync`
// automatically; these impls opt it in manually.
// NOTE(review): nothing in this file synchronizes access to `context`, so
// soundness presumably relies on the callbacks/context supplied by the caller
// being safe to use from other threads — confirm against the C API contract.
unsafe impl Send for C2paStream {}
unsafe impl Sync for C2paStream {}
/// Test-only owner of a `C2paStream` backed by a `TestC2paStream`.
///
/// Wraps the raw pointer returned by `TestC2paStream::into_c_stream`; the
/// `Drop` impl hands it to `TestC2paStream::drop_c_stream`.
#[cfg(test)]
pub struct TestStream(*mut C2paStream);
#[cfg(test)]
impl TestStream {
    /// Builds an in-memory test stream over `data` and takes ownership of the
    /// resulting raw `C2paStream` pointer.
    pub fn new(data: Vec<u8>) -> Self {
        let backing = TestC2paStream::new(data);
        let raw_stream = backing.into_c_stream();
        Self(raw_stream)
    }

    /// Borrows the wrapped stream mutably so the `Read`/`Seek`/`Write` traits
    /// can be used on it.
    pub fn stream_mut(&mut self) -> &mut C2paStream {
        let raw_stream = self.0;
        unsafe { &mut *raw_stream }
    }

    /// Returns the wrapped raw pointer without giving up ownership.
    pub fn as_ptr(&mut self) -> *mut C2paStream {
        let TestStream(raw_stream) = self;
        *raw_stream
    }
}
#[cfg(test)]
impl Drop for TestStream {
    /// Hands the wrapped stream pointer to `TestC2paStream::drop_c_stream`,
    /// which frees both the stream and its context.
    fn drop(&mut self) {
        let raw_stream = self.0;
        unsafe { TestC2paStream::drop_c_stream(raw_stream) };
    }
}
#[no_mangle]
pub unsafe extern "C" fn c2pa_create_stream(
context: *mut StreamContext,
reader: ReadCallback,
seeker: SeekCallback,
writer: WriteCallback,
flusher: FlushCallback,
) -> *mut C2paStream {
box_tracked!(C2paStream::new(context, reader, seeker, writer, flusher,))
}
/// C entry point that releases a stream by passing its pointer to `cimpl_free`.
///
/// Only the `C2paStream` allocation is handed to `cimpl_free`; the stream's
/// `context` pointer is not touched here.
///
/// # Safety
///
/// `stream` is forwarded as an untyped pointer; presumably it must be a
/// pointer previously returned by `c2pa_create_stream` and not yet released —
/// confirm against `cimpl_free`'s contract.
#[no_mangle]
pub unsafe extern "C" fn c2pa_release_stream(stream: *mut C2paStream) {
    let untyped = stream.cast::<std::ffi::c_void>();
    cimpl_free(untyped);
}
/// In-memory stream used to exercise `C2paStream` through real callbacks.
///
/// Its associated `reader`/`seeker`/`writer`/`flusher` functions match the
/// callback type aliases and operate on the `cursor` below.
pub struct TestC2paStream {
// Backing buffer plus current position; all callbacks read, write and seek through it.
cursor: Cursor<Vec<u8>>,
}
impl TestC2paStream {
    /// Creates a stream whose contents are `data`, positioned at offset 0.
    pub fn new(data: Vec<u8>) -> Self {
        Self {
            cursor: Cursor::new(data),
        }
    }

    /// Records `message` as the last `CimplError` and returns the `-1`
    /// failure code used by the callbacks below.
    fn fail(message: &str) -> isize {
        CimplError::set_last(CimplError::from(C2paError::Other(message.to_string())));
        -1
    }

    /// Read callback: copies up to `len` bytes from the cursor into `data`.
    ///
    /// Returns the number of bytes read. A null `context` is rejected by
    /// `deref_mut_or_return_int!`; a negative `len`, or a null `data` with a
    /// non-zero `len`, returns `-1` with an error recorded instead of building
    /// an invalid slice.
    unsafe extern "C" fn reader(context: *mut StreamContext, data: *mut u8, len: isize) -> isize {
        let stream = deref_mut_or_return_int!(context as *mut TestC2paStream, TestC2paStream);
        // A negative length cast with `as usize` would become a huge slice length.
        let Ok(len) = usize::try_from(len) else {
            return Self::fail("Invalid read length");
        };
        if len == 0 {
            // Nothing to read; also avoids touching `data`, which may be null here.
            return 0;
        }
        if data.is_null() {
            return Self::fail("Null read buffer");
        }
        // SAFETY: `data` is non-null and `len` is a positive length that fits
        // in `isize`; the caller vouches that it points to `len` writable bytes.
        let data: &mut [u8] = slice::from_raw_parts_mut(data, len);
        ok_or_return_int!(stream.cursor.read(data)) as isize
    }

    /// Seek callback: moves the cursor and returns the new position.
    ///
    /// A negative offset with `C2paSeekMode::Start` fails with
    /// "Offset out of bounds"; a failing relative seek records the I/O error
    /// text. Both return `-1`.
    unsafe extern "C" fn seeker(
        context: *mut StreamContext,
        offset: isize,
        mode: C2paSeekMode,
    ) -> isize {
        let stream = deref_mut_or_return_int!(context as *mut TestC2paStream, TestC2paStream);
        let target = match mode {
            C2paSeekMode::Start => match u64::try_from(offset) {
                Ok(position) => SeekFrom::Start(position),
                Err(_) => return Self::fail("Offset out of bounds"),
            },
            // `isize` is at most 64 bits wide, so widening to `i64` is lossless.
            C2paSeekMode::Current => SeekFrom::Current(offset as i64),
            C2paSeekMode::End => SeekFrom::End(offset as i64),
        };
        match stream.cursor.seek(target) {
            // A position that does not fit in `isize` must not wrap into a
            // negative (error-looking) or truncated return value.
            Ok(position) => {
                isize::try_from(position).unwrap_or_else(|_| Self::fail("Offset out of bounds"))
            }
            Err(e) => {
                CimplError::set_last(CimplError::from(C2paError::Io(e.to_string())));
                -1
            }
        }
    }

    /// Flush callback: the in-memory cursor has nothing to flush, so this
    /// always reports success (`0`).
    unsafe extern "C" fn flusher(_context: *mut StreamContext) -> isize {
        0
    }

    /// Write callback: writes `len` bytes from `data` at the cursor position.
    ///
    /// Returns the number of bytes written. Applies the same `context`, `len`
    /// and `data` validation as `reader`; a failing write records the I/O
    /// error text and returns `-1`.
    unsafe extern "C" fn writer(context: *mut StreamContext, data: *const u8, len: isize) -> isize {
        // Null-check the context the same way `reader` and `seeker` do.
        let stream = deref_mut_or_return_int!(context as *mut TestC2paStream, TestC2paStream);
        let Ok(len) = usize::try_from(len) else {
            return Self::fail("Invalid write length");
        };
        if len == 0 {
            // Nothing to write; also avoids touching `data`, which may be null here.
            return 0;
        }
        if data.is_null() {
            return Self::fail("Null write buffer");
        }
        // SAFETY: `data` is non-null and `len` is a positive length that fits
        // in `isize`; the caller vouches that it points to `len` readable bytes.
        let data: &[u8] = slice::from_raw_parts(data, len);
        match stream.cursor.write(data) {
            Ok(bytes) => bytes as isize,
            Err(e) => {
                CimplError::set_last(CimplError::from(C2paError::Io(e.to_string())));
                -1
            }
        }
    }

    /// Moves `self` behind a `box_tracked!` pointer, uses that pointer as the
    /// stream context, and returns a `box_tracked!` `C2paStream` wired to the
    /// callbacks above. Release it with `drop_c_stream`.
    pub fn into_c_stream(self) -> *mut C2paStream {
        let context = box_tracked!(self) as *mut StreamContext;
        // SAFETY: `context` points to a `TestC2paStream`, which is exactly what
        // the callbacks passed here cast it back to.
        unsafe {
            box_tracked!(C2paStream::new(
                context,
                Self::reader,
                Self::seeker,
                Self::writer,
                Self::flusher,
            ))
        }
    }

    /// Convenience wrapper: builds a test stream over `data` and converts it
    /// with `into_c_stream`.
    pub fn from_bytes(data: Vec<u8>) -> *mut C2paStream {
        Self::new(data).into_c_stream()
    }

    /// Releases a stream created by `into_c_stream`: passes its context (when
    /// the stream pointer is non-null) and then the stream pointer itself to
    /// `cimpl_free`.
    ///
    /// # Safety
    ///
    /// `c_stream` is dereferenced when non-null, so it must point to a live
    /// `C2paStream`; presumably it must come from `into_c_stream` and not have
    /// been released already — confirm against `cimpl_free`'s contract.
    pub unsafe fn drop_c_stream(c_stream: *mut C2paStream) {
        if !c_stream.is_null() {
            // Read the context out first: the stream is freed right after.
            let context = unsafe { (*c_stream).context };
            cimpl_free(context as *mut std::ffi::c_void);
        }
        cimpl_free(c_stream as *mut std::ffi::c_void);
    }
}
// Unit tests driving `C2paStream` through the `TestC2paStream` callbacks.
#[cfg(test)]
mod tests {
use super::*;
// Two consecutive reads: the second is short because only two bytes remain.
#[test]
fn test_cstream_read() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
let mut buf = [0u8; 3];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 3);
assert_eq!(buf, [1, 2, 3]);
let mut buf = [0u8; 3];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 2);
// The unread tail of the buffer keeps its initial zero.
assert_eq!(buf, [4, 5, 0]);
}
// Exercises all three seek origins, reading after each to confirm the position.
#[test]
fn test_cstream_seek() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
stream.stream_mut().seek(SeekFrom::Start(2)).unwrap();
let mut buf = [0u8; 3];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 3);
assert_eq!(buf, [3, 4, 5]);
stream.stream_mut().seek(SeekFrom::End(-2)).unwrap();
let mut buf = [0u8; 2];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 2);
assert_eq!(buf, [4, 5]);
// Position is 5 after the previous read, so -4 lands on index 1.
stream.stream_mut().seek(SeekFrom::Current(-4)).unwrap();
let mut buf = [0u8; 3];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 3);
assert_eq!(buf, [2, 3, 4]);
}
// Appending at the end grows the stream from 5 to 8 bytes.
#[test]
fn test_cstream_write() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
stream.stream_mut().seek(SeekFrom::End(0)).unwrap();
let buf = [6, 7, 8];
let bytes_written = stream.stream_mut().write(&buf).unwrap();
assert_eq!(bytes_written, 3);
assert_eq!(stream.stream_mut().seek(SeekFrom::End(0)).unwrap(), 8);
}
// Reading from an empty stream yields zero bytes and leaves the buffer untouched.
#[test]
fn test_cstream_read_empty() {
let data = vec![];
let mut stream = TestStream::new(data);
let mut buf = [0u8; 3];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 0);
assert_eq!(buf, [0, 0, 0]);
}
// Seeking past the end is accepted; seeking to a negative position is an error.
#[test]
fn test_cstream_seek_out_of_bounds() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
assert!(stream.stream_mut().seek(SeekFrom::Start(10)).is_ok());
assert!(stream.stream_mut().seek(SeekFrom::Current(-20)).is_err());
}
// Writing in the middle overwrites existing bytes without changing the length.
#[test]
fn test_cstream_write_overwrite() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
stream.stream_mut().seek(SeekFrom::Start(2)).unwrap();
let buf = [9, 9];
let bytes_written = stream.stream_mut().write(&buf).unwrap();
assert_eq!(bytes_written, 2);
stream.stream_mut().seek(SeekFrom::Start(0)).unwrap();
let mut buf = [0u8; 5];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 5);
assert_eq!(buf, [1, 2, 9, 9, 5]);
}
// The test flusher always returns 0, so flush must succeed.
#[test]
fn test_cstream_flush() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
assert!(stream.stream_mut().flush().is_ok());
}
// A buffer larger than the stream receives only the available bytes.
#[test]
fn test_cstream_large_read() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
let mut buf = [0u8; 10];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 5);
assert_eq!(buf[..5], [1, 2, 3, 4, 5]);
assert_eq!(buf[5..], [0, 0, 0, 0, 0]);
}
// Appends more bytes than the stream originally held, then reads everything back.
#[test]
fn test_cstream_large_write() {
let data = vec![1, 2, 3];
let mut stream = TestStream::new(data);
stream.stream_mut().seek(SeekFrom::End(0)).unwrap();
let buf = [6, 7, 8, 9, 10];
let bytes_written = stream.stream_mut().write(&buf).unwrap();
assert_eq!(bytes_written, 5);
stream.stream_mut().seek(SeekFrom::Start(0)).unwrap();
let mut buf = [0u8; 8];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 8);
assert_eq!(buf, [1, 2, 3, 6, 7, 8, 9, 10]);
}
// Seeking to the end reports the stream length, and a read there returns 0.
#[test]
fn test_cstream_seek_to_end() {
let data = vec![1, 2, 3, 4, 5];
let mut stream = TestStream::new(data);
let end_pos = stream.stream_mut().seek(SeekFrom::End(0)).unwrap();
assert_eq!(end_pos, 5);
let mut buf = [0u8; 1];
assert_eq!(stream.stream_mut().read(&mut buf).unwrap(), 0); }
// Builds a stream through the C entry point and releases it the same way.
#[test]
fn test_create_stream() {
let test_stream = TestC2paStream::new(vec![1, 2, 3, 4, 5]);
let context = box_tracked!(test_stream) as *mut StreamContext;
let c2pa_stream = unsafe {
c2pa_create_stream(
context,
TestC2paStream::reader,
TestC2paStream::seeker,
TestC2paStream::writer,
TestC2paStream::flusher,
)
};
let c2pa_stream = unsafe { &mut *c2pa_stream };
let mut buf = [0u8; 3];
let result = c2pa_stream.read(&mut buf);
result.expect("Failed to read from C2paStream");
assert_eq!(buf, [1, 2, 3]);
// `c2pa_release_stream` frees only the stream; the context is freed separately below.
unsafe { c2pa_release_stream(c2pa_stream) };
cimpl_free(context as *mut std::ffi::c_void);
}
}