oxillama_arch/traits.rs
1//! Core traits defining the model architecture plugin system.
2//!
3//! Every model family (LLaMA, Qwen3, Mistral, etc.) implements
4//! [`ModelArchitecture`] to register itself, and [`ForwardPass`]
5//! for the actual inference computation.
6
7use crate::common::sequence_state::{AttentionSequenceState, SequenceState};
8use crate::config::ModelConfig;
9use crate::error::{ArchError, ArchResult};
10use crate::lora::{LoadedLora, LoraStack};
11use oxillama_gguf::TensorStore;
12
/// Describes one tensor name that an architecture expects to find in a model
/// file, used when validating a loaded file and when reporting diagnostics.
#[derive(Clone, Debug)]
pub struct TensorNamePattern {
    /// Pattern matched against tensor names (regex or glob syntax).
    pub pattern: String,
    /// Short human-readable explanation of what the tensor represents.
    pub description: String,
    /// `true` when the architecture requires a matching tensor, `false` when
    /// the tensor is optional.
    pub required: bool,
}
23
24/// Trait for a model architecture plugin.
25///
26/// Implementations register themselves with the [`ArchitectureRegistry`](crate::registry::ArchitectureRegistry)
27/// and provide the ability to build a runnable model from GGUF data.
28pub trait ModelArchitecture: Send + Sync {
29 /// Architecture identifier string (matches GGUF `general.architecture` metadata).
30 ///
31 /// Examples: `"llama"`, `"qwen3"`, `"mistral"`, `"gemma"`, `"phi"`.
32 fn arch_id(&self) -> &str;
33
34 /// Build a runnable model from configuration and loaded tensors.
35 ///
36 /// This is called once during model loading. The returned [`ForwardPass`]
37 /// implementation owns the model weights and is used for inference.
38 fn build(
39 &self,
40 config: &ModelConfig,
41 tensors: &TensorStore,
42 ) -> ArchResult<Box<dyn ForwardPass>>;
43
44 /// Expected tensor name patterns for this architecture.
45 ///
46 /// Used for validation and diagnostics when loading a model file.
47 fn tensor_names(&self) -> Vec<TensorNamePattern>;
48
49 /// Returns the sliding-window attention configuration for this model.
50 ///
51 /// Returns `Some((window_size, is_interleaved))` when the architecture
52 /// uses SWA on at least some layers, or `None` for pure global attention.
53 fn swa_config(&self) -> Option<(u32, bool)> {
54 None
55 }
56}
57
/// One in-flight request's slot within the shared KV pool.
///
/// Every in-flight request is handed exactly one `KvSlot` identifying the part
/// of the KV pool that belongs to it. The slot is released back to the pool
/// when the request finishes (EOS or max-token limit reached).
///
/// `Hash` is derived alongside `Eq` so slots can be used as `HashMap` /
/// `HashSet` keys (e.g. when tracking active requests).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KvSlot {
    /// Unique identifier of the request that owns this slot.
    pub request_id: u64,
    /// Index into the shared KV cache pool (e.g. the row within a paged KV
    /// cache or the sequence slot index in a flat pool).
    pub kv_cache_idx: usize,
    /// Current sequence position, i.e. the number of tokens committed so far.
    pub position: usize,
}

impl KvSlot {
    /// Construct a `KvSlot` from its three fields.
    ///
    /// This is a `const fn`, so slots can also be built in constant contexts
    /// such as test fixtures or `const` tables.
    #[must_use]
    pub const fn new(request_id: u64, kv_cache_idx: usize, position: usize) -> Self {
        Self {
            request_id,
            kv_cache_idx,
            position,
        }
    }
}
84
/// Read access to the KV caches of several concurrent requests, consumed by
/// batched decode attention.
///
/// By the decode phase every request has accumulated keys and values from its
/// prefill and earlier decode steps. This view hands the batched-attention
/// kernel one key slice and one value slice per request, leaving the caller
/// free to lay out the underlying memory however it likes.
///
/// A typical implementor wraps a pool of KV cache buffers indexed by [`KvSlot`].
pub trait BatchedKvView: Sync {
    /// How many request slots this batch contains.
    fn slot_count(&self) -> usize;

    /// Number of KV tokens already committed for `slot`; equivalently, the
    /// sequence position the next token for this slot will be written to.
    fn position(&self, slot: usize) -> usize;

