Skip to main content

openinfer_simulator/
random.rs

1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4use rand::rngs::StdRng;
5use rand::{Rng, SeedableRng};
6
7use crate::tensor::{numel, BF16, F8, I1, I2, I4, Tensor, TensorOptions};
8
9/// Random tensor generator for a specific element type.
10pub struct Random<T> {
11    rng: StdRng,
12    _marker: PhantomData<T>,
13}
14
15impl<T> Random<T>
16where
17    T: RandomValue,
18{
19    /// Create a seeded random generator.
20    pub fn with_seed(seed: u64) -> Self {
21        Self {
22            rng: StdRng::seed_from_u64(seed),
23            _marker: PhantomData,
24        }
25    }
26
27    /// Generate a tensor with a default seed and options.
28    pub fn generate(range: (T, T), len: usize) -> Result<Tensor<T>> {
29        Self::generate_with_seed_opts(0, range, len, TensorOptions::default())
30    }
31
32    /// Generate a tensor with custom options and default seed.
33    pub fn generate_with_opts(range: (T, T), len: usize, opts: TensorOptions) -> Result<Tensor<T>> {
34        Self::generate_with_seed_opts(0, range, len, opts)
35    }
36
37    /// Generate a tensor with an explicit seed.
38    pub fn generate_with_seed(seed: u64, range: (T, T), len: usize) -> Result<Tensor<T>> {
39        Self::generate_with_seed_opts(seed, range, len, TensorOptions::default())
40    }
41
42    /// Generate a tensor with an explicit seed and options.
43    pub fn generate_with_seed_opts(
44        seed: u64,
45        range: (T, T),
46        len: usize,
47        opts: TensorOptions,
48    ) -> Result<Tensor<T>> {
49        let mut rng = StdRng::seed_from_u64(seed);
50        generate_with_rng::<T>(&mut rng, range, len, opts)
51    }
52
53    /// Generate the next tensor using the internal RNG.
54    pub fn next(&mut self, range: (T, T), len: usize) -> Result<Tensor<T>> {
55        self.next_with_opts(range, len, TensorOptions::default())
56    }
57
58    /// Generate the next tensor using the internal RNG and options.
59    pub fn next_with_opts(
60        &mut self,
61        range: (T, T),
62        len: usize,
63        opts: TensorOptions,
64    ) -> Result<Tensor<T>> {
65        generate_with_rng::<T>(&mut self.rng, range, len, opts)
66    }
67}
68
69fn generate_with_rng<T: RandomValue>(
70    rng: &mut StdRng,
71    range: (T, T),
72    len: usize,
73    opts: TensorOptions,
74) -> Result<Tensor<T>> {
75    if opts.shape.is_none() && opts.strides.is_some() {
76        return Err(anyhow!("random tensor strides require an explicit shape"));
77    }
78    let shape = match opts.shape.as_ref() {
79        Some(shape) => {
80            let expected = numel(shape);
81            if expected != len {
82                return Err(anyhow!(
83                    "random tensor shape {:?} expects {} values, got {}",
84                    shape,
85                    expected,
86                    len
87                ));
88            }
89            shape.clone()
90        }
91        None => vec![len],
92    };
93    let mut data = Vec::with_capacity(len);
94    for _ in 0..len {
95        data.push(T::sample(rng, range)?);
96    }
97    Tensor::from_vec_with_opts(
98        data,
99        TensorOptions {
100            shape: Some(shape),
101            strides: opts.strides,
102            allow_len_mismatch: opts.allow_len_mismatch,
103        },
104    )
105}
106
107/// Trait for values that can be sampled by `Random`.
108pub trait RandomValue: Sized + Copy {
109    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self>;
110}
111
112impl RandomValue for f32 {
113    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
114        Ok(rng.gen_range(range.0..=range.1))
115    }
116}
117
118impl RandomValue for f64 {
119    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
120        Ok(rng.gen_range(range.0..=range.1))
121    }
122}
123
124impl RandomValue for i8 {
125    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
126        Ok(rng.gen_range(range.0..=range.1))
127    }
128}
129
130impl RandomValue for i16 {
131    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
132        Ok(rng.gen_range(range.0..=range.1))
133    }
134}
135
136impl RandomValue for i32 {
137    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
138        Ok(rng.gen_range(range.0..=range.1))
139    }
140}
141
142impl RandomValue for i64 {
143    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
144        Ok(rng.gen_range(range.0..=range.1))
145    }
146}
147
148impl RandomValue for u8 {
149    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
150        Ok(rng.gen_range(range.0..=range.1))
151    }
152}
153
154impl RandomValue for u16 {
155    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
156        Ok(rng.gen_range(range.0..=range.1))
157    }
158}
159
160impl RandomValue for u32 {
161    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
162        Ok(rng.gen_range(range.0..=range.1))
163    }
164}
165
166impl RandomValue for u64 {
167    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
168        Ok(rng.gen_range(range.0..=range.1))
169    }
170}
171
172impl RandomValue for crate::tensor::F16 {
173    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
174        let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
175        Ok(crate::tensor::F16::from_f32(value))
176    }
177}
178
179impl RandomValue for BF16 {
180    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
181        let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
182        Ok(BF16::from_f32(value))
183    }
184}
185
186impl RandomValue for F8 {
187    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
188        let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
189        Ok(F8::from_f32(value))
190    }
191}
192
193impl RandomValue for I4 {
194    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
195        let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
196        Ok(I4::from_i8(value))
197    }
198}
199
200impl RandomValue for I2 {
201    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
202        let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
203        Ok(I2::from_i8(value))
204    }
205}
206
207impl RandomValue for I1 {
208    fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
209        let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
210        Ok(I1::from_i8(value))
211    }
212}