use crate::{DType, Result, Shape, TorshError};
use std::collections::HashMap;
/// Conversion from an external representation (e.g. a foreign tensor or
/// buffer type) into a ToRSh value, copying data if necessary.
pub trait FromExternal<T> {
/// Build `Self` from `value`, returning an error if the conversion is
/// unsupported for this input.
fn from_external(value: T) -> Result<Self>
where
Self: Sized;
}
/// Conversion from a ToRSh value into an external representation, copying
/// data if necessary.
pub trait ToExternal<T> {
/// Produce the external representation of `self`.
fn to_external(&self) -> Result<T>;
}
/// Zero-copy variant of [`FromExternal`]: the conversion should reuse the
/// source buffer rather than copying it when the layout permits.
pub trait FromExternalZeroCopy<T> {
/// Build `Self` from `value` without copying the underlying data.
fn from_external_zero_copy(value: T) -> Result<Self>
where
Self: Sized;
/// Whether `value`'s layout allows a zero-copy conversion (callers can
/// probe this before committing to the zero-copy path).
fn can_zero_copy(value: &T) -> bool;
}
/// Zero-copy variant of [`ToExternal`]: export `self` by sharing its buffer
/// rather than copying when the layout permits.
pub trait ToExternalZeroCopy<T> {
/// Produce the external representation without copying the data.
fn to_external_zero_copy(&self) -> Result<T>;
/// Whether `self`'s layout allows a zero-copy export.
fn can_zero_copy(&self) -> bool;
}
/// NumPy-style array metadata: shape, byte strides, element type and
/// contiguity flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NumpyArrayInfo {
// Dimension sizes, outermost axis first.
pub shape: Vec<usize>,
// Byte stride per dimension (NumPy convention); signed so negative
// strides (reversed views) are representable.
pub strides: Vec<isize>,
// Element type of the described array.
pub dtype: DType,
// True when the strides match row-major (C-order) layout.
pub c_contiguous: bool,
// True when the strides match column-major (Fortran-order) layout.
pub f_contiguous: bool,
// Total payload size in bytes (element count times item size).
pub nbytes: usize,
}
impl NumpyArrayInfo {
    /// Create metadata for a freshly allocated, C-contiguous array.
    ///
    /// Strides are computed in **bytes** (NumPy convention) for row-major
    /// layout. Arrays of rank 0 or 1 are also Fortran-contiguous, since the
    /// two layouts coincide there.
    pub fn new(shape: Vec<usize>, dtype: DType) -> Self {
        let strides = Self::compute_c_strides(&shape, dtype.size());
        let nbytes = shape.iter().product::<usize>() * dtype.size();
        Self {
            c_contiguous: true,
            // C- and F-order coincide for rank <= 1.
            f_contiguous: shape.len() <= 1,
            shape,
            strides,
            dtype,
            nbytes,
        }
    }

    /// Create metadata with explicit byte strides.
    ///
    /// Contiguity flags are derived by comparing the given strides against
    /// the canonical C-order and Fortran-order strides for `shape`.
    pub fn with_strides(shape: Vec<usize>, strides: Vec<isize>, dtype: DType) -> Self {
        let nbytes = shape.iter().product::<usize>() * dtype.size();
        // Compute the flags before moving `strides` into the struct; the
        // previous ordering forced a needless clone of the stride vector.
        let c_contiguous = strides == Self::compute_c_strides(&shape, dtype.size());
        let f_contiguous = strides == Self::compute_f_strides(&shape, dtype.size());
        Self {
            shape,
            strides,
            dtype,
            c_contiguous,
            f_contiguous,
            nbytes,
        }
    }

