// kerasty/lib.rs
//! Keras for Candle (Rust ML framework) with support for Web Assembly.
//!
//! # Roadmap of Supported Layers
//!
//! | Layer | State | Example |
//! |------------|-------|-----------------------------------------------------------|
//! | Dense |✅| [Dense](https://docs.rs/kerasty/latest/kerasty/layer/dense/struct.Dense.html) |
//! | Convolution |🏗️| CNN|
//! | Normalization |🏗️| Norm|
//! | Flatten |🏗️| Flatten|
//! | Pooling |🏗️| Pool|
//! | Recurrent |🏗️| RNN|
//! | Attention |🏗️| Attn|
//! | Bert |🏗️| BERT|
//! | Llama |🏗️| LLAMA|
//!
//! # Examples
//!
//! Solution to the classic [XOR problem](https://www.geeksforgeeks.org/how-neural-networks-solve-the-xor-problem):
//!
//! ```rust,no_run
//! use kerasty::{Dense, Device, Loss, Metric, Model, Optimizer, Result, Sequential, Tensor};
//!
//! fn test_xor_problem() -> Result<()> {
//!     // Define the XOR input and output data
//!     let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
//!     let x = Tensor::from_slice(&x_data, (4, 2), &Device::Cpu)?;
//!     let y_data = vec![0.0, 1.0, 1.0, 0.0];
//!     let y = Tensor::from_slice(&y_data, (4, 1), &Device::Cpu)?;
//!
//!     // Build the neural network model
//!     let mut model = Sequential::new();
//!     model.add(Dense::new(2, 2, "relu"));
//!     model.add(Dense::new(1, 2, "sigmoid"));
//!
//!     // Compile the model
//!     model.compile(
//!         Optimizer::Adam(0.001, 0.9, 0.999, 1e-8, 0.0),
//!         Loss::BinaryCrossEntropyWithLogit,
//!         vec![Metric::Accuracy],
//!     )?;
//!
//!     // Train the model
//!     model.fit(x.clone(), y.clone(), 10000)?;
//!
//!     // Make predictions
//!     let predictions = model.predict(&x);
//!     let predictions = predictions.reshape(4)?.to_vec1::<f64>()?;
//!     let predictions: Vec<i32> = predictions
//!         .iter()
//!         .map(|&p| if p >= 0.5 { 1 } else { 0 })
//!         .collect();
//!
//!     println!("Predictions:");
//!     for i in 0..4 {
//!         println!(
//!             "Input: {:?} => Predicted Output: {}, Actual Output: {}",
//!             &x_data[i * 2..i * 2 + 2],
//!             predictions[i],
//!             y_data[i]
//!         );
//!     }
//!
//!     // Evaluate the model
//!     let score = model.evaluate(&x, &y);
//!     println!("Average loss: {}", score[0]);
//!     println!("Accuracy: {}", score[1]);
//!
//!     Ok(())
//! }
//! ```
//!
//! The expected output is as follows:
//!
//! ```text
//! Predictions:
//! Input: [0.0, 0.0] => Predicted Output: 0, Actual Output: 0
//! Input: [0.0, 1.0] => Predicted Output: 1, Actual Output: 1
//! Input: [1.0, 0.0] => Predicted Output: 1, Actual Output: 1
//! Input: [1.0, 1.0] => Predicted Output: 0, Actual Output: 0
//! ```
//!
//! # License
//!
//! MIT
//!
//! Copyright © 2025-2035 Homero Roman Roman
//! Copyright © 2025-2035 Frederick Roman

// Re-export the candle-core primitives that appear in this crate's public API
// (tensors, devices, dtypes, the `Result` alias and `bail!` macro), so users
// can depend on `kerasty` alone without importing `candle_core` directly.
pub use candle_core::{bail, DType, Device, Result, Tensor};

// Shared building blocks: the `Activation`/`Loss`/`Optimizer`/... definitions
// and the `Layer`/`Model` traits implemented by the types below.
pub mod common;
pub use crate::common::definitions::{
    Activation, Initializer, Loss, Metric, Optimizer, Regularizer,
};
pub use crate::common::traits::{Layer, Model};

// Layer implementations; `Dense` is re-exported at the crate root for
// convenience (`kerasty::Dense` instead of `kerasty::layer::dense::Dense`).
pub mod layer;
pub use crate::layer::dense::Dense;

// The `Sequential` model container used to stack layers.
pub mod sequential;
pub use crate::sequential::Sequential;