    /// Flattened `(keys, values)` for `slot`.
    ///
    /// Each slice is laid out row-major as `[seq_len, kv_dim]` and so holds
    /// `position(slot) * kv_dim` elements. `kv_dim` is the per-token
    /// key/value width; it is not exposed by this trait (presumably
    /// `num_kv_heads * head_dim`, as for [`KvCacheAccess::kv_dim`] — confirm).
    ///
    /// # Panics
    ///
    /// Implementations may panic when `slot >= slot_count()`.
    fn kv_for_slot(&self, slot: usize) -> (&[f32], &[f32]);
}
112
113/// Minimal KV cache interface used by forward pass implementations.
114///
115/// This trait is defined in `oxillama-arch` to avoid a circular dependency
116/// with `oxillama-runtime` where the full KV cache lives.
117pub trait KvCacheAccess: Send + Sync {
118 /// Get the current sequence length (number of cached tokens).
119 fn seq_len(&self) -> usize;
120
121 /// Store key and value tensors for a layer at the current position.
122 fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()>;
123
124 /// Retrieve all cached keys for a layer up to the current sequence length.
125 fn get_keys(&self, layer: usize) -> ArchResult<&[f32]>;
126
127 /// Retrieve all cached values for a layer up to the current sequence length.
128 fn get_values(&self, layer: usize) -> ArchResult<&[f32]>;
129
130 /// Advance the cache position by one token.
131 ///
132 /// Called after all layers have stored their K/V for the current token.
133 fn advance(&mut self);
134
135 /// KV dimension per token (num_kv_heads * head_dim).
136 ///
137 /// Returns `0` by default, which signals that per-token iteration is not
138 /// available via the default [`for_each_key`](Self::for_each_key) /
139 /// [`for_each_value`](Self::for_each_value) helpers. Implementations that
140 /// know their KV dimension should override this.
141 fn kv_dim(&self) -> usize {
142 0
143 }
144
145 /// Iterate over every cached key token for `layer`, calling `f(pos, key_data)`.
146 ///
147 /// The default implementation chunks `get_keys()` using [`kv_dim()`](Self::kv_dim).
148 /// Paged implementations override this to avoid assembling a contiguous slice.
149 ///
150 /// # Errors
151 ///
152 /// Returns [`ArchError::NotSupported`] if `kv_dim()` returns `0`.
153 /// Propagates any error from [`get_keys()`](Self::get_keys).
154 fn for_each_key(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
155 let dim = self.kv_dim();
156 if dim == 0 {
157 return Err(ArchError::NotSupported {
158 detail: "kv_dim() not implemented; cannot iterate per-token keys".to_string(),
159 });
160 }
161 let keys = self.get_keys(layer)?;
162 for (pos, slice) in keys.chunks_exact(dim).enumerate() {
163 f(pos, slice);
164 }
165 Ok(())
166 }
167
168 /// Iterate over every cached value token for `layer`, calling `f(pos, value_data)`.
169 ///
170 /// The default implementation chunks `get_values()` using [`kv_dim()`](Self::kv_dim).
171 /// Paged implementations override this to avoid assembling a contiguous slice.
172 ///
173 /// # Errors
174 ///
175 /// Returns [`ArchError::NotSupported`] if `kv_dim()` returns `0`.
176 /// Propagates any error from [`get_values()`](Self::get_values).
177 fn for_each_value(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
178 let dim = self.kv_dim();
179 if dim == 0 {
180 return Err(ArchError::NotSupported {
181 detail: "kv_dim() not implemented; cannot iterate per-token values".to_string(),
182 });
183 }
184 let values = self.get_values(layer)?;
185 for (pos, slice) in values.chunks_exact(dim).enumerate() {
186 f(pos, slice);
187 }
188 Ok(())
189 }
190}
191
192/// Trait for running forward passes through a loaded model.
193///
194/// Implementations own the model weights and maintain any mutable state
195/// needed during inference (e.g., internal buffers).
196pub trait ForwardPass: Send + Sync {
197 /// Run one forward pass, returning logits for the next token prediction.
198 ///
199 /// # Arguments
200 /// * `tokens` - Input token IDs for this step.
201 /// * `kv_cache` - Mutable reference to the key-value cache.
202 ///
203 /// # Returns
204 /// A vector of logits with length equal to the vocabulary size.
205 fn forward(&mut self, tokens: &[u32], kv_cache: &mut dyn KvCacheAccess)
206 -> ArchResult<Vec<f32>>;
207
208 /// Run one forward pass, returning the post-output-norm hidden state
209 /// (not projected through the LM head).
210 ///
211 /// This is the embedding extraction path: it runs all transformer layers
212 /// and applies the final RMSNorm, but stops before the LM-head projection
213 /// that maps hidden_size → vocab_size. The returned vector has length
214 /// `hidden_size`, not `vocab_size`.
215 ///
216 /// The default implementation returns [`ArchError::NotSupported`].
217 /// Each architecture overrides this with a concrete implementation.
218 fn embed(&mut self, tokens: &[u32], kv_cache: &mut dyn KvCacheAccess) -> ArchResult<Vec<f32>> {
219 let _ = (tokens, kv_cache);
220 Err(ArchError::NotSupported {
221 detail: "embed() not implemented for this architecture".to_string(),
222 })
223 }
224
225 /// Run one forward pass, returning per-token hidden states for **all** tokens
226 /// as a flat `[seq_len × hidden_size]` vector in row-major order.
227 ///
228 /// This is the multi-token embedding extraction path used when a pooling
229 /// mode other than `Last` is requested. The returned vector has length
230 /// `seq_len * hidden_size`.
231 ///
232 /// The default implementation returns [`ArchError::NotSupported`].
233 /// Architectures that require multi-token pooling should override this.
234 fn embed_all(
235 &mut self,
236 tokens: &[u32],
237 kv_cache: &mut dyn KvCacheAccess,
238 ) -> ArchResult<Vec<f32>> {
239 let _ = (tokens, kv_cache);
240 Err(ArchError::NotSupported {
241 detail: "embed_all() not implemented for this architecture; use embed() instead"
242 .to_string(),
243 })
244 }
245
246 /// Returns the model's vocabulary size.
247 fn vocab_size(&self) -> usize;
248
249 /// Returns the model's maximum context length.
250 fn max_context_length(&self) -> usize;
251
252 /// Returns the model's hidden size (embedding dimension).
253 fn hidden_size(&self) -> usize;
254
255 /// Apply LoRA adapter corrections to this model's linear layers.
256 ///
257 /// Walks the model's `QuantLinear` fields and calls
258 /// [`QuantLinear::set_lora`](crate::common::linear::QuantLinear::set_lora)
259 /// for every layer whose name appears in `lora.adapters`.
260 ///
261 /// The default implementation is a no-op: models that do not yet support
262 /// LoRA patching will silently ignore the adapter. Override this method
263 /// in each architecture implementation that supports LoRA.
264 fn apply_lora(&mut self, lora: &LoadedLora) -> ArchResult<()> {
265 let _ = lora;
266 Ok(())
267 }
268
269 /// Apply an ordered stack of LoRA adapters.
270 ///
271 /// Default implementation: applies each adapter in the stack in order via
272 /// [`apply_lora`](Self::apply_lora), ignoring per-entry scale multipliers.
273 /// Override for architectures that support scaled stacking.
274 fn apply_lora_stack(&mut self, stack: &LoraStack) -> ArchResult<()> {
275 for (lora, _scale) in stack.entries() {
276 self.apply_lora(lora)?;
277 }
278 Ok(())
279 }
280
281 /// Returns the sliding-window attention configuration for this loaded model.
282 ///
283 /// Returns `Some((window_size, is_interleaved))` when the model uses SWA
284 /// on at least some layers, or `None` for pure global attention.
285 fn swa_config(&self) -> Option<(u32, bool)> {
286 None
287 }
288
289 /// Set a persistent LoRA adapter stack that applies to all subsequent
290 /// `forward()` calls.
291 ///
292 /// Default implementation is a no-op — architectures that do not support
293 /// LoRA stacking silently ignore the call (compatible with the existing
294 /// `apply_lora_stack` interface).
295 ///
296 /// # Errors
297 ///
298 /// Returns [`ArchError::LoraIncompatible`] if the adapter's rank or
299 /// dimensions are incompatible with this model.
300 fn with_lora_stack(&mut self, _stack: LoraStack) -> ArchResult<()> {
301 Ok(())
302 }
303
304 /// Remove all LoRA adapters from every `QuantLinear` in this model.
305 ///
306 /// The inverse of [`Self::apply_lora_stack`]: sets `lora = None` on every linear
307 /// layer that was patched. The default is a no-op; architectures that
308 /// override [`Self::apply_lora`] must also override this.
309 fn unapply_all_loras(&mut self) {}
310
311 /// Allocate a fresh per-sequence state object for this model.
312 ///
313 /// The runtime calls this once per pool slot at model load time.
314 /// Default implementation returns an [`AttentionSequenceState`] suitable
315 /// for all KV-cache-based architectures.
316 ///
317 /// SSM and hybrid architectures **must** override this to return the
318 /// correct state type (e.g. [`Mamba2SequenceState`] or `JambaSequenceState`).
319 ///
320 /// [`Mamba2SequenceState`]: crate::common::sequence_state::Mamba2SequenceState
321 fn allocate_sequence_state(&self, max_context_length: usize) -> Box<dyn SequenceState> {
322 Box::new(AttentionSequenceState::new(max_context_length))
323 }
324
325 /// Run a batched decode-phase forward pass across multiple concurrent requests.
326 ///
327 /// Each slot in `kv_view` corresponds to one batch element. `q_batch` is
328 /// laid out as `[batch_size, num_heads, head_dim]` in row-major order.
329 ///
330 /// The default implementation returns [`ArchError::NotSupported`].
331 /// Architectures that support continuous batching override this.
332 ///
333 /// # Arguments
334 ///
335 /// * `q_batch` - Query tensor, shape `[batch_size, num_heads, head_dim]`.
336 /// * `kv_view` - Per-slot KV cache view.
337 /// * `num_heads` - Number of query attention heads.
338 /// * `head_dim` - Per-head dimension.
339 /// * `scale` - Softmax scale factor (typically `1 / sqrt(head_dim)`).
340 ///
341 /// # Returns
342 ///
343 /// Output tensor with layout `[batch_size, num_heads, head_dim]`.
344 fn forward_batched(
345 &mut self,
346 q_batch: &[f32],
347 kv_view: &dyn BatchedKvView,
348 num_heads: usize,
349 head_dim: usize,
350 scale: f32,
351 ) -> ArchResult<Vec<f32>> {
352 let _ = (q_batch, kv_view, num_heads, head_dim, scale);
353 Err(ArchError::NotSupported {
354 detail: "forward_batched() not implemented for this architecture".to_string(),
355 })
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ArchError;

    /// Minimal `ForwardPass` that implements only the required methods, so
    /// the trait's default methods can be exercised directly.
    struct StubModel;

    impl ForwardPass for StubModel {
        fn forward(
            &mut self,
            _tokens: &[u32],
            _kv_cache: &mut dyn KvCacheAccess,
        ) -> ArchResult<Vec<f32>> {
            Ok(vec![])
        }

        fn vocab_size(&self) -> usize {
            1
        }

        fn max_context_length(&self) -> usize {
            1
        }

        fn hidden_size(&self) -> usize {
            1
        }
    }

    /// `BatchedKvView` with no slots.
    struct EmptyKvView;

    impl BatchedKvView for EmptyKvView {
        fn slot_count(&self) -> usize {
            0
        }

        fn kv_for_slot(&self, _slot: usize) -> (&[f32], &[f32]) {
            (&[], &[])
        }

        fn position(&self, _slot: usize) -> usize {
            0
        }
    }

    /// `KvCacheAccess` that does not override `kv_dim()` (so it stays `0`).
    struct MinimalCache;

    impl KvCacheAccess for MinimalCache {
        fn seq_len(&self) -> usize {
            0
        }
        fn store_kv(&mut self, _layer: usize, _key: &[f32], _value: &[f32]) -> ArchResult<()> {
            Ok(())
        }
        fn get_keys(&self, _layer: usize) -> ArchResult<&[f32]> {
            Ok(&[])
        }
        fn get_values(&self, _layer: usize) -> ArchResult<&[f32]> {
            Ok(&[])
        }
        fn advance(&mut self) {}
    }

    /// `KvCacheAccess` over flat key/value buffers with a fixed, non-zero
    /// `kv_dim()`, used to exercise the default per-token iterators.
    struct FlatCache {
        keys: Vec<f32>,
        values: Vec<f32>,
        dim: usize,
    }

    impl KvCacheAccess for FlatCache {
        fn seq_len(&self) -> usize {
            self.keys.len() / self.dim
        }
        fn store_kv(&mut self, _layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
            self.keys.extend_from_slice(key);
            self.values.extend_from_slice(value);
            Ok(())
        }
        fn get_keys(&self, _layer: usize) -> ArchResult<&[f32]> {
            Ok(self.keys.as_slice())
        }
        fn get_values(&self, _layer: usize) -> ArchResult<&[f32]> {
            Ok(self.values.as_slice())
        }
        fn advance(&mut self) {}
        fn kv_dim(&self) -> usize {
            self.dim
        }
    }

    #[test]
    fn forward_batched_default_returns_not_supported() {
        let mut model = StubModel;
        let view = EmptyKvView;
        let result = model.forward_batched(&[], &view, 2, 4, 0.5);
        match result {
            Err(ArchError::NotSupported { detail }) => {
                assert!(
                    detail.contains("forward_batched"),
                    "error detail should mention forward_batched, got: {detail}"
                );
            }
            other => panic!("expected NotSupported, got: {other:?}"),
        }
    }

    #[test]
    fn forward_batched_empty_batch_via_default_is_not_supported() {
        // The default implementation always returns NotSupported regardless of
        // batch size — it cannot know the correct answer without weights.
        let mut model = StubModel;
        let view = EmptyKvView;
        let result = model.forward_batched(&[], &view, 1, 8, 1.0);
        assert!(result.is_err(), "default must return Err");
    }

    #[test]
    fn embed_defaults_return_not_supported() {
        let mut model = StubModel;
        let mut cache = MinimalCache;

        match model.embed(&[1, 2], &mut cache) {
            Err(ArchError::NotSupported { detail }) => {
                assert!(detail.contains("embed()"), "unexpected detail: {detail}");
            }
            other => panic!("expected NotSupported, got: {other:?}"),
        }

        match model.embed_all(&[1, 2], &mut cache) {
            Err(ArchError::NotSupported { detail }) => {
                assert!(detail.contains("embed_all()"), "unexpected detail: {detail}");
            }
            other => panic!("expected NotSupported, got: {other:?}"),
        }
    }

    #[test]
    fn forward_pass_default_swa_config_is_none() {
        assert_eq!(StubModel.swa_config(), None);
    }

    #[test]
    fn kv_slot_construction() {
        let slot = KvSlot::new(42, 7, 100);
        assert_eq!(slot.request_id, 42);
        assert_eq!(slot.kv_cache_idx, 7);
        assert_eq!(slot.position, 100);
    }

    #[test]
    fn kv_cache_access_default_kv_dim_is_zero() {
        let cache = MinimalCache;
        assert_eq!(cache.kv_dim(), 0, "default kv_dim must be 0");

        // Both per-token iterators must refuse to run when kv_dim == 0.
        let mut called = false;
        let result = cache.for_each_key(0, &mut |_, _| {
            called = true;
        });
        assert!(
            matches!(result, Err(ArchError::NotSupported { .. })),
            "for_each_key must return NotSupported when kv_dim() == 0"
        );
        assert!(!called, "for_each_key callback must not be invoked");

        let mut called = false;
        let result = cache.for_each_value(0, &mut |_, _| {
            called = true;
        });
        assert!(
            matches!(result, Err(ArchError::NotSupported { .. })),
            "for_each_value must return NotSupported when kv_dim() == 0"
        );
        assert!(!called, "for_each_value callback must not be invoked");
    }

    #[test]
    fn for_each_key_and_value_chunk_by_kv_dim() {
        let cache = FlatCache {
            keys: vec![1.0, 2.0, 3.0, 4.0],
            values: vec![5.0, 6.0, 7.0, 8.0],
            dim: 2,
        };
        assert_eq!(cache.seq_len(), 2);

        let mut keys: Vec<(usize, Vec<f32>)> = Vec::new();
        cache
            .for_each_key(0, &mut |pos: usize, key: &[f32]| keys.push((pos, key.to_vec())))
            .expect("for_each_key must succeed when kv_dim() > 0");
        let expected_keys: Vec<(usize, Vec<f32>)> = vec![(0, vec![1.0, 2.0]), (1, vec![3.0, 4.0])];
        assert_eq!(keys, expected_keys);

        let mut values: Vec<(usize, Vec<f32>)> = Vec::new();
        cache
            .for_each_value(0, &mut |pos: usize, value: &[f32]| {
                values.push((pos, value.to_vec()))
            })
            .expect("for_each_value must succeed when kv_dim() > 0");
        let expected_values: Vec<(usize, Vec<f32>)> =
            vec![(0, vec![5.0, 6.0]), (1, vec![7.0, 8.0])];
        assert_eq!(values, expected_values);
    }
}