use super::*;
use crate::device::Device;
use crate::dtype::{get_dtype_info, DType, DTypeMapping};
use crate::tensor::{DataPtr, Tensor};
use crate::{DTYPE_FLOAT32, DTYPE_INT32};
use bytemuck::{cast_slice, Pod};
use cudarc::driver::DevicePtr;
impl Tensor {
    /// Creates a contiguous CPU tensor by copying `data` into a freshly allocated byte buffer.
    ///
    /// The dtype is `T::DTYPE`; strides are produced by `compute_row_major_strides` from
    /// `shape` and `size_of::<T>()`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` is not exactly `data.len()` (including when that
    /// product overflows `usize`). Such a tensor would describe a different number of
    /// elements than its buffer holds; `new_cpu_from_bytes` rejects the same situation.
    pub fn new_cpu_from_slice<T: Pod + DTypeMapping>(data: &[T], shape: Vec<usize>) -> Self {
        assert!(
            checked_element_count(&shape) == Some(data.len()),
            "new_cpu_from_slice: shape {:?} does not match data length {}",
            shape,
            data.len()
        );
        let elem_size = std::mem::size_of::<T>();
        let bytes = cast_slice(data).to_vec().into_boxed_slice();
        let strides = Self::compute_row_major_strides(&shape, elem_size);
        Tensor {
            data: DataPtr::Cpu(bytes),
            shape,
            strides,
            dtype: T::DTYPE,
            device: Device::Cpu,
        }
    }

    /// Creates a contiguous CPU tensor from `f32` values (copied into a byte buffer).
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` differs from `data.len()`.
    pub fn new_cpu_from_f32(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self::new_cpu_from_slice(&data, shape)
    }

    /// Creates a contiguous CPU tensor from `i32` values (copied into a byte buffer).
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` differs from `data.len()`.
    pub fn new_cpu_from_i32(data: Vec<i32>, shape: Vec<usize>) -> Self {
        Self::new_cpu_from_slice(&data, shape)
    }

    /// Wraps an existing byte buffer, without copying, as a contiguous CPU tensor of `dtype`.
    ///
    /// # Errors
    ///
    /// Returns `Err` in three cases:
    /// - `dtype` is unknown to `get_dtype_info`;
    /// - the byte size implied by `shape` overflows `usize`;
    /// - `bytes.len()` differs from that byte size.
    pub fn new_cpu_from_bytes(
        bytes: Box<[u8]>,
        shape: Vec<usize>,
        dtype: DType,
    ) -> Result<Self, String> {
        let elem_size = get_dtype_info(dtype).ok_or("Unknown dtype")?.size;
        // Checked arithmetic: a wrapped product could otherwise "match" a small buffer and
        // produce a tensor whose shape claims more storage than it owns.
        let expected_size = checked_byte_len(&shape, elem_size)?;
        if bytes.len() != expected_size {
            return Err(format!(
                "Byte size mismatch: expected {}, got {}",
                expected_size,
                bytes.len()
            ));
        }
        let strides = Self::compute_row_major_strides(&shape, elem_size);
        Ok(Tensor {
            data: DataPtr::Cpu(bytes),
            shape,
            strides,
            dtype,
            device: Device::Cpu,
        })
    }

    /// Allocates a zero-filled, contiguous CPU tensor of the given `shape` and `dtype`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `dtype` is unknown to `get_dtype_info` or if the byte size
    /// implied by `shape` overflows `usize`.
    pub fn new_contiguous(shape: Vec<usize>, dtype: DType) -> Result<Self, String> {
        let elem_size = get_dtype_info(dtype).ok_or("Unknown dtype")?.size;
        let total_bytes = checked_byte_len(&shape, elem_size)?;
        let bytes = vec![0u8; total_bytes].into_boxed_slice();
        let strides = Self::compute_row_major_strides(&shape, elem_size);
        Ok(Tensor {
            data: DataPtr::Cpu(bytes),
            shape,
            strides,
            dtype,
            device: Device::Cpu,
        })
    }