    /// Row-major (C-order) byte strides: the last axis varies fastest.
    fn compute_c_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in (0..shape.len()).rev() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }

    /// Column-major (Fortran-order) byte strides: the first axis varies fastest.
    fn compute_f_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in 0..shape.len() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }
}
/// Metadata describing an ONNX tensor value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxTensorInfo {
// ONNX element type of the tensor.
pub elem_type: OnnxDataType,
// Per-dimension size; `None` marks a dynamic (symbolic) dimension.
pub shape: Vec<Option<usize>>,
// Optional tensor name.
pub name: Option<String>,
}
/// ONNX tensor element types.
///
/// Discriminant values match the `TensorProto.DataType` enum in the ONNX
/// protobuf schema, so they can be used directly when (de)serializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnnxDataType {
Undefined = 0,
Float = 1,
Uint8 = 2,
Int8 = 3,
Uint16 = 4,
Int16 = 5,
Int32 = 6,
Int64 = 7,
String = 8,
Bool = 9,
Float16 = 10,
Double = 11,
Uint32 = 12,
Uint64 = 13,
Complex64 = 14,
Complex128 = 15,
Bfloat16 = 16,
}
impl From<DType> for OnnxDataType {
    /// Map a ToRSh dtype onto the matching ONNX element type.
    ///
    /// Quantized dtypes are encoded as their integer storage type
    /// (e.g. `QInt8` -> `Int8`), so quantization parameters are not
    /// representable in the result.
    fn from(dtype: DType) -> Self {
        match dtype {
            // Floating point.
            DType::F16 => OnnxDataType::Float16,
            DType::BF16 => OnnxDataType::Bfloat16,
            DType::F32 => OnnxDataType::Float,
            DType::F64 => OnnxDataType::Double,
            // Signed integers (including quantized storage).
            DType::I8 | DType::QInt8 => OnnxDataType::Int8,
            DType::I16 => OnnxDataType::Int16,
            DType::I32 | DType::QInt32 => OnnxDataType::Int32,
            DType::I64 => OnnxDataType::Int64,
            // Unsigned integers (including quantized storage).
            DType::U8 | DType::QUInt8 => OnnxDataType::Uint8,
            DType::U32 => OnnxDataType::Uint32,
            DType::U64 => OnnxDataType::Uint64,
            // Other.
            DType::Bool => OnnxDataType::Bool,
            DType::C64 => OnnxDataType::Complex64,
            DType::C128 => OnnxDataType::Complex128,
        }
    }
}
impl TryFrom<OnnxDataType> for DType {
    type Error = TorshError;

    /// Convert an ONNX element type back to a ToRSh `DType`.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::UnsupportedOperation` for ONNX types with no
    /// `DType` counterpart (`Undefined`, `String`, `Uint16`).
    fn try_from(onnx_type: OnnxDataType) -> Result<Self> {
        match onnx_type {
            OnnxDataType::Float => Ok(DType::F32),
            OnnxDataType::Double => Ok(DType::F64),
            OnnxDataType::Float16 => Ok(DType::F16),
            OnnxDataType::Bfloat16 => Ok(DType::BF16),
            OnnxDataType::Int8 => Ok(DType::I8),
            OnnxDataType::Uint8 => Ok(DType::U8),
            OnnxDataType::Int16 => Ok(DType::I16),
            OnnxDataType::Int32 => Ok(DType::I32),
            OnnxDataType::Int64 => Ok(DType::I64),
            // These two arms were missing, which broke the U32/U64 round
            // trip even though `From<DType> for OnnxDataType` produces
            // `Uint32`/`Uint64`.
            OnnxDataType::Uint32 => Ok(DType::U32),
            OnnxDataType::Uint64 => Ok(DType::U64),
            OnnxDataType::Bool => Ok(DType::Bool),
            OnnxDataType::Complex64 => Ok(DType::C64),
            OnnxDataType::Complex128 => Ok(DType::C128),
            _ => Err(TorshError::UnsupportedOperation {
                op: "ONNX data type conversion".to_string(),
                dtype: format!("{onnx_type:?}"),
            }),
        }
    }
}
/// An Arrow logical type together with associated string metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrowTypeInfo {
// The Arrow logical type.
pub data_type: ArrowDataType,
// Arbitrary key/value string metadata attached to the type.
pub metadata: HashMap<String, String>,
}
/// Subset of Apache Arrow logical data types used by the interop layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArrowDataType {
Boolean,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64,
// Fixed-length list: element type plus list length. Used here to encode
// complex numbers as [real, imag] pairs (see `From<DType>` below).
FixedSizeList(Box<ArrowDataType>, usize),
}
impl From<DType> for ArrowDataType {
    /// Map a ToRSh dtype onto its Arrow logical type.
    ///
    /// Quantized integers use their storage type, `BF16` is widened to
    /// `Float32` (no bfloat16 variant is available here), and complex
    /// numbers become a fixed-size list of two real components.
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::Bool => ArrowDataType::Boolean,
            // Signed integers (quantized variants share storage types).
            DType::I8 | DType::QInt8 => ArrowDataType::Int8,
            DType::I16 => ArrowDataType::Int16,
            DType::I32 | DType::QInt32 => ArrowDataType::Int32,
            DType::I64 => ArrowDataType::Int64,
            // Unsigned integers.
            DType::U8 | DType::QUInt8 => ArrowDataType::UInt8,
            DType::U32 => ArrowDataType::UInt32,
            DType::U64 => ArrowDataType::UInt64,
            // Floating point; bf16 is widened to f32.
            DType::F16 => ArrowDataType::Float16,
            DType::BF16 => ArrowDataType::Float32,
            DType::F32 => ArrowDataType::Float32,
            DType::F64 => ArrowDataType::Float64,
            // Complex values: [real, imag] pairs of the component type.
            DType::C64 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float32), 2),
            DType::C128 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float64), 2),
        }
    }
}
/// Stateless helper functions for shape and layout conversions.
pub struct ConversionUtils;

