//! # attnres
//!
//! First Rust implementation of Attention Residuals from the MoonshotAI/Kimi paper,
//! built on the [burn](https://burn.dev) deep learning framework.
//!
//! Attention Residuals replace standard fixed-weight residual connections in Transformers
//! with learned softmax attention over depth, enabling selective information routing
//! across layers.
//!
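//! ## How It Works (Sketch)
//!
//! As a rough, illustrative sketch (the notation below is ours, not the paper's):
//! a standard Transformer layer updates the residual stream as
//! `h_l = h_{l-1} + f_l(h_{l-1})`, giving the previous state a fixed weight of 1.
//! With Attention Residuals, a layer instead builds its input by softmax-attending
//! over the outputs of earlier blocks along the depth axis, so the mixing weights
//! are learned rather than fixed:
//!
//! ```text
//! // Illustrative pseudocode only; see the `attn_res_op` module for this crate's operator.
//! w_0..w_{l-1} = softmax(depth scores)   // learned attention weights over layers
//! x_l          = sum_i w_i * h_i         // replaces the single term h_{l-1}
//! h_l          = x_l + f_l(x_l)
//! ```
//!
//! Because the weights come from a softmax they sum to 1, so a layer can emphasize
//! early-layer features without inflating the residual stream.
//!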
//! ## Quick Start
//!
//! ```rust
//! use attnres::{AttnResConfig, AttnResTransformer};
//! use burn::prelude::*;
//! use burn::backend::NdArray;
//!
//! type B = NdArray;
//!
//! let device = Default::default();
//! let config = AttnResConfig::new(128, 8, 2)
//!     .with_num_heads(4)
//!     .with_vocab_size(1000);
//!
//! let model: AttnResTransformer<B> = config.init_model(&device);
//! let input_ids = Tensor::<B, 2, Int>::zeros([1, 16], &device);
//! let logits = model.forward(input_ids, None);
//! assert_eq!(logits.dims(), [1, 16, 1000]);
//! ```

pub mod attention;
pub mod attn_res_op;
pub mod block_state;
pub mod config;
pub mod feed_forward;
pub mod layer;
pub mod model;
pub mod rms_norm;
pub mod serialization;
pub mod two_phase;
pub mod utils;

// Public API re-exports
pub use attention::{MultiHeadAttention, MultiHeadAttentionConfig};
pub use attn_res_op::AttnResOp;
pub use block_state::BlockState;
pub use config::AttnResConfig;
pub use feed_forward::{FeedForward, FeedForwardConfig};
pub use layer::AttnResLayer;
pub use model::AttnResTransformer;
pub use rms_norm::{RmsNorm, RmsNormConfig};
pub use serialization::SerializationError;
pub use utils::causal_mask;