#![allow(clippy::missing_safety_doc)]
use core::ffi::{c_char, c_int};
use std::cell::RefCell;
use std::ffi::CString;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr;
use tenso::{decode, dense_required_size, encode_dense_into, parse_header};
use tenso::{ArraySpec, Decoded, Dtype, EncodeOpts, TensoError};
// Status codes returned across the C ABI. Zero means success; every failure
// is a distinct negative value, and a human-readable message for the most
// recent failure on the calling thread is available via `tenso_last_error`.

/// The call succeeded.
pub const TENSO_OK: c_int = 0;
/// Reported for `TensoError::TooShort` (input too short).
pub const TENSO_ERR_TOO_SHORT: c_int = -1;
/// Reported for `TensoError::BadMagic`.
pub const TENSO_ERR_BAD_MAGIC: c_int = -2;
/// Reported for `TensoError::UnsupportedVersion`.
pub const TENSO_ERR_UNSUPPORTED_VERSION: c_int = -3;
/// Reported for `TensoError::BadDtype` (unknown dtype code).
pub const TENSO_ERR_BAD_DTYPE: c_int = -4;
/// Reported for `TensoError::TooManyDims`.
pub const TENSO_ERR_TOO_MANY_DIMS: c_int = -5;
/// Reported for `TensoError::TooManyElements`.
pub const TENSO_ERR_TOO_MANY_ELEMENTS: c_int = -6;
/// Reported for `TensoError::IntegrityMismatch`.
pub const TENSO_ERR_INTEGRITY: c_int = -7;
/// Reported for `TensoError::BadBundle`.
pub const TENSO_ERR_BAD_BUNDLE: c_int = -8;
/// Reported for `TensoError::Lz4`.
pub const TENSO_ERR_LZ4: c_int = -9;
/// Reported for `TensoError::BufferTooSmall`.
pub const TENSO_ERR_BUFFER_TOO_SMALL: c_int = -10;
/// A required pointer was null, or a null pointer was paired with a non-zero
/// length/count/capacity.
pub const TENSO_ERR_NULL: c_int = -11;
/// Reported for `TensoError::Malformed` (malformed packet).
pub const TENSO_ERR_MALFORMED: c_int = -12;
/// A Rust panic was caught at the FFI boundary instead of unwinding into C.
pub const TENSO_ERR_PANIC: c_int = -13;
/// The decoded packet is a structured kind (bundle/sparse/string/ragged/ipc-ref)
/// that cannot be exposed as a flat `TensoView`.
pub const TENSO_ERR_UNSUPPORTED_KIND: c_int = -14;
// Per-thread slot holding the message of the most recent error, or `None`
// when no error is recorded. It is written by `set_last_error`, reset by
// `clear_last_error`, and exposed to C through `tenso_last_error`.
thread_local! {
static LAST_ERROR: RefCell<Option<CString>> = const { RefCell::new(None) };
}
/// Records `msg` as the calling thread's last-error message.
///
/// NUL characters are stripped first so the text can always be stored as a
/// C string; the generic fallback message only exists to avoid a panic should
/// `CString::new` ever reject the sanitized text.
fn set_last_error(msg: &str) {
    let without_nul: String = msg.chars().filter(|ch| *ch != '\0').collect();
    let stored = match CString::new(without_nul) {
        Ok(text) => text,
        Err(_) => CString::new("tenso: error").unwrap(),
    };
    LAST_ERROR.with(|cell| {
        let mut current = cell.borrow_mut();
        *current = Some(stored);
    });
}
/// Resets the calling thread's last-error slot to "no error".
///
/// Called at the start of each fallible entry point so a stale message from
/// an earlier call is never reported for a later one.
fn clear_last_error() {
    LAST_ERROR.with(|cell| {
        // `take` leaves `None` behind and drops any previous message.
        cell.borrow_mut().take();
    });
}
/// Translates a [`TensoError`] into its C status code and records a matching
/// human-readable message as the thread's last error.
///
/// The match is deliberately exhaustive (no wildcard arm) so that adding a
/// variant to `TensoError` fails to compile here until it is given a code.
fn report_error(err: &TensoError) -> c_int {
    // Stores the message and hands the status code back, so every arm is a
    // single expression.
    let record = |code: c_int, text: &str| -> c_int {
        set_last_error(text);
        code
    };
    match err {
        TensoError::TooShort => record(TENSO_ERR_TOO_SHORT, "tenso: input too short"),
        TensoError::BadMagic => record(TENSO_ERR_BAD_MAGIC, "tenso: bad magic"),
        TensoError::UnsupportedVersion(v) => record(
            TENSO_ERR_UNSUPPORTED_VERSION,
            &format!("tenso: unsupported version {v}"),
        ),
        TensoError::BadDtype(c) => {
            record(TENSO_ERR_BAD_DTYPE, &format!("tenso: bad dtype code {c}"))
        }
        TensoError::TooManyDims => record(TENSO_ERR_TOO_MANY_DIMS, "tenso: too many dims"),
        TensoError::TooManyElements => {
            record(TENSO_ERR_TOO_MANY_ELEMENTS, "tenso: too many elements")
        }
        TensoError::IntegrityMismatch => {
            record(TENSO_ERR_INTEGRITY, "tenso: integrity mismatch")
        }
        TensoError::BadBundle => record(TENSO_ERR_BAD_BUNDLE, "tenso: bad bundle"),
        TensoError::Lz4(reason) => record(TENSO_ERR_LZ4, &format!("tenso: lz4: {reason}")),
        TensoError::BufferTooSmall => {
            record(TENSO_ERR_BUFFER_TOO_SMALL, "tenso: buffer too small")
        }
        TensoError::Malformed => record(TENSO_ERR_MALFORMED, "tenso: malformed packet"),
    }
}
/// C-layout summary of a packet header, filled in by `tenso_parse_header`.
///
/// Field order and types are part of the C ABI (`#[repr(C)]`); do not reorder.
#[repr(C)]
pub struct TensoHeader {
/// Format version as reported by the parsed header.
pub version: u8,
/// Raw header flags; their meaning is defined by the tenso format and is not
/// interpreted in this file.
pub flags: u16,
/// Numeric dtype code as reported by the parsed header.
pub dtype_code: u8,
/// Number of dimensions; `tenso_parse_header` saturates this at `u32::MAX`.
pub ndim: u32,
/// Header `base_size` value; `tenso_parse_header` saturates this at `u32::MAX`.
pub base_size: u32,
}
/// Owned result of a successful `tenso_decode`.
///
/// The struct has no `#[repr(C)]` and private fields: C callers hold it only
/// by pointer and read it through the `tenso_view_*` accessors, releasing it
/// with `tenso_view_free`.
pub struct TensoView {
// Numeric dtype code of the decoded tensor (`dtype.code()`).
dtype_code: u8,
// Dimension sizes; its length is what `tenso_view_ndim` reports.
shape: Vec<u32>,
// Owned copy of the payload: the dense body, or the packed bytes for a
// quantized tensor (see `build_view`).
body: Vec<u8>,
}
/// Parses the header of the packet at `data[..len]` and writes a summary to
/// `*out`.
///
/// Returns `TENSO_OK` on success. On failure a negative `TENSO_ERR_*` code is
/// returned, `*out` is left untouched, and the message is available through
/// `tenso_last_error`. Panics inside the parser are caught and reported as
/// `TENSO_ERR_PANIC`.
///
/// # Safety
///
/// * `data` must be valid for reads of `len` bytes, or `len` must be 0.
/// * `out` must be null (reported as `TENSO_ERR_NULL`) or valid for writing
///   one `TensoHeader`.
#[no_mangle]
pub unsafe extern "C" fn tenso_parse_header(
    data: *const u8,
    len: usize,
    out: *mut TensoHeader,
) -> c_int {
    clear_last_error();
    if out.is_null() {
        set_last_error("tenso: null out pointer");
        return TENSO_ERR_NULL;
    }
    let input = match slice_from_raw(data, len) {
        Err(status) => return status,
        Ok(slice) => slice,
    };
    // Never let a panic unwind across the C boundary.
    let parsed = match catch_unwind(AssertUnwindSafe(|| parse_header(input))) {
        Ok(outcome) => outcome,
        Err(_) => {
            set_last_error("tenso: panic in parse_header");
            return TENSO_ERR_PANIC;
        }
    };
    match parsed {
        Err(e) => report_error(&e),
        Ok(h) => {
            // `ndim` and `base_size` are clamped so the narrowing casts can
            // never wrap.
            let summary = TensoHeader {
                version: h.version,
                flags: h.flags,
                dtype_code: h.dtype_code,
                ndim: h.ndim.min(u32::MAX as usize) as u32,
                base_size: h.base_size.min(u32::MAX as usize) as u32,
            };
            // SAFETY: `out` was checked non-null above; the caller guarantees
            // it is valid for a `TensoHeader` write.
            ptr::write(out, summary);
            TENSO_OK
        }
    }
}
/// Converts the raw C arguments shared by the dense encode entry points into
/// an [`ArraySpec`] borrowing the caller's buffers plus the [`EncodeOpts`].
///
/// Validation order is: dtype code, then `data`, then `shape`. On any failure
/// the last-error message has already been recorded and the status code is
/// returned in `Err`. An `alignment` of 0 selects `tenso::ALIGNMENT`.
///
/// # Safety
///
/// * `data` must be valid for reads of `data_len` bytes, or `data_len` must be 0.
/// * `shape` must be valid for reads of `ndim` `u32` values, or `ndim` must be 0.
/// * Both buffers must stay valid and unmodified for the chosen lifetime `'a`.
unsafe fn build_spec_opts<'a>(
    data: *const u8,
    data_len: usize,
    dtype_code: u8,
    shape: *const u32,
    ndim: usize,
    check_integrity: bool,
    compress: bool,
    alignment: usize,
) -> Result<(ArraySpec<'a>, EncodeOpts), c_int> {
    let dtype = Dtype::from_code(dtype_code).map_err(|e| report_error(&e))?;
    let spec = ArraySpec {
        data: slice_from_raw(data, data_len)?,
        dtype,
        shape: slice_from_raw_t::<u32>(shape, ndim)?,
    };
    let opts = EncodeOpts {
        check_integrity,
        compress,
        alignment: match alignment {
            0 => tenso::ALIGNMENT,
            requested => requested,
        },
    };
    Ok((spec, opts))
}
/// Computes the output size needed to encode the described dense array and
/// writes it to `*out_size`.
///
/// Arguments are interpreted by `build_spec_opts` (an `alignment` of 0 picks
/// the library default). Returns `TENSO_OK` on success; otherwise a negative
/// `TENSO_ERR_*` code, with `*out_size` untouched and the message available
/// through `tenso_last_error`.
///
/// # Safety
///
/// * `data` must be valid for reads of `data_len` bytes, or `data_len` must be 0.
/// * `shape` must be valid for reads of `ndim` `u32` values, or `ndim` must be 0.
/// * `out_size` must be null (reported as `TENSO_ERR_NULL`) or valid for
///   writing one `usize`.
#[no_mangle]
pub unsafe extern "C" fn tenso_dense_required_size(
    data: *const u8,
    data_len: usize,
    dtype_code: u8,
    shape: *const u32,
    ndim: usize,
    check_integrity: bool,
    compress: bool,
    alignment: usize,
    out_size: *mut usize,
) -> c_int {
    clear_last_error();
    if out_size.is_null() {
        set_last_error("tenso: null out_size pointer");
        return TENSO_ERR_NULL;
    }
    let built = build_spec_opts(
        data,
        data_len,
        dtype_code,
        shape,
        ndim,
        check_integrity,
        compress,
        alignment,
    );
    let (spec, opts) = match built {
        Err(status) => return status,
        Ok(pair) => pair,
    };
    // Never let a panic unwind across the C boundary.
    match catch_unwind(AssertUnwindSafe(|| dense_required_size(&spec, &opts))) {
        Err(_) => {
            set_last_error("tenso: panic in dense_required_size");
            TENSO_ERR_PANIC
        }
        Ok(Err(e)) => report_error(&e),
        Ok(Ok(required)) => {
            // SAFETY: `out_size` was checked non-null above; the caller
            // guarantees it is valid for a `usize` write.
            ptr::write(out_size, required);
            TENSO_OK
        }
    }
}
/// Encodes the described dense array into the caller-provided buffer
/// `out[..out_cap]` and writes the number of bytes produced to `*written`.
///
/// Arguments are interpreted by `build_spec_opts` (an `alignment` of 0 picks
/// the library default). Returns `TENSO_OK` on success; otherwise a negative
/// `TENSO_ERR_*` code, with `*written` untouched and the message available
/// through `tenso_last_error`.
///
/// # Safety
///
/// * `data` must be valid for reads of `data_len` bytes, or `data_len` must be 0.
/// * `shape` must be valid for reads of `ndim` `u32` values, or `ndim` must be 0.
/// * `out` must be valid for writes of `out_cap` bytes, or `out_cap` must be 0,
///   and must not overlap `data` or `shape` (a shared and an exclusive slice
///   are formed over them at the same time).
/// * `written` must be null (reported as `TENSO_ERR_NULL`) or valid for
///   writing one `usize`.
#[no_mangle]
pub unsafe extern "C" fn tenso_encode_dense_into(
    data: *const u8,
    data_len: usize,
    dtype_code: u8,
    shape: *const u32,
    ndim: usize,
    check_integrity: bool,
    compress: bool,
    alignment: usize,
    out: *mut u8,
    out_cap: usize,
    written: *mut usize,
) -> c_int {
    clear_last_error();
    if written.is_null() {
        set_last_error("tenso: null written pointer");
        return TENSO_ERR_NULL;
    }
    let built = build_spec_opts(
        data,
        data_len,
        dtype_code,
        shape,
        ndim,
        check_integrity,
        compress,
        alignment,
    );
    let (spec, opts) = match built {
        Err(status) => return status,
        Ok(pair) => pair,
    };
    let dest = match mut_slice_from_raw(out, out_cap) {
        Err(status) => return status,
        Ok(buffer) => buffer,
    };
    // Never let a panic unwind across the C boundary.
    match catch_unwind(AssertUnwindSafe(|| encode_dense_into(&spec, dest, &opts))) {
        Err(_) => {
            set_last_error("tenso: panic in encode_dense_into");
            TENSO_ERR_PANIC
        }
        Ok(Err(e)) => report_error(&e),
        Ok(Ok(produced)) => {
            // SAFETY: `written` was checked non-null above; the caller
            // guarantees it is valid for a `usize` write.
            ptr::write(written, produced);
            TENSO_OK
        }
    }
}
/// Decodes the packet at `data[..len]` into a heap-allocated [`TensoView`].
///
/// Returns a pointer the caller must release with `tenso_view_free`, or null
/// on failure. This function returns no status code; on failure the reason is
/// available as text through `tenso_last_error`.
///
/// The returned view owns copies of everything it exposes, so the input buffer
/// may be freed or reused as soon as this call returns. For that reason the
/// packet is decoded directly from the caller's buffer instead of first
/// duplicating the whole input.
///
/// # Safety
///
/// `data` must be valid for reads of `len` bytes for the duration of the call,
/// or `len` must be 0.
#[no_mangle]
pub unsafe extern "C" fn tenso_decode(data: *const u8, len: usize) -> *mut TensoView {
    clear_last_error();
    let bytes = match slice_from_raw(data, len) {
        Ok(b) => b,
        // `slice_from_raw` has already recorded the error message.
        Err(_) => return ptr::null_mut(),
    };
    // Never let a panic unwind across the C boundary.
    match catch_unwind(AssertUnwindSafe(|| build_view(bytes))) {
        Ok(Ok(view)) => Box::into_raw(Box::new(view)),
        // `build_view` has already recorded the error message; the numeric
        // code cannot be conveyed through a pointer return.
        Ok(Err(_)) => ptr::null_mut(),
        Err(_) => {
            set_last_error("tenso: panic in decode");
            ptr::null_mut()
        }
    }
}
/// Decodes `owned` and flattens the result into an owning [`TensoView`].
///
/// Dense tensors expose their body; quantized tensors expose their packed
/// bytes. Structured kinds have no single flat body and are rejected with
/// `TENSO_ERR_UNSUPPORTED_KIND`. In every `Err` case the last-error message
/// has already been recorded.
fn build_view(owned: &[u8]) -> Result<TensoView, c_int> {
    let (dtype_code, shape, body) = match decode(owned).map_err(|e| report_error(&e))? {
        Decoded::Dense(t) => (t.dtype.code(), t.shape, t.body.to_vec()),
        Decoded::Quantized(q) => (q.dtype.code(), q.shape, q.packed.to_vec()),
        // Listed explicitly (no wildcard) so a new `Decoded` variant forces a
        // decision here.
        Decoded::Bundle(_)
        | Decoded::Sparse { .. }
        | Decoded::String { .. }
        | Decoded::Ragged { .. }
        | Decoded::IpcRef(_) => {
            set_last_error(
                "tenso: decoded packet is structured (bundle/sparse/string/ragged/ipc-ref); \
                 no flat body — use a structured decode path",
            );
            return Err(TENSO_ERR_UNSUPPORTED_KIND);
        }
    };
    Ok(TensoView {
        dtype_code,
        shape,
        body,
    })
}
/// Turns a possibly-null view pointer into an optional shared reference.
///
/// # Safety
///
/// If non-null, `view` must point to a live `TensoView` (one returned by
/// `tenso_decode` and not yet freed) that remains valid for `'a`.
unsafe fn view_ref<'a>(view: *const TensoView) -> Option<&'a TensoView> {
    // SAFETY: the caller guarantees a non-null `view` is valid; `as_ref`
    // performs the null check itself.
    view.as_ref()
}
/// Returns the dtype code stored in `view`, or 0 when `view` is null.
///
/// # Safety
///
/// `view` must be null or a live pointer obtained from `tenso_decode`.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_dtype(view: *const TensoView) -> u8 {
    view_ref(view).map_or(0, |v| v.dtype_code)
}
/// Returns the number of dimensions of `view` (the length of its shape), or 0
/// when `view` is null.
///
/// # Safety
///
/// `view` must be null or a live pointer obtained from `tenso_decode`.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_ndim(view: *const TensoView) -> usize {
    view_ref(view).map_or(0, |v| v.shape.len())
}
/// Returns a pointer to the shape array of `view` (`tenso_view_ndim` entries),
/// or null when `view` is null.
///
/// The pointer borrows from the view and must not be used after
/// `tenso_view_free`. For a zero-dimensional view the pointer is non-null but
/// must not be dereferenced.
///
/// # Safety
///
/// `view` must be null or a live pointer obtained from `tenso_decode`.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_shape(view: *const TensoView) -> *const u32 {
    view_ref(view).map_or(ptr::null(), |v| v.shape.as_ptr())
}
/// Returns a pointer to the body bytes of `view` (`tenso_view_body_len`
/// bytes), or null when `view` is null.
///
/// The pointer borrows from the view and must not be used after
/// `tenso_view_free`. For an empty body the pointer is non-null but must not
/// be dereferenced.
///
/// # Safety
///
/// `view` must be null or a live pointer obtained from `tenso_decode`.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_body_ptr(view: *const TensoView) -> *const u8 {
    view_ref(view).map_or(ptr::null(), |v| v.body.as_ptr())
}
/// Returns the length in bytes of the body of `view`, or 0 when `view` is null.
///
/// # Safety
///
/// `view` must be null or a live pointer obtained from `tenso_decode`.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_body_len(view: *const TensoView) -> usize {
    view_ref(view).map_or(0, |v| v.body.len())
}
/// Releases a view returned by `tenso_decode`. Passing null is a no-op.
///
/// # Safety
///
/// `view` must be null or a pointer obtained from `tenso_decode` that has not
/// already been freed. Pointers previously obtained from the view's accessors
/// are invalid after this call.
#[no_mangle]
pub unsafe extern "C" fn tenso_view_free(view: *mut TensoView) {
    if view.is_null() {
        return;
    }
    // SAFETY: non-null views originate from `Box::into_raw` in `tenso_decode`;
    // the caller guarantees this is the first and only free.
    drop(Box::from_raw(view));
}
/// Returns the calling thread's most recent error message as a NUL-terminated
/// C string, or an empty string when no error is recorded. Never returns null.
///
/// The pointer refers to thread-local storage: it stays valid only until the
/// next call on this thread that sets or clears the last error, and must not
/// be freed by the caller.
#[no_mangle]
pub extern "C" fn tenso_last_error() -> *const c_char {
    LAST_ERROR.with(|cell| {
        cell.borrow()
            .as_ref()
            .map_or(EMPTY.as_ptr().cast::<c_char>(), |message| message.as_ptr())
    })
}
// A lone NUL byte, i.e. the empty C string that `tenso_last_error` returns
// when no error message is stored.
static EMPTY: [u8; 1] = [0];
/// Builds a byte slice from a C pointer/length pair.
///
/// A zero `len` always yields an empty slice, whatever the pointer is. A null
/// pointer with a non-zero `len` records an error and returns
/// `Err(TENSO_ERR_NULL)`.
///
/// # Safety
///
/// When `len` is non-zero and `data` is non-null, `data` must be valid for
/// reads of `len` bytes and the memory must not be mutated for the chosen
/// lifetime `'a`.
unsafe fn slice_from_raw<'a>(data: *const u8, len: usize) -> Result<&'a [u8], c_int> {
    match (len, data.is_null()) {
        (0, _) => Ok(&[]),
        (_, true) => {
            set_last_error("tenso: null data pointer with non-zero length");
            Err(TENSO_ERR_NULL)
        }
        // SAFETY: non-null with non-zero length; validity is the caller's
        // contract.
        (_, false) => Ok(std::slice::from_raw_parts(data, len)),
    }
}
/// Builds a typed slice from a C pointer/element-count pair.
///
/// A zero `count` always yields an empty slice, whatever the pointer is. A
/// null pointer with a non-zero `count` records an error and returns
/// `Err(TENSO_ERR_NULL)`.
///
/// # Safety
///
/// When `count` is non-zero and `data` is non-null, `data` must be properly
/// aligned for `T` and valid for reads of `count` initialized `T` values, and
/// the memory must not be mutated for the chosen lifetime `'a`.
unsafe fn slice_from_raw_t<'a, T>(data: *const T, count: usize) -> Result<&'a [T], c_int> {
    if count == 0 {
        return Ok(&[]);
    }
    if data.is_null() {
        set_last_error("tenso: null array pointer with non-zero count");
        return Err(TENSO_ERR_NULL);
    }
    // SAFETY: non-null with non-zero count; validity and alignment are the
    // caller's contract.
    Ok(std::slice::from_raw_parts(data, count))
}
/// Builds a mutable byte slice from a C output pointer/capacity pair.
///
/// A zero `cap` always yields an empty slice, whatever the pointer is. A null
/// pointer with a non-zero `cap` records an error and returns
/// `Err(TENSO_ERR_NULL)`.
///
/// # Safety
///
/// When `cap` is non-zero and `out` is non-null, `out` must be valid for reads
/// and writes of `cap` bytes, and nothing else may access that memory for the
/// chosen lifetime `'a`.
unsafe fn mut_slice_from_raw<'a>(out: *mut u8, cap: usize) -> Result<&'a mut [u8], c_int> {
    if cap == 0 {
        return Ok(&mut []);
    }
    if out.is_null() {
        set_last_error("tenso: null out pointer with non-zero capacity");
        return Err(TENSO_ERR_NULL);
    }
    // SAFETY: non-null with non-zero capacity; validity and exclusivity are
    // the caller's contract.
    Ok(std::slice::from_raw_parts_mut(out, cap))
}
/// Conformance tests that drive the exported C entry points exactly as a C
/// caller would and compare the results byte-for-byte against golden packets
/// stored under `tests/fixtures`.
#[cfg(test)]
mod c_abi_conformance {
    use super::*;
    use tenso::Dtype;

