use std::sync::Arc;
use sim_kernel::{
Cx, DefaultFactory, Factory, NoopEvalPolicy, NumberLiteral, Result, Symbol, Value,
};
use crate::Tensor;
use sim_lib_numbers_core::domains;
/// Common interface for tensor types that report a shape and a dtype and can
/// be converted to and from the crate's [`Tensor`] type.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait SpecTensor: Send + Sync + 'static {
/// Returns the shape as a slice of `usize` values.
// NOTE(review): presumably one entry per dimension — confirm with implementors.
fn shape(&self) -> &[usize];
/// Returns the [`Symbol`] that identifies this tensor's dtype.
fn dtype(&self) -> Symbol;
/// Produces a [`Tensor`] from this value.
// NOTE(review): presumably a lossless conversion to the general-purpose
// representation — confirm with implementors.
fn to_uniform(&self) -> Tensor;
/// Attempts to build `Self` from a [`Tensor`]; the `Option` return allows an
/// implementor to decline.
// NOTE(review): presumably `None` means the tensor cannot be represented by
// this type — confirm with implementors.
fn from_uniform(tensor: &Tensor) -> Option<Self>
where
// Restricting to `Sized` keeps the remaining methods usable through
// `dyn SpecTensor`.
Self: Sized;
}
/// Metadata record describing one spec tensor kind.
///
/// `spec_tensor_descriptor_value` turns a descriptor into a table value with
/// one entry per field, keyed by the field name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SpecTensorDescriptor {
/// Symbol stored under the `symbol` key.
// NOTE(review): presumably the tensor kind's own identifier — confirm.
pub symbol: Symbol,
/// Symbol stored under the `dtype` key.
// NOTE(review): presumably the element type — confirm.
pub dtype: Symbol,
/// Static text stored under the `implementation` key.
pub implementation: &'static str,
/// Static text stored under the `storage` key.
pub storage: &'static str,
}
/// Builds the [`Symbol`] for `name` qualified with the fixed
/// `numbers/tensor-spec` prefix.
pub fn spec_tensor_symbol(name: &str) -> Symbol {
    // Every spec-tensor symbol shares this qualifier.
    const NAMESPACE: &str = "numbers/tensor-spec";
    Symbol::qualified(NAMESPACE, name)
}
pub fn spec_tensor_descriptor_value(
factory: &dyn Factory,
descriptor: SpecTensorDescriptor,
) -> Result<Value> {
factory.table(vec![
(
Symbol::new("kind"),
factory.string("spec-tensor".to_owned())?,
),
(Symbol::new("symbol"), factory.symbol(descriptor.symbol)?),
(Symbol::new("dtype"), factory.symbol(descriptor.dtype)?),
(
Symbol::new("implementation"),
factory.string(descriptor.implementation.to_owned())?,
),
(
Symbol::new("storage"),
factory.string(descriptor.storage.to_owned())?,
),
])
}
/// Returns the number of cells described by `shape`: the product of all
/// dimensions.
///
/// An empty shape yields `1`, because the product of no factors is `1`. Any
/// zero dimension yields `0`.
///
/// The multiplication is not overflow-checked here; use
/// [`checked_element_count`] when the shape is not known to be small.
pub fn element_count(shape: &[usize]) -> usize {
    shape.iter().product()
}
pub fn checked_element_count(shape: &[usize]) -> Result<usize> {
shape.iter().try_fold(1_usize, |acc, &dim| {
acc.checked_mul(dim).ok_or_else(|| {
sim_kernel::Error::Eval(format!("tensor shape {shape:?} cell count overflows usize"))
})
})
}
/// Extracts the [`NumberLiteral`] of a tensor cell, if it has one.
///
/// A fresh [`Cx`] built from [`NoopEvalPolicy`] and [`DefaultFactory`] is
/// created on every call and passed to `number_literal`.
///
/// Returns `None` when the value's object is not a number value, when
/// `number_literal` fails, or when it succeeds without a literal.
pub fn number_literal_for_tensor_cell(value: &Value) -> Option<NumberLiteral> {
    let mut cx = Cx::new(Arc::new(NoopEvalPolicy), Arc::new(DefaultFactory));
    match value.object().as_number_value()?.number_literal(&mut cx) {
        Ok(literal) => literal,
        // Errors are deliberately collapsed into "no literal".
        Err(_) => None,
    }
}
/// Reads a tensor cell as an `i64`.
///
/// Returns `None` when the cell has no number literal, when the literal's
/// domain is not `domains::i64()`, or when its canonical text does not parse
/// as an `i64`.
pub fn parse_i64_literal_cell(value: &Value) -> Option<i64> {
    let literal = number_literal_for_tensor_cell(value)?;
    if literal.domain != domains::i64() {
        return None;
    }
    literal.canonical.parse().ok()
}
/// Reads a tensor cell as an `f64`.
///
/// Returns `None` when the cell has no number literal, when the literal's
/// domain is not `domains::f64()`, or when its canonical text does not parse
/// as an `f64`.
pub fn parse_f64_literal_cell(value: &Value) -> Option<f64> {
    match number_literal_for_tensor_cell(value) {
        Some(literal) if literal.domain == domains::f64() => literal.canonical.parse().ok(),
        _ => None,
    }
}
/// Reads a tensor cell as a `(numerator, denominator)` pair.
///
/// The literal's canonical text is split at its first `/` and each side is
/// parsed as an `i64`. The pair is returned as written: it is not reduced and
/// the denominator is not checked against zero.
///
/// Returns `None` when the cell has no number literal, when the literal's
/// domain is not `domains::rational()`, when the text contains no `/`, or
/// when either side fails to parse.
pub fn parse_rational_literal_cell(value: &Value) -> Option<(i64, i64)> {
    let literal = number_literal_for_tensor_cell(value)
        .filter(|candidate| candidate.domain == domains::rational())?;
    let (numerator_text, denominator_text) = literal.canonical.split_once('/')?;
    let numerator = numerator_text.parse::<i64>().ok()?;
    let denominator = denominator_text.parse::<i64>().ok()?;
    Some((numerator, denominator))
}
/// Reads a tensor cell as a `(real, imaginary)` pair of `f64`s.
///
/// The literal's canonical text must end in `i`. After removing that suffix,
/// the text is split in front of the sign that starts the imaginary part, and
/// both halves are parsed as `f64` (the imaginary half keeps its sign).
///
/// The separating sign is the first `+` or `-` that is neither the leading
/// character (the real part's own sign) nor directly after `e`/`E` (an
/// exponent sign). The exponent rule lets text such as `1e-5+2i` parse;
/// taking the first sign unconditionally would cut the real part at `1e`.
///
/// Returns `None` when the cell has no number literal, when the literal's
/// domain is not `domains::complex()`, when the text lacks the `i` suffix or
/// a separating sign, or when either half fails to parse.
pub fn parse_complex_literal_cell(value: &Value) -> Option<(f64, f64)> {
    let literal = number_literal_for_tensor_cell(value)?;
    if literal.domain != domains::complex() {
        return None;
    }
    let text = literal.canonical.strip_suffix('i')?;
    // Each window is (previous byte, candidate sign). Starting at the second
    // byte skips a leading sign. `+`, `-`, `e` and `E` are ASCII, so a match
    // always falls on a char boundary and `split_at` cannot panic.
    let preceding = text.as_bytes().windows(2).position(|pair| {
        matches!(pair[1], b'+' | b'-') && !matches!(pair[0], b'e' | b'E')
    })?;
    let split = preceding + 1;
    let (real, imag) = text.split_at(split);
    Some((real.parse::<f64>().ok()?, imag.parse::<f64>().ok()?))
}