Skip to main content

ferrotorch_llama/
lib.rs

1// Crate-level lint baseline. Mirrors the workspace-wide rust-quality
2// posture: deny correctness/idiom/Debug/docs problems; warn pedantic
3// stylistic issues. Specific pedantic lints are allowed crate-wide
4// where the lint is consistently wrong for ML/numeric kernel code —
5// each allow names the reason it's noise here rather than signal.
6
7// Correctness / hygiene — these are real bugs if they fire.
8#![deny(unsafe_code)]
9#![deny(rust_2018_idioms)]
10#![deny(missing_debug_implementations)]
11#![deny(missing_docs)]
12// Style baseline.
13#![warn(clippy::all)]
14#![warn(clippy::pedantic)]
15// Pedantic lints allowed crate-wide. Each of these has a reason that
16// repeats in this crate at scale; allow at the lint level (not at
17// `clippy::pedantic` group level) so we still catch the rest.
18//
19// Casts: dimension math (`as usize`, `as f32`, `as u32`) is intrinsic
20// to tensor indexing and bf16 ↔ f32 conversion — every kernel call
21// would otherwise need a per-call allow. Lint fires ~150x in the crate
22// with no actionable signal.
23#![allow(clippy::cast_possible_truncation)]
24#![allow(clippy::cast_precision_loss)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::cast_possible_wrap)]
27#![allow(clippy::cast_lossless)]
28// `must_use_candidate` would flag every getter on every public struct.
29// We use `Result` returns where misuse is observable; the warning is
30// noise for infallible accessors.
31#![allow(clippy::must_use_candidate)]
32// Long kernel-dispatch routines (gpu.rs forward_core) are a single
33// linear pipeline; splitting them yields no clarity. Fires ~6 times.
34#![allow(clippy::too_many_lines)]
35// Identifiers like `bf16`, `f32`, `RoPE`, `cuBLAS` are flagged as
36// missing backticks even when they appear in code-fenced text. The
37// lint's heuristic produces false positives in this crate's docs.
38#![allow(clippy::doc_markdown)]
39// Triggers on `if !x { ... } else { ... }` patterns that are clearer
40// in context (e.g. error-path-first matches the audit-finding shape).
41#![allow(clippy::if_not_else)]
42// `explicit_iter_loop` flags `for x in v.iter()` where `for x in &v`
43// is supposedly clearer; the explicit form is more discoverable here.
44#![allow(clippy::explicit_iter_loop)]
45// `items_after_statements` flags the kernel-dispatch helper structs
46// declared inside `forward_core`-shaped functions; co-locating them
47// with the only caller is the more readable choice.
48#![allow(clippy::items_after_statements)]
49// `match_same_arms` collapses arms whose bodies happen to match but
50// whose meanings differ (different error categories with the same
51// message format).
52#![allow(clippy::match_same_arms)]
53// `match_wildcard_for_single_variants` requires enumerating
54// `#[non_exhaustive]` enums explicitly; with `FerrotorchError` and
55// `GpuError` both `#[non_exhaustive]` upstream the wildcard is the
56// future-proof match.
57#![allow(clippy::match_wildcard_for_single_variants)]
58// `return_self_not_must_use` flags every builder-style method; we
59// use `must_use` selectively where it actually matters.
60#![allow(clippy::return_self_not_must_use)]
61// `redundant_closure_for_method_calls` would fail on every
62// `.map(|e| e.to_string())` style call; rewriting to method ref
63// is purely cosmetic.
64#![allow(clippy::redundant_closure_for_method_calls)]
65// `manual_let_else` flags `match … { Some(x) => x, None => return … }`
66// patterns that pre-date `let-else`; the explicit form is fine here.
67#![allow(clippy::manual_let_else)]
68// `needless_pass_by_value` would force `&Config` and `&Tensor<T>`
69// signatures throughout, hiding ownership transfer in the API.
70#![allow(clippy::needless_pass_by_value)]
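// `map_unwrap_or` favours `map_or` / `map_or_else` combinators that
// hurt readability when the branches are non-trivial. (The related
// `option_if_let_else` lint belongs to `clippy::nursery`, which is not
// enabled in this file, so it needs no allow.)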
73#![allow(clippy::map_unwrap_or)]
74// `float_cmp` fires on epsilon-free comparisons that are correct here
75// (e.g. `temperature == 0.0` switches to greedy decoding — exact
76// match against the sentinel is the contract).
77#![allow(clippy::float_cmp)]
// `implicit_clone` fires on `.to_owned()` / `.to_vec()` style calls
// whose receiver is already the owned type; inside hot paths the
// explicit form is clearer than `.clone()`.
80#![allow(clippy::implicit_clone)]
// `ptr_as_ptr` flags raw-pointer `as` casts (e.g. `*const T as *const U`)
// in bytemuck-backed reinterprets; the casts are correct and bytemuck
// verifies layout.
83#![allow(clippy::ptr_as_ptr)]
84// `unnecessary_wraps` flags `Result`-returning helpers that today
85// always succeed but are part of an extensible API surface.
86#![allow(clippy::unnecessary_wraps)]
87// `uninlined_format_args` flags `format!("x={}", x)` vs `format!("x={x}")`.
88// Both forms are equally clear and the fixup churn is high.
89#![allow(clippy::uninlined_format_args)]
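// `unnested_or_patterns` flags `(A, X) | (B, X)` vs `(A | B, X)`. The
// unnested form mirrors the pattern shape across the rest of the
// match in grammar/state.rs (every other arm is a 2-tuple of refs);
// rewriting collapses readability.
// NOTE(review): the grammar code now lives in the `ferrotorch_grammar`
// crate (re-exported below as `grammar`) — confirm this allow is still
// needed here.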
94#![allow(clippy::unnested_or_patterns)]
95// `needless_continue` flags a redundant `continue` at the tail of a
96// loop branch; some sites use it for symmetry with parallel branches.
97#![allow(clippy::needless_continue)]
98// `needless_range_loop` flags index-driven loops where the body needs
99// the index for parallel arrays (`cos_bits`, `sin_bits` etc.). The
100// `enumerate()` rewrite obscures the intent.
101#![allow(clippy::needless_range_loop)]
102// `similar_names` flags variable pairs like `walk_l_state` and
103// `walk_u_state` (lower / upper bounds in the DFA construction) — the
104// similarity is the *point* and the name is shorter than any
105// disambiguator the lint would prefer.
106#![allow(clippy::similar_names)]
107// `many_single_char_names` flags conventional ML kernel locals
108// (`q`, `k`, `v` for query/key/value, `t` for time/token, `h` for
109// hidden) where the convention is the documentation.
110#![allow(clippy::many_single_char_names)]
111// `doc_link_with_quotes` flags identifiers in single-quoted
112// terminator examples (`',', ']'`) as missing intra-doc-link
113// backticks. Char literals are not intra-doc links.
114#![allow(clippy::doc_link_with_quotes)]
115
116//! Llama 3 (Meta LLaMA) model composition for ferrotorch.
117//!
118//! Assembles the standard Llama decoder stack from ferrotorch primitives:
119//!
120//! ```text
121//! LlamaForCausalLM
122//! ├── LlamaModel
123//! │   ├── Embedding                      (token embeddings)
124//! │   ├── LlamaDecoderLayer × N
125//! │   │   ├── RMSNorm                    (pre-attn)
126//! │   │   ├── LlamaAttention             (GQA + RoPE)
127//! │   │   ├── residual
128//! │   │   ├── RMSNorm                    (pre-MLP)
129//! │   │   ├── SwiGLU                     (gate/up/down projections)
130//! │   │   └── residual
131//! │   └── RMSNorm                        (final)
132//! └── Linear lm_head                     (projection to vocab)
133//! ```
134//!
135//! # Loading real weights
136//!
137//! [`LlamaForCausalLM::load_hf_state_dict`] accepts a `StateDict` whose
138//! keys use the HuggingFace transformers naming convention and rewrites
139//! them to match the ferrotorch parameter paths before delegating to
140//! [`Module::load_state_dict`]. Combined with
141//! `ferrotorch_serialize::load_safetensors_sharded` this gives a direct
142//! path from a downloaded Meta-Llama-3-8B checkpoint to a loaded model.
143
144pub mod attention;
145pub mod config;
146pub mod generation;
147pub mod gguf_remap;
148#[cfg(feature = "cuda")]
149pub mod gpu;
150#[cfg(feature = "cuda")]
151pub mod gpu_gguf;
152/// Re-export of [`ferrotorch_grammar`] for backward compatibility.
153///
154/// The constrained-decoding grammar processors used to live in
155/// `ferrotorch_llama::grammar`; in v0.5.1 they were extracted into a
156/// standalone [`ferrotorch_grammar`] crate. This alias keeps the old
157/// path working so existing callers compile unchanged.
158pub use ferrotorch_grammar as grammar;
159pub mod kv_cache;
160pub mod layer;
161pub mod mlp;
162pub mod model;
163pub mod quant_loaders;
164pub mod spec_decode;
165
166pub use attention::LlamaAttention;
167pub use config::{LlamaActivation, LlamaConfig};
168pub use generation::{
169    GenerationConfig, apply_repetition_penalty, apply_temperature, argmax, generate,
170    generate_with_streamer, sample_softmax, top_k_filter, top_p_filter,
171};
172pub use gguf_remap::{gguf_key_to_hf, gguf_to_hf_state_dict};
173pub use kv_cache::{LayerKvCache, LlamaKvCache};
174pub use layer::LlamaDecoderLayer;
175pub use mlp::LlamaMLP;
176pub use model::{LlamaForCausalLM, LlamaModel};
177pub use quant_loaders::{
178    AwqQ4, GptqQ4, HqqQ4Axis1, dequantize_awq_q4, dequantize_gptq_q4, dequantize_hqq_q4_axis1,
179    hqq_q4_axis1_to_dense, hqq_state_dict_to_dense,
180};
181pub use spec_decode::{
182    LlamaHandle, ModelHandle, SpecDecodeConfig, SpecDecodeOutput, speculative_decode,
183};
184
185#[cfg(feature = "cuda")]
186pub use gpu::{GraphedDecoder, LlamaGpuInferencer, LlamaGpuLayer, ProfiledForwardResult};
187#[cfg(feature = "cuda")]
188pub use gpu_gguf::{
189    apply_grammar_mask_gpu, cubecl_cuda_client, gpu_dequantize_to_bf16_cudarc, masked_decode_loop,
190};