    /// Loads a golden fixture by file name, panicking with the resolved path
    /// if it cannot be read.
    fn fixture(name: &str) -> Vec<u8> {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("../../tests/fixtures");
        path.push(name);
        match std::fs::read(&path) {
            Ok(bytes) => bytes,
            Err(e) => panic!("read {}: {e}", path.display()),
        }
    }

    /// Serializes `vals` by concatenating each element's `to_le` byte array.
    fn le<T: Copy, const N: usize>(vals: &[T], to_le: impl Fn(T) -> [u8; N]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(vals.len() * N);
        for &v in vals {
            bytes.extend_from_slice(&to_le(v));
        }
        bytes
    }

    /// Encodes through the C ABI using the size-query-then-encode protocol
    /// and returns exactly the bytes the encoder reported as written.
    fn c_encode(
        data: &[u8],
        dtype: Dtype,
        shape: &[u32],
        integrity: bool,
        compress: bool,
    ) -> Vec<u8> {
        let code = dtype.code();

        let mut required = 0usize;
        // SAFETY: every pointer is derived from a live slice or local and is
        // paired with its true length.
        let rc = unsafe {
            tenso_dense_required_size(
                data.as_ptr(),
                data.len(),
                code,
                shape.as_ptr(),
                shape.len(),
                integrity,
                compress,
                0,
                &mut required,
            )
        };
        assert_eq!(rc, TENSO_OK, "tenso_dense_required_size returned {rc}");

        let mut packet = vec![0u8; required];
        let mut produced = 0usize;
        // SAFETY: as above; `packet` is a distinct allocation of `required`
        // bytes.
        let rc = unsafe {
            tenso_encode_dense_into(
                data.as_ptr(),
                data.len(),
                code,
                shape.as_ptr(),
                shape.len(),
                integrity,
                compress,
                0,
                packet.as_mut_ptr(),
                packet.len(),
                &mut produced,
            )
        };
        assert_eq!(rc, TENSO_OK, "tenso_encode_dense_into returned {rc}");

        packet.truncate(produced);
        packet
    }