impl ConversionUtils {
    /// Convert a ToRSh `Shape` into a NumPy-style dimension vector.
    pub fn torsh_shape_to_numpy(shape: &Shape) -> Vec<usize> {
        shape.dims().to_vec()
    }

    /// Convert a NumPy-style dimension vector into a ToRSh `Shape`.
    ///
    /// # Errors
    ///
    /// Currently infallible; the `Result` return is kept so validation can
    /// be added later without breaking callers.
    pub fn numpy_shape_to_torsh(shape: Vec<usize>) -> Result<Shape> {
        Ok(Shape::new(shape))
    }

    /// True when two arrays share both shape and byte strides, i.e. their
    /// element layouts in memory are identical.
    pub fn is_layout_compatible(
        shape1: &[usize],
        strides1: &[isize],
        shape2: &[usize],
        strides2: &[isize],
    ) -> bool {
        // Slice equality already compares lengths first, so the previous
        // explicit `len()` check was redundant.
        shape1 == shape2 && strides1 == strides2
    }

    /// Score how efficiently a strided layout uses memory, in `[0.0, 1.0]`.
    ///
    /// * `1.0` — C-contiguous (or empty shape).
    /// * `0.9` — Fortran-contiguous.
    /// * otherwise — ratio of the dense payload size to the actual memory
    ///   span, capped at `1.0`; `0.0` when the span degenerates to zero.
    pub fn layout_efficiency_score(shape: &[usize], strides: &[isize], itemsize: usize) -> f64 {
        if shape.is_empty() {
            return 1.0;
        }
        let c_strides = NumpyArrayInfo::compute_c_strides(shape, itemsize);
        if strides == c_strides {
            return 1.0;
        }
        let f_strides = NumpyArrayInfo::compute_f_strides(shape, itemsize);
        if strides == f_strides {
            return 0.9;
        }
        let total_elements: usize = shape.iter().product();
        let expected_size = total_elements * itemsize;
        let actual_span = Self::compute_memory_span(shape, strides, itemsize);
        if actual_span == 0 {
            return 0.0;
        }
        (expected_size as f64 / actual_span as f64).min(1.0)
    }

    /// Number of bytes spanned by the addressed region: distance between the
    /// lowest and highest element byte offsets, plus one element. Tracking
    /// both extremes makes this correct for negative strides too.
    fn compute_memory_span(shape: &[usize], strides: &[isize], itemsize: usize) -> usize {
        if shape.is_empty() {
            return 0;
        }
        let mut min_offset = 0isize;
        let mut max_offset = 0isize;
        for (&dim, &stride) in shape.iter().zip(strides.iter()) {
            // Dimensions of size <= 1 contribute no extent.
            if dim > 1 {
                let offset = stride * (dim as isize - 1);
                min_offset = min_offset.min(offset);
                max_offset = max_offset.max(offset);
            }
        }
        (max_offset - min_offset) as usize + itemsize
    }
}
/// Human-readable documentation helpers for the interop layer.
pub struct InteropDocs;

impl InteropDocs {
    /// Render a bullet-list of the supported tensor format conversions.
    pub fn supported_conversions() -> String {
        // A const array avoids the heap allocation the old `vec!` paid for
        // a purely static table.
        const CONVERSIONS: [(&str, &str, &str); 5] = [
            ("NumPy", "ndarray", "Zero-copy when C-contiguous"),
            ("ndarray", "Array", "Zero-copy when contiguous"),
            ("ONNX", "TensorProto", "Schema mapping"),
            ("Arrow", "Array", "Type mapping with metadata"),
            ("Rust", "Vec<T>", "Direct conversion"),
        ];
        let mut doc = String::from("Supported Tensor Format Conversions:\n");
        doc.push_str("=========================================\n\n");
        for (from, to, notes) in CONVERSIONS {
            doc.push_str(&format!("• {from} ↔ {to}: {notes}\n"));
        }
        doc
    }

