use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::Complex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Element type of an array: fixed-width scalars, fixed-length byte strings,
/// complex numbers, or packed records of named fields.
///
/// The byte width of every variant is reported by `DType::size_in_bytes`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// Boolean stored in a single byte.
    Bool,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 32-bit IEEE float.
    Float32,
    /// 64-bit IEEE float.
    Float64,
    /// Fixed-length byte string occupying exactly the given number of bytes.
    String(usize),
    /// Complex number with `f32` real and imaginary parts (8 bytes, displayed as `complex64`).
    Complex32,
    /// Complex number with `f64` real and imaginary parts (16 bytes, displayed as `complex128`).
    Complex64,
    /// Record of named fields laid out back to back with no padding.
    Struct(Vec<Field>),
}
impl DType {
    /// Number of bytes one element of this dtype occupies.
    ///
    /// `String(len)` occupies exactly `len` bytes. A `Struct` is the plain sum of its
    /// fields' sizes, because fields are packed without alignment padding.
    pub fn size_in_bytes(&self) -> usize {
        match self {
            DType::Bool | DType::Int8 | DType::UInt8 => 1,
            DType::Int16 | DType::UInt16 => 2,
            DType::Int32 | DType::UInt32 | DType::Float32 => 4,
            DType::Int64 | DType::UInt64 | DType::Float64 | DType::Complex32 => 8,
            DType::Complex64 => 16,
            DType::String(len) => *len,
            DType::Struct(fields) => fields
                .iter()
                .fold(0, |total, field| total + field.dtype.size_in_bytes()),
        }
    }

    /// True for every scalar dtype, meaning everything except strings and structs.
    ///
    /// Note that `Bool` and the complex types count as numeric here.
    pub fn is_numeric(&self) -> bool {
        !matches!(self, DType::String(_) | DType::Struct(_))
    }

    /// True for the real float types and also for both complex types.
    pub fn is_floating_point(&self) -> bool {
        self.is_complex() || matches!(self, DType::Float32 | DType::Float64)
    }

    /// True for `Complex32` and `Complex64`.
    pub fn is_complex(&self) -> bool {
        matches!(self, DType::Complex32 | DType::Complex64)
    }

    /// True for fixed-length `String` dtypes.
    pub fn is_string(&self) -> bool {
        matches!(self, DType::String(_))
    }

