Struct ha_ndarray::ArrayOp
pub struct ArrayOp<Op> { /* private fields */ }
Implementations
impl<Op> ArrayOp<Op>
pub fn new(shape: Shape, op: Op) -> Self
Examples found in repository:
examples/backprop.rs (line 12), excerpt of source lines 9–67:
/// Trains a single-layer linear model by gradient descent so that
/// `inputs · weights` approximates the logical AND of two thresholded inputs.
///
/// The training graph (`loss`, `new_weights`) is built once, outside the loop.
/// Each iteration materializes `loss`, checks for convergence/divergence, and then
/// overwrites the `weights` buffer with `new_weights`.
/// NOTE(review): this relies on the graph reading the `weights` buffer lazily, so
/// each iteration sees the updated weights — confirm against the crate's semantics.
///
/// # Errors
/// Propagates any [`Error`] raised while building or evaluating array operations.
///
/// # Panics
/// Panics if any per-example loss becomes infinite or NaN (training diverged).
fn main() -> Result<(), Error> {
    let context = Context::default()?;

    // Initial weights, shape [2, 1]: RandomNormal samples shifted by -0.5.
    let weights = RandomNormal::with_context(context.clone(), 2)?;
    let weights = ArrayOp::new(vec![2, 1], weights) - 0.5;
    let mut weights = ArrayBase::<Arc<RwLock<Buffer<f32>>>>::copy(&weights)?;

    // Inputs, shape [NUM_EXAMPLES, 2]: RandomUniform samples scaled by 2.
    let inputs = RandomUniform::with_context(context, vec![NUM_EXAMPLES, 2])?;
    let inputs = ArrayOp::new(vec![NUM_EXAMPLES, 2], inputs) * 2.;
    let inputs = ArrayBase::<Arc<Buffer<f32>>>::copy(&inputs)?;

    // Labels, shape [NUM_EXAMPLES, 1]: 1.0 where both input columns are < 1.0,
    // otherwise 0.0 (boolean AND cast to f32).
    let inputs_bool = inputs.clone().lt_scalar(1.0)?;
    let inputs_left = inputs_bool
        .clone()
        .slice(vec![(0..NUM_EXAMPLES).into(), 0.into()])?;
    let inputs_right = inputs_bool.slice(vec![(0..NUM_EXAMPLES).into(), 1.into()])?;
    let labels = inputs_left
        .and(inputs_right)?
        .expand_dims(vec![1])?
        .cast()?;
    let labels = ArrayBase::<Buffer<f32>>::copy(&labels)?;

    // Forward pass and per-example squared error.
    let output = inputs.clone().matmul(weights.clone())?;
    let error = labels.sub(output)?;
    let loss = error.clone().pow_scalar(2.)?;

    // With loss = (labels - output)^2, `2 * error` is the *negative* derivative
    // of the loss with respect to `output`.
    let d_loss = error * 2.;

    // Chain rule through `output = inputs · weights`: the weight gradient is
    // inputs^T · d_loss, i.e. [2, N] · [N, 1] = [2, 1], already summed over the
    // examples. (Multiplying by weights^T instead would give the gradient with
    // respect to the inputs, not the weights.)
    let inputs_t = inputs.transpose(None)?;
    let gradient = inputs_t.matmul(d_loss)?;

    // `gradient` is the negative loss gradient, so adding it is a descent step.
    let new_weights = weights.clone().add(gradient * LEARNING_RATE)?;

    let mut i = 0;
    loop {
        // Materialize the loss for the current contents of `weights`.
        let loss = ArrayBase::<Buffer<f32>>::copy(&loss)?;

        // Converged once every per-example squared error is below 1.0.
        if loss.clone().lt_scalar(1.0)?.all()? {
            return Ok(());
        }

        if i % 100 == 0 {
            println!(
                "loss: {} (max {})",
                loss.clone().sum_all()?,
                loss.clone().max_all()?
            );
        }

        assert!(!loss.clone().is_inf()?.any()?, "divergence at iteration {i}");
        assert!(!loss.is_nan()?.any()?, "unstable by iteration {i}");

        // Apply the gradient step by overwriting the shared weights buffer.
        weights.write(&new_weights)?;
        i += 1;
    }
}
Trait Implementations
impl<T: CDatatype, Op: Op<Out = T>, O> Add<ArrayBase<O>> for ArrayOp<Op> where ArrayBase<O>: NDArray<DType = T>,
impl<T: CDatatype, Buf: BufferInstance<DType = T>, O> Add<ArrayOp<O>> for ArrayBase<Buf> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Add<ArrayOp<O>> for ArrayOp<Op> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Add<ArrayOp<O>> for ArraySlice<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Add<ArrayOp<O>> for ArrayView<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Add<ArraySlice<O>> for ArrayOp<Op> where ArraySlice<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Add<ArrayView<O>> for ArrayOp<Op> where ArrayView<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Div<ArrayBase<O>> for ArrayOp<Op> where ArrayBase<O>: NDArray<DType = T>,
impl<T: CDatatype, Buf: BufferInstance<DType = T>, O> Div<ArrayOp<O>> for ArrayBase<Buf> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Div<ArrayOp<O>> for ArrayOp<Op> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Div<ArrayOp<O>> for ArraySlice<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Div<ArrayOp<O>> for ArrayView<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Div<ArraySlice<O>> for ArrayOp<Op> where ArraySlice<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Div<ArrayView<O>> for ArrayOp<Op> where ArrayView<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Mul<ArrayBase<O>> for ArrayOp<Op> where ArrayBase<O>: NDArray<DType = T>,
impl<T: CDatatype, Buf: BufferInstance<DType = T>, O> Mul<ArrayOp<O>> for ArrayBase<Buf> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Mul<ArrayOp<O>> for ArrayOp<Op> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Mul<ArrayOp<O>> for ArraySlice<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Mul<ArrayOp<O>> for ArrayView<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Mul<ArraySlice<O>> for ArrayOp<Op> where ArraySlice<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Mul<ArrayView<O>> for ArrayOp<Op> where ArrayView<O>: NDArray<DType = T>,
impl<Op: Op> NDArrayRead for ArrayOp<Op>
impl<Op: Op> NDArrayTransform for ArrayOp<Op>
type Broadcast = ArrayView<ArrayOp<Op>>
type Expand = ArrayOp<Op>
type Reshape = ArrayOp<Op>
type Slice = ArraySlice<ArrayOp<Op>>
type Transpose = ArrayView<ArrayOp<Op>>
fn broadcast(self, shape: Shape) -> Result<Self::Broadcast, Error>
fn expand_dims(self, axes: Vec<usize>) -> Result<Self::Expand, Error>
fn reshape(self, shape: Shape) -> Result<Self::Reshape, Error>
fn slice(self, bounds: Vec<AxisBound>) -> Result<Self::Slice, Error>
fn transpose(self, axes: Option<Vec<usize>>) -> Result<Self::Transpose, Error>
impl<T: CDatatype, Op: Op<Out = T>, O> Rem<ArrayBase<O>> for ArrayOp<Op> where ArrayBase<O>: NDArray<DType = T>,
impl<T: CDatatype, Buf: BufferInstance<DType = T>, O> Rem<ArrayOp<O>> for ArrayBase<Buf> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Rem<ArrayOp<O>> for ArrayOp<Op> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Rem<ArrayOp<O>> for ArraySlice<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Rem<ArrayOp<O>> for ArrayView<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Rem<ArraySlice<O>> for ArrayOp<Op> where ArraySlice<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Rem<ArrayView<O>> for ArrayOp<Op> where ArrayView<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Sub<ArrayBase<O>> for ArrayOp<Op> where ArrayBase<O>: NDArray<DType = T>,
impl<T: CDatatype, Buf: BufferInstance<DType = T>, O> Sub<ArrayOp<O>> for ArrayBase<Buf> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Sub<ArrayOp<O>> for ArrayOp<Op> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Sub<ArrayOp<O>> for ArraySlice<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, A: NDArray<DType = T>, O> Sub<ArrayOp<O>> for ArrayView<A> where ArrayOp<O>: NDArray<DType = T>,
impl<T: CDatatype, Op: Op<Out = T>, O> Sub<ArraySlice<O>> for ArrayOp<Op> where ArraySlice<O>: NDArray<DType = T>,
Auto Trait Implementations
impl<Op> RefUnwindSafe for ArrayOp<Op> where Op: RefUnwindSafe,
impl<Op> Send for ArrayOp<Op> where Op: Send,
impl<Op> Sync for ArrayOp<Op> where Op: Sync,
impl<Op> Unpin for ArrayOp<Op> where Op: Unpin,
impl<Op> UnwindSafe for ArrayOp<Op> where Op: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.