Struct ha_ndarray::ops::RandomNormal
pub struct RandomNormal { /* private fields */ }
Implementations
impl RandomNormal
pub fn new(size: usize) -> Result<Self, Error>
pub fn with_context(context: Context, size: usize) -> Result<Self, Error>
Examples found in repository?
examples/backprop.rs (line 11)
(listing of source lines 9–67)
/// Trains a single-layer linear model (no bias term) by gradient descent.
///
/// The model maps each 2-element input row to a single output: `weights` has shape `[2, 1]`.
/// The inputs have shape `[NUM_EXAMPLES, 2]`: `RandomUniform` samples scaled by 2. Each row's
/// label is the `cast` of "both elements are below 1.0". Each iteration overwrites the shared
/// `weights` buffer with `weights + LEARNING_RATE * gradient`. The loop stops once every
/// example's squared error is below 1.0.
///
/// # Errors
///
/// Propagates any `Error` returned while creating the context, generating random data,
/// building the ops, or reading and writing buffers.
///
/// # Panics
///
/// Panics if any per-example loss becomes infinite or NaN. There is no iteration limit, so
/// the loop only ends on convergence, an error, or one of these panics.
fn main() -> Result<(), Error> {
    let context = Context::default()?;

    // Weights: shape [2, 1], drawn from `RandomNormal` and shifted by -0.5. They are stored
    // in an `Arc<RwLock<_>>` buffer, which `weights.write` overwrites in place in the loop.
    let weights = RandomNormal::with_context(context.clone(), 2)?;
    let weights = ArrayOp::new(vec![2, 1], weights) - 0.5;
    let mut weights = ArrayBase::<Arc<RwLock<Buffer<f32>>>>::copy(&weights)?;

    // Inputs: shape [NUM_EXAMPLES, 2], `RandomUniform` samples scaled by 2.
    let inputs = RandomUniform::with_context(context, vec![NUM_EXAMPLES, 2])?;
    let inputs = ArrayOp::new(vec![NUM_EXAMPLES, 2], inputs) * 2.;
    let inputs = ArrayBase::<Arc<Buffer<f32>>>::copy(&inputs)?;

    // Labels: shape [NUM_EXAMPLES, 1], set where both input columns are < 1.0.
    let inputs_bool = inputs.clone().lt_scalar(1.0)?;
    let inputs_left = inputs_bool
        .clone()
        .slice(vec![(0..NUM_EXAMPLES).into(), 0.into()])?;
    let inputs_right = inputs_bool.slice(vec![(0..NUM_EXAMPLES).into(), 1.into()])?;
    let labels = inputs_left
        .and(inputs_right)?
        .expand_dims(vec![1])?
        .cast()?;
    let labels = ArrayBase::<Buffer<f32>>::copy(&labels)?;

    // Take the transpose before `matmul` below consumes `inputs`; the weight gradient needs it.
    let inputs_t = inputs.clone().transpose(None)?;

    // Forward pass: [NUM_EXAMPLES, 2] x [2, 1] -> [NUM_EXAMPLES, 1].
    let output = inputs.matmul(weights.clone())?;
    let error = labels.sub(output)?;
    let loss = error.clone().pow_scalar(2.)?;

    // Backward pass. For loss = (labels - inputs·W)^2, the descent direction for W is
    // inputsᵀ · (2 * error): [2, NUM_EXAMPLES] x [NUM_EXAMPLES, 1] -> [2, 1], which already
    // matches the shape of `weights` and sums over all examples. Multiplying by Wᵀ instead
    // would give the gradient with respect to the inputs, not the weights.
    let d_loss = error * 2.;
    let gradient = inputs_t.matmul(d_loss)?;
    let new_weights = weights.clone().add(gradient * LEARNING_RATE)?;

    // NOTE(review): `loss` and `new_weights` are built once, outside the loop, and re-read
    // after each `weights.write`. This presumably relies on them being evaluated lazily
    // against the current contents of the `weights` buffer; confirm against ha_ndarray.
    let mut i = 0;
    loop {
        let loss = ArrayBase::<Buffer<f32>>::copy(&loss)?;
        if loss.clone().lt_scalar(1.0)?.all()? {
            return Ok(());
        }

        if i % 100 == 0 {
            println!(
                "loss: {} (max {})",
                loss.clone().sum_all()?,
                loss.clone().max_all()?
            );
        }

        assert!(!loss.clone().is_inf()?.any()?, "divergence at iteration {i}");
        assert!(!loss.is_nan()?.any()?, "unstable by iteration {i}");

        weights.write(&new_weights)?;
        i += 1;
    }
}
Trait Implementations
impl Clone for RandomNormal
fn clone(&self) -> RandomNormal
Returns a copy of the value. Read more
1.0.0 · fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
impl Op for RandomNormal
Auto Trait Implementations
impl RefUnwindSafe for RandomNormal
impl Send for RandomNormal
impl Sync for RandomNormal
impl Unpin for RandomNormal
impl UnwindSafe for RandomNormal
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more