    /// Parses flat string tokens (row-major order) into a CPU tensor of `shape`.
    ///
    /// Dtype resolution:
    /// - `dtype_hint` may be `Some("i32")`, `Some("f32")`, or `None`. Any other hint is
    ///   rejected on every path.
    /// - With no hint, the tensor is `i32` when every token parses as `i32`, otherwise `f32`.
    ///
    /// Special case: if `strings` is empty but `shape` has a non-zero element count, a
    /// zero-filled tensor of the hinted dtype (default `f32`) is returned instead of an error.
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - the hint is unsupported;
    /// - the element count of `shape` overflows;
    /// - `strings.len()` differs from that count;
    /// - a token fails to parse as the chosen dtype.
    pub fn from_strings(
        strings: &[&str],
        shape: &[usize],
        dtype_hint: Option<&str>,
    ) -> Result<Self, String> {
        let hinted = dtype_from_hint(dtype_hint)?;
        let total_elements = checked_element_count(shape)
            .ok_or_else(|| format!("Shape {:?} has too many elements", shape))?;
        if strings.is_empty() && total_elements > 0 {
            return Self::new_contiguous(shape.to_vec(), hinted.unwrap_or(DTYPE_FLOAT32));
        }
        if strings.len() != total_elements {
            return Err(format!(
                "Number of strings ({}) does not match shape product ({})",
                strings.len(),
                total_elements
            ));
        }
        let dtype = hinted.unwrap_or_else(|| infer_dtype(strings));
        match dtype {
            DTYPE_FLOAT32 => {
                let values = parse_all::<f32>(strings, "f32")?;
                Ok(Self::new_cpu_from_f32(values, shape.to_vec()))
            }
            DTYPE_INT32 => {
                let values = parse_all::<i32>(strings, "i32")?;
                Ok(Self::new_cpu_from_i32(values, shape.to_vec()))
            }
            _ => Err(format!("Unsupported dtype: {}", dtype)),
        }
    }

    /// Parses a literal such as `"[[1, 2], [3, 4]]"` or `"[1.5, 2]; f32"` into a CPU tensor.
    ///
    /// Everything after the last `;` (trimmed) is used as the dtype hint for
    /// `from_strings`; without a `;` the dtype is inferred.
    ///
    /// # Errors
    ///
    /// Returns `Err` for malformed nesting, ragged dimensions, unsupported hints, or
    /// unparsable tokens.
    pub fn from_string_literal(s: &str) -> Result<Self, String> {
        let s = s.trim();
        let (data_str, dtype_hint) = match s.rfind(';') {
            Some(semi_pos) => (s[..semi_pos].trim(), Some(s[semi_pos + 1..].trim())),
            None => (s, None),
        };
        let (strings, shape) = Self::parse_nested_array(data_str)?;
        Self::from_strings(&strings, &shape, dtype_hint)
    }

    /// Recursively splits a bracketed literal into its leaf tokens and its shape.
    ///
    /// - Tokens are returned in row-major order, trimmed but not parsed, and borrow from `s`.
    /// - A token without surrounding brackets is a scalar with shape `[]`.
    /// - `[]` (or `[ ]`) is an empty dimension with shape `[0]`.
    /// - All children of one bracket level must have identical shapes.
    /// - A single trailing comma inside a level is tolerated.
    fn parse_nested_array(s: &str) -> Result<(Vec<&str>, Vec<usize>), String> {
        let s = s.trim();
        if s.is_empty() {
            return Err("Empty string".into());
        }
        let inner = match s.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
            Some(inner) => inner.trim(),
            // A stray bracket in a scalar means the nesting is malformed; report that
            // rather than a confusing number-parse failure later on.
            None if s.contains(|c: char| c == '[' || c == ']') => {
                return Err(format!("Unbalanced brackets in '{}'", s));
            }
            None => return Ok((vec![s], vec![])),
        };
        if inner.is_empty() {
            return Ok((vec![], vec![0]));
        }
        let elements = split_top_level(inner)?;
        let count = elements.len();
        let mut leaves = Vec::new();
        let mut child_shapes = Vec::with_capacity(count);
        for element in elements {
            let (tokens, child_shape) = Self::parse_nested_array(element)?;
            leaves.extend(tokens);
            child_shapes.push(child_shape);
        }
        let child_shape = match child_shapes.split_first() {
            Some((first, rest)) => {
                if rest.iter().any(|other| other != first) {
                    return Err("Inconsistent dimensions".into());
                }
                first.as_slice()
            }
            None => return Ok((vec![], vec![0])),
        };
        let mut shape = Vec::with_capacity(child_shape.len() + 1);
        shape.push(count);
        shape.extend_from_slice(child_shape);
        Ok((leaves, shape))
    }
}