    /// True for record (`Struct`) dtypes.
    pub fn is_struct(&self) -> bool {
        matches!(self, DType::Struct(_))
    }
}
impl fmt::Display for DType {
    /// Renders a NumPy-style name.
    ///
    /// - Scalars print as `int32`, `float64` and so on.
    /// - Complex types are named by their total bit width.
    /// - Strings print as `S<len>`.
    /// - Structs print as `struct{a: int32, b: float64}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DType::Bool => "bool",
            DType::Int8 => "int8",
            DType::Int16 => "int16",
            DType::Int32 => "int32",
            DType::Int64 => "int64",
            DType::UInt8 => "uint8",
            DType::UInt16 => "uint16",
            DType::UInt32 => "uint32",
            DType::UInt64 => "uint64",
            DType::Float32 => "float32",
            DType::Float64 => "float64",
            DType::Complex32 => "complex64",
            DType::Complex64 => "complex128",
            DType::String(len) => return write!(f, "S{}", len),
            DType::Struct(fields) => {
                f.write_str("struct{")?;
                for (position, field) in fields.iter().enumerate() {
                    if position != 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}: {}", field.name, field.dtype)?;
                }
                return f.write_str("}");
            }
        };
        f.write_str(name)
    }
}
/// A named member of a `DType::Struct` record.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Field {
    /// Name used to look the field up (first match wins if names repeat).
    pub name: String,
    /// Element type stored for this field.
    pub dtype: DType,
}
impl Field {
pub fn new<S: Into<String>>(name: S, dtype: DType) -> Self {
Self {
name: name.into(),
dtype,
}
}
}
impl fmt::Display for Field {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {}", self.name, self.dtype)
}
}
/// N-dimensional array whose elements are stored as raw bytes described by a `DType`.
///
/// Elements are packed back to back in row-major order (last index varies fastest).
/// Each element occupies `dtype.size_in_bytes()` bytes.
#[derive(Debug, Clone)]
pub struct StructuredArray {
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// Element type; per-field access requires a `DType::Struct`.
    dtype: DType,
    /// Backing buffer of `product(shape) * dtype.size_in_bytes()` bytes.
    data: Vec<u8>,
}
impl StructuredArray {
pub fn new(shape: &[usize], dtype: DType) -> Self {
let size = shape.iter().product::<usize>();
let byte_size = size * dtype.size_in_bytes();
let data = vec![0; byte_size];
Self {
shape: shape.to_vec(),
dtype,
data,
}
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn dtype(&self) -> &DType {
&self.dtype
}
pub fn data(&self) -> &[u8] {
&self.data
}
pub fn size(&self) -> usize {
self.shape.iter().product()
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn field<T: Clone + Default + 'static>(&self, field_name: &str) -> Result<Array<T>> {
if let DType::Struct(fields) = &self.dtype {
let field = fields
.iter()
.find(|f| f.name == field_name)
.ok_or_else(|| {
NumRs2Error::IndexError(format!("Field '{}' not found", field_name))
})?;
let mut offset = 0;
for f in fields.iter() {
if f.name == field_name {
break;
}
offset += f.dtype.size_in_bytes();
}
let field_size = field.dtype.size_in_bytes();
let element_size = self.dtype.size_in_bytes();
let mut field_data = Vec::with_capacity(self.size());
for i in 0..self.size() {
let start = i * element_size + offset;
let end = start + field_size;
let bytes = &self.data[start..end];
let value = bytes_to_value::<T>(bytes, &field.dtype)?;
field_data.push(value);
}
let arr = Array::from_vec(field_data).reshape(&self.shape);
Ok(arr)
} else {
Err(NumRs2Error::ValueError(
"Not a structured array".to_string(),
))
}
}
pub fn set_field<T: Clone + 'static>(
&mut self,
index: &[usize],
field_name: &str,
value: T,
) -> Result<()> {
if let DType::Struct(fields) = &self.dtype {
if index.len() != self.ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} dimensions, got {}",
self.ndim(),
index.len()
)));
}
for (i, &idx) in index.iter().enumerate() {
if idx >= self.shape[i] {
return Err(NumRs2Error::IndexError(format!(
"Index {} out of bounds for dimension {} with size {}",
idx, i, self.shape[i]
)));
}
}
let field = fields
.iter()
.find(|f| f.name == field_name)
.ok_or_else(|| {
NumRs2Error::IndexError(format!("Field '{}' not found", field_name))
})?;
let mut offset = 0;
for f in fields.iter() {
if f.name == field_name {
break;
}
offset += f.dtype.size_in_bytes();
}
let mut flat_index = 0;
let mut stride = 1;
for i in (0..self.ndim()).rev() {
flat_index += index[i] * stride;
stride *= self.shape[i];
}
let element_size = self.dtype.size_in_bytes();
let start = flat_index * element_size + offset;
let end = start + field.dtype.size_in_bytes();
let bytes = value_to_bytes(&value, &field.dtype)?;
if bytes.len() != field.dtype.size_in_bytes() {
return Err(NumRs2Error::ValueError(format!(
"Expected {} bytes, got {}",
field.dtype.size_in_bytes(),
bytes.len()
)));
}
self.data[start..end].copy_from_slice(&bytes);
Ok(())
} else {
Err(NumRs2Error::ValueError(
"Not a structured array".to_string(),
))
}
}
pub fn from_arrays<T: Clone + Default + 'static>(
arrays: &HashMap<String, Array<T>>,
shape: &[usize],
) -> Result<Self> {
for (name, arr) in arrays.iter() {
if arr.shape() != shape {
return Err(NumRs2Error::DimensionMismatch(format!(
"Array '{}' has shape {:?}, expected {:?}",
name,
arr.shape(),
shape
)));
}
}
let fields = arrays
.keys()
.map(|name| {
Field::new(name.clone(), DType::Float64) })
.collect();
let dtype = DType::Struct(fields);
let mut result = Self::new(shape, dtype);
let size = shape.iter().product::<usize>();
for i in 0..size {
let index = flat_to_index(i, shape);
for (name, _arr) in arrays.iter() {
let value = T::clone(&T::default());
result.set_field(&index, name, value)?;
}
}
Ok(result)
}
}
/// Structured array paired with one `f64` array per field for convenient access.
///
/// Field reads (`field`, `field_mut`) are served from `field_cache`.
/// NOTE(review): edits made through `field_mut` touch only the cache, not `array`.
#[derive(Debug, Clone)]
pub struct RecordArray {
    /// Packed byte representation of all fields.
    array: StructuredArray,
    /// Per-field `f64` arrays keyed by field name.
    field_cache: HashMap<String, Array<f64>>, }