    #[test]
    fn c_abi_encode_dense_f32_matches_python() {
        let payload = le(&[1.0f32, 2.0, 3.0, 4.0, 5.0], f32::to_le_bytes);
        let encoded = c_encode(&payload, Dtype::F32, &[5], false, false);
        assert_eq!(encoded, fixture("dense_f32_vec.tenso"));
    }

    #[test]
    fn c_abi_encode_i32_integrity_matches_python() {
        let values: Vec<i32> = (0..8).collect();
        let payload = le(&values, i32::to_le_bytes);
        let encoded = c_encode(&payload, Dtype::I32, &[8], true, false);
        assert_eq!(encoded, fixture("dense_i32_integrity.tenso"));
    }

    #[test]
    fn c_abi_encode_compressed_matches_python() {
        // Sixteen repetitions of the ramp 0..64: 1024 highly compressible values.
        let values: Vec<f64> = (0..16).flat_map(|_| 0..64).map(|x| x as f64).collect();
        let payload = le(&values, f64::to_le_bytes);
        let encoded = c_encode(&payload, Dtype::F64, &[1024], true, true);
        assert_eq!(encoded, fixture("dense_f64_compressed.tenso"));
    }

    #[test]
    fn c_abi_decode_f64_matches_python() {
        let packet = fixture("dense_f64_mat.tenso");
        // SAFETY: `packet` outlives the call; the view is only read through
        // its accessors before being freed exactly once.
        unsafe {
            let view = tenso_decode(packet.as_ptr(), packet.len());
            assert!(!view.is_null(), "tenso_decode returned null");
            assert_eq!(tenso_view_dtype(view), Dtype::F64.code());
            assert_eq!(tenso_view_ndim(view), 2);

            let shape = std::slice::from_raw_parts(tenso_view_shape(view), tenso_view_ndim(view));
            assert_eq!(shape, &[3, 4]);

            let body =
                std::slice::from_raw_parts(tenso_view_body_ptr(view), tenso_view_body_len(view));
            let expected: Vec<f64> = (0..12).map(|x| x as f64).collect();
            assert_eq!(body, le(&expected, f64::to_le_bytes).as_slice());

            tenso_view_free(view);
        }
    }
}