/// Product of `shape`, or `None` if it overflows `usize`. An empty shape yields 1.
fn checked_element_count(shape: &[usize]) -> Option<usize> {
    shape.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
}

/// Total byte size for `shape` with `elem_size`-byte elements, or `Err` on overflow.
fn checked_byte_len(shape: &[usize], elem_size: usize) -> Result<usize, String> {
    checked_element_count(shape)
        .and_then(|count| count.checked_mul(elem_size))
        .ok_or_else(|| {
            format!(
                "Tensor of shape {:?} is too large: byte size overflows usize",
                shape
            )
        })
}

/// Maps a textual dtype hint to a dtype; `None` means "no hint, infer".
fn dtype_from_hint(hint: Option<&str>) -> Result<Option<DType>, String> {
    match hint {
        None => Ok(None),
        Some("i32") => Ok(Some(DTYPE_INT32)),
        Some("f32") => Ok(Some(DTYPE_FLOAT32)),
        Some(other) => Err(format!("Unsupported dtype hint: {}", other)),
    }
}

/// Picks `i32` when every token parses as `i32`, otherwise `f32`.
///
/// An empty token list yields `i32`. Tokens like `inf`, `NaN` or `3000000000` fall back to
/// `f32` instead of failing an `i32` parse.
fn infer_dtype(strings: &[&str]) -> DType {
    if strings.iter().all(|s| s.parse::<i32>().is_ok()) {
        DTYPE_INT32
    } else {
        DTYPE_FLOAT32
    }
}

/// Parses every token as `T`, stopping at the first failure with a message naming `type_name`.
fn parse_all<T>(strings: &[&str], type_name: &str) -> Result<Vec<T>, String>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Display,
{
    strings
        .iter()
        .map(|s| {
            s.parse::<T>()
                .map_err(|e| format!("Failed to parse '{}' as {}: {}", s, type_name, e))
        })
        .collect()
}

/// Splits one bracket level's contents at commas not nested in deeper brackets, trimming
/// each piece.
///
/// Scanning bytes is sound for UTF-8: `[`, `]` and `,` are ASCII and never occur inside a
/// multi-byte sequence, so every split index is a char boundary.
///
/// A trailing comma is ignored. Any other empty piece is kept so the recursive parse
/// reports it.
fn split_top_level(inner: &str) -> Result<Vec<&str>, String> {
    let mut pieces = Vec::new();
    let mut depth = 0usize;
    let mut start = 0;
    for (idx, byte) in inner.bytes().enumerate() {
        match byte {
            b'[' => depth += 1,
            b']' => {
                depth = depth
                    .checked_sub(1)
                    .ok_or_else(|| format!("Unbalanced brackets in '{}'", inner))?;
            }
            b',' if depth == 0 => {
                pieces.push(inner[start..idx].trim());
                start = idx + 1;
            }
            _ => {}
        }
    }
    if depth != 0 {
        return Err(format!("Unbalanced brackets in '{}'", inner));
    }
    let last = inner[start..].trim();
    if !last.is_empty() || pieces.is_empty() {
        pieces.push(last);
    }
    Ok(pieces)
}