use runmat_builtins::{NumericDType, Tensor, Value};
/// How the window's sample points are laid out.
///
/// * `Symmetric` — the `len` samples are generated directly over `len` points.
/// * `Periodic` — the samples are generated as if the window had `len + 1` points,
///   and only the first `len` are kept (see `window_tensor`).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum WindowSampling {
    Symmetric,
    Periodic,
}
/// Requested element type of the produced window vector.
///
/// `Double` maps to `NumericDType::F64` and `Single` to `NumericDType::F32`
/// (see `dtype_for`).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum WindowOutputType {
    Double,
    Single,
}
/// Fully resolved arguments for building a window vector.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) struct WindowOptions {
    /// Number of samples in the produced column vector.
    pub len: usize,
    /// Symmetric or periodic sample layout.
    pub sampling: WindowSampling,
    /// Element type of the produced tensor.
    pub output_type: WindowOutputType,
}
/// Failures that can occur while parsing window arguments or building the window.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum WindowArgError {
    /// The length argument could not be interpreted as a valid window length.
    InvalidLength,
    /// A trailing option was neither a string nor a char array.
    InvalidOptionType,
    /// A trailing option keyword (already lowercased) was not recognised.
    UnknownOption(String),
    /// Constructing the output tensor failed; carries the underlying message.
    TensorBuild(String),
}
/// Converts a window-length argument into a sample count.
///
/// Accepts numeric scalars, integer values, logicals (`true` -> 1, `false` -> 0),
/// and single-element tensors. The value is rounded to the nearest integer
/// (halfway cases round away from zero, per `f64::round`).
///
/// # Errors
///
/// Returns [`WindowArgError::InvalidLength`] when the argument is not one of the
/// accepted scalar forms, or when the value is NaN, infinite, or negative. It also
/// returns this error when the rounded value does not fit in `usize`.
pub(crate) fn scalar_length_arg(value: Value) -> Result<usize, WindowArgError> {
    let scalar = match value {
        Value::Num(n) => n,
        Value::Int(i) => i.to_f64(),
        Value::Bool(b) => usize::from(b) as f64,
        Value::Tensor(t) if t.data.len() == 1 => t.data[0],
        _ => return Err(WindowArgError::InvalidLength),
    };
    if !scalar.is_finite() || scalar < 0.0 {
        return Err(WindowArgError::InvalidLength);
    }
    let rounded = scalar.round();
    // A float-to-int `as` cast saturates silently, so `1e30` would become
    // `usize::MAX` and later overflow/abort when the window is built.
    // `usize::MAX as f64` is >= usize::MAX (it rounds up to 2^64 on 64-bit
    // targets), so `>=` rejects every value the cast below could not represent.
    if rounded >= usize::MAX as f64 {
        return Err(WindowArgError::InvalidLength);
    }
    Ok(rounded as usize)
}
/// Builds a `len x 1` window column vector from a per-sample coefficient function.
///
/// `coeff(idx, n)` is called for `idx` in `0..len`. Here `n` is the length of the
/// symmetric window being sampled:
///
/// * `len` for [`WindowSampling::Symmetric`];
/// * `len + 1` for [`WindowSampling::Periodic`]. This is equivalent to building a
///   `len + 1` symmetric window and dropping its last sample.
///
/// Special cases: `len == 0` yields an empty `0 x 1` tensor. `len == 1` yields
/// `[1.0]` without calling `coeff`.
///
/// # Errors
///
/// * [`WindowArgError::InvalidLength`] if the periodic length `len + 1` overflows.
/// * [`WindowArgError::TensorBuild`] if the sample buffer cannot be allocated, or if
///   tensor construction fails.
pub(crate) fn window_tensor(
    options: WindowOptions,
    coeff: impl Fn(usize, usize) -> f64,
) -> Result<Value, WindowArgError> {
    let len = options.len;
    let data = match len {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            // Denominator length handed to `coeff`; periodic windows sample one
            // extra (never materialised) point.
            let span = match options.sampling {
                WindowSampling::Symmetric => len,
                WindowSampling::Periodic => len
                    .checked_add(1)
                    .ok_or(WindowArgError::InvalidLength)?,
            };
            // Fallible reservation: an absurd length becomes an error instead of
            // a capacity-overflow panic or allocation abort.
            let mut samples = Vec::new();
            samples
                .try_reserve_exact(len)
                .map_err(|e| WindowArgError::TensorBuild(e.to_string()))?;
            samples.extend((0..len).map(|idx| coeff(idx, span)));
            samples
        }
    };
    Tensor::new_with_dtype(data, vec![len, 1], dtype_for(options.output_type))
        .map(Value::Tensor)
        .map_err(|e| WindowArgError::TensorBuild(e.to_string()))
}
pub(crate) fn parse_window_options(
len_value: Value,
rest: &[Value],
allow_type_name: bool,
) -> Result<WindowOptions, WindowArgError> {
let len = scalar_length_arg(len_value)?;
let mut sampling = WindowSampling::Symmetric;
let mut output_type = WindowOutputType::Double;
for arg in rest {
let Some(keyword) = string_keyword(arg) else {
return Err(WindowArgError::InvalidOptionType);
};
match keyword.as_str() {
"symmetric" => sampling = WindowSampling::Symmetric,
"periodic" => sampling = WindowSampling::Periodic,
"double" if allow_type_name => output_type = WindowOutputType::Double,
"single" if allow_type_name => output_type = WindowOutputType::Single,
_ => return Err(WindowArgError::UnknownOption(keyword)),
}
}
Ok(WindowOptions {
len,
sampling,
output_type,
})
}
/// Returns the tensor dtype used to store a window of the requested output type.
///
/// `Single` maps to `NumericDType::F32`; `Double` maps to `NumericDType::F64`.
pub(crate) fn dtype_for(output_type: WindowOutputType) -> NumericDType {
    match output_type {
        WindowOutputType::Single => NumericDType::F32,
        WindowOutputType::Double => NumericDType::F64,
    }
}
/// Reports whether an acceleration provider is registered and whether its precision
/// matches the requested output type.
///
/// Returns `true` only for F64 + `Double` or F32 + `Single`. Returns `false` when
/// no provider is available or the precisions differ.
pub(crate) fn provider_precision_matches(output_type: WindowOutputType) -> bool {
    if let Some(provider) = runmat_accelerate_api::provider() {
        let precision = provider.precision();
        matches!(
            (output_type, precision),
            (
                WindowOutputType::Double,
                runmat_accelerate_api::ProviderPrecision::F64
            ) | (
                WindowOutputType::Single,
                runmat_accelerate_api::ProviderPrecision::F32
            )
        )
    } else {
        false
    }
}
/// Extracts an ASCII-lowercased keyword from a string or char-array value.
///
/// Char-array elements are concatenated in storage order. Returns `None` for any
/// other kind of value.
fn string_keyword(value: &Value) -> Option<String> {
    match value {
        Value::CharArray(array) => {
            let mut text: String = array.data.iter().collect();
            text.make_ascii_lowercase();
            Some(text)
        }
        Value::String(text) => Some(text.to_ascii_lowercase()),
        _ => None,
    }
}