Skip to main content

oxillama_arch/
traits.rs

//! Core traits defining the model architecture plugin system.
//!
//! Every model family (LLaMA, Qwen3, Mistral, etc.) implements
//! [`ModelArchitecture`] to register itself, and [`ForwardPass`]
//! for the actual inference computation.
6
7use crate::common::sequence_state::{AttentionSequenceState, SequenceState};
8use crate::config::ModelConfig;
9use crate::error::{ArchError, ArchResult};
10use crate::lora::{LoadedLora, LoraStack};
11use oxillama_gguf::TensorStore;
12
/// Describes one tensor name (or family of names) that an architecture
/// expects to find in a model file.
#[derive(Clone, Debug)]
pub struct TensorNamePattern {
    /// Pattern matched against tensor names (regex or glob syntax); stored
    /// as an opaque string here.
    pub pattern: String,
    /// Human-readable explanation of the tensor's role, for diagnostics.
    pub description: String,
    /// `true` when a model file of this architecture must contain a match.
    pub required: bool,
}
23
24/// Trait for a model architecture plugin.
25///
26/// Implementations register themselves with the [`ArchitectureRegistry`](crate::registry::ArchitectureRegistry)
27/// and provide the ability to build a runnable model from GGUF data.
28pub trait ModelArchitecture: Send + Sync {
29    /// Architecture identifier string (matches GGUF `general.architecture` metadata).
30    ///
31    /// Examples: `"llama"`, `"qwen3"`, `"mistral"`, `"gemma"`, `"phi"`.
32    fn arch_id(&self) -> &str;
33
34    /// Build a runnable model from configuration and loaded tensors.
35    ///
36    /// This is called once during model loading. The returned [`ForwardPass`]
37    /// implementation owns the model weights and is used for inference.
38    fn build(
39        &self,
40        config: &ModelConfig,
41        tensors: &TensorStore,
42    ) -> ArchResult<Box<dyn ForwardPass>>;
43
44    /// Expected tensor name patterns for this architecture.
45    ///
46    /// Used for validation and diagnostics when loading a model file.
47    fn tensor_names(&self) -> Vec<TensorNamePattern>;
48
49    /// Returns the sliding-window attention configuration for this model.
50    ///
51    /// Returns `Some((window_size, is_interleaved))` when the architecture
52    /// uses SWA on at least some layers, or `None` for pure global attention.
53    fn swa_config(&self) -> Option<(u32, bool)> {
54        None
55    }
56}
57
/// A single request's slot within the shared KV pool.
///
/// Each in-flight request receives one `KvSlot` that identifies which
/// position in the KV pool belongs to it.  The slot is released back to the
/// pool when the request finishes (EOS or max-token limit reached).
///
/// `KvSlot` is a small `Copy` value that also implements `Hash`, so it can be
/// used directly as a key in `HashMap` / `HashSet` bookkeeping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KvSlot {
    /// Unique identifier of the request that owns this slot.
    pub request_id: u64,
    /// Index into the shared KV cache pool (e.g. the row within a paged KV
    /// cache or the sequence slot index in a flat pool).
    pub kv_cache_idx: usize,
    /// Current sequence position (number of tokens committed so far).
    pub position: usize,
}

impl KvSlot {
    /// Construct a new `KvSlot` from its three fields.
    ///
    /// This is a `const fn`, so slots can also be built in constant contexts.
    #[must_use]
    pub const fn new(request_id: u64, kv_cache_idx: usize, position: usize) -> Self {
        Self {
            request_id,
            kv_cache_idx,
            position,
        }
    }
}
84
/// Read-only access to the per-request KV caches of a decode batch.
///
/// By the time a request reaches the decode phase it has accumulated keys and
/// values from its prefill and any earlier decode steps.  This trait lets the
/// batched-attention kernel fetch each request's KV slices without dictating
/// how the caller lays that memory out.
///
/// A typical implementor wraps a pool of KV cache buffers addressed by
/// [`KvSlot`].
pub trait BatchedKvView: Sync {
    /// How many request slots this batch contains.
    fn slot_count(&self) -> usize;

    /// Count of KV tokens already committed for slot `slot` — equivalently,
    /// the sequence position the next token will be written at.
    fn position(&self, slot: usize) -> usize;

