1use std::marker::PhantomData;
2
3use anyhow::{anyhow, Result};
4use rand::rngs::StdRng;
5use rand::{Rng, SeedableRng};
6
7use crate::tensor::{numel, BF16, F8, I1, I2, I4, Tensor, TensorOptions};
8
9pub struct Random<T> {
11 rng: StdRng,
12 _marker: PhantomData<T>,
13}
14
15impl<T> Random<T>
16where
17 T: RandomValue,
18{
19 pub fn with_seed(seed: u64) -> Self {
21 Self {
22 rng: StdRng::seed_from_u64(seed),
23 _marker: PhantomData,
24 }
25 }
26
27 pub fn generate(range: (T, T), len: usize) -> Result<Tensor<T>> {
29 Self::generate_with_seed_opts(0, range, len, TensorOptions::default())
30 }
31
32 pub fn generate_with_opts(range: (T, T), len: usize, opts: TensorOptions) -> Result<Tensor<T>> {
34 Self::generate_with_seed_opts(0, range, len, opts)
35 }
36
37 pub fn generate_with_seed(seed: u64, range: (T, T), len: usize) -> Result<Tensor<T>> {
39 Self::generate_with_seed_opts(seed, range, len, TensorOptions::default())
40 }
41
42 pub fn generate_with_seed_opts(
44 seed: u64,
45 range: (T, T),
46 len: usize,
47 opts: TensorOptions,
48 ) -> Result<Tensor<T>> {
49 let mut rng = StdRng::seed_from_u64(seed);
50 generate_with_rng::<T>(&mut rng, range, len, opts)
51 }
52
53 pub fn next(&mut self, range: (T, T), len: usize) -> Result<Tensor<T>> {
55 self.next_with_opts(range, len, TensorOptions::default())
56 }
57
58 pub fn next_with_opts(
60 &mut self,
61 range: (T, T),
62 len: usize,
63 opts: TensorOptions,
64 ) -> Result<Tensor<T>> {
65 generate_with_rng::<T>(&mut self.rng, range, len, opts)
66 }
67}
68
69fn generate_with_rng<T: RandomValue>(
70 rng: &mut StdRng,
71 range: (T, T),
72 len: usize,
73 opts: TensorOptions,
74) -> Result<Tensor<T>> {
75 if opts.shape.is_none() && opts.strides.is_some() {
76 return Err(anyhow!("random tensor strides require an explicit shape"));
77 }
78 let shape = match opts.shape.as_ref() {
79 Some(shape) => {
80 let expected = numel(shape);
81 if expected != len {
82 return Err(anyhow!(
83 "random tensor shape {:?} expects {} values, got {}",
84 shape,
85 expected,
86 len
87 ));
88 }
89 shape.clone()
90 }
91 None => vec![len],
92 };
93 let mut data = Vec::with_capacity(len);
94 for _ in 0..len {
95 data.push(T::sample(rng, range)?);
96 }
97 Tensor::from_vec_with_opts(
98 data,
99 TensorOptions {
100 shape: Some(shape),
101 strides: opts.strides,
102 allow_len_mismatch: opts.allow_len_mismatch,
103 },
104 )
105}
106
107pub trait RandomValue: Sized + Copy {
109 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self>;
110}
111
112impl RandomValue for f32 {
113 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
114 Ok(rng.gen_range(range.0..=range.1))
115 }
116}
117
118impl RandomValue for f64 {
119 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
120 Ok(rng.gen_range(range.0..=range.1))
121 }
122}
123
124impl RandomValue for i8 {
125 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
126 Ok(rng.gen_range(range.0..=range.1))
127 }
128}
129
130impl RandomValue for i16 {
131 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
132 Ok(rng.gen_range(range.0..=range.1))
133 }
134}
135
136impl RandomValue for i32 {
137 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
138 Ok(rng.gen_range(range.0..=range.1))
139 }
140}
141
142impl RandomValue for i64 {
143 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
144 Ok(rng.gen_range(range.0..=range.1))
145 }
146}
147
148impl RandomValue for u8 {
149 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
150 Ok(rng.gen_range(range.0..=range.1))
151 }
152}
153
154impl RandomValue for u16 {
155 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
156 Ok(rng.gen_range(range.0..=range.1))
157 }
158}
159
160impl RandomValue for u32 {
161 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
162 Ok(rng.gen_range(range.0..=range.1))
163 }
164}
165
166impl RandomValue for u64 {
167 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
168 Ok(rng.gen_range(range.0..=range.1))
169 }
170}
171
172impl RandomValue for crate::tensor::F16 {
173 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
174 let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
175 Ok(crate::tensor::F16::from_f32(value))
176 }
177}
178
179impl RandomValue for BF16 {
180 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
181 let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
182 Ok(BF16::from_f32(value))
183 }
184}
185
186impl RandomValue for F8 {
187 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
188 let value = rng.gen_range(range.0.to_f32()..=range.1.to_f32());
189 Ok(F8::from_f32(value))
190 }
191}
192
193impl RandomValue for I4 {
194 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
195 let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
196 Ok(I4::from_i8(value))
197 }
198}
199
200impl RandomValue for I2 {
201 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
202 let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
203 Ok(I2::from_i8(value))
204 }
205}
206
207impl RandomValue for I1 {
208 fn sample(rng: &mut StdRng, range: (Self, Self)) -> Result<Self> {
209 let value = rng.gen_range(range.0.to_i8()..=range.1.to_i8());
210 Ok(I1::from_i8(value))
211 }
212}