use crate::dtype::DType;
use crate::error::{Result, TorshError};
use crate::shape::Shape;
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
use scirs2_core::error::CoreError;
use scirs2_core::ndarray::{ArrayView, Dimension, IxDyn};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
ZeroCopy,
DirectCopy,
StrideAdjustment,
TypeConversion,
SimdBatchConversion,
TransposedCopy,
}
#[derive(Debug, Clone)]
pub struct TransferMetadata {
pub strategy: TransferStrategy,
pub source_dtype: DType,
pub target_dtype: DType,
pub total_bytes: usize,
pub estimated_throughput: f64,
pub layouts_compatible: bool,
pub zero_copy_possible: bool,
}
pub struct SciRS2Bridge;
impl SciRS2Bridge {
pub fn analyze_transfer(
shape: &Shape,
source_dtype: DType,
target_dtype: DType,
is_contiguous: bool,
) -> TransferMetadata {
let total_bytes = shape.numel().saturating_mul(target_dtype.size());
let layouts_compatible = is_contiguous;
let dtypes_match = source_dtype == target_dtype;
let zero_copy_possible = layouts_compatible && dtypes_match;
let strategy = if zero_copy_possible {
TransferStrategy::ZeroCopy
} else if !layouts_compatible && dtypes_match {
TransferStrategy::StrideAdjustment
} else if !dtypes_match && is_contiguous {
#[cfg(feature = "simd")]
{
if Self::can_use_simd(source_dtype, target_dtype) && total_bytes > 4096 {
TransferStrategy::SimdBatchConversion
} else {
TransferStrategy::TypeConversion
}
}
#[cfg(not(feature = "simd"))]
{
TransferStrategy::TypeConversion
}
} else if !dtypes_match && !is_contiguous {
TransferStrategy::TransposedCopy
} else {
TransferStrategy::DirectCopy
};
let estimated_throughput = Self::estimate_throughput(strategy, total_bytes);
TransferMetadata {
strategy,
source_dtype,
target_dtype,
total_bytes,
estimated_throughput,
layouts_compatible,
zero_copy_possible,
}
}
#[cfg(feature = "simd")]
fn can_use_simd(source: DType, target: DType) -> bool {
use DType::*;
matches!(
(source, target),
(F32, F64) | (F64, F32) | (F32, F32) | (F64, F64) | (I32, F32) | (I64, F64)
)
}
fn estimate_throughput(strategy: TransferStrategy, total_bytes: usize) -> f64 {
let base_throughput = match strategy {
TransferStrategy::ZeroCopy => 100_000_000_000.0, TransferStrategy::DirectCopy => 50_000_000_000.0, TransferStrategy::StrideAdjustment => 20_000_000_000.0, TransferStrategy::TypeConversion => 10_000_000_000.0, TransferStrategy::SimdBatchConversion => 40_000_000_000.0, TransferStrategy::TransposedCopy => 15_000_000_000.0, };
if total_bytes < 1024 {
base_throughput * 0.5
} else if total_bytes < 65536 {
base_throughput * 0.8
} else {
base_throughput
}
}
#[inline]
pub fn shape_to_ixdyn(shape: &Shape) -> IxDyn {
IxDyn(shape.dims())
}
#[inline]
pub fn ixdyn_to_shape<D: Dimension>(dim: &D) -> Shape {
let dims: Vec<usize> = dim.slice().to_vec();
Shape::new(dims)
}
pub fn create_zero_copy_view<T: Clone>(array: &ArrayView<T, IxDyn>) -> ZeroCopyView<T> {
let shape = array.shape().to_vec();
let strides = array.strides().to_vec();
let ptr = array.as_ptr();
let len = array.len();
ZeroCopyView {
shape,
strides,
ptr,
len,
is_contiguous: array.is_standard_layout(),
_phantom: core::marker::PhantomData,
}
}
#[cfg(feature = "simd")]
pub fn convert_f32_to_f64_simd(src: &[f32], dst: &mut [f64]) -> Result<()> {
if src.len() != dst.len() {
return Err(TorshError::InvalidShape(
"Source and destination lengths must match".into(),
));
}
Self::convert_f32_to_f64_simd_impl(src, dst);
Ok(())
}
#[cfg(all(feature = "simd", target_arch = "x86_64", target_feature = "avx2"))]
fn convert_f32_to_f64_simd_impl(src: &[f32], dst: &mut [f64]) {
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::*;
let chunks = src.len() / 4;
let remainder = src.len() % 4;
unsafe {
for i in 0..chunks {
let base = i * 4;
let src_ptr = src.as_ptr().add(base);
let v_f32 = _mm_loadu_ps(src_ptr);
let v_f64_lo = _mm_cvtps_pd(v_f32);
let v_f32_hi = _mm_movehl_ps(v_f32, v_f32);
let v_f64_hi = _mm_cvtps_pd(v_f32_hi);
let dst_ptr = dst.as_mut_ptr().add(base);
_mm_storeu_pd(dst_ptr, v_f64_lo);
_mm_storeu_pd(dst_ptr.add(2), v_f64_hi);
}
}
let base = chunks * 4;
for i in 0..remainder {
dst[base + i] = src[base + i] as f64;
}
}
}
#[cfg(all(
feature = "simd",
not(all(target_arch = "x86_64", target_feature = "avx2"))
))]
fn convert_f32_to_f64_simd_impl(src: &[f32], dst: &mut [f64]) {
let chunks = src.len() / 4;
let remainder = src.len() % 4;
for i in 0..chunks {
let base = i * 4;
dst[base] = src[base] as f64;
dst[base + 1] = src[base + 1] as f64;
dst[base + 2] = src[base + 2] as f64;
dst[base + 3] = src[base + 3] as f64;
}
let base = chunks * 4;
for i in 0..remainder {
dst[base + i] = src[base + i] as f64;
}
}
#[cfg(feature = "simd")]
pub fn convert_f64_to_f32_simd(src: &[f64], dst: &mut [f32]) -> Result<()> {
if src.len() != dst.len() {
return Err(TorshError::InvalidShape(
"Source and destination lengths must match".into(),
));
}
Self::convert_f64_to_f32_simd_impl(src, dst);
Ok(())
}
#[cfg(all(feature = "simd", target_arch = "x86_64", target_feature = "avx2"))]
fn convert_f64_to_f32_simd_impl(src: &[f64], dst: &mut [f32]) {
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::*;
let chunks = src.len() / 4;
let remainder = src.len() % 4;
unsafe {
for i in 0..chunks {
let base = i * 4;
let src_ptr = src.as_ptr().add(base);
let v_f64_lo = _mm_loadu_pd(src_ptr);
let v_f64_hi = _mm_loadu_pd(src_ptr.add(2));
let v_f32_lo = _mm_cvtpd_ps(v_f64_lo);
let v_f32_hi = _mm_cvtpd_ps(v_f64_hi);
let v_f32 = _mm_shuffle_ps(v_f32_lo, v_f32_hi, 0x44);
let dst_ptr = dst.as_mut_ptr().add(base);
_mm_storeu_ps(dst_ptr, v_f32);
}
}
let base = chunks * 4;
for i in 0..remainder {
dst[base + i] = src[base + i] as f32;
}
}
}
#[cfg(all(
feature = "simd",
not(all(target_arch = "x86_64", target_feature = "avx2"))
))]
fn convert_f64_to_f32_simd_impl(src: &[f64], dst: &mut [f32]) {
let chunks = src.len() / 4;
let remainder = src.len() % 4;
for i in 0..chunks {
let base = i * 4;
dst[base] = src[base] as f32;
dst[base + 1] = src[base + 1] as f32;
dst[base + 2] = src[base + 2] as f32;
dst[base + 3] = src[base + 3] as f32;
}
let base = chunks * 4;
for i in 0..remainder {
dst[base + i] = src[base + i] as f32;
}
}
#[cfg(feature = "simd")]
pub fn batch_convert_f32_to_f64(
sources: &[&[f32]],
destinations: &mut [&mut [f64]],
) -> Result<()> {
if sources.len() != destinations.len() {
return Err(TorshError::InvalidShape(
"Number of source and destination arrays must match".into(),
));
}
for (src, dst) in sources.iter().zip(destinations.iter_mut()) {
Self::convert_f32_to_f64_simd(src, dst)?;
}
Ok(())
}
#[cfg(feature = "simd")]
pub fn batch_convert_f64_to_f32(
sources: &[&[f64]],
destinations: &mut [&mut [f32]],
) -> Result<()> {
if sources.len() != destinations.len() {
return Err(TorshError::InvalidShape(
"Number of source and destination arrays must match".into(),
));
}
for (src, dst) in sources.iter().zip(destinations.iter_mut()) {
Self::convert_f64_to_f32_simd(src, dst)?;
}
Ok(())
}
}
#[derive(Debug)]
pub struct ZeroCopyView<T> {
pub shape: Vec<usize>,
pub strides: Vec<isize>,
pub ptr: *const T,
pub len: usize,
pub is_contiguous: bool,
_phantom: core::marker::PhantomData<T>,
}
unsafe impl<T> Send for ZeroCopyView<T> where T: Send {}
unsafe impl<T> Sync for ZeroCopyView<T> where T: Sync {}
impl<T> ZeroCopyView<T> {
pub fn size_bytes(&self) -> usize {
self.len * core::mem::size_of::<T>()
}
pub fn supports_zero_copy(&self) -> bool {
self.is_contiguous
}
}
pub struct ErrorMapper;
impl ErrorMapper {
pub fn from_scirs2(error: CoreError) -> TorshError {
match error {
CoreError::ShapeError(ctx) => TorshError::InvalidShape(ctx.message),
CoreError::DimensionError(ctx) => TorshError::InvalidShape(ctx.message),
CoreError::IndexError(ctx) => {
TorshError::InvalidShape(ctx.message)
}
CoreError::ValueError(ctx) => TorshError::InvalidOperation(ctx.message),
CoreError::ComputationError(ctx) => TorshError::ComputeError(ctx.message),
CoreError::MemoryError(ctx) => TorshError::AllocationError(ctx.message),
CoreError::NotImplementedError(ctx) => TorshError::NotImplemented(ctx.message),
_ => TorshError::RuntimeError(format!("SciRS2 error: {:?}", error)),
}
}
pub fn from_scirs2_with_context(error: CoreError, context: &str) -> TorshError {
let base_error = Self::from_scirs2(error);
match base_error {
TorshError::InvalidShape(msg) => {
TorshError::InvalidShape(format!("{}: {}", context, msg))
}
TorshError::InvalidOperation(msg) => {
TorshError::InvalidOperation(format!("{}: {}", context, msg))
}
TorshError::ComputeError(msg) => {
TorshError::ComputeError(format!("{}: {}", context, msg))
}
TorshError::AllocationError(msg) => {
TorshError::AllocationError(format!("{}: {}", context, msg))
}
TorshError::NotImplemented(msg) => {
TorshError::NotImplemented(format!("{}: {}", context, msg))
}
TorshError::RuntimeError(msg) => {
TorshError::RuntimeError(format!("{}: {}", context, msg))
}
other => other,
}
}
pub fn to_scirs2(error: TorshError) -> CoreError {
use scirs2_core::error::ErrorContext;
match error {
TorshError::ShapeMismatch { expected, got } => {
CoreError::ShapeError(ErrorContext::new(format!(
"Shape mismatch: expected {:?}, got {:?}",
expected, got
)))
}
TorshError::InvalidShape(msg) => CoreError::ShapeError(ErrorContext::new(msg)),
TorshError::IndexOutOfBounds { index, size } => {
CoreError::IndexError(ErrorContext::new(format!(
"Index {} out of bounds for dimension with size {}",
index, size
)))
}
TorshError::IndexError { index, size } => {
CoreError::IndexError(ErrorContext::new(format!(
"Index {} out of bounds for dimension with size {}",
index, size
)))
}
TorshError::InvalidOperation(msg) => CoreError::ValueError(ErrorContext::new(msg)),
TorshError::ComputeError(msg) => CoreError::ComputationError(ErrorContext::new(msg)),
TorshError::AllocationError(msg) => CoreError::MemoryError(ErrorContext::new(msg)),
TorshError::NotImplemented(msg) => {
CoreError::NotImplementedError(ErrorContext::new(msg))
}
TorshError::UnsupportedOperation { op, dtype } => {
CoreError::NotImplementedError(ErrorContext::new(format!(
"Unsupported operation '{}' for data type '{}'",
op, dtype
)))
}
TorshError::DeviceMismatch => CoreError::ValueError(ErrorContext::new(
"Device mismatch: tensors must be on the same device".to_string(),
)),
_ => {
CoreError::ComputationError(ErrorContext::new(format!("ToRSh error: {:?}", error)))
}
}
}
pub fn map_result<T>(result: core::result::Result<T, CoreError>) -> Result<T> {
result.map_err(Self::from_scirs2)
}
}
#[cfg(feature = "std")]
pub struct SharedBufferManager {
buffers: std::sync::Mutex<std::collections::HashMap<usize, Arc<[u8]>>>,
}
#[cfg(feature = "std")]
impl SharedBufferManager {
pub fn new() -> Self {
Self {
buffers: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
pub fn register_buffer(&self, id: usize, buffer: Arc<[u8]>) {
let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers.insert(id, buffer);
}
pub fn get_buffer(&self, id: usize) -> Option<Arc<[u8]>> {
let buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers.get(&id).cloned()
}
pub fn remove_buffer(&self, id: usize) {
let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers.remove(&id);
}
pub fn buffer_count(&self) -> usize {
let buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers.len()
}
pub fn total_bytes(&self) -> usize {
let buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers.values().map(|b| b.len()).sum()
}
}
#[cfg(feature = "std")]
impl Default for SharedBufferManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transfer_strategy_analysis() {
let shape = Shape::new(vec![10, 20, 30]);
let metadata = SciRS2Bridge::analyze_transfer(
&shape,
DType::F32,
DType::F32,
true, );
assert_eq!(metadata.strategy, TransferStrategy::ZeroCopy);
assert!(metadata.zero_copy_possible);
let metadata = SciRS2Bridge::analyze_transfer(
&shape,
DType::F32,
DType::F64,
true, );
#[cfg(feature = "simd")]
assert!(matches!(
metadata.strategy,
TransferStrategy::SimdBatchConversion | TransferStrategy::TypeConversion
));
#[cfg(not(feature = "simd"))]
assert_eq!(metadata.strategy, TransferStrategy::TypeConversion);
assert!(!metadata.zero_copy_possible);
let metadata = SciRS2Bridge::analyze_transfer(
&shape,
DType::F32,
DType::F32,
false, );
assert_eq!(metadata.strategy, TransferStrategy::StrideAdjustment);
assert!(!metadata.zero_copy_possible);
}
#[test]
fn test_shape_conversion() {
let torsh_shape = Shape::new(vec![2, 3, 4]);
let ixdyn = SciRS2Bridge::shape_to_ixdyn(&torsh_shape);
assert_eq!(ixdyn.as_array_view().to_vec(), vec![2, 3, 4]);
let back_to_shape = SciRS2Bridge::ixdyn_to_shape(&ixdyn);
assert_eq!(back_to_shape.dims(), torsh_shape.dims());
}
#[test]
fn test_error_mapping_from_scirs2() {
let scirs2_error = CoreError::ShapeError(scirs2_core::error::ErrorContext::new(
"Shape mismatch: expected [2, 3], got [3, 2]".to_string(),
));
let torsh_error = ErrorMapper::from_scirs2(scirs2_error);
match torsh_error {
TorshError::InvalidShape(msg) => {
assert!(msg.contains("Shape mismatch"));
}
_ => panic!("Expected InvalidShape error"),
}
}
#[test]
fn test_error_mapping_to_scirs2() {
let torsh_error = TorshError::InvalidShape("Test error".to_string());
let scirs2_error = ErrorMapper::to_scirs2(torsh_error);
match scirs2_error {
CoreError::ShapeError(ctx) => assert_eq!(ctx.message, "Test error"),
_ => panic!("Expected ShapeError"),
}
}
#[test]
#[cfg(feature = "std")]
fn test_shared_buffer_manager() {
let manager = SharedBufferManager::new();
let buffer: Arc<[u8]> = Arc::from(vec![1, 2, 3, 4, 5]);
manager.register_buffer(1, buffer.clone());
assert_eq!(manager.buffer_count(), 1);
assert_eq!(manager.total_bytes(), 5);
let retrieved = manager.get_buffer(1).expect("get_buffer should succeed");
assert_eq!(retrieved.len(), 5);
assert_eq!(&*retrieved, &[1, 2, 3, 4, 5]);
manager.remove_buffer(1);
assert_eq!(manager.buffer_count(), 0);
assert!(manager.get_buffer(1).is_none());
}
#[test]
fn test_zero_copy_view_metadata() {
use scirs2_core::ndarray::array;
let arr = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
let view = arr.view().into_dyn();
let zero_copy = SciRS2Bridge::create_zero_copy_view(&view);
assert_eq!(zero_copy.shape, vec![2, 3]);
assert_eq!(zero_copy.len, 6);
assert!(zero_copy.is_contiguous);
assert!(zero_copy.supports_zero_copy());
assert_eq!(zero_copy.size_bytes(), 24); }
#[test]
#[cfg(feature = "simd")]
fn test_simd_type_conversion_f32_to_f64() {
let src = vec![1.0f32, 2.0, 3.0, 4.0];
let mut dst = vec![0.0f64; 4];
SciRS2Bridge::convert_f32_to_f64_simd(&src, &mut dst)
.expect("convert_f32_to_f64_simd should succeed");
assert_eq!(dst, vec![1.0f64, 2.0, 3.0, 4.0]);
}
#[test]
#[cfg(feature = "simd")]
fn test_simd_type_conversion_f64_to_f32() {
let src = vec![1.0f64, 2.0, 3.0, 4.0];
let mut dst = vec![0.0f32; 4];
SciRS2Bridge::convert_f64_to_f32_simd(&src, &mut dst)
.expect("convert_f64_to_f32_simd should succeed");
assert_eq!(dst, vec![1.0f32, 2.0, 3.0, 4.0]);
}
#[test]
fn test_transfer_metadata_throughput_estimation() {
let shape = Shape::new(vec![1000, 1000]);
let metadata = SciRS2Bridge::analyze_transfer(&shape, DType::F32, DType::F32, true);
assert!(metadata.estimated_throughput > 10_000_000_000.0);
let small_shape = Shape::new(vec![10, 10]);
let small_metadata =
SciRS2Bridge::analyze_transfer(&small_shape, DType::F32, DType::F32, true);
assert!(small_metadata.estimated_throughput < metadata.estimated_throughput);
}
#[test]
fn test_error_mapping_with_context() {
let scirs2_error = CoreError::ShapeError(scirs2_core::error::ErrorContext::new(
"Invalid dimensions".to_string(),
));
let torsh_error =
ErrorMapper::from_scirs2_with_context(scirs2_error, "During tensor reshape");
match torsh_error {
TorshError::InvalidShape(msg) => {
assert!(msg.contains("During tensor reshape"));
assert!(msg.contains("Invalid dimensions"));
}
_ => panic!("Expected InvalidShape error"),
}
}
#[test]
fn test_error_mapping_comprehensive() {
let test_cases = vec![
(
CoreError::DimensionError(scirs2_core::error::ErrorContext::new(
"Dimension too large".to_string(),
)),
"InvalidShape",
),
(
CoreError::ValueError(scirs2_core::error::ErrorContext::new(
"Invalid value".to_string(),
)),
"InvalidOperation",
),
(
CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
"Computation failed".to_string(),
)),
"ComputeError",
),
(
CoreError::MemoryError(scirs2_core::error::ErrorContext::new(
"Out of memory".to_string(),
)),
"AllocationError",
),
];
for (scirs2_error, expected_type) in test_cases {
let torsh_error = ErrorMapper::from_scirs2(scirs2_error);
let error_type = format!("{:?}", torsh_error);
assert!(
error_type.contains(expected_type),
"Expected {} but got {:?}",
expected_type,
torsh_error
);
}
}
#[test]
#[cfg(feature = "simd")]
fn test_batch_convert_f32_to_f64() {
let src1 = vec![1.0f32, 2.0, 3.0];
let src2 = vec![4.0f32, 5.0, 6.0];
let mut dst1 = vec![0.0f64; 3];
let mut dst2 = vec![0.0f64; 3];
let sources = vec![&src1[..], &src2[..]];
let mut destinations = vec![&mut dst1[..], &mut dst2[..]];
SciRS2Bridge::batch_convert_f32_to_f64(&sources, &mut destinations)
.expect("batch_convert should succeed");
assert_eq!(dst1, vec![1.0f64, 2.0, 3.0]);
assert_eq!(dst2, vec![4.0f64, 5.0, 6.0]);
}
#[test]
#[cfg(feature = "simd")]
fn test_batch_convert_f64_to_f32() {
let src1 = vec![1.0f64, 2.0, 3.0];
let src2 = vec![4.0f64, 5.0, 6.0];
let mut dst1 = vec![0.0f32; 3];
let mut dst2 = vec![0.0f32; 3];
let sources = vec![&src1[..], &src2[..]];
let mut destinations = vec![&mut dst1[..], &mut dst2[..]];
SciRS2Bridge::batch_convert_f64_to_f32(&sources, &mut destinations)
.expect("batch_convert should succeed");
assert_eq!(dst1, vec![1.0f32, 2.0, 3.0]);
assert_eq!(dst2, vec![4.0f32, 5.0, 6.0]);
}
#[test]
#[cfg(feature = "simd")]
fn test_simd_conversion_large_arrays() {
let size = 100;
let src: Vec<f32> = (0..size).map(|i| i as f32).collect();
let mut dst = vec![0.0f64; size];
SciRS2Bridge::convert_f32_to_f64_simd(&src, &mut dst).expect("convert should succeed");
for i in 0..size {
assert_eq!(dst[i], src[i] as f64);
}
}
#[test]
#[cfg(feature = "simd")]
fn test_simd_conversion_error_handling() {
let src = vec![1.0f32, 2.0, 3.0];
let mut dst = vec![0.0f64; 2];
let result = SciRS2Bridge::convert_f32_to_f64_simd(&src, &mut dst);
assert!(result.is_err());
match result {
Err(TorshError::InvalidShape(msg)) => {
assert!(msg.contains("lengths must match"));
}
_ => panic!("Expected InvalidShape error"),
}
}
#[test]
fn test_error_mapper_result_conversion() {
let ok_result: core::result::Result<i32, CoreError> = Ok(42);
let mapped_ok = ErrorMapper::map_result(ok_result);
assert_eq!(mapped_ok.expect("map_result should succeed"), 42);
let err_result: core::result::Result<i32, CoreError> = Err(CoreError::ValueError(
scirs2_core::error::ErrorContext::new("Test error".to_string()),
));
let mapped_err = ErrorMapper::map_result(err_result);
assert!(mapped_err.is_err());
}
}