// `generic_const_exprs` is a nightly-only feature that rustc flags as incomplete and
// warns about on every build. The allow below silences that warning, because the
// feature is required for the const-expression `where` bounds used in this crate root.
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

/// Floating-point scalar alias; currently `f64`.
///
/// NOTE(review): presumably the numeric element type used by tensors/layers elsewhere
/// in the crate — confirm before changing, since every `Float` use follows this alias.
pub type Float = f64;
5
/// Marker trait implemented only by [`Assert<true>`](Assert).
///
/// Putting `Assert<{ condition }>: IsTrue` in a `where` clause turns `condition` into a
/// compile-time check: the bound holds only when the expression evaluates to `true`.
#[doc(hidden)]
pub trait IsTrue {}

/// Zero-sized carrier for a compile-time boolean; pair it with [`IsTrue`].
///
/// Only `Assert<true>` implements [`IsTrue`]. `Assert<false>` intentionally has no impl,
/// so any bound requiring it fails to compile.
#[doc(hidden)]
pub struct Assert<const CHECK: bool>;

impl IsTrue for Assert<true> {}
13
/// Compile-time proof that two element counts are equal.
///
/// The sole implementation is for `()` with both parameters set to the same value.
/// A bound `(): ReshapePreservesElementCount<FROM, TO>` therefore compiles only when
/// `FROM == TO`.
///
/// NOTE(review): judging by the name, this is meant to guard reshapes so that the
/// total number of elements is unchanged — confirm at the call sites.
#[doc(hidden)]
pub trait ReshapePreservesElementCount<const FROM: usize, const TO: usize> {}

impl<const COUNT: usize> ReshapePreservesElementCount<COUNT, COUNT> for () {}
18
/// Compile-time marker asserting that a convolution's geometry is valid.
///
/// The trait has no items; validity is decided purely by which types implement it.
/// In this crate, the blanket impl for `()` requires both `conv::conv_out_dim`
/// results (one per spatial axis) to be greater than zero.
///
/// NOTE(review): parameter roles are inferred from their names and the order in which
/// they are passed to `conv::conv_out_dim`. Confirm them there:
/// - `H`/`W`: input height/width
/// - `FH`/`FW`: filter height/width
/// - `S`: stride
/// - `P`: padding
#[doc(hidden)]
pub trait ConvGeometryIsValid<
    const H: usize,
    const W: usize,
    const FH: usize,
    const FW: usize,
    const S: usize,
    const P: usize,
> {
}
30
31impl<
32 const H: usize,
33 const W: usize,
34 const FH: usize,
35 const FW: usize,
36 const S: usize,
37 const P: usize,
38> ConvGeometryIsValid<H, W, FH, FW, S, P> for ()
39where
40 Assert<{ conv::conv_out_dim(H, P, FH, S) > 0 }>: IsTrue,
41 Assert<{ conv::conv_out_dim(W, P, FW, S) > 0 }>: IsTrue,
42{
43}
44
45pub mod shape;
46mod tensor;
47
48pub mod conv;
49pub mod data;
50
51pub use autodiff::{EvalTape, ExprGraph, Gradients, NodeId, Op, ReverseTape, Tape, TapeError, Var};
52pub use data::Sample;
53pub use network::{
54 Adam, DenseLayer, Flatten, Initializer, KaimingUniform, Layer, LayerDims, LossFunction,
55 MeanSquaredError, ModelBuilder, Optimizer, ReLU, Sequential, Sgd, Sigmoid, TrainConfig,
56 Uniform, XavierUniform, mse_loss,
57};
58pub use shape::TensorShape;
59#[doc(hidden)]
60pub use tensor::__tensor_from_literal;
61pub use tensor::{Tensor, TensorMut, TensorRef};
62
63pub mod network;
65
66pub mod autodiff;