use crate::device::Device;
use crate::dtype::{DTYPE_FLOAT32, DTYPE_INT32, DType};
use anyhow::{Result, bail};
/// Trims surrounding whitespace from `token` and, if what remains is wrapped
/// in a pair of double quotes, returns the text between them.
///
/// Anything that is not enclosed by both an opening and a closing `"` (this
/// includes a lone `"`) is returned trimmed but otherwise untouched. Only the
/// outermost quote pair is removed; the quoted content is not trimmed again.
fn strip_quotes(token: &str) -> &str {
    let trimmed = token.trim();
    // `strip_suffix` runs on the remainder after the opening quote, so a
    // single `"` cannot serve as both the opening and the closing quote.
    trimmed
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(trimmed)
}
/// Parses a tensor literal of the form `data; suffix; suffix; ...`.
///
/// The text before the first `;` is handed to `parse_nested_array`, which
/// yields the flattened element strings and the shape. Every following
/// `;`-separated segment is an optional suffix token (blank segments are
/// skipped, surrounding double quotes are removed):
///
/// * a token that parses as a [`Device`] sets the device,
/// * `f32` / `i32` select the dtype.
///
/// When a kind of token appears more than once, the last occurrence wins.
/// Without any suffix the dtype is `None` and the device is `Device::Cpu`.
///
/// # Errors
///
/// Fails if the data segment cannot be parsed or if a suffix token is neither
/// a device nor one of the recognized dtype names.
pub(crate) fn parse_full_tensor_string(
    s: &str,
) -> Result<(Vec<&str>, Vec<usize>, Option<DType>, Device)> {
    let mut segments = s.trim().split(';').map(str::trim);

    // `split` always yields at least one segment, so this is purely defensive.
    let data_segment = match segments.next() {
        Some(segment) => segment,
        None => bail!("Empty string"),
    };
    let (strings, shape) = parse_nested_array(data_segment)?;

    let mut dtype: Option<DType> = None;
    let mut device: Device = Device::Cpu;

    for segment in segments.filter(|segment| !segment.is_empty()) {
        let token = strip_quotes(segment);
        // Device parsing is attempted first; only tokens it rejects are
        // considered as dtype names.
        match token.parse::<Device>() {
            Ok(parsed) => device = parsed,
            Err(_) => match token {
                "f32" => dtype = Some(DTYPE_FLOAT32),
                "i32" => dtype = Some(DTYPE_INT32),
                other => bail!(
                    "Unrecognized suffix token: '{}'. Expected dtype (f32/i32) or device (cpu/cuda:N)",
                    other
                ),
            },
        }
    }

    Ok((strings, shape, dtype, device))
}
fn parse_nested_array(s: &str) -> Result<(Vec<&str>, Vec<usize>)> {
let s = s.trim();
if s.is_empty() {
return Ok((vec![], vec![0]));
}
let inner = if s.starts_with('[') && s.ends_with(']') {
&s[1..s.len() - 1]
} else {
return Ok((vec![s], vec![]));
};
if inner.is_empty() {
return Ok((vec![], vec![0]));
}
let mut elements = Vec::new();
let mut depth = 0;
let mut start = 0;
let chars: Vec<char> = inner.chars().collect();
for i in 0..chars.len() {
match chars[i] {
'[' => depth += 1,
']' => depth -= 1,
',' if depth == 0 => {
let elem = inner[start..i].trim();
elements.push(elem);
start = i + 1;
}
_ => {}
}
}
if start < inner.len() {
let elem = inner[start..].trim();
elements.push(elem);
}
let mut all_strings = Vec::new();
let mut child_shapes = Vec::new();
for elem in &elements {
let (strings, shape) = parse_nested_array(elem)?;
all_strings.extend(strings);
child_shapes.push(shape);
}
if child_shapes.is_empty() {
return Ok((vec![], vec![0]));
}
let first_shape = &child_shapes[0];
for shape in &child_shapes[1..] {
if shape != first_shape {
bail!("Inconsistent dimensions");
}
}
let mut shape = vec![elements.len()];
shape.extend(first_shape);
Ok((all_strings, shape))
}