    /// Flattened `(keys, values)` for batch slot `slot`.
    ///
    /// Each slice is row-major `[seq_len, kv_dim]`, i.e. it holds
    /// `position(slot) * kv_dim` floats.
    ///
    /// # Panics
    ///
    /// An implementation may panic when `slot >= slot_count()`.
    fn kv_for_slot(&self, slot: usize) -> (&[f32], &[f32]);
}
112
113/// Minimal KV cache interface used by forward pass implementations.
114///
115/// This trait is defined in `oxillama-arch` to avoid a circular dependency
116/// with `oxillama-runtime` where the full KV cache lives.
117pub trait KvCacheAccess: Send + Sync {
118    /// Get the current sequence length (number of cached tokens).
119    fn seq_len(&self) -> usize;
120
121    /// Store key and value tensors for a layer at the current position.
122    fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()>;
123
124    /// Retrieve all cached keys for a layer up to the current sequence length.
125    fn get_keys(&self, layer: usize) -> ArchResult<&[f32]>;
126
127    /// Retrieve all cached values for a layer up to the current sequence length.
128    fn get_values(&self, layer: usize) -> ArchResult<&[f32]>;
129
130    /// Advance the cache position by one token.
131    ///
132    /// Called after all layers have stored their K/V for the current token.
133    fn advance(&mut self);
134
135    /// KV dimension per token (num_kv_heads * head_dim).
136    ///
137    /// Returns `0` by default, which signals that per-token iteration is not
138    /// available via the default [`for_each_key`](Self::for_each_key) /
139    /// [`for_each_value`](Self::for_each_value) helpers.  Implementations that
140    /// know their KV dimension should override this.
141    fn kv_dim(&self) -> usize {
142        0
143    }
144
145    /// Iterate over every cached key token for `layer`, calling `f(pos, key_data)`.
146    ///
147    /// The default implementation chunks `get_keys()` using [`kv_dim()`](Self::kv_dim).
148    /// Paged implementations override this to avoid assembling a contiguous slice.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`ArchError::NotSupported`] if `kv_dim()` returns `0`.
153    /// Propagates any error from [`get_keys()`](Self::get_keys).
154    fn for_each_key(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
155        let dim = self.kv_dim();
156        if dim == 0 {
157            return Err(ArchError::NotSupported {
158                detail: "kv_dim() not implemented; cannot iterate per-token keys".to_string(),
159            });
160        }
161        let keys = self.get_keys(layer)?;
162        for (pos, slice) in keys.chunks_exact(dim).enumerate() {
163            f(pos, slice);
164        }
165        Ok(())
166    }
167
168    /// Iterate over every cached value token for `layer`, calling `f(pos, value_data)`.
169    ///
170    /// The default implementation chunks `get_values()` using [`kv_dim()`](Self::kv_dim).
171    /// Paged implementations override this to avoid assembling a contiguous slice.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`ArchError::NotSupported`] if `kv_dim()` returns `0`.
176    /// Propagates any error from [`get_values()`](Self::get_values).
177    fn for_each_value(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
178        let dim = self.kv_dim();
179        if dim == 0 {
180            return Err(ArchError::NotSupported {
181                detail: "kv_dim() not implemented; cannot iterate per-token values".to_string(),
182            });
183        }
184        let values = self.get_values(layer)?;
185        for (pos, slice) in values.chunks_exact(dim).enumerate() {
186            f(pos, slice);
187        }
188        Ok(())
189    }
190}
191
192/// Trait for running forward passes through a loaded model.
193///
194/// Implementations own the model weights and maintain any mutable state
195/// needed during inference (e.g., internal buffers).
196pub trait ForwardPass: Send + Sync {
197    /// Run one forward pass, returning logits for the next token prediction.
198    ///
199    /// # Arguments
200    /// * `tokens` - Input token IDs for this step.
201    /// * `kv_cache` - Mutable reference to the key-value cache.
202    ///
203    /// # Returns
204    /// A vector of logits with length equal to the vocabulary size.
205    fn forward(&mut self, tokens: &[u32], kv_cache: &mut dyn KvCacheAccess)
206        -> ArchResult<Vec<f32>>;
207
208    /// Run one forward pass, returning the post-output-norm hidden state
209    /// (not projected through the LM head).
210    ///
211    /// This is the embedding extraction path: it runs all transformer layers
212    /// and applies the final RMSNorm, but stops before the LM-head projection
213    /// that maps hidden_size → vocab_size. The returned vector has length
214    /// `hidden_size`, not `vocab_size`.
215    ///
216    /// The default implementation returns [`ArchError::NotSupported`].
217    /// Each architecture overrides this with a concrete implementation.
218    fn embed(&mut self, tokens: &[u32], kv_cache: &mut dyn KvCacheAccess) -> ArchResult<Vec<f32>> {
219        let _ = (tokens, kv_cache);
220        Err(ArchError::NotSupported {
221            detail: "embed() not implemented for this architecture".to_string(),
222        })
223    }
224
225    /// Run one forward pass, returning per-token hidden states for **all** tokens
226    /// as a flat `[seq_len × hidden_size]` vector in row-major order.
227    ///
228    /// This is the multi-token embedding extraction path used when a pooling
229    /// mode other than `Last` is requested. The returned vector has length
230    /// `seq_len * hidden_size`.
231    ///
232    /// The default implementation returns [`ArchError::NotSupported`].
233    /// Architectures that require multi-token pooling should override this.
234    fn embed_all(
235        &mut self,
236        tokens: &[u32],
237        kv_cache: &mut dyn KvCacheAccess,
238    ) -> ArchResult<Vec<f32>> {
239        let _ = (tokens, kv_cache);
240        Err(ArchError::NotSupported {
241            detail: "embed_all() not implemented for this architecture; use embed() instead"
242                .to_string(),
243        })
244    }
245
246    /// Returns the model's vocabulary size.
247    fn vocab_size(&self) -> usize;
248
249    /// Returns the model's maximum context length.
250    fn max_context_length(&self) -> usize;
251
252    /// Returns the model's hidden size (embedding dimension).
253    fn hidden_size(&self) -> usize;
254
255    /// Apply LoRA adapter corrections to this model's linear layers.
256    ///
257    /// Walks the model's `QuantLinear` fields and calls
258    /// [`QuantLinear::set_lora`](crate::common::linear::QuantLinear::set_lora)
259    /// for every layer whose name appears in `lora.adapters`.
260    ///
261    /// The default implementation is a no-op: models that do not yet support
262    /// LoRA patching will silently ignore the adapter.  Override this method
263    /// in each architecture implementation that supports LoRA.
264    fn apply_lora(&mut self, lora: &LoadedLora) -> ArchResult<()> {
265        let _ = lora;
266        Ok(())
267    }
268
269    /// Apply an ordered stack of LoRA adapters.
270    ///
271    /// Default implementation: applies each adapter in the stack in order via
272    /// [`apply_lora`](Self::apply_lora), ignoring per-entry scale multipliers.
273    /// Override for architectures that support scaled stacking.
274    fn apply_lora_stack(&mut self, stack: &LoraStack) -> ArchResult<()> {
275        for (lora, _scale) in stack.entries() {
276            self.apply_lora(lora)?;
277        }
278        Ok(())
279    }
280
281    /// Returns the sliding-window attention configuration for this loaded model.
282    ///
283    /// Returns `Some((window_size, is_interleaved))` when the model uses SWA
284    /// on at least some layers, or `None` for pure global attention.
285    fn swa_config(&self) -> Option<(u32, bool)> {
286        None
287    }
288
289    /// Set a persistent LoRA adapter stack that applies to all subsequent
290    /// `forward()` calls.
291    ///
292    /// Default implementation is a no-op — architectures that do not support
293    /// LoRA stacking silently ignore the call (compatible with the existing
294    /// `apply_lora_stack` interface).
295    ///
296    /// # Errors
297    ///
298    /// Returns [`ArchError::LoraIncompatible`] if the adapter's rank or
299    /// dimensions are incompatible with this model.
300    fn with_lora_stack(&mut self, _stack: LoraStack) -> ArchResult<()> {
301        Ok(())
302    }
303
304    /// Remove all LoRA adapters from every `QuantLinear` in this model.
305    ///
306    /// The inverse of [`Self::apply_lora_stack`]: sets `lora = None` on every linear
307    /// layer that was patched.  The default is a no-op; architectures that
308    /// override [`Self::apply_lora`] must also override this.
309    fn unapply_all_loras(&mut self) {}
310
311    /// Allocate a fresh per-sequence state object for this model.
312    ///
313    /// The runtime calls this once per pool slot at model load time.
314    /// Default implementation returns an [`AttentionSequenceState`] suitable
315    /// for all KV-cache-based architectures.
316    ///
317    /// SSM and hybrid architectures **must** override this to return the
318    /// correct state type (e.g. [`Mamba2SequenceState`] or `JambaSequenceState`).
319    ///
320    /// [`Mamba2SequenceState`]: crate::common::sequence_state::Mamba2SequenceState
321    fn allocate_sequence_state(&self, max_context_length: usize) -> Box<dyn SequenceState> {
322        Box::new(AttentionSequenceState::new(max_context_length))
323    }
324
325    /// Run a batched decode-phase forward pass across multiple concurrent requests.
326    ///
327    /// Each slot in `kv_view` corresponds to one batch element.  `q_batch` is
328    /// laid out as `[batch_size, num_heads, head_dim]` in row-major order.
329    ///
330    /// The default implementation returns [`ArchError::NotSupported`].
331    /// Architectures that support continuous batching override this.
332    ///
333    /// # Arguments
334    ///
335    /// * `q_batch`    - Query tensor, shape `[batch_size, num_heads, head_dim]`.
336    /// * `kv_view`    - Per-slot KV cache view.
337    /// * `num_heads`  - Number of query attention heads.
338    /// * `head_dim`   - Per-head dimension.
339    /// * `scale`      - Softmax scale factor (typically `1 / sqrt(head_dim)`).
340    ///
341    /// # Returns
342    ///
343    /// Output tensor with layout `[batch_size, num_heads, head_dim]`.
344    fn forward_batched(
345        &mut self,
346        q_batch: &[f32],
347        kv_view: &dyn BatchedKvView,
348        num_heads: usize,
349        head_dim: usize,
350        scale: f32,
351    ) -> ArchResult<Vec<f32>> {
352        let _ = (q_batch, kv_view, num_heads, head_dim, scale);
353        Err(ArchError::NotSupported {
354            detail: "forward_batched() not implemented for this architecture".to_string(),
355        })
356    }
357}
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362    use crate::error::ArchError;
363
364    /// A minimal stub implementing ForwardPass to test the default
365    /// `forward_batched` returns NotSupported.
366    struct StubModel;
367
368    impl ForwardPass for StubModel {
369        fn forward(
370            &mut self,
371            _tokens: &[u32],
372            _kv_cache: &mut dyn KvCacheAccess,
373        ) -> ArchResult<Vec<f32>> {
374            Ok(vec![])
375        }
376
377        fn vocab_size(&self) -> usize {
378            1
379        }
380
381        fn max_context_length(&self) -> usize {
382            1
383        }
384
385        fn hidden_size(&self) -> usize {
386            1
387        }
388    }
389
390    /// A minimal BatchedKvView for testing.
391    struct EmptyKvView;
392    impl BatchedKvView for EmptyKvView {
393        fn slot_count(&self) -> usize {
394            0
395        }
396
397        fn kv_for_slot(&self, _slot: usize) -> (&[f32], &[f32]) {
398            (&[], &[])
399        }
400
401        fn position(&self, _slot: usize) -> usize {
402            0
403        }
404    }
405
406    #[test]
407    fn forward_batched_default_returns_not_supported() {
408        let mut model = StubModel;
409        let view = EmptyKvView;
410        let result = model.forward_batched(&[], &view, 2, 4, 0.5);
411        match result {
412            Err(ArchError::NotSupported { detail }) => {
413                assert!(
414                    detail.contains("forward_batched"),
415                    "error detail should mention forward_batched, got: {detail}"
416                );
417            }
418            other => panic!("expected NotSupported, got: {other:?}"),
419        }
420    }
421
422    #[test]
423    fn forward_batched_empty_batch_via_default_is_not_supported() {
424        // The default implementation always returns NotSupported regardless of
425        // batch size — it cannot know the correct answer without weights.
426        let mut model = StubModel;
427        let view = EmptyKvView;
428        let result = model.forward_batched(&[], &view, 1, 8, 1.0);
429        assert!(result.is_err(), "default must return Err");
430    }
431
432    #[test]
433    fn kv_slot_construction() {
434        let slot = KvSlot::new(42, 7, 100);
435        assert_eq!(slot.request_id, 42);
436        assert_eq!(slot.kv_cache_idx, 7);
437        assert_eq!(slot.position, 100);
438    }
439
440    #[test]
441    fn kv_cache_access_default_kv_dim_is_zero() {
442        /// Minimal KvCacheAccess impl that does not override kv_dim().
443        struct MinimalCache;
444        impl KvCacheAccess for MinimalCache {
445            fn seq_len(&self) -> usize {
446                0
447            }
448            fn store_kv(&mut self, _layer: usize, _key: &[f32], _value: &[f32]) -> ArchResult<()> {
449                Ok(())
450            }
451            fn get_keys(&self, _layer: usize) -> ArchResult<&[f32]> {
452                Ok(&[])
453            }
454            fn get_values(&self, _layer: usize) -> ArchResult<&[f32]> {
455                Ok(&[])
456            }
457            fn advance(&mut self) {}
458        }
459
460        let cache = MinimalCache;
461        assert_eq!(cache.kv_dim(), 0, "default kv_dim must be 0");
462
463        // for_each_key must return NotSupported when kv_dim == 0
464        let mut called = false;
465        let result = cache.for_each_key(0, &mut |_, _| {
466            called = true;
467        });
468        assert!(result.is_err(), "must return Err when kv_dim() == 0");
469        assert!(!called, "callback must not be invoked");
470    }
471}