impl RecordArray {
pub fn new(shape: &[usize], fields: Vec<Field>) -> Self {
let dtype = DType::Struct(fields.clone());
let array = StructuredArray::new(shape, dtype);
let mut field_cache = HashMap::new();
for field in &fields {
let field_array = Array::zeros(shape);
field_cache.insert(field.name.clone(), field_array);
}
Self { array, field_cache }
}
pub fn from_arrays(arrays: &HashMap<String, Array<f64>>, shape: &[usize]) -> Result<Self> {
let array = StructuredArray::from_arrays(arrays, shape)?;
let mut field_cache = HashMap::new();
for (name, arr) in arrays.iter() {
field_cache.insert(name.clone(), arr.clone());
}
Ok(Self { array, field_cache })
}
pub fn shape(&self) -> &[usize] {
self.array.shape()
}
pub fn dtype(&self) -> &DType {
self.array.dtype()
}
pub fn size(&self) -> usize {
self.array.size()
}
pub fn ndim(&self) -> usize {
self.array.ndim()
}
pub fn field(&self, field_name: &str) -> Result<&Array<f64>> {
if self.field_cache.contains_key(field_name) {
Ok(&self.field_cache[field_name])
} else {
Err(NumRs2Error::IndexError(format!(
"Field '{}' not found",
field_name
)))
}
}
pub fn field_mut(&mut self, field_name: &str) -> Result<&mut Array<f64>> {
self.field_cache
.get_mut(field_name)
.ok_or_else(|| NumRs2Error::IndexError(format!("Field '{}' not found", field_name)))
}
pub fn set_field(&mut self, index: &[usize], field_name: &str, value: f64) -> Result<()> {
if let Some(arr) = self.field_cache.get_mut(field_name) {
arr.set(index, value)?;
}
self.array.set_field(index, field_name, value)
}
pub fn add_field(&mut self, field_name: &str, data: Array<f64>) -> Result<()> {
if self.field_cache.contains_key(field_name) {
return Err(NumRs2Error::ValueError(format!(
"Field '{}' already exists",
field_name
)));
}
if data.shape() != self.array.shape() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Array has shape {:?}, expected {:?}",
data.shape(),
self.array.shape()
)));
}
self.field_cache
.insert(field_name.to_string(), data.clone());
let mut new_fields = Vec::new();
if let DType::Struct(ref fields) = &self.array.dtype {
new_fields.extend(fields.clone());
}
new_fields.push(Field::new(field_name, DType::Float64));
let new_dtype = DType::Struct(new_fields);
let mut new_array = StructuredArray::new(self.array.shape(), new_dtype);
for existing_field_name in self.field_cache.keys() {
if existing_field_name != field_name {
let size = self.array.size();
for i in 0..size {
let index = flat_to_index(i, self.array.shape());
if let Some(field_array) = self.field_cache.get(existing_field_name) {
let value = match index.len() {
1 => field_array.array()[[index[0]]],
2 => field_array.array()[[index[0], index[1]]],
3 => field_array.array()[[index[0], index[1], index[2]]],
_ => {
return Err(NumRs2Error::NotImplemented(
"More than 3 dimensions not supported in add_field".to_string(),
))
}
};
new_array.set_field(&index, existing_field_name, value)?;
}
}
}
}
let size = self.array.size();
for i in 0..size {
let index = flat_to_index(i, self.array.shape());
let value = match index.len() {
1 => data.array()[[index[0]]],
2 => data.array()[[index[0], index[1]]],
3 => data.array()[[index[0], index[1], index[2]]],
_ => {
return Err(NumRs2Error::NotImplemented(
"More than 3 dimensions not supported in add_field".to_string(),
))
}
};
new_array.set_field(&index, field_name, value)?;
}
self.array = new_array;
Ok(())
}
pub fn remove_field(&mut self, field_name: &str) -> Result<Array<f64>> {
let arr = self
.field_cache
.remove(field_name)
.ok_or_else(|| NumRs2Error::IndexError(format!("Field '{}' not found", field_name)))?;
if let DType::Struct(ref mut fields) = &mut self.array.dtype {
fields.retain(|f| f.name != field_name);
}
Ok(arr)
}
pub fn field_names(&self) -> Vec<String> {
self.field_cache.keys().cloned().collect()
}
}
impl fmt::Display for StructuredArray {
    /// Prints a summary `StructuredArray(shape=[..], dtype=..)`; element bytes are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shape = &self.shape;
        let dtype = &self.dtype;
        f.write_fmt(format_args!(
            "StructuredArray(shape={:?}, dtype={})",
            shape, dtype
        ))
    }
}
impl fmt::Display for RecordArray {
    /// Prints a summary `RecordArray(shape=[..], fields=a, b, ..)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let names = self.field_names();
        let listing = names.join(", ");
        f.write_fmt(format_args!(
            "RecordArray(shape={:?}, fields={})",
            self.shape(),
            listing
        ))
    }
}
/// Decodes one little-endian value of `dtype` from the front of `bytes` as a `T`.
///
/// `T` must be exactly the Rust type for `dtype`:
/// - `bool` for `Bool` (any non-zero byte reads as `true`).
/// - `i8`..`u64`, `f32` or `f64` for the matching scalar dtypes.
/// - `String` for `String(_)`. Bytes are read up to the first NUL, and invalid UTF-8
///   is replaced lossily.
/// - `Complex<f32>` / `Complex<f64>` for the complex dtypes, stored real part first.
///
/// # Errors
/// - `TypeCastError("Type mismatch for <dtype>")` when `T` does not match; nothing is
///   decoded in that case.
/// - `ValueError` for `Struct`.
///
/// # Panics
/// Panics if `bytes` is shorter than the dtype's width.
fn bytes_to_value<T: Clone + Default + 'static>(bytes: &[u8], dtype: &DType) -> Result<T> {
    use std::any::{Any, TypeId};

    // Runs `decode` only when `T` is exactly `U`, and hands the result back as a `T`
    // without `unsafe`: `Option<U>` and `Option<T>` are then the same type, so the
    // downcast succeeds and `take` moves the value out with no heap allocation.
    fn decode_as<T: 'static, U: 'static, F: FnOnce() -> U>(label: &str, decode: F) -> Result<T> {
        let mismatch = || NumRs2Error::TypeCastError(format!("Type mismatch for {}", label));
        if TypeId::of::<T>() != TypeId::of::<U>() {
            return Err(mismatch());
        }
        let mut slot = Some(decode());
        let erased: &mut dyn Any = &mut slot;
        erased
            .downcast_mut::<Option<T>>()
            .and_then(Option::take)
            .ok_or_else(mismatch)
    }

    // Copies the first `N` bytes into a fixed-size array for the `from_le_bytes` constructors.
    fn le_bytes<const N: usize>(bytes: &[u8]) -> [u8; N] {
        let mut buf = [0u8; N];
        buf.copy_from_slice(&bytes[..N]);
        buf
    }

    match dtype {
        DType::Bool => decode_as::<T, bool, _>("Bool", || bytes[0] != 0),
        DType::Int8 => decode_as::<T, i8, _>("Int8", || i8::from_le_bytes([bytes[0]])),
        DType::Int16 => decode_as::<T, i16, _>("Int16", || i16::from_le_bytes(le_bytes(bytes))),
        DType::Int32 => decode_as::<T, i32, _>("Int32", || i32::from_le_bytes(le_bytes(bytes))),
        DType::Int64 => decode_as::<T, i64, _>("Int64", || i64::from_le_bytes(le_bytes(bytes))),
        DType::UInt8 => decode_as::<T, u8, _>("UInt8", || bytes[0]),
        DType::UInt16 => {
            decode_as::<T, u16, _>("UInt16", || u16::from_le_bytes(le_bytes(bytes)))
        }
        DType::UInt32 => {
            decode_as::<T, u32, _>("UInt32", || u32::from_le_bytes(le_bytes(bytes)))
        }
        DType::UInt64 => {
            decode_as::<T, u64, _>("UInt64", || u64::from_le_bytes(le_bytes(bytes)))
        }
        DType::Float32 => {
            decode_as::<T, f32, _>("Float32", || f32::from_le_bytes(le_bytes(bytes)))
        }
        DType::Float64 => {
            decode_as::<T, f64, _>("Float64", || f64::from_le_bytes(le_bytes(bytes)))
        }
        DType::String(_) => decode_as::<T, String, _>("String", || {
            // Strings are zero-padded on write, so the first NUL marks the end.
            let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
            String::from_utf8_lossy(&bytes[..end]).into_owned()
        }),
        DType::Complex32 => decode_as::<T, Complex<f32>, _>("Complex32", || {
            Complex::new(
                f32::from_le_bytes(le_bytes(&bytes[0..4])),
                f32::from_le_bytes(le_bytes(&bytes[4..8])),
            )
        }),
        DType::Complex64 => decode_as::<T, Complex<f64>, _>("Complex64", || {
            Complex::new(
                f64::from_le_bytes(le_bytes(&bytes[0..8])),
                f64::from_le_bytes(le_bytes(&bytes[8..16])),
            )
        }),
        DType::Struct(_) => Err(NumRs2Error::ValueError(
            "Cannot convert struct to single value".to_string(),
        )),
    }
}
/// Encodes `value` as the little-endian byte image of `dtype`.
///
/// `T` must be exactly the Rust type for `dtype`:
/// - `bool` for `Bool` (encoded as a single `1`/`0` byte).
/// - `i8`..`u64`, `f32` or `f64` for the matching scalar dtypes.
/// - `String` for `String(max_len)`. The result is zero-padded or truncated to exactly
///   `max_len` bytes, and truncation never splits a UTF-8 character.
/// - `Complex<f32>` / `Complex<f64>` for the complex dtypes, with the real part first.
///
/// # Errors
/// - `TypeCastError("Type mismatch for <dtype>")` when `T` does not match.
/// - `ValueError` for `Struct`.
fn value_to_bytes<T: Clone + 'static>(value: &T, dtype: &DType) -> Result<Vec<u8>> {
    use std::any::Any;

    // Safely views the type-erased value as `U`, or reports the dtype mismatch.
    fn as_type<'a, U: 'static>(value: &'a dyn Any, label: &str) -> Result<&'a U> {
        value
            .downcast_ref::<U>()
            .ok_or_else(|| NumRs2Error::TypeCastError(format!("Type mismatch for {}", label)))
    }

    let erased: &dyn Any = value;
    match dtype {
        DType::Bool => Ok(vec![u8::from(*as_type::<bool>(erased, "Bool")?)]),
        DType::Int8 => Ok(as_type::<i8>(erased, "Int8")?.to_le_bytes().to_vec()),
        DType::Int16 => Ok(as_type::<i16>(erased, "Int16")?.to_le_bytes().to_vec()),
        DType::Int32 => Ok(as_type::<i32>(erased, "Int32")?.to_le_bytes().to_vec()),
        DType::Int64 => Ok(as_type::<i64>(erased, "Int64")?.to_le_bytes().to_vec()),
        DType::UInt8 => Ok(vec![*as_type::<u8>(erased, "UInt8")?]),
        DType::UInt16 => Ok(as_type::<u16>(erased, "UInt16")?.to_le_bytes().to_vec()),
        DType::UInt32 => Ok(as_type::<u32>(erased, "UInt32")?.to_le_bytes().to_vec()),
        DType::UInt64 => Ok(as_type::<u64>(erased, "UInt64")?.to_le_bytes().to_vec()),
        DType::Float32 => Ok(as_type::<f32>(erased, "Float32")?.to_le_bytes().to_vec()),
        DType::Float64 => Ok(as_type::<f64>(erased, "Float64")?.to_le_bytes().to_vec()),
        DType::String(max_len) => {
            let text = as_type::<String>(erased, "String")?;
            // Back off to a char boundary so a multi-byte character is never cut in
            // half (which would later decode as U+FFFD).
            let mut cut = text.len().min(*max_len);
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }
            let mut bytes = text.as_bytes()[..cut].to_vec();
            bytes.resize(*max_len, 0);
            Ok(bytes)
        }
        DType::Complex32 => {
            let c = as_type::<Complex<f32>>(erased, "Complex32")?;
            let mut bytes = Vec::with_capacity(8);
            bytes.extend_from_slice(&c.re.to_le_bytes());
            bytes.extend_from_slice(&c.im.to_le_bytes());
            Ok(bytes)
        }
        DType::Complex64 => {
            let c = as_type::<Complex<f64>>(erased, "Complex64")?;
            let mut bytes = Vec::with_capacity(16);
            bytes.extend_from_slice(&c.re.to_le_bytes());
            bytes.extend_from_slice(&c.im.to_le_bytes());
            Ok(bytes)
        }
        DType::Struct(_) => Err(NumRs2Error::ValueError(
            "Cannot convert single value to struct".to_string(),
        )),
    }
}
/// Converts a row-major flat position into per-dimension coordinates for `shape`.
///
/// The last dimension varies fastest. An empty `shape` yields an empty index.
/// Positions past the end wrap around in the first dimension.
///
/// # Panics
/// Panics with a division by zero if any extent in `shape` is 0.
fn flat_to_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
    let mut rest = flat_index;
    let mut coords: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&extent| {
            let coord = rest % extent;
            rest /= extent;
            coord
        })
        .collect();
    coords.reverse();
    coords
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-field record shared by several tests: a 4-byte int followed by an 8-byte float.
    fn int_float_fields() -> Vec<Field> {
        vec![
            Field::new("a", DType::Int32),
            Field::new("b", DType::Float64),
        ]
    }

    #[test]
    fn test_dtype_size() {
        let expectations = [
            (DType::Bool, 1),
            (DType::Int32, 4),
            (DType::Float64, 8),
            (DType::String(10), 10),
        ];
        for (dtype, bytes) in expectations.iter() {
            assert_eq!(dtype.size_in_bytes(), *bytes);
        }
        // Struct fields are packed back to back: 4 + 8.
        assert_eq!(DType::Struct(int_float_fields()).size_in_bytes(), 12);
    }

    #[test]
    fn test_dtype_properties() {
        assert!(DType::Int32.is_numeric());
        assert!(DType::Float64.is_floating_point());
        assert!(DType::Complex64.is_complex());
        assert!(DType::String(10).is_string());
        assert!(DType::Struct(int_float_fields()).is_struct());
    }

    #[test]
    fn test_field_creation() {
        let field = Field::new("test", DType::Int32);
        assert_eq!(field.name, "test");
        assert_eq!(field.dtype, DType::Int32);
    }
}