    /// Return example snippets demonstrating the conversion APIs.
    pub fn conversion_examples() -> String {
        r#"
Conversion Examples:
==================
// NumPy-style array info
let numpy_info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
assert!(numpy_info.c_contiguous);
// ONNX type conversion
let onnx_type = OnnxDataType::from(DType::F32);
let back_to_dtype = DType::try_from(onnx_type).unwrap();
// Arrow type conversion
let arrow_type = ArrowDataType::from(DType::C64);
match arrow_type {
ArrowDataType::FixedSizeList(inner, size) => {
assert_eq!(size, 2); // Real and imaginary parts
}
_ => panic!("Unexpected type"),
}
// Layout efficiency checking
let shape = vec![1000, 1000];
let c_strides = vec![4000, 4]; // C-contiguous for f32
let efficiency = ConversionUtils::layout_efficiency_score(&shape, &c_strides, 4);
assert_eq!(efficiency, 1.0);
"#
        .to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numpy_array_info() {
        // A 2x3x4 f32 array: row-major byte strides are [48, 16, 4].
        let info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
        assert_eq!(info.shape, vec![2, 3, 4]);
        assert_eq!(info.strides, vec![48, 16, 4]);
        assert!(info.c_contiguous);
        assert!(!info.f_contiguous);
        assert_eq!(info.nbytes, 96);
    }

    #[test]
    fn test_onnx_dtype_conversion() {
        // Each dtype must survive a DType -> ONNX -> DType round trip.
        let round_trip = [
            DType::F32,
            DType::F64,
            DType::I8,
            DType::U8,
            DType::I32,
            DType::Bool,
            DType::C64,
        ];
        for dtype in round_trip {
            let onnx_type = OnnxDataType::from(dtype);
            let restored = DType::try_from(onnx_type).expect("try_from should succeed");
            assert_eq!(dtype, restored);
        }
    }

    #[test]
    fn test_arrow_dtype_conversion() {
        assert_eq!(ArrowDataType::from(DType::F32), ArrowDataType::Float32);
        assert_eq!(ArrowDataType::from(DType::Bool), ArrowDataType::Boolean);
        // Complex values map to a fixed-size [real, imag] pair of f32.
        match ArrowDataType::from(DType::C64) {
            ArrowDataType::FixedSizeList(inner, size) => {
                assert_eq!(*inner, ArrowDataType::Float32);
                assert_eq!(size, 2);
            }
            _ => panic!("Expected FixedSizeList for C64"),
        }
    }

    #[test]
    fn test_layout_efficiency() {
        let shape = vec![10, 10];
        let itemsize = 4;
        // C-contiguous layouts score a perfect 1.0 ...
        let c_strides = vec![40, 4];
        assert_eq!(
            ConversionUtils::layout_efficiency_score(&shape, &c_strides, itemsize),
            1.0
        );
        // ... while Fortran-contiguous layouts score 0.9.
        let f_strides = vec![4, 40];
        assert_eq!(
            ConversionUtils::layout_efficiency_score(&shape, &f_strides, itemsize),
            0.9
        );
    }

    #[test]
    fn test_conversion_utils() {
        let shape = Shape::new(vec![2, 3, 4]);
        let numpy_shape = ConversionUtils::torsh_shape_to_numpy(&shape);
        assert_eq!(numpy_shape, vec![2, 3, 4]);
        let restored = ConversionUtils::numpy_shape_to_torsh(numpy_shape)
            .expect("numpy_shape_to_torsh should succeed");
        assert_eq!(shape.dims(), restored.dims());
    }

    #[test]
    fn test_layout_compatibility() {
        let shape = vec![2, 3];
        let strides = vec![12, 4];
        // Identical shape and strides are layout-compatible.
        assert!(ConversionUtils::is_layout_compatible(
            &shape, &strides, &shape, &strides
        ));
        // The same shape with different strides is not.
        let other_strides = vec![4, 8];
        assert!(!ConversionUtils::is_layout_compatible(
            &shape, &strides, &shape, &other_strides
        ));
    }
}