kerasty/
lib.rs

1//! A Keras-style neural network API for Candle (the Rust ML framework), with WebAssembly support.
2//!
3//! # Roadmap of Supported Layers
4//!
5//! |    Layer   | State |                        Example                            |
6//! |------------|-------|-----------------------------------------------------------|
7//! |    Dense   |✅| [Dense](https://docs.rs/kerasty/latest/kerasty/layer/dense/struct.Dense.html) |
8//! |    Convolution   |🏗️| CNN|
9//! |    Normalization |🏗️| Norm|
10//! |    Flatten       |🏗️| Flatten|
11//! |    Pooling       |🏗️| Pool|
12//! |    Recurrent     |🏗️| RNN|
13//! |    Attention     |🏗️| Attn|
14//! |    Bert          |🏗️| BERT|
15//! |    Llama         |🏗️| LLAMA|
16//!
17//!# Examples
18//!A solution to the classic [XOR problem](https://www.geeksforgeeks.org/how-neural-networks-solve-the-xor-problem):
19//!```rust,no_run
20//!use kerasty::{Dense, Device, Loss, Metric, Model, Optimizer, Result, Sequential, Tensor};
21//!
22//!fn test_xor_problem() -> Result<()> {
23//!  // Define the XOR input and output data
24//!  let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
25//!  let x = Tensor::from_slice(&x_data, (4, 2), &Device::Cpu)?;
26//!  let y_data = vec![0.0, 1.0, 1.0, 0.0];
27//!  let y = Tensor::from_slice(&y_data, (4, 1), &Device::Cpu)?;
28//!
29//!  // Build the neural network model
30//!  let mut model = Sequential::new();
31//!  model.add(Dense::new(2, 2, "relu"));
32//!  model.add(Dense::new(1, 2, "sigmoid"));
33//!
34//!  // Compile the model
35//!  model.compile(
36//!      Optimizer::Adam(0.001, 0.9, 0.999, 1e-8, 0.0),
37//!      Loss::BinaryCrossEntropyWithLogit,
38//!      vec![Metric::Accuracy],
39//!  )?;
40//!
41//!  // Train the model
42//!  model.fit(x.clone(), y.clone(), 10000)?;
43//!
44//!  // Make predictions
45//!  let predictions = model.predict(&x);
46//!  let predictions = predictions.reshape(4)?.to_vec1::<f64>()?;
47//!  let predictions: Vec<i32> = predictions
48//!      .iter()
49//!      .map(|&p| if p >= 0.5 { 1 } else { 0 })
50//!      .collect();
51//!
52//!  println!("Predictions:");
53//!  for i in 0..4 {
54//!      println!(
55//!          "Input: {:?} => Predicted Output: {}, Actual Output: {}",
56//!          &x_data[i * 2..i * 2 + 2],
57//!          predictions[i],
58//!          y_data[i]
59//!      );
60//!  }
61//! 
62//!  // Evaluate the model
63//!  let score = model.evaluate(&x, &y);
64//!  println!("Average loss: {}", score[0]);
65//!  println!("Accuracy: {}", score[1]);
66//!
67//!  Ok(())
68//!}
69//!```
70//!The expected predictions are shown below (the loss and accuracy lines printed afterwards are omitted):
71//!```text
72//!Predictions:
73//!Input: [0.0, 0.0] => Predicted Output: 0, Actual Output: 0
74//!Input: [0.0, 1.0] => Predicted Output: 1, Actual Output: 1
75//!Input: [1.0, 0.0] => Predicted Output: 1, Actual Output: 1
76//!Input: [1.0, 1.0] => Predicted Output: 0, Actual Output: 0
77//!```
78//!# License
79//!MIT
80//!
81//!Copyright © 2025-2035 Homero Roman Roman  
82//!Copyright © 2025-2035 Frederick Roman
83
84pub use candle_core::{bail, DType, Device, Result, Tensor};
85
86pub mod common;
87pub use crate::common::definitions::{
88    Activation, Initializer, Loss, Metric, Optimizer, Regularizer,
89};
90pub use crate::common::traits::{Layer, Model};
91
92pub mod layer;
93pub use crate::layer::dense::Dense;
94
95pub mod sequential;
96pub use crate::sequential::Sequential;