// oxibonsai_runtime/engine.rs

//! Inference engine orchestrating model loading and generation.
//!
//! The [`InferenceEngine`] is the main entry point for running inference.
//! It owns the model, kernel dispatcher, and sampler, and provides both
//! blocking ([`InferenceEngine::generate`]) and streaming
//! ([`InferenceEngine::generate_streaming`]) generation APIs.
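//!
//! A minimal usage sketch (not a compiled doctest): the GGUF path, prompt
//! tokens, and seed below are placeholders, and the `use` paths assume the
//! crate re-exports match this file's module layout.
//!
//! ```ignore
//! use oxibonsai_runtime::engine::InferenceEngine;
//! use oxibonsai_runtime::sampling::SamplingParams;
//!
//! // Load and own an engine for the process lifetime (leaks the mapping).
//! let mut engine = InferenceEngine::from_gguf_path(
//!     "models/model.gguf",        // placeholder path
//!     SamplingParams::default(),
//!     42,                         // sampler seed
//!     4096,                       // max sequence length
//! )?;
//!
//! // `prompt_tokens: Vec<u32>` comes from the tokenizer (not shown).
//! let output = engine.generate(&prompt_tokens, 128)?;
//! ```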

use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;

use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_kernels::traits::OneBitKernel;
use oxibonsai_kernels::KernelDispatcher;
use oxibonsai_model::model::BonsaiModel;

use crate::batch_engine::{self, BatchResult};
use crate::error::{RuntimeError, RuntimeResult};
use crate::metrics::InferenceMetrics;
#[cfg(all(feature = "metal", target_os = "macos"))]
use crate::ngram_cache::NgramCache;
use crate::request_id::RequestId;
use crate::request_metrics::{RequestRateAggregator, RequestRateSnapshot, RequestRateTracker};
use crate::sampling::{Sampler, SamplingParams};

/// EOS token for Qwen3 models.
pub const EOS_TOKEN_ID: u32 = 151645;

/// Statistics about engine usage, accumulated over the engine's lifetime.
#[derive(Debug)]
pub struct EngineStats {
    /// Total number of tokens generated.
    pub total_tokens_generated: AtomicU64,
    /// Total number of inference requests completed.
    pub total_requests: AtomicU64,
    /// Number of currently active sessions.
    pub active_sessions: AtomicUsize,
    /// Engine start time.
    pub start_time: Instant,
}

impl EngineStats {
    /// Create new engine stats, recording the current time as start.
    pub fn new() -> Self {
        Self {
            total_tokens_generated: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
            active_sessions: AtomicUsize::new(0),
            start_time: Instant::now(),
        }
    }

    /// Engine uptime in seconds.
    pub fn uptime_seconds(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    /// Record that a request completed with the given number of generated tokens.
    pub fn record_request(&self, tokens_generated: usize) {
        self.total_tokens_generated
            .fetch_add(tokens_generated as u64, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

    /// Get total tokens generated.
    pub fn tokens_generated(&self) -> u64 {
        self.total_tokens_generated.load(Ordering::Relaxed)
    }

    /// Get total requests completed.
    pub fn requests_completed(&self) -> u64 {
        self.total_requests.load(Ordering::Relaxed)
    }

    /// Get number of active sessions.
    pub fn active_session_count(&self) -> usize {
        self.active_sessions.load(Ordering::Relaxed)
    }

    /// Average tokens per request (returns 0.0 if no requests).
    pub fn avg_tokens_per_request(&self) -> f64 {
        let reqs = self.requests_completed();
        if reqs == 0 {
            return 0.0;
        }
        self.tokens_generated() as f64 / reqs as f64
    }
}

impl Default for EngineStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Top-level inference engine.
pub struct InferenceEngine<'a> {
    model: BonsaiModel<'a>,
    kernel: KernelDispatcher,
    sampler: Sampler,
    metrics: Option<Arc<InferenceMetrics>>,
    stats: Arc<EngineStats>,
    /// Cumulative number of tokens that have been processed by
    /// [`InferenceEngine::prefill_from_pos`] across the engine's lifetime.
    ///
    /// Used by the prefix-cache integration to verify that cached prefixes
    /// actually reduce prefill work — the cached portion of a prompt is not
    /// re-fed into prefill, so a repeated prompt should increment this
    /// counter by strictly fewer tokens than its full length.
    prefill_token_count: u64,
    /// Optional workload-level rate aggregator. When attached, every
    /// `generate_tracked` call records its [`RequestRateSnapshot`] here on
    /// completion, allowing the operator to surface workload-level p50/p95
    /// inter-token latency, EWMA tokens-per-second, and queue-wait gauges
    /// (see [`InferenceMetrics::update_request_rate`]).
    rate_aggregator: Option<Arc<RequestRateAggregator>>,
}

impl<'a> InferenceEngine<'a> {
    /// Create a new inference engine from a configuration (no weights — for testing).
    pub fn new(config: Qwen3Config, sampling_params: SamplingParams, seed: u64) -> Self {
        let model = BonsaiModel::new(config);
        let kernel = KernelDispatcher::auto_detect();
        let sampler = Sampler::new(sampling_params, seed);

        tracing::info!(kernel = kernel.name(), "inference engine initialized");

        Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        }
    }

    /// Wrap an already-constructed [`BonsaiModel`] in an inference engine.
    ///
    /// Lets tests (and future custom-model paths) build a model with
    /// non-trivial weights and then attach the standard sampler/kernel
    /// machinery without going through the GGUF loader.
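    ///
    /// A sketch of the intended test-side usage (not a compiled doctest; the
    /// config constructor is the one this file's tests use):
    ///
    /// ```ignore
    /// let model = BonsaiModel::new(Qwen3Config::bonsai_8b());
    /// let engine = InferenceEngine::from_model(model, SamplingParams::default(), 42);
    /// ```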
    pub fn from_model(model: BonsaiModel<'a>, sampling_params: SamplingParams, seed: u64) -> Self {
        Self::from_model_with_kernel(
            model,
            KernelDispatcher::auto_detect(),
            sampling_params,
            seed,
        )
    }

    /// Wrap an already-constructed [`BonsaiModel`] using a caller-supplied
    /// kernel dispatcher.
    ///
    /// Use this when you need to pin the engine to a specific kernel tier
    /// (e.g. a CPU-only `KernelTier::Reference` for tests that exercise the
    /// CPU KV-cache path on a host that would otherwise auto-detect a GPU).
    pub fn from_model_with_kernel(
        model: BonsaiModel<'a>,
        kernel: KernelDispatcher,
        sampling_params: SamplingParams,
        seed: u64,
    ) -> Self {
        let sampler = Sampler::new(sampling_params, seed);
        Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        }
    }

    /// Create a new inference engine from a loaded GGUF file.
    pub fn from_gguf(
        gguf: &'a GgufFile<'a>,
        sampling_params: SamplingParams,
        seed: u64,
        max_seq_len: usize,
    ) -> RuntimeResult<Self> {
        let mut model = BonsaiModel::from_gguf(gguf, max_seq_len)?;
        let kernel = KernelDispatcher::auto_detect();

        // Upload all model weights to GPU memory once (no-op on CPU-only tiers).
        model.upload_weights_to_gpu(&kernel);

        // Pre-build GPU weight cache eagerly so it's outside the timing window.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            tracing::info!("pre-building GPU weight cache");
            model.get_or_create_gpu_cache().map_err(|e| {
                RuntimeError::Model(oxibonsai_model::error::ModelError::Internal(format!(
                    "GPU weight cache init: {e}"
                )))
            })?;
        }

        // Pre-warm both CUDA code paths so all first-call overhead (CUDA driver graph
        // capture, prefill kernel module loading, weight uploads) is paid during model
        // loading and NOT inside the benchmark timer.
        //
        // Two passes are required:
        //   1. Single-token decode via `model.forward` → captures the 36-layer CUDA
        //      driver graph (slow-path ~490ms becomes fast-path ~44ms thereafter).
        //   2. Multi-token batch (17 tokens) via `model.forward_prefill` → loads the prefill PTX
        //      module into GPU driver memory (`init_prefill_modules`) which takes
        //      ~100-200ms on first call and is not triggered by step 1.
        //
        // Both warmup K/V cache writes are at positions that real inference
        // overwrites immediately (K/V is written before attention reads it).
        // The CUDA KV cache is separate from the CPU-side `model.kv_cache`.
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            tracing::info!("CUDA warmup: pre-capturing driver graph + prefill modules");
            // Step 1: capture the 36-layer decode CUDA driver graph.
            let _ = model.forward(0, 0, &kernel);
            // Step 2: pre-load the batch-prefill PTX module into GPU driver memory
            // (`init_prefill_modules`) and pre-allocate the prefill KV cache,
            // single-token attention buffers, and activation buffers.
            // We use 17 tokens so the CUDA batch prefill code path is exercised
            // (prompts ≤ 16 tokens use the fast decode-graph path instead).
            // This ensures all one-time batch-prefill setup costs are paid before
            // the benchmark timer, covering longer prompts without a cold-start penalty.
            let _ = model.forward_prefill(&[0u32; 17], 0, &kernel);
            tracing::info!("CUDA warmup complete");
        }

        let sampler = Sampler::new(sampling_params, seed);

        tracing::info!(kernel = kernel.name(), "inference engine loaded from GGUF");

        Ok(Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        })
    }

    /// Attach shared metrics to this engine for recording inference telemetry.
    pub fn set_metrics(&mut self, metrics: Arc<InferenceMetrics>) {
        self.metrics = Some(metrics);
    }

    /// Attach a workload-level [`RequestRateAggregator`] to this engine.
    ///
    /// Once attached, every call to [`InferenceEngine::generate_tracked`] (or
    /// [`InferenceEngine::generate_with_request_id`]) will push its
    /// per-request [`RequestRateSnapshot`] into the aggregator on completion.
    /// The aggregator is reference-counted, so the same instance can be shared
    /// with the Prometheus metrics layer or the admin endpoints.
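    ///
    /// A sketch of the sharing pattern (not a compiled doctest; how the
    /// aggregator is constructed is left to `RequestRateAggregator`'s own API):
    ///
    /// ```ignore
    /// // One aggregator handle shared between the engine and the metrics layer.
    /// let aggregator: Arc<RequestRateAggregator> = /* built by the metrics setup */;
    /// engine.set_rate_aggregator(Arc::clone(&aggregator));
    /// // The same handle can now be polled by the Prometheus layer or admin endpoints.
    /// ```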
    pub fn set_rate_aggregator(&mut self, aggregator: Arc<RequestRateAggregator>) {
        self.rate_aggregator = Some(aggregator);
    }

    /// Read-only access to the attached rate aggregator, if any.
    pub fn rate_aggregator(&self) -> Option<&Arc<RequestRateAggregator>> {
        self.rate_aggregator.as_ref()
    }

    /// Get a reference to the model.
    pub fn model(&self) -> &BonsaiModel<'a> {
        &self.model
    }

    /// Get a mutable reference to the model.
    ///
    /// Used by the prefix-cache integration to inject restored KV blocks
    /// before running the abbreviated prefill.
    pub fn model_mut(&mut self) -> &mut BonsaiModel<'a> {
        &mut self.model
    }

    /// Get a reference to the kernel dispatcher.
    pub fn kernel(&self) -> &KernelDispatcher {
        &self.kernel
    }

    /// Run prefill at a given KV-cache offset.
    ///
    /// Unlike [`InferenceEngine::generate`], this does **not** reset the
    /// model's KV cache before execution: callers (e.g. the prefix-cache
    /// engine) are expected to have prepared the cache state explicitly.
    ///
    /// Increments the [`prefill_token_count`](Self::prefill_token_count)
    /// counter by `prompt_tokens.len()` on success.
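    ///
    /// A sketch of the prefix-cache call sequence (not a compiled doctest;
    /// `cached_len` is a placeholder split point, and restoring the cached KV
    /// blocks via [`InferenceEngine::model_mut`] is elided):
    ///
    /// ```ignore
    /// // Only the uncached suffix of the prompt is fed through prefill.
    /// let logits = engine.prefill_from_pos(&prompt[cached_len..], cached_len)?;
    /// // Decoding then continues exactly as in `generate`.
    /// let first = engine.sample(&logits)?;
    /// let logits = engine.decode_step(first, prompt.len())?;
    /// ```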
    pub fn prefill_from_pos(
        &mut self,
        prompt_tokens: &[u32],
        pos_start: usize,
    ) -> RuntimeResult<Vec<f32>> {
        let logits = self
            .model
            .forward_prefill(prompt_tokens, pos_start, &self.kernel)?;
        self.prefill_token_count = self
            .prefill_token_count
            .saturating_add(prompt_tokens.len() as u64);
        Ok(logits)
    }

    /// Forward one token at the given absolute position.
    pub fn decode_step(&mut self, token: u32, pos: usize) -> RuntimeResult<Vec<f32>> {
        Ok(self.model.forward(token, pos, &self.kernel)?)
    }

    /// Sample one token from `logits` using the engine's current sampler.
    pub fn sample(&mut self, logits: &[f32]) -> RuntimeResult<u32> {
        self.sampler.sample(logits)
    }

    /// Cumulative number of tokens that have been processed by
    /// [`InferenceEngine::prefill_from_pos`] over this engine's lifetime.
    pub fn prefill_token_count(&self) -> u64 {
        self.prefill_token_count
    }

    /// Reset the model state for a new conversation.
    pub fn reset(&mut self) {
        self.model.reset();
    }

    /// Get a shared reference to the engine statistics.
    pub fn stats(&self) -> &Arc<EngineStats> {
        &self.stats
    }

    /// Number of currently active sessions (tracked via stats).
    pub fn active_sessions(&self) -> usize {
        self.stats.active_session_count()
    }

    /// Total number of completed requests (tracked via stats).
    pub fn session_count(&self) -> u64 {
        self.stats.requests_completed()
    }

    /// Process a batch of prompts, delegating to [`batch_engine::batch_generate`].
    ///
    /// Resets the engine state between each prompt. Returns one result per prompt.
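    ///
    /// A minimal sketch (not a compiled doctest; the prompt token IDs are
    /// placeholders):
    ///
    /// ```ignore
    /// let prompts = vec![vec![1_u32, 2, 3], vec![4_u32, 5]];
    /// for result in engine.batch_generate(&prompts, 64) {
    ///     match result {
    ///         Ok(batch) => println!("generated {} tokens", batch.generated_tokens.len()),
    ///         Err(e) => eprintln!("prompt failed: {e}"),
    ///     }
    /// }
    /// ```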
    pub fn batch_generate(
        &mut self,
        prompts: &[Vec<u32>],
        max_tokens: usize,
    ) -> Vec<RuntimeResult<BatchResult>> {
        self.stats.active_sessions.fetch_add(1, Ordering::Relaxed);

        let results = batch_engine::batch_generate(self, prompts, max_tokens);

        // Record stats for successful results
        for br in results.iter().flatten() {
            self.stats.record_request(br.generated_tokens.len());
        }

        self.stats.active_sessions.fetch_sub(1, Ordering::Relaxed);

        results
    }

    /// Generate tokens from a prompt.
    ///
    /// Runs prefill (process the entire prompt), then decodes
    /// token by token until `max_tokens` or EOS is reached.
    /// Returns the generated token IDs (not including the prompt).
    #[tracing::instrument(skip(self, prompt_tokens), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        // ═══════════════════════════════════════════════════════
        // 1. Prefill: batch process all prompt tokens
        // ═══════════════════════════════════════════════════════
        let prefill_start = std::time::Instant::now();
        let mut last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        // ═══════════════════════════════════════════════════════
        // 2. Decode: sample and generate
        // ═══════════════════════════════════════════════════════
        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();

            // Sample next token
            let next_token = self.sampler.sample(&last_logits)?;

            // Check for EOS
            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated");
                break;
            }

            output_tokens.push(next_token);

            // Forward the generated token
            last_logits = self.model.forward(next_token, pos, &self.kernel)?;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        // Record tokens/sec and update memory gauge
        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }

        // Record engine-level stats
        self.stats.record_request(output_tokens.len());

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "generation complete"
        );

        Ok(output_tokens)
    }

    /// Generate tokens from a prompt while populating a [`RequestRateTracker`].
    ///
    /// Behaves identically to [`InferenceEngine::generate`] but additionally:
    /// - records `record_admission()` immediately on entry,
    /// - records `record_first_token()` for the first sampled token,
    /// - records `record_token()` for every subsequent sampled token,
    /// - on success, pushes the resulting [`RequestRateSnapshot`] into the
    ///   engine's attached [`RequestRateAggregator`] (if any).
    ///
    /// The tracker is borrowed mutably so callers can inspect intermediate
    /// state via [`RequestRateTracker::snapshot`] after the call returns.
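    ///
    /// A sketch of the calling pattern (not a compiled doctest; the prompt is
    /// a placeholder):
    ///
    /// ```ignore
    /// let mut tracker = RequestRateTracker::new();
    /// let tokens = engine.generate_tracked(&prompt, 128, &mut tracker)?;
    /// // Per-request metrics remain available after the call.
    /// let snapshot = tracker.snapshot();
    /// ```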
    #[tracing::instrument(skip(self, prompt_tokens, tracker), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_tracked(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tracker: &mut RequestRateTracker,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }
        tracker.record_admission();

        let prefill_start = std::time::Instant::now();
        let mut last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);
        let mut first_token_recorded = false;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();
            let next_token = self.sampler.sample(&last_logits)?;
            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated");
                break;
            }
            output_tokens.push(next_token);
            if !first_token_recorded {
                tracker.record_first_token();
                first_token_recorded = true;
            } else {
                tracker.record_token();
            }
            last_logits = self.model.forward(next_token, pos, &self.kernel)?;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }
        self.stats.record_request(output_tokens.len());

        if let Some(agg) = &self.rate_aggregator {
            let snap: RequestRateSnapshot = tracker.snapshot();
            agg.record(snap);
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "tracked generation complete"
        );

        Ok(output_tokens)
    }

    /// Generate tokens from a prompt with a [`RequestId`] tagging the
    /// surrounding tracing span and an internally-managed
    /// [`RequestRateTracker`].
    ///
    /// Returns both the generated tokens and the final tracker so callers
    /// can extract per-request metrics (e.g. queue-wait, p95 inter-token
    /// latency) for client-side observability.
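    ///
    /// A sketch of the calling pattern (not a compiled doctest; how a
    /// [`RequestId`] is issued depends on the request layer and is elided):
    ///
    /// ```ignore
    /// let request_id: RequestId = /* issued by the serving layer */;
    /// let (tokens, tracker) = engine.generate_with_request_id(request_id, &prompt, 128)?;
    /// let per_request = tracker.snapshot();
    /// ```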
    pub fn generate_with_request_id(
        &mut self,
        request_id: RequestId,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<(Vec<u32>, RequestRateTracker)> {
        let span = tracing::info_span!("generate_request", request_id = %request_id);
        let _enter = span.enter();
        let mut tracker = RequestRateTracker::new();
        let tokens = self.generate_tracked(prompt_tokens, max_tokens, &mut tracker)?;
        Ok((tokens, tracker))
    }

    /// Generate tokens from a prompt using a specific seed for this run.
    ///
    /// Temporarily overrides the sampler seed for deterministic multi-completion
    /// generation (`n > 1`). The sampler state is replaced for the duration of
    /// this call and then restored.
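    ///
    /// A sketch of multi-completion (`n > 1`) use (not a compiled doctest;
    /// deriving per-completion seeds as `base_seed + i` is just one choice):
    ///
    /// ```ignore
    /// let params = SamplingParams::default();
    /// let completions: Vec<Vec<u32>> = (0..3_u64)
    ///     .map(|i| engine.generate_with_seed(&prompt, 128, base_seed + i, &params))
    ///     .collect::<Result<_, _>>()?;
    /// ```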
    pub fn generate_with_seed(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        seed: u64,
        params: &crate::sampling::SamplingParams,
    ) -> RuntimeResult<Vec<u32>> {
        // Swap in a fresh sampler with the given seed
        let old_sampler = std::mem::replace(
            &mut self.sampler,
            crate::sampling::Sampler::new(params.clone(), seed),
        );
        let result = self.generate(prompt_tokens, max_tokens);
        // Restore the original sampler
        self.sampler = old_sampler;
        result
    }

    /// Generate tokens one at a time, sending each through the channel.
    /// Returns the total count of generated tokens.
    ///
    /// Not available on WASM targets (tokio channels not supported on wasm32-unknown-unknown).
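    ///
    /// A sketch of the streaming setup (not a compiled doctest; in a server
    /// the receiver would typically be drained from an async task):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    /// // Generation blocks this thread; tokens show up on `rx` as they are produced.
    /// let generated = engine.generate_streaming(&prompt, 128, &tx)?;
    /// while let Ok(token) = rx.try_recv() {
    ///     // detokenize / forward `token` to the client
    /// }
    /// ```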
    #[cfg(not(target_arch = "wasm32"))]
    #[tracing::instrument(skip(self, prompt_tokens, tx), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_streaming(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tx: &tokio::sync::mpsc::UnboundedSender<u32>,
    ) -> RuntimeResult<usize> {
        if prompt_tokens.is_empty() {
            return Ok(0);
        }

        // Prefill: batch process all prompt tokens
        let prefill_start = std::time::Instant::now();
        let mut logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut generated = 0;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();
            let next_token = self.sampler.sample(&logits)?;

            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated (streaming)");
                break;
            }

            // Send token through channel; if receiver dropped, stop generating
            if tx.send(next_token).is_err() {
                tracing::debug!(pos, "receiver dropped, stopping generation");
                break;
            }

            logits = self.model.forward(next_token, pos, &self.kernel)?;
            generated += 1;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        // Record tokens/sec and update memory gauge
        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && generated > 0 {
                let tok_per_sec = generated as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(generated as u64);
            m.update_memory_from_rss();
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated,
            "streaming generation complete"
        );

        Ok(generated)
    }

    /// Streaming generation using a synchronous `std::sync::mpsc::Sender`.
    ///
    /// Each generated token is sent through the channel immediately, allowing
    /// the consumer to print tokens as they arrive without requiring a tokio runtime.
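    ///
    /// A sketch with a consumer thread (not a compiled doctest; detokenization
    /// is elided):
    ///
    /// ```ignore
    /// let (tx, rx) = std::sync::mpsc::channel();
    /// let printer = std::thread::spawn(move || {
    ///     for token in rx {
    ///         println!("{token}");
    ///     }
    /// });
    /// let generated = engine.generate_streaming_sync(&prompt, 128, &tx)?;
    /// drop(tx); // close the channel so the printer thread exits
    /// printer.join().unwrap();
    /// ```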
    #[tracing::instrument(skip(self, prompt_tokens, tx), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_streaming_sync(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tx: &std::sync::mpsc::Sender<u32>,
    ) -> RuntimeResult<usize> {
        if prompt_tokens.is_empty() {
            return Ok(0);
        }

        // Prefill: batch process all prompt tokens
        let prefill_start = std::time::Instant::now();
        let mut logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut generated = 0;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();

            let next_token = self.sampler.sample(&logits)?;

            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated (streaming_sync)");
                break;
            }

            if tx.send(next_token).is_err() {
                tracing::debug!(pos, "receiver dropped, stopping generation");
                break;
            }

            logits = self.model.forward(next_token, pos, &self.kernel)?;
            generated += 1;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && generated > 0 {
                let tok_per_sec = generated as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(generated as u64);
            m.update_memory_from_rss();
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated,
            "streaming sync generation complete"
        );

        Ok(generated)
    }

    /// Greedy generation entirely on GPU (temperature=0, argmax on Metal).
    ///
    /// Runs the full forward pass + argmax in a single GPU command buffer per
    /// token, downloading only the 4-byte token ID instead of the ~607KB logits
    /// vector. If the GPU greedy path fails, it falls back to the normal
    /// forward pass with a CPU-side argmax.
    ///
    /// Returns the generated token IDs (not including the prompt).
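    ///
    /// A sketch (not a compiled doctest; Metal builds only). The n-gram
    /// speculative path is opt-in via the `OXIBONSAI_SPEC` environment
    /// variable checked inside this function:
    ///
    /// ```ignore
    /// std::env::set_var("OXIBONSAI_SPEC", "1"); // enable n-gram speculation
    /// let tokens = engine.generate_greedy_gpu(&prompt, 256)?;
    /// ```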
    #[cfg(all(feature = "metal", target_os = "macos"))]
    #[tracing::instrument(skip(self, prompt_tokens), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_greedy_gpu(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        // ═══════════════════════════════════════════════════════
        // 1. Prefill: batch process all prompt tokens
        // ═══════════════════════════════════════════════════════
        let prefill_start = std::time::Instant::now();
        let last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        // First decode token: argmax from prefill logits
        let first_token = {
            let mut best_idx = 0u32;
            let mut best_val = f32::NEG_INFINITY;
            for (i, &v) in last_logits.iter().enumerate() {
                if v > best_val {
                    best_val = v;
                    best_idx = i as u32;
                }
            }
            best_idx
        };

        // ═══════════════════════════════════════════════════════
        // 2. Decode: speculative greedy with n-gram drafting
        // ═══════════════════════════════════════════════════════
        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);

        if first_token == EOS_TOKEN_ID {
            self.stats.record_request(0);
            return Ok(vec![]);
        }
        output_tokens.push(first_token);

        // N-gram cache for zero-cost draft generation
        let mut ngram_cache = NgramCache::new();
        ngram_cache.record(prompt_tokens);

        // Running context: prompt + generated tokens (for n-gram lookups)
        let mut context: Vec<u32> = prompt_tokens.to_vec();
        context.push(first_token);

        let speculation_k: usize = 4;
        let mut spec_attempts: u64 = 0;
        let mut spec_accepted_total: u64 = 0;
        let spec_enabled = std::env::var("OXIBONSAI_SPEC")
            .map(|v| v == "1")
            .unwrap_or(false);
        let spec_warmup = 15_usize; // build cache before speculating

        let mut next_token = first_token;
        let mut pos = prompt_tokens.len() + 1;
        let max_pos = prompt_tokens.len() + max_tokens;

        while pos < max_pos && output_tokens.len() < max_tokens {
            let step_start = std::time::Instant::now();
            let tokens_generated = output_tokens.len();

            // Try n-gram draft — skip warmup phase unless explicitly enabled
            let draft = if !spec_enabled || tokens_generated < spec_warmup {
                Vec::new()
            } else {
                ngram_cache.draft(&context, speculation_k)
            };

            // Adaptive: only speculate if recent accuracy > 60%
            // (batch of 5 costs ~4x single token, need high hit rate)
            let spec_ok = if spec_attempts >= 5 {
                let accuracy = spec_accepted_total as f64
                    / (spec_attempts as f64 * speculation_k as f64).max(1.0);
                accuracy > 0.6 || spec_attempts % 20 == 0
            } else {
                true // optimistic for first 5 attempts
            };

            if !draft.is_empty() && spec_ok {
                // ── Speculative path: batch verify ──────────────
                let mut batch = Vec::with_capacity(1 + draft.len());
                batch.push(next_token);
                batch.extend_from_slice(&draft);

                match self
                    .model
                    .forward_prefill_verify(&batch, pos - 1, &self.kernel)
                {
                    Ok(model_preds) => {
                        spec_attempts += 1;

                        // Verify draft against model predictions
                        let mut accepted: usize = 0;
                        for i in 0..draft.len() {
                            if i < model_preds.len() && draft[i] == model_preds[i] {
                                accepted += 1;
                            } else {
                                break;
                            }
                        }
                        spec_accepted_total += accepted as u64;

                        // Collect accepted draft tokens + bonus
                        let mut eos_seen = false;
                        for &token in draft.iter().take(accepted) {
                            if token == EOS_TOKEN_ID {
                                eos_seen = true;
                                break;
                            }
                            output_tokens.push(token);
                            context.push(token);
                        }

                        if !eos_seen {
                            // Bonus: model's prediction at the accept/reject boundary
                            let bonus = if accepted < model_preds.len() {
                                model_preds[accepted]
                            } else {
                                // All draft tokens matched, take the last prediction
                                match model_preds.last() {
                                    Some(&tok) => tok,
                                    None => break,
                                }
                            };

                            if bonus == EOS_TOKEN_ID {
                                tracing::debug!(pos, accepted, "EOS from speculative bonus");
                                break;
                            }

                            output_tokens.push(bonus);
                            context.push(bonus);
                            next_token = bonus;
                            pos += accepted + 1;

                            // Update n-gram cache with the newly accepted window
                            let window_start = context.len().saturating_sub(accepted + 4);
                            ngram_cache.record(&context[window_start..]);
                        } else {
                            tracing::debug!(pos, accepted, "EOS in draft tokens");
                            break;
                        }
                    }
                    Err(_e) => {
                        // Speculative verify failed — fall through to single-token decode
                        tracing::debug!("speculative verify failed, using single-token decode");
                        match self.model.forward_greedy_gpu(next_token, pos - 1) {
                            Ok(token_id) => {
                                if token_id == EOS_TOKEN_ID {
                                    tracing::debug!(pos, "EOS token generated (greedy GPU)");
                                    break;
                                }
                                output_tokens.push(token_id);
                                context.push(token_id);
                                let window_start = context.len().saturating_sub(3);
                                ngram_cache.record(&context[window_start..]);
                                next_token = token_id;
                                pos += 1;
                            }
                            Err(e) => {
                                tracing::warn!(
                                    error = %e, pos,
                                    "greedy GPU path failed, falling back to normal forward"
                                );
                                let logits =
                                    self.model.forward(next_token, pos - 1, &self.kernel)?;
                                let mut best_idx = 0u32;
                                let mut best_val = f32::NEG_INFINITY;
                                for (i, &v) in logits.iter().enumerate() {
                                    if v > best_val {
                                        best_val = v;
                                        best_idx = i as u32;
                                    }
                                }
                                if best_idx == EOS_TOKEN_ID {
                                    tracing::debug!(pos, "EOS from CPU fallback");
                                    break;
                                }
                                output_tokens.push(best_idx);
                                context.push(best_idx);
                                let window_start = context.len().saturating_sub(3);
                                ngram_cache.record(&context[window_start..]);
                                next_token = best_idx;
                                pos += 1;
                            }
                        }
                    }
                }
            } else {
                // ── Single-token decode (no draft or accuracy too low) ──
                match self.model.forward_greedy_gpu(next_token, pos - 1) {
                    Ok(token_id) => {
                        if token_id == EOS_TOKEN_ID {
                            tracing::debug!(pos, "EOS token generated (greedy GPU)");
                            break;
                        }
                        output_tokens.push(token_id);
                        context.push(token_id);
                        let window_start = context.len().saturating_sub(3);
                        ngram_cache.record(&context[window_start..]);
                        next_token = token_id;
                        pos += 1;
                    }
                    Err(e) => {
                        tracing::warn!(
                            error = %e, pos,
                            "greedy GPU path failed, falling back to normal forward"
                        );
                        let logits = self.model.forward(next_token, pos - 1, &self.kernel)?;
                        let mut best_idx = 0u32;
                        let mut best_val = f32::NEG_INFINITY;
                        for (i, &v) in logits.iter().enumerate() {
                            if v > best_val {
                                best_val = v;
                                best_idx = i as u32;
                            }
                        }
                        if best_idx == EOS_TOKEN_ID {
                            tracing::debug!(pos, "EOS from CPU fallback");
                            break;
                        }
                        output_tokens.push(best_idx);
                        context.push(best_idx);
                        let window_start = context.len().saturating_sub(3);
                        ngram_cache.record(&context[window_start..]);
                        next_token = best_idx;
                        pos += 1;
                    }
                }
            }

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }

            // Check for EOS from single-token path
            if output_tokens.last() == Some(&EOS_TOKEN_ID) {
                output_tokens.pop(); // Don't include EOS in output
                break;
            }
        }

        // Log speculative decode statistics
        if spec_attempts > 0 {
            let avg_accepted = spec_accepted_total as f64 / spec_attempts as f64;
            let accuracy =
                spec_accepted_total as f64 / (spec_attempts as f64 * speculation_k as f64).max(1.0);
            tracing::info!(
                spec_attempts,
                spec_accepted_total,
                avg_accepted = format!("{:.2}", avg_accepted),
                accuracy = format!("{:.1}%", accuracy * 100.0),
                "speculative decode stats"
            );
        }

        // Record tokens/sec and update memory gauge
        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }

        self.stats.record_request(output_tokens.len());

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "greedy GPU generation complete"
        );

        Ok(output_tokens)
    }
}

impl InferenceEngine<'static> {
    /// Load an [`InferenceEngine`] directly from a path to a GGUF file.
    ///
    /// This is a convenience wrapper intended for server/CLI entry points that
    /// need an owned, `'static` engine. It memory-maps the file, parses the
    /// GGUF container, and leaks both allocations so that the borrowed
    /// `GgufFile<'a>` lifetime can be promoted to `'static`.
    ///
    /// The leaked memory is intentional — the engine is expected to live for
    /// the process lifetime. Do not call this in hot paths.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::FileNotFound`] if `path` does not exist. Other
    /// IO / parse / model-init errors propagate through [`RuntimeError`].
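    ///
    /// A sketch of server-startup usage (not a compiled doctest; the path is a
    /// placeholder):
    ///
    /// ```ignore
    /// let engine = InferenceEngine::from_gguf_path(
    ///     "/models/model.gguf",
    ///     SamplingParams::default(),
    ///     0,
    ///     8192,
    /// )
    /// .unwrap_or_else(|e| {
    ///     eprintln!("failed to load model: {e}");
    ///     std::process::exit(1);
    /// });
    /// ```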
    pub fn from_gguf_path(
        path: impl AsRef<std::path::Path>,
        sampling_params: SamplingParams,
        seed: u64,
        max_seq_len: usize,
    ) -> RuntimeResult<Self> {
        let path_ref = path.as_ref();
        if !path_ref.exists() {
            return Err(RuntimeError::FileNotFound {
                path: path_ref.display().to_string(),
            });
        }

        // Memory-map and parse, then leak both so the resulting `GgufFile`
        // can live for `'static` without RAII concerns.
        let mmap = oxibonsai_core::gguf::reader::mmap_gguf_file(path_ref)?;
        let mmap: &'static memmap2::Mmap = Box::leak(Box::new(mmap));
        let gguf = oxibonsai_core::gguf::reader::GgufFile::parse(mmap)?;
        let gguf: &'static oxibonsai_core::gguf::reader::GgufFile<'static> =
            Box::leak(Box::new(gguf));

        Self::from_gguf(gguf, sampling_params, seed, max_seq_len)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn engine_creation() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        assert_eq!(engine.model().config().num_layers, 36);
    }

    #[test]
    fn engine_stats_initial() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let stats = engine.stats();
        assert_eq!(stats.tokens_generated(), 0);
        assert_eq!(stats.requests_completed(), 0);
        assert_eq!(stats.active_session_count(), 0);
        assert!(stats.uptime_seconds() >= 0.0);
        assert!((stats.avg_tokens_per_request() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn engine_stats_record() {
        let stats = EngineStats::new();
        stats.record_request(10);
        stats.record_request(20);
        assert_eq!(stats.tokens_generated(), 30);
        assert_eq!(stats.requests_completed(), 2);
        assert!((stats.avg_tokens_per_request() - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn engine_session_tracking() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        assert_eq!(engine.active_sessions(), 0);
        assert_eq!(engine.session_count(), 0);
    }

    #[test]
    fn engine_batch_generate_empty() {
        let config = Qwen3Config::bonsai_8b();
        let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let results = engine.batch_generate(&[], 10);
        assert!(results.is_empty());
        assert_eq!(engine.session_count(), 0);
    }

    #[test]
    fn engine_batch_generate_empty_prompts() {
        let config = Qwen3Config::bonsai_8b();
        let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let prompts = vec![vec![], vec![]];
        let results = engine.batch_generate(&prompts, 5);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.is_ok());
        }
        // Stats should reflect the completed requests
        assert_eq!(engine.stats().requests_completed(), 2);
    }

    #[test]
    fn engine_stats_default() {
        let stats = EngineStats::default();
        assert_eq!(stats.tokens_generated(), 0);
        assert_eq!(stats.requests_completed(), 0);
    }
}