use runmat_accelerate_api::{GpuTensorHandle, HostTensorView};
use runmat_builtins::{LogicalArray, ResolveContext, Tensor, Type, Value};
use runmat_macros::runtime_builtin;
use crate::build_runtime_error;
use crate::builtins::array::type_resolvers::tensor_type_from_rank;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{random, tensor};
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::creation::randi")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: "randi",
op_kind: GpuOpKind::Custom("generator"),
supported_precisions: &[ScalarType::F32, ScalarType::F64],
broadcast: BroadcastSemantics::None,
provider_hooks: &[
ProviderHook::Custom("random_integer_range"),
ProviderHook::Custom("random_integer_like"),
],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::NewHandle,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Providers may offer integer RNG kernels via random_integer_range / random_integer_like; the runtime falls back to host sampling and upload when unavailable.",
};
fn builtin_error(message: impl Into<String>) -> crate::RuntimeError {
build_runtime_error(message).with_builtin("randi").build()
}
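/// Resolves the static return type: a single argument yields a scalar `Type::Num`,
/// string options make the result unknown, and otherwise the tensor rank is inferred
/// from the trailing dimension arguments.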
fn randi_type(args: &[Type], ctx: &ResolveContext) -> Type {
if args.is_empty() {
return Type::Unknown;
}
if args.len() == 1 {
return Type::Num;
}
let rest = &args[1..];
if rest.iter().any(|arg| matches!(arg, Type::String)) {
return Type::Unknown;
}
let rest_ctx = ResolveContext::new(ctx.literal_args.get(1..).unwrap_or(&[]).to_vec());
tensor_type_from_rank(rest, &rest_ctx)
}
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::creation::randi")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: "randi",
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: false,
notes: "Random integer generation is treated as a sink and excluded from fusion planning.",
};
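/// Argument grammar (mirroring MATLAB's `randi`):
/// - `randi(imax)` samples from `[1, imax]`; `randi([lo hi])` samples from `[lo, hi]`.
/// - Trailing numeric scalars or size vectors set the output dimensions.
/// - `'like', prototype` matches the prototype's class (and its shape when no size is given).
/// - `'double'` / `'logical'` select the output class; other classes are not implemented yet.
///
/// For example, `randi([3 8], 2, 3)` produces a 2x3 double tensor of integers in `[3, 8]`.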
#[runtime_builtin(
name = "randi",
category = "array/creation",
summary = "Uniform random integers with inclusive bounds.",
keywords = "randi,random,integer,gpu,like",
accel = "array_construct",
type_resolver(randi_type),
builtin_path = "crate::builtins::array::creation::randi"
)]
async fn randi_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
let parsed = ParsedRandi::parse(args).await?;
build_output(parsed).await
}
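/// Validated `randi` arguments: sampling bounds, output shape, and output template.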
struct ParsedRandi {
bounds: Bounds,
shape: Vec<usize>,
template: OutputTemplate,
}
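/// How the result should be materialized: a plain double tensor, a logical array, or a
/// value matching a prototype (explicit `'like'` or an implicit prototype argument).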
#[derive(Clone)]
enum OutputTemplate {
Double,
Logical,
Like(Value),
}
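/// Inclusive integer sampling range; `span` is the number of representable values,
/// i.e. `upper - lower + 1`.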
#[derive(Clone, Copy)]
struct Bounds {
lower: i64,
upper: i64,
span: u64,
}
impl Bounds {
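/// Validates `lower <= upper` and that the range width fits within the 2^53 precision
/// of the f64-based uniform sampler.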
fn new(lower: i64, upper: i64) -> crate::BuiltinResult<Self> {
if lower > upper {
return Err(builtin_error("randi: lower bound must be <= upper bound"));
}
let span = (upper as i128)
.checked_sub(lower as i128)
.and_then(|delta| delta.checked_add(1))
.ok_or_else(|| builtin_error("randi: range width overflows 64-bit arithmetic"))?;
if span <= 0 {
return Err(builtin_error("randi: invalid bounds"));
}
if span > (1u64 << 53) as i128 {
return Err(builtin_error(
"randi: range width exceeds RNG precision (2^53)",
));
}
Ok(Self {
lower,
upper,
span: span as u64,
})
}
}
impl ParsedRandi {
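/// Parses the raw argument list. The first value fixes the bounds; the rest are consumed
/// as size arguments, class keywords (`'double'`, `'logical'`), or a `'like'` prototype.
/// A non-keyword, non-size value acts as an implicit prototype. With no size or prototype
/// the output defaults to a 1x1 double.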
async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
if args.is_empty() {
return Err(builtin_error("randi: requires at least one input argument"));
}
let mut iter = args.into_iter();
let bounds_value = iter.next().unwrap();
let bounds = parse_bounds(bounds_value).await?;
let mut dims: Vec<usize> = Vec::new();
let mut saw_dims_arg = false;
let mut shape_source: Option<Vec<usize>> = None;
let mut like_proto: Option<Value> = None;
let mut class_override: Option<OutputTemplate> = None;
let mut implicit_proto: Option<Value> = None;
let rest: Vec<Value> = iter.collect();
let mut idx = 0;
while idx < rest.len() {
let arg = rest[idx].clone();
if let Some(keyword) = keyword_of(&arg) {
match keyword.as_str() {
"like" => {
if like_proto.is_some() {
return Err(builtin_error(
"randi: multiple 'like' specifications are not supported",
));
}
if let Some(spec) = &class_override {
let keyword = match spec {
OutputTemplate::Logical => "'logical'",
OutputTemplate::Double => "'double'",
OutputTemplate::Like(_) => "another class specifier",
};
return Err(builtin_error(format!(
"randi: cannot combine 'like' with {keyword}"
)));
}
let Some(proto) = rest.get(idx + 1).cloned() else {
return Err(builtin_error("randi: expected prototype after 'like'"));
};
like_proto = Some(proto.clone());
if shape_source.is_none() && !saw_dims_arg {
shape_source = Some(shape_from_value(&proto)?);
}
idx += 2;
continue;
}
"double" => {
if like_proto.is_some() {
return Err(builtin_error(
"randi: cannot combine 'like' with 'double'",
));
}
class_override = Some(OutputTemplate::Double);
idx += 1;
continue;
}
"logical" => {
if like_proto.is_some() {
return Err(builtin_error(
"randi: cannot combine 'like' with 'logical'",
));
}
class_override = Some(OutputTemplate::Logical);
idx += 1;
continue;
}
"single" => {
return Err(builtin_error(
"randi: single precision output is not implemented yet",
));
}
"int8" | "uint8" | "int16" | "uint16" | "int32" | "uint32" | "int64"
| "uint64" => {
return Err(builtin_error(format!(
"randi: output class '{}' is not implemented yet",
keyword
)));
}
other => {
return Err(builtin_error(format!(
"randi: unrecognised option '{other}'"
)));
}
}
}
if let Some(parsed_dims) = extract_dims(&arg).await? {
saw_dims_arg = true;
if dims.is_empty() {
dims = parsed_dims;
} else {
dims.extend(parsed_dims);
}
idx += 1;
continue;
}
if shape_source.is_none() {
shape_source = Some(shape_from_value(&arg)?);
}
if implicit_proto.is_none() {
implicit_proto = Some(arg.clone());
}
idx += 1;
}
let shape = if saw_dims_arg {
if dims.is_empty() {
vec![0, 0]
} else if dims.len() == 1 {
vec![dims[0], dims[0]]
} else {
dims
}
} else if let Some(shape) = shape_source {
shape
} else {
vec![1, 1]
};
let template = if let Some(proto) = like_proto {
OutputTemplate::Like(proto)
} else if let Some(spec) = class_override {
spec
} else if let Some(proto) = implicit_proto {
OutputTemplate::Like(proto)
} else {
OutputTemplate::Double
};
Ok(Self {
bounds,
shape,
template,
})
}
}
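/// Materializes the sampled values according to the parsed output template.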
async fn build_output(parsed: ParsedRandi) -> crate::BuiltinResult<Value> {
match parsed.template {
OutputTemplate::Double => randi_double(&parsed.bounds, &parsed.shape),
OutputTemplate::Logical => randi_logical(&parsed.bounds, &parsed.shape),
OutputTemplate::Like(proto) => randi_like(&proto, &parsed.bounds, &parsed.shape).await,
}
}
fn randi_double(bounds: &Bounds, shape: &[usize]) -> crate::BuiltinResult<Value> {
let tensor = integer_tensor(bounds, shape)?;
Ok(tensor::tensor_into_value(tensor))
}
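/// Logical output: bounds must lie within [0, 1]; non-zero draws map to true.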
fn randi_logical(bounds: &Bounds, shape: &[usize]) -> crate::BuiltinResult<Value> {
if bounds.lower < 0 || bounds.upper > 1 {
return Err(builtin_error(
"randi: logical output requires bounds contained within the inclusive range [0, 1]",
));
}
let len = tensor::element_count(shape);
let mut data: Vec<u8> = Vec::with_capacity(len);
if len == 0 {
let logical = LogicalArray::new(data, shape.to_vec())
.map_err(|e| builtin_error(format!("randi: {e}")))?;
return Ok(Value::LogicalArray(logical));
}
if bounds.span == 1 {
let byte = if bounds.lower == 0 { 0u8 } else { 1u8 };
data.resize(len, byte);
} else {
let samples = generate_integer_data(bounds, len)?;
data = samples
.into_iter()
.map(|value| if value != 0.0 { 1u8 } else { 0u8 })
.collect();
}
let logical = LogicalArray::new(data, shape.to_vec())
.map_err(|e| builtin_error(format!("randi: {e}")))?;
Ok(Value::LogicalArray(logical))
}
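/// Chooses the output representation from a prototype value: GPU handles stay on the
/// device when possible, logical prototypes yield logical arrays, and numeric, char, and
/// string prototypes yield double tensors. Complex and cell prototypes are rejected.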
#[async_recursion::async_recursion(?Send)]
async fn randi_like(
proto: &Value,
bounds: &Bounds,
shape: &[usize],
) -> crate::BuiltinResult<Value> {
match proto {
Value::GpuTensor(handle) => randi_like_gpu(handle, bounds, shape).await,
Value::LogicalArray(_) | Value::Bool(_) => randi_logical(bounds, shape),
Value::Tensor(_) | Value::Num(_) | Value::Int(_) => randi_double(bounds, shape),
Value::CharArray(_) | Value::String(_) | Value::StringArray(_) => {
randi_double(bounds, shape)
}
Value::Complex(_, _) | Value::ComplexTensor(_) => Err(builtin_error(
"randi: complex prototypes are not supported; expected real-valued arrays",
)),
Value::Cell(_) => Err(builtin_error("randi: cell prototypes are not supported")),
other => Err(builtin_error(format!(
"randi: unsupported prototype {other:?}"
))),
}
}
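/// GPU path: try the provider's integer RNG hooks first (`random_integer_like` when the
/// prototype shape matches, `random_integer_range` otherwise), then fall back to host
/// sampling plus upload, and finally to returning the host tensor. Without a provider the
/// handle is gathered and re-dispatched through `randi_like`.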
#[async_recursion::async_recursion(?Send)]
async fn randi_like_gpu(
handle: &GpuTensorHandle,
bounds: &Bounds,
shape: &[usize],
) -> crate::BuiltinResult<Value> {
if let Some(provider) = runmat_accelerate_api::provider() {
let attempt = if handle.shape == shape {
provider.random_integer_like(handle, bounds.lower, bounds.upper)
} else {
provider.random_integer_range(bounds.lower, bounds.upper, shape)
};
if let Ok(gpu) = attempt {
return Ok(Value::GpuTensor(gpu));
}
let tensor = integer_tensor(bounds, shape)?;
let view = HostTensorView {
data: &tensor.data,
shape: &tensor.shape,
};
if let Ok(gpu) = provider.upload(&view) {
return Ok(Value::GpuTensor(gpu));
}
return Ok(tensor::tensor_into_value(tensor));
}
let gathered = crate::dispatcher::gather_if_needed_async(&Value::GpuTensor(handle.clone()))
.await
.map_err(|e| builtin_error(format!("randi: {e}")))?;
randi_like(&gathered, bounds, shape).await
}
fn integer_tensor(bounds: &Bounds, shape: &[usize]) -> crate::BuiltinResult<Tensor> {
let len = tensor::element_count(shape);
let data = generate_integer_data(bounds, len)?;
Tensor::new(data, shape.to_vec()).map_err(|e| builtin_error(format!("randi: {e}")))
}
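/// Maps uniform samples in [0, 1) to integers in [lower, upper] by scaling by the span,
/// flooring, and clamping; a degenerate range (lower == upper) is filled with the lower bound.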
fn generate_integer_data(bounds: &Bounds, len: usize) -> crate::BuiltinResult<Vec<f64>> {
if len == 0 {
return Ok(Vec::new());
}
if bounds.span == 1 {
return Ok(vec![bounds.lower as f64; len]);
}
let uniforms = random::generate_uniform(len, "randi")?;
let span = bounds.span as f64;
let lower = bounds.lower as i128;
let upper = bounds.upper as i128;
let mut out = Vec::with_capacity(len);
for u in uniforms {
let mut offset = (u * span).floor() as u64;
if offset >= bounds.span {
offset = bounds.span - 1;
}
let mut value = lower
.checked_add(offset as i128)
.ok_or_else(|| builtin_error("randi: integer overflow while sampling"))?;
if value > upper {
value = upper;
}
out.push(value as f64);
}
Ok(out)
}
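/// Parses the first argument into sampling bounds: a numeric scalar `imax` means
/// `[1, imax]`, a two-element vector means `[lo, hi]`. GPU values are gathered first;
/// logical, string, char, and complex bounds are rejected.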
async fn parse_bounds(value: Value) -> crate::BuiltinResult<Bounds> {
let value = match value {
Value::GpuTensor(_) => crate::dispatcher::gather_if_needed_async(&value)
.await
.map_err(|e| builtin_error(format!("randi: {e}")))?,
other => other,
};
match value {
Value::Tensor(t) => parse_bounds_tensor(&t),
Value::LogicalArray(_) | Value::Bool(_) => Err(builtin_error(
"randi: bounds must be numeric scalars or vectors",
)),
Value::String(s) => Err(builtin_error(format!(
"randi: unexpected option '{s}' in first argument"
))),
Value::StringArray(_) => Err(builtin_error(
"randi: unexpected string array in first argument",
)),
Value::CharArray(_) => Err(builtin_error("randi: string bounds are not supported")),
Value::Complex(_, _) | Value::ComplexTensor(_) => {
Err(builtin_error("randi: complex bounds are not supported"))
}
other => {
let Some(raw) = tensor::scalar_f64_from_value_async(&other)
.await
.map_err(|e| builtin_error(format!("randi: {e}")))?
else {
return Err(builtin_error(format!(
"randi: unsupported bounds argument {other:?}"
)));
};
parse_upper_num(raw)
}
}
}
fn parse_upper_scalar(upper: i64) -> crate::BuiltinResult<Bounds> {
if upper < 1 {
return Err(builtin_error("randi: upper bound must be >= 1"));
}
Bounds::new(1, upper)
}
fn parse_upper_num(n: f64) -> crate::BuiltinResult<Bounds> {
if !n.is_finite() {
return Err(builtin_error("randi: bounds must be finite"));
}
let rounded = n.round();
if (rounded - n).abs() > f64::EPSILON {
return Err(builtin_error("randi: bounds must be integers"));
}
let upper = rounded as i64;
parse_upper_scalar(upper)
}
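/// A one-element tensor is treated as `imax`; a two-element row or column vector is
/// treated as `[lo, hi]`; anything else is rejected.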
fn parse_bounds_tensor(tensor: &Tensor) -> crate::BuiltinResult<Bounds> {
let len = tensor.data.len();
if len == 0 {
return Err(builtin_error("randi: empty bound vector is not allowed"));
}
if len == 1 {
return parse_upper_num(tensor.data[0]);
}
if len == 2 && is_vector_like(tensor) {
let lower = parse_integer_component(tensor.data[0])?;
let upper = parse_integer_component(tensor.data[1])?;
Bounds::new(lower, upper)
} else {
Err(builtin_error(
"randi: bound vector must contain exactly two elements",
))
}
}
fn parse_integer_component(value: f64) -> crate::BuiltinResult<i64> {
if !value.is_finite() {
return Err(builtin_error("randi: bounds must be finite"));
}
let rounded = value.round();
if (rounded - value).abs() > f64::EPSILON {
return Err(builtin_error("randi: bounds must be integers"));
}
Ok(rounded as i64)
}
fn is_vector_like(tensor: &Tensor) -> bool {
tensor.rows() == 1 || tensor.cols() == 1 || tensor.shape.len() == 1
}
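/// Extracts a lower-cased option keyword from string-like values (strings, one-element
/// string arrays, and single-row char arrays).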
fn keyword_of(value: &Value) -> Option<String> {
match value {
Value::String(s) => Some(s.to_ascii_lowercase()),
Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].to_ascii_lowercase()),
Value::CharArray(ca) if ca.rows == 1 => {
let text: String = ca.data.iter().collect();
Some(text.to_ascii_lowercase())
}
_ => None,
}
}
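/// Attempts to read a value as a size specification. Logical arrays, plain tensors, and
/// non-scalar GPU tensors that fail the conversion are treated as prototypes (`Ok(None)`);
/// other conversion failures are reported as errors.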
async fn extract_dims(value: &Value) -> crate::BuiltinResult<Option<Vec<usize>>> {
if matches!(value, Value::LogicalArray(_)) {
return Ok(None);
}
let gpu_scalar = match value {
Value::GpuTensor(handle) => tensor::element_count(&handle.shape) == 1,
_ => false,
};
match tensor::dims_from_value_async(value).await {
Ok(dims) => Ok(dims),
Err(err) => {
if matches!(value, Value::Tensor(_))
|| (matches!(value, Value::GpuTensor(_)) && !gpu_scalar)
{
Ok(None)
} else {
Err(builtin_error(format!("randi: {err}")))
}
}
}
}
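/// Derives the output shape from a prototype value; complex prototypes are rejected.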
fn shape_from_value(value: &Value) -> crate::BuiltinResult<Vec<usize>> {
match value {
Value::Tensor(t) => Ok(t.shape.clone()),
Value::ComplexTensor(_) => {
Err(builtin_error("randi: complex prototypes are not supported"))
}
Value::LogicalArray(l) => Ok(l.shape.clone()),
Value::GpuTensor(h) => Ok(h.shape.clone()),
Value::CharArray(ca) => Ok(vec![ca.rows, ca.cols]),
Value::Cell(cell) => Ok(vec![cell.rows, cell.cols]),
Value::Num(_) | Value::Int(_) | Value::Bool(_) => Ok(vec![1, 1]),
other => Err(builtin_error(format!(
"randi: unsupported prototype {other:?}"
))),
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::builtins::common::{random, test_support};
use futures::executor::block_on;
use runmat_builtins::LogicalArray;
fn reset_rng_clean() {
runmat_accelerate_api::clear_provider();
random::reset_rng();
}
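/// Mirrors `generate_integer_data` over the test RNG's expected uniform sequence so
/// tests can assert exact sampled values.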
fn expected_sequence(bounds: &Bounds, count: usize) -> Vec<i64> {
let uniforms = random::expected_uniform_sequence(count);
let span = bounds.span as f64;
uniforms
.into_iter()
.map(|u| {
let mut offset = (u * span).floor() as u64;
if offset >= bounds.span {
offset = bounds.span - 1;
}
bounds.lower + offset as i64
})
.collect()
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_default_scalar() {
let _guard = random::test_lock().lock().unwrap();
reset_rng_clean();
let result = block_on(randi_builtin(vec![Value::Num(6.0)])).expect("randi");
let expected = expected_sequence(&Bounds::new(1, 6).unwrap(), 1)[0] as f64;
match result {
Value::Num(v) => {
assert!((1.0..=6.0).contains(&v));
assert!((v - expected).abs() < 1e-12);
}
other => panic!("expected scalar double, got {other:?}"),
}
}
#[test]
fn randi_type_single_bound_is_num() {
assert_eq!(
randi_type(&[Type::Num], &ResolveContext::new(Vec::new())),
Type::Num
);
}
#[test]
fn randi_type_infers_rank_from_dims() {
let ctx = ResolveContext::new(Vec::new());
assert_eq!(
randi_type(&[Type::Num, Type::Num, Type::Num], &ctx),
Type::Tensor {
shape: Some(vec![None, None])
}
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_range_with_dims() {
let _guard = random::test_lock().lock().unwrap();
reset_rng_clean();
let bounds = Tensor::new(vec![3.0, 8.0], vec![1, 2]).unwrap();
let args = vec![Value::Tensor(bounds), Value::Num(2.0), Value::Num(3.0)];
let result = block_on(randi_builtin(args)).expect("randi");
match result {
Value::Tensor(t) => {
assert_eq!(t.shape, vec![2, 3]);
let expected = expected_sequence(&Bounds::new(3, 8).unwrap(), 6);
for (observed, exp) in t.data.iter().zip(expected.iter().map(|v| *v as f64)) {
assert!((*observed - exp).abs() < 1e-12);
}
}
other => panic!("expected tensor result, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_like_tensor() {
let _guard = random::test_lock().lock().unwrap();
reset_rng_clean();
let proto = Tensor::new(vec![0.0; 4], vec![2, 2]).unwrap();
let args = vec![Value::Num(5.0), Value::from("like"), Value::Tensor(proto)];
let result = block_on(randi_builtin(args)).expect("randi");
match result {
Value::Tensor(t) => {
assert_eq!(t.shape, vec![2, 2]);
for v in &t.data {
assert!((1.0..=5.0).contains(v));
}
}
other => panic!("expected tensor result, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_logical_output() {
let _guard = random::test_lock().lock().unwrap();
reset_rng_clean();
let bounds = Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap();
let args = vec![
Value::Tensor(bounds),
Value::Num(2.0),
Value::Num(2.0),
Value::from("logical"),
];
let result = block_on(randi_builtin(args)).expect("randi logical");
match result {
Value::LogicalArray(logical) => {
assert_eq!(logical.shape, vec![2, 2]);
let expected = expected_sequence(&Bounds::new(0, 1).unwrap(), 4);
for (idx, &byte) in logical.data.iter().enumerate() {
assert!(byte <= 1);
assert_eq!(byte, if expected[idx] == 0 { 0 } else { 1 });
}
}
other => panic!("expected logical array, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_logical_requires_binary_bounds() {
let err =
block_on(randi_builtin(vec![Value::Num(3.0), Value::from("logical")])).unwrap_err();
let message = err.to_string();
assert!(message.contains("logical output requires"));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_like_logical_prototype() {
let _guard = random::test_lock().lock().unwrap();
reset_rng_clean();
let proto = LogicalArray::zeros(vec![2, 3]);
let bounds = Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap();
let args = vec![
Value::Tensor(bounds),
Value::from("like"),
Value::LogicalArray(proto),
];
let result = block_on(randi_builtin(args)).expect("randi logical like");
match result {
Value::LogicalArray(logical) => {
assert_eq!(logical.shape, vec![2, 3]);
assert!(logical.data.iter().all(|&b| b <= 1));
}
other => panic!("expected logical array, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_like_requires_prototype() {
let err = block_on(randi_builtin(vec![Value::Num(5.0), Value::from("like")])).unwrap_err();
let message = err.to_string();
assert!(message.contains("expected prototype"));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_duplicate_like_is_error() {
let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
let args = vec![
Value::Num(5.0),
Value::from("like"),
Value::Tensor(proto.clone()),
Value::from("like"),
Value::Tensor(proto),
];
let err = block_on(randi_builtin(args)).unwrap_err();
let message = err.to_string();
assert!(message.contains("multiple 'like' specifications"));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_like_logical_conflict_is_error() {
let proto = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
let args = vec![
Value::Num(1.0),
Value::from("logical"),
Value::from("like"),
Value::Tensor(proto),
];
let err = block_on(randi_builtin(args)).unwrap_err();
let message = err.to_string();
assert!(message.contains("cannot combine 'like' with 'logical'"));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_gpu_like_roundtrip() {
let _guard = random::test_lock().lock().unwrap();
random::reset_rng();
test_support::with_test_provider(|provider| {
let tensor = Tensor::new(vec![0.0; 4], vec![2, 2]).unwrap();
let view = HostTensorView {
data: &tensor.data,
shape: &tensor.shape,
};
let handle = provider.upload(&view).expect("upload");
let args = vec![
Value::Num(4.0),
Value::from("like"),
Value::GpuTensor(handle),
];
let result = block_on(randi_builtin(args)).expect("randi");
match result {
Value::GpuTensor(gpu) => {
let gathered =
test_support::gather(Value::GpuTensor(gpu)).expect("gather to host");
assert_eq!(gathered.shape, vec![2, 2]);
for value in gathered.data {
assert!((1.0..=4.0).contains(&value));
}
}
other => panic!("expected GPU tensor, got {other:?}"),
}
});
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_gpu_like_shape_override() {
let _guard = random::test_lock().lock().unwrap();
random::reset_rng();
test_support::with_test_provider(|provider| {
let proto = Tensor::new(vec![0.0; 4], vec![2, 2]).unwrap();
let view = HostTensorView {
data: &proto.data,
shape: &proto.shape,
};
let handle = provider.upload(&view).expect("upload");
let bounds = Tensor::new(vec![1.0, 4.0], vec![1, 2]).unwrap();
let args = vec![
Value::Tensor(bounds),
Value::Num(3.0),
Value::Num(1.0),
Value::from("like"),
Value::GpuTensor(handle),
];
let result = block_on(randi_builtin(args)).expect("randi gpu override");
match result {
Value::GpuTensor(gpu) => {
let gathered =
test_support::gather(Value::GpuTensor(gpu)).expect("gather override");
assert_eq!(gathered.shape, vec![3, 1]);
for value in gathered.data {
assert!((1.0..=4.0).contains(&value));
}
}
other => panic!("expected GPU tensor, got {other:?}"),
}
});
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn randi_invalid_upper_errors() {
let err = block_on(randi_builtin(vec![Value::Num(0.0)])).unwrap_err();
let message = err.to_string();
assert!(message.contains("upper bound"));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn randi_wgpu_like_produces_in_range_values() {
let _guard = random::test_lock().lock().unwrap();
random::reset_rng();
let provider = match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
) {
Ok(_) => runmat_accelerate_api::provider().expect("wgpu provider registered"),
Err(err) => {
tracing::warn!("randi_wgpu_like_produces_in_range_values skipped: {err}");
return;
}
};
let proto = Tensor::new(vec![0.0; 6], vec![2, 3]).unwrap();
let view = runmat_accelerate_api::HostTensorView {
data: &proto.data,
shape: &proto.shape,
};
let handle = provider.upload(&view).expect("upload prototype");
let bounds = Tensor::new(vec![1.0, 8.0], vec![1, 2]).unwrap();
let args = vec![
Value::Tensor(bounds),
Value::from("like"),
Value::GpuTensor(handle),
];
let result = block_on(randi_builtin(args)).expect("randi");
match result {
Value::GpuTensor(gpu) => {
let gathered =
test_support::gather(Value::GpuTensor(gpu)).expect("gather gpu result");
assert_eq!(gathered.shape, vec![2, 3]);
for value in gathered.data {
assert!(
(1.0..=8.0).contains(&value),
"expected value within [1, 8], got {value}"
);
}
}
other => panic!("expected GPU tensor result, got {other:?}"),
}
}
}