attnres 0.1.0

Attention Residuals (MoonshotAI/Kimi) implementation in Rust using burn
Documentation
<p align="center">
  <strong>attnres</strong>
</p>

<p align="center">
  <em>Learned depth attention for Transformers. Replaces fixed residual connections with softmax attention over layers.</em>
</p>

<p align="center">
  <a href="https://github.com/AbdelStark/attnres/actions"><img src="https://github.com/AbdelStark/attnres/actions/workflows/ci.yml/badge.svg" alt="CI"></a>
  <a href="https://crates.io/crates/attnres"><img src="https://img.shields.io/crates/v/attnres.svg" alt="crates.io"></a>
  <a href="https://docs.rs/attnres"><img src="https://img.shields.io/docsrs/attnres" alt="docs.rs"></a>
  <a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT"></a>
</p>

---

The first Rust implementation of [**Attention Residuals**](https://github.com/MoonshotAI/Attention-Residuals) (MoonshotAI / Kimi). Built on [burn](https://github.com/tracel-ai/burn). Runs on CPU, CUDA, Metal, and wgpu.

Standard residual connections add every layer's output with equal weight. In deep networks, this dilutes early representations and offers no mechanism to *choose* which layers matter. **AttnRes** replaces the fixed `h + F(h)` with a learned softmax over all prior block representations, giving each layer selective, content-aware access to the full depth of the network.

The result: **1.25x compute advantage** at **< 2% inference overhead**.

## How It Works

```
Standard Residual              Attention Residual
─────────────────              ──────────────────
h = h_l + F(h_l)              V = [b₀; b₁; …; bₙ]      ← stack all blocks
    └─ fixed +1               K = RMSNorm(V)             ← normalize keys
                               α = softmax(K · w_l)      ← depth attention
                               h = Σ αᵢ · Vᵢ             ← weighted combination
```

Each transformer layer has **two** AttnRes operations (before self-attention and before MLP), each with its own learned pseudo-query vector **w_l** initialized to zero. At initialization, all blocks receive equal weight (standard residual behavior). During training, the model learns to selectively route information from the most relevant depths.

## Quick Start

```toml
[dependencies]
attnres = "0.1"
burn = { version = "0.20", features = ["ndarray"] }
```

```rust
use attnres::{AttnResConfig, AttnResTransformer};
use burn::prelude::*;
use burn::backend::NdArray;

type B = NdArray;

let device = Default::default();
let config = AttnResConfig::new(128, 8, 2)  // d_model, 8 sublayers, 2 blocks
    .with_num_heads(4)
    .with_vocab_size(1000);

let model: AttnResTransformer<B> = config.init_model(&device);
let input = Tensor::<B, 2, Int>::zeros([1, 16], &device);
let logits = model.forward(input, None);  // [1, 16, 1000]
```

## Interactive TUI Demo

Watch the algorithm come alive. The TUI demo trains a model in real time and visualizes depth attention patterns evolving from uniform to selective:

```bash
cargo run --example demo_tui --release
```

<p align="center">
  <img src="docs/assets/img/attnres-tui-demo-1.png" width="700" alt="TUI demo at initialization — uniform depth attention weights, all pseudo-query norms at zero">
</p>

<p align="center">
  <em>Step 0: All pseudo-queries are zero-initialized. Every block receives equal attention weight (1/N).</em>
</p>

<p align="center">
  <img src="docs/assets/img/attnres-tui-demo-2.png" width="700" alt="TUI demo mid-training — loss decreasing, attention patterns becoming selective, pseudo-query norms growing">
</p>

<p align="center">
  <em>Mid-training: Deeper layers develop stronger preferences. Pseudo-query norms grow as the model learns which blocks are most useful at each depth.</em>
</p>

**Controls:** `Space` start/pause | `Up/Down` speed | `r` reset | `q` quit

**What you're seeing:**
- **Loss Curve** — Cross-entropy loss decreasing as the model learns
- **Depth Attention Weights** — Actual softmax attention weights computed from the model's pseudo-query vectors and block states. Rows are transformer layers, columns are source blocks (Emb, B1, ..., Partial)
- **Pseudo-Query Norms** — The magnitude of each layer's learned w_l vector. Zero means uniform; larger means more selective
- **Algorithm** — The 5-step AttnRes computation, cycling to build intuition

## Architecture

```
AttnResTransformer
  ├── Token Embedding
  ├── AttnResLayer ×N              ← N = num_layers / 2 (each layer = 2 sublayers)
  │     ├── AttnResOp              ← depth attention before self-attention
  │     │     ├── RMSNorm(V)       ← normalize stacked blocks
  │     │     ├── K · w_l          ← pseudo-query scores each block
  │     │     └── softmax → Σ αV   ← weighted combination
  │     ├── RMSNorm + MultiHeadAttention
  │     ├── AttnResOp              ← depth attention before MLP
  │     └── RMSNorm + FeedForward
  ├── Final RMSNorm
  └── LM Head
```

**Block AttnRes** groups layers into N blocks. With `num_blocks=4` for a 100-layer model, the attention operates over 4 block representations instead of 100 individual layers — keeping overhead minimal while retaining the selective routing benefit.

> **Note on `num_layers`:** This counts *sublayers*. Each transformer layer has 2 sublayers (attention + MLP), so `num_layers=8` creates 4 transformer layers.

## Features

- **Full algorithm** — AttnResOp, BlockState tracking, RMSNorm, block boundary management
- **Two-phase inference** — Batched inter-block + sequential intra-block attention (Algorithm 1 from the paper)
- **Serialization** — Save/load via burn records (NamedMpk, binary, compact half-precision) + JSON config
- **Backend-generic** — Runs on any burn backend: NdArray (CPU), wgpu (cross-platform GPU), CUDA, Metal
- **Zero-cost initialization** — Pseudo-queries start at zero; AttnRes begins as standard residual
- **Comprehensive tests** — 87 tests: unit, differential, property-based (proptest), integration, doctest
- **Interactive demos** — Terminal TUI + browser WASM demo with live visualizations

## Examples

```bash
# Interactive TUI training visualization (recommended first experience)
cargo run --example demo_tui --release

# Train a small model on synthetic data
cargo run --example train_tiny

# Compare standard residuals vs AttnRes on same architecture
cargo run --example compare_residuals

# Visualize depth attention weight patterns
cargo run --example visualize_weights
```

## Web Demo

An interactive browser demo runs the core algorithm via Rust compiled to WebAssembly:

```bash
cd web-demo
npm install
npm run build:wasm   # Requires wasm-pack + wasm32-unknown-unknown target
npm run dev          # localhost:5173
```

Configurable model parameters, live heatmaps, training simulation, and standard vs AttnRes comparison — all running in the browser with no GPU required.

## Development

```bash
cargo build                        # Build
cargo test --all-features          # 87 tests (unit + differential + property + integration + doctest)
cargo clippy -- -D warnings        # Lint
cargo fmt                          # Format
cargo bench                        # Criterion benchmarks
cargo doc --open                   # Documentation
```

## Key Implementation Details

These are the invariants that make AttnRes correct. Getting any of them wrong breaks training:

| Invariant | Why |
|---|---|
| **Zero-init pseudo-queries** | Ensures AttnRes starts as standard residual (uniform 1/N weights). Random init destabilizes training. |
| **Two AttnRes per layer** | One before self-attention, one before MLP. Each has its own w_l. |
| **Softmax over depth** | Attention is over the block/depth dimension (dim=0), NOT the sequence dimension. This is attention over *layers*. |
| **RMSNorm on keys** | Prevents blocks with large accumulated magnitudes from dominating the attention logits. |
| **Cumulative block sums** | `blocks[n]` is the sum of all layer outputs in block n, not individual layer outputs. |

## Project Structure

```
src/
  lib.rs              Public API re-exports
  config.rs           AttnResConfig (builder pattern)
  attn_res_op.rs      Core depth attention operation
  block_state.rs      Block state tracking
  layer.rs            AttnResLayer (2 AttnRes + attention + MLP)
  model.rs            AttnResTransformer (full model)
  rms_norm.rs         RMSNorm (3D + 4D)
  attention.rs        Multi-head self-attention
  feed_forward.rs     MLP (SwiGLU-style)
  serialization.rs    Model save/load
  two_phase.rs        Two-phase inference optimization
  utils.rs            Causal mask helper

tests/                Unit, differential, property-based, integration
examples/             train_tiny, compare_residuals, visualize_weights, demo_tui
benches/              Criterion benchmarks
web-demo/             Interactive WASM browser demo (Vite + TypeScript)
```

## Citation

If you use this implementation in your research:

```bibtex
@software{attnres_rust,
  author = {Abdel},
  title = {attnres: Attention Residuals in Rust},
  url = {https://github.com/AbdelStark/attnres},
  year = {2026}
}
```

Original paper:

```bibtex
@article{attention_residuals,
  title = {Attention Residuals},
  author = {Kimi Team, MoonshotAI},
  year = {2026},
  url = {https://github.com/MoonshotAI/Attention-Residuals}
}
```

## License

[MIT](LICENSE)