Trait ha_ndarray::NDArrayNumeric

pub trait NDArrayNumeric: NDArray + Sized
where
    Self::DType: Float,
{
    // Provided methods
    fn is_inf(self) -> Result<ArrayOp<ArrayUnary<Self::DType, u8, Self>>, Error> { ... }
    fn is_nan(self) -> Result<ArrayOp<ArrayUnary<Self::DType, u8, Self>>, Error> { ... }
}
Float-specific array methods.
Provided Methods§
fn is_inf(self) -> Result<ArrayOp<ArrayUnary<Self::DType, u8, Self>>, Error>
Test which elements of this array are infinite.
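For instance, a standalone divergence check might look like the following sketch. The helper name assert_bounded is hypothetical, and the glob import assumes the crate re-exports these names at the top level; the array type and the any reduction are taken from the backprop example below.

use ha_ndarray::*;

// Hypothetical helper: fail fast if any element of `loss` is infinite.
// `is_inf` returns a deferred u8 mask; `any` reduces it to a single flag.
fn assert_bounded(loss: ArrayBase<Buffer<f32>>) -> Result<(), Error> {
    assert!(!loss.is_inf()?.any()?, "encountered an infinite value");
    Ok(())
}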
Examples found in repository
examples/backprop.rs (line 60)
fn main() -> Result<(), Error> {
    let context = Context::default()?;

    // initialize a 2x1 weight matrix with random values
    let weights = RandomNormal::with_context(context.clone(), 2)?;
    let weights = ArrayOp::new(vec![2, 1], weights) - 0.5;
    let mut weights = ArrayBase::<Arc<RwLock<Buffer<f32>>>>::copy(&weights)?;

    // generate random training inputs and scale them by 2
    let inputs = RandomUniform::with_context(context, vec![NUM_EXAMPLES, 2])?;
    let inputs = ArrayOp::new(vec![NUM_EXAMPLES, 2], inputs) * 2.;
    let inputs = ArrayBase::<Arc<Buffer<f32>>>::copy(&inputs)?;

    // label each example 1 iff both of its features are less than 1
    let inputs_bool = inputs.clone().lt_scalar(1.0)?;
    let inputs_left = inputs_bool
        .clone()
        .slice(vec![(0..NUM_EXAMPLES).into(), 0.into()])?;
    let inputs_right = inputs_bool.slice(vec![(0..NUM_EXAMPLES).into(), 1.into()])?;

    let labels = inputs_left
        .and(inputs_right)?
        .expand_dims(vec![1])?
        .cast()?;

    let labels = ArrayBase::<Buffer<f32>>::copy(&labels)?;

    // construct the forward pass and squared-error loss as deferred ops
    let output = inputs.matmul(weights.clone())?;
    let error = labels.sub(output)?;
    let loss = error.clone().pow_scalar(2.)?;

    // construct the gradient and the weight update, also as deferred ops
    let d_loss = error * 2.;
    let weights_t = weights.clone().transpose(None)?;
    let gradient = d_loss.matmul(weights_t)?;
    let deltas = gradient.sum(vec![0], false)?.expand_dims(vec![1])?;
    let new_weights = weights.clone().add(deltas * LEARNING_RATE)?;

    let mut i = 0;
    loop {
        // copying the deferred `loss` op re-evaluates it against the
        // current contents of the shared `weights` buffer
        let loss = ArrayBase::<Buffer<f32>>::copy(&loss)?;

        if loss.clone().lt_scalar(1.0)?.all()? {
            return Ok(());
        }

        if i % 100 == 0 {
            println!(
                "loss: {} (max {})",
                loss.clone().sum_all()?,
                loss.clone().max_all()?
            );
        }

        // guard against divergence: no element of the loss may be Inf or NaN
        assert!(!loss.clone().is_inf()?.any()?, "divergence at iteration {i}");
        assert!(!loss.is_nan()?.any()?, "unstable by iteration {i}");

        // apply the weight update in place
        weights.write(&new_weights)?;

        i += 1;
    }
}
fn is_nan(self) -> Result<ArrayOp<ArrayUnary<Self::DType, u8, Self>>, Error>
Test which elements of this array are not-a-number.
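As with is_inf, the resulting mask can be reduced to a single flag. A minimal sketch (the helper name assert_numeric and the glob import are assumptions; the calls mirror the example below):

use ha_ndarray::*;

// Hypothetical helper: fail fast if any element of `loss` is NaN.
fn assert_numeric(loss: ArrayBase<Buffer<f32>>) -> Result<(), Error> {
    assert!(!loss.is_nan()?.any()?, "encountered NaN");
    Ok(())
}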
Examples found in repository
examples/backprop.rs (line 61)
(The full listing is identical to the one shown above for is_inf; is_nan appears on the next line of the same training loop.)
Object Safety§
This trait is not object safe, since it requires Self: Sized.