use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;

use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_kernels::traits::OneBitKernel;
use oxibonsai_kernels::KernelDispatcher;
use oxibonsai_model::model::BonsaiModel;

use crate::batch_engine::{self, BatchResult};
use crate::error::{RuntimeError, RuntimeResult};
use crate::metrics::InferenceMetrics;
#[cfg(all(feature = "metal", target_os = "macos"))]
use crate::ngram_cache::NgramCache;
use crate::request_id::RequestId;
use crate::request_metrics::{RequestRateAggregator, RequestRateSnapshot, RequestRateTracker};
use crate::sampling::{Sampler, SamplingParams};

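/// Token id that terminates generation. The value 151645 is the end-of-turn
/// token in the Qwen3 tokenizer family (an assumption inferred from the Qwen3
/// config used throughout this engine, not stated elsewhere in this module).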
pub const EOS_TOKEN_ID: u32 = 151645;

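/// Cumulative, lock-free counters covering the lifetime of an engine instance.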
#[derive(Debug)]
pub struct EngineStats {
    pub total_tokens_generated: AtomicU64,
    pub total_requests: AtomicU64,
    pub active_sessions: AtomicUsize,
    pub start_time: Instant,
}

impl EngineStats {
    pub fn new() -> Self {
        Self {
            total_tokens_generated: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
            active_sessions: AtomicUsize::new(0),
            start_time: Instant::now(),
        }
    }

    pub fn uptime_seconds(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    pub fn record_request(&self, tokens_generated: usize) {
        self.total_tokens_generated
            .fetch_add(tokens_generated as u64, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

    pub fn tokens_generated(&self) -> u64 {
        self.total_tokens_generated.load(Ordering::Relaxed)
    }

    pub fn requests_completed(&self) -> u64 {
        self.total_requests.load(Ordering::Relaxed)
    }

    pub fn active_session_count(&self) -> usize {
        self.active_sessions.load(Ordering::Relaxed)
    }

    pub fn avg_tokens_per_request(&self) -> f64 {
        let reqs = self.requests_completed();
        if reqs == 0 {
            return 0.0;
        }
        self.tokens_generated() as f64 / reqs as f64
    }
}

impl Default for EngineStats {
    fn default() -> Self {
        Self::new()
    }
}

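/// Single-model inference engine: owns the model, the kernel dispatcher, and
/// the sampler, and drives the prefill/decode loops for the various generation
/// modes below.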
pub struct InferenceEngine<'a> {
    model: BonsaiModel<'a>,
    kernel: KernelDispatcher,
    sampler: Sampler,
    metrics: Option<Arc<InferenceMetrics>>,
    stats: Arc<EngineStats>,
    prefill_token_count: u64,
    rate_aggregator: Option<Arc<RequestRateAggregator>>,
}

impl<'a> InferenceEngine<'a> {
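    /// Builds an engine from a model configuration, auto-detecting the kernel
    /// backend.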
    pub fn new(config: Qwen3Config, sampling_params: SamplingParams, seed: u64) -> Self {
        let model = BonsaiModel::new(config);
        let kernel = KernelDispatcher::auto_detect();
        let sampler = Sampler::new(sampling_params, seed);

        tracing::info!(kernel = kernel.name(), "inference engine initialized");

        Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        }
    }

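    /// Wraps an already constructed model, auto-detecting the kernel backend.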
    pub fn from_model(model: BonsaiModel<'a>, sampling_params: SamplingParams, seed: u64) -> Self {
        Self::from_model_with_kernel(
            model,
            KernelDispatcher::auto_detect(),
            sampling_params,
            seed,
        )
    }

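    /// Wraps an already constructed model with an explicitly chosen kernel
    /// dispatcher.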
    pub fn from_model_with_kernel(
        model: BonsaiModel<'a>,
        kernel: KernelDispatcher,
        sampling_params: SamplingParams,
        seed: u64,
    ) -> Self {
        let sampler = Sampler::new(sampling_params, seed);
        Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        }
    }

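    /// Loads a model from a parsed GGUF file, uploads weights to the GPU where
    /// supported, and runs backend-specific warmup (Metal weight cache,
    /// CUDA graph/prefill capture) before the first request.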
    pub fn from_gguf(
        gguf: &'a GgufFile<'a>,
        sampling_params: SamplingParams,
        seed: u64,
        max_seq_len: usize,
    ) -> RuntimeResult<Self> {
        let mut model = BonsaiModel::from_gguf(gguf, max_seq_len)?;
        let kernel = KernelDispatcher::auto_detect();

        model.upload_weights_to_gpu(&kernel);

        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            tracing::info!("pre-building GPU weight cache");
            model.get_or_create_gpu_cache().map_err(|e| {
                RuntimeError::Model(oxibonsai_model::error::ModelError::Internal(format!(
                    "GPU weight cache init: {e}"
                )))
            })?;
        }

        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            tracing::info!("CUDA warmup: pre-capturing driver graph + prefill modules");
            let _ = model.forward(0, 0, &kernel);
            let _ = model.forward_prefill(&[0u32; 17], 0, &kernel);
            tracing::info!("CUDA warmup complete");
        }

        let sampler = Sampler::new(sampling_params, seed);

        tracing::info!(kernel = kernel.name(), "inference engine loaded from GGUF");

        Ok(Self {
            model,
            kernel,
            sampler,
            metrics: None,
            stats: Arc::new(EngineStats::new()),
            prefill_token_count: 0,
            rate_aggregator: None,
        })
    }

    pub fn set_metrics(&mut self, metrics: Arc<InferenceMetrics>) {
        self.metrics = Some(metrics);
    }

    pub fn set_rate_aggregator(&mut self, aggregator: Arc<RequestRateAggregator>) {
        self.rate_aggregator = Some(aggregator);
    }

    pub fn rate_aggregator(&self) -> Option<&Arc<RequestRateAggregator>> {
        self.rate_aggregator.as_ref()
    }

    pub fn model(&self) -> &BonsaiModel<'a> {
        &self.model
    }

    pub fn model_mut(&mut self) -> &mut BonsaiModel<'a> {
        &mut self.model
    }

    pub fn kernel(&self) -> &KernelDispatcher {
        &self.kernel
    }

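    /// Runs prefill over `prompt_tokens` starting at KV-cache position
    /// `pos_start` and returns the logits for the final token.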
    pub fn prefill_from_pos(
        &mut self,
        prompt_tokens: &[u32],
        pos_start: usize,
    ) -> RuntimeResult<Vec<f32>> {
        let logits = self
            .model
            .forward_prefill(prompt_tokens, pos_start, &self.kernel)?;
        self.prefill_token_count = self
            .prefill_token_count
            .saturating_add(prompt_tokens.len() as u64);
        Ok(logits)
    }

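    /// Runs a single decode step for `token` at position `pos` and returns the
    /// next-token logits.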
    pub fn decode_step(&mut self, token: u32, pos: usize) -> RuntimeResult<Vec<f32>> {
        Ok(self.model.forward(token, pos, &self.kernel)?)
    }

    pub fn sample(&mut self, logits: &[f32]) -> RuntimeResult<u32> {
        self.sampler.sample(logits)
    }

    pub fn prefill_token_count(&self) -> u64 {
        self.prefill_token_count
    }

    pub fn reset(&mut self) {
        self.model.reset();
    }

    pub fn stats(&self) -> &Arc<EngineStats> {
        &self.stats
    }

    pub fn active_sessions(&self) -> usize {
        self.stats.active_session_count()
    }

    pub fn session_count(&self) -> u64 {
        self.stats.requests_completed()
    }

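    /// Generates up to `max_tokens` tokens for each prompt via the batch
    /// engine, recording per-prompt statistics for successful results.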
    pub fn batch_generate(
        &mut self,
        prompts: &[Vec<u32>],
        max_tokens: usize,
    ) -> Vec<RuntimeResult<BatchResult>> {
        self.stats.active_sessions.fetch_add(1, Ordering::Relaxed);

        let results = batch_engine::batch_generate(self, prompts, max_tokens);

        for br in results.iter().flatten() {
            self.stats.record_request(br.generated_tokens.len());
        }

        self.stats.active_sessions.fetch_sub(1, Ordering::Relaxed);

        results
    }

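    /// Generates up to `max_tokens` tokens for `prompt_tokens`, stopping early
    /// at EOS. Returns only the newly generated token ids (the prompt is
    /// excluded).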
    #[tracing::instrument(skip(self, prompt_tokens), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        let prefill_start = std::time::Instant::now();
        let mut last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();

            let next_token = self.sampler.sample(&last_logits)?;

            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated");
                break;
            }

            output_tokens.push(next_token);

            last_logits = self.model.forward(next_token, pos, &self.kernel)?;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }

        self.stats.record_request(output_tokens.len());

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "generation complete"
        );

        Ok(output_tokens)
    }

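    /// Like [`Self::generate`], but records admission, first-token, and
    /// per-token timings in `tracker`, and reports a snapshot to the rate
    /// aggregator when one is configured.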
    #[tracing::instrument(skip(self, prompt_tokens, tracker), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_tracked(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tracker: &mut RequestRateTracker,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }
        tracker.record_admission();

        let prefill_start = std::time::Instant::now();
        let mut last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);
        let mut first_token_recorded = false;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();
            let next_token = self.sampler.sample(&last_logits)?;
            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated");
                break;
            }
            output_tokens.push(next_token);
            if !first_token_recorded {
                tracker.record_first_token();
                first_token_recorded = true;
            } else {
                tracker.record_token();
            }
            last_logits = self.model.forward(next_token, pos, &self.kernel)?;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }
        self.stats.record_request(output_tokens.len());

        if let Some(agg) = &self.rate_aggregator {
            let snap: RequestRateSnapshot = tracker.snapshot();
            agg.record(snap);
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "tracked generation complete"
        );

        Ok(output_tokens)
    }

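    /// Runs tracked generation inside a tracing span tagged with `request_id`,
    /// returning the generated tokens together with the populated tracker.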
    pub fn generate_with_request_id(
        &mut self,
        request_id: RequestId,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<(Vec<u32>, RequestRateTracker)> {
        let span = tracing::info_span!("generate_request", request_id = %request_id);
        let _enter = span.enter();
        let mut tracker = RequestRateTracker::new();
        let tokens = self.generate_tracked(prompt_tokens, max_tokens, &mut tracker)?;
        Ok((tokens, tracker))
    }

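    /// Generates with a temporary sampler built from `params` and `seed`,
    /// restoring the previous sampler afterwards.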
    pub fn generate_with_seed(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        seed: u64,
        params: &crate::sampling::SamplingParams,
    ) -> RuntimeResult<Vec<u32>> {
        let old_sampler = std::mem::replace(
            &mut self.sampler,
            crate::sampling::Sampler::new(params.clone(), seed),
        );
        let result = self.generate(prompt_tokens, max_tokens);
        self.sampler = old_sampler;
        result
    }

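    /// Generates tokens and sends each one through the Tokio channel `tx` as it
    /// is produced, stopping early at EOS or when the receiver is dropped.
    /// Returns the number of tokens sent.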
    #[cfg(not(target_arch = "wasm32"))]
    #[tracing::instrument(skip(self, prompt_tokens, tx), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_streaming(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tx: &tokio::sync::mpsc::UnboundedSender<u32>,
    ) -> RuntimeResult<usize> {
        if prompt_tokens.is_empty() {
            return Ok(0);
        }

        let prefill_start = std::time::Instant::now();
        let mut logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut generated = 0;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();
            let next_token = self.sampler.sample(&logits)?;

            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated (streaming)");
                break;
            }

            if tx.send(next_token).is_err() {
                tracing::debug!(pos, "receiver dropped, stopping generation");
                break;
            }

            logits = self.model.forward(next_token, pos, &self.kernel)?;
            generated += 1;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && generated > 0 {
                let tok_per_sec = generated as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(generated as u64);
            m.update_memory_from_rss();
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated,
            "streaming generation complete"
        );

        Ok(generated)
    }

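    /// Blocking variant of [`Self::generate_streaming`] that sends tokens over
    /// a `std::sync::mpsc` channel.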
    #[tracing::instrument(skip(self, prompt_tokens, tx), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_streaming_sync(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
        tx: &std::sync::mpsc::Sender<u32>,
    ) -> RuntimeResult<usize> {
        if prompt_tokens.is_empty() {
            return Ok(0);
        }

        let prefill_start = std::time::Instant::now();
        let mut logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let decode_start = std::time::Instant::now();
        let mut generated = 0;

        for (pos, _) in (prompt_tokens.len()..).zip(0..max_tokens) {
            let step_start = std::time::Instant::now();

            let next_token = self.sampler.sample(&logits)?;

            if next_token == EOS_TOKEN_ID {
                tracing::debug!(pos, "EOS token generated (streaming_sync)");
                break;
            }

            if tx.send(next_token).is_err() {
                tracing::debug!(pos, "receiver dropped, stopping generation");
                break;
            }

            logits = self.model.forward(next_token, pos, &self.kernel)?;
            generated += 1;

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && generated > 0 {
                let tok_per_sec = generated as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(generated as u64);
            m.update_memory_from_rss();
        }

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated,
            "streaming sync generation complete"
        );

        Ok(generated)
    }

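    /// Greedy (argmax) decoding on the Metal GPU path with optional n-gram
    /// speculative decoding. When `OXIBONSAI_SPEC=1` is set and a short warmup
    /// has passed, up to `speculation_k` (4) draft tokens are proposed from an
    /// n-gram cache and verified in a single batched forward pass; drafting is
    /// throttled when measured acceptance accuracy drops. On verification
    /// failure the loop falls back to single-token greedy GPU decode, and
    /// finally to the regular forward path if the GPU path errors out.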
    #[cfg(all(feature = "metal", target_os = "macos"))]
    #[tracing::instrument(skip(self, prompt_tokens), fields(prompt_len = prompt_tokens.len()))]
    pub fn generate_greedy_gpu(
        &mut self,
        prompt_tokens: &[u32],
        max_tokens: usize,
    ) -> RuntimeResult<Vec<u32>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        let prefill_start = std::time::Instant::now();
        let last_logits = self.model.forward_prefill(prompt_tokens, 0, &self.kernel)?;
        if let Some(m) = &self.metrics {
            m.prefill_duration_seconds
                .observe(prefill_start.elapsed().as_secs_f64());
        }

        let first_token = {
            let mut best_idx = 0u32;
            let mut best_val = f32::NEG_INFINITY;
            for (i, &v) in last_logits.iter().enumerate() {
                if v > best_val {
                    best_val = v;
                    best_idx = i as u32;
                }
            }
            best_idx
        };

        let decode_start = std::time::Instant::now();
        let mut output_tokens = Vec::with_capacity(max_tokens);

        if first_token == EOS_TOKEN_ID {
            self.stats.record_request(0);
            return Ok(vec![]);
        }
        output_tokens.push(first_token);

        let mut ngram_cache = NgramCache::new();
        ngram_cache.record(prompt_tokens);

        let mut context: Vec<u32> = prompt_tokens.to_vec();
        context.push(first_token);

        let speculation_k: usize = 4;
        let mut spec_attempts: u64 = 0;
        let mut spec_accepted_total: u64 = 0;
        let spec_enabled = std::env::var("OXIBONSAI_SPEC")
            .map(|v| v == "1")
            .unwrap_or(false);
        let spec_warmup = 15_usize;

        let mut next_token = first_token;
        let mut pos = prompt_tokens.len() + 1;
        let max_pos = prompt_tokens.len() + max_tokens;

        while pos < max_pos && output_tokens.len() < max_tokens {
            let step_start = std::time::Instant::now();
            let tokens_generated = output_tokens.len();

            let draft = if !spec_enabled || tokens_generated < spec_warmup {
                Vec::new()
            } else {
                ngram_cache.draft(&context, speculation_k)
            };

            let spec_ok = if spec_attempts >= 5 {
                let accuracy = spec_accepted_total as f64
                    / (spec_attempts as f64 * speculation_k as f64).max(1.0);
                accuracy > 0.6 || spec_attempts % 20 == 0
            } else {
                true
            };

            if !draft.is_empty() && spec_ok {
                let mut batch = Vec::with_capacity(1 + draft.len());
                batch.push(next_token);
                batch.extend_from_slice(&draft);

                match self
                    .model
                    .forward_prefill_verify(&batch, pos - 1, &self.kernel)
                {
                    Ok(model_preds) => {
                        spec_attempts += 1;

                        let mut accepted: usize = 0;
                        for i in 0..draft.len() {
                            if i < model_preds.len() && draft[i] == model_preds[i] {
                                accepted += 1;
                            } else {
                                break;
                            }
                        }
                        spec_accepted_total += accepted as u64;

                        let mut eos_seen = false;
                        for &token in draft.iter().take(accepted) {
                            if token == EOS_TOKEN_ID {
                                eos_seen = true;
                                break;
                            }
                            output_tokens.push(token);
                            context.push(token);
                        }

                        if !eos_seen {
                            let bonus = if accepted < model_preds.len() {
                                model_preds[accepted]
                            } else {
                                match model_preds.last() {
                                    Some(&tok) => tok,
                                    None => break,
                                }
                            };

                            if bonus == EOS_TOKEN_ID {
                                tracing::debug!(pos, accepted, "EOS from speculative bonus");
                                break;
                            }

                            output_tokens.push(bonus);
                            context.push(bonus);
                            next_token = bonus;
                            pos += accepted + 1;

                            let window_start = context.len().saturating_sub(accepted + 4);
                            ngram_cache.record(&context[window_start..]);
                        } else {
                            tracing::debug!(pos, accepted, "EOS in draft tokens");
                            break;
                        }
                    }
                    Err(_e) => {
                        tracing::debug!("speculative verify failed, using single-token decode");
                        match self.model.forward_greedy_gpu(next_token, pos - 1) {
                            Ok(token_id) => {
                                if token_id == EOS_TOKEN_ID {
                                    tracing::debug!(pos, "EOS token generated (greedy GPU)");
                                    break;
                                }
                                output_tokens.push(token_id);
                                context.push(token_id);
                                let window_start = context.len().saturating_sub(3);
                                ngram_cache.record(&context[window_start..]);
                                next_token = token_id;
                                pos += 1;
                            }
                            Err(e) => {
                                tracing::warn!(
                                    error = %e, pos,
                                    "greedy GPU path failed, falling back to normal forward"
                                );
                                let logits =
                                    self.model.forward(next_token, pos - 1, &self.kernel)?;
                                let mut best_idx = 0u32;
                                let mut best_val = f32::NEG_INFINITY;
                                for (i, &v) in logits.iter().enumerate() {
                                    if v > best_val {
                                        best_val = v;
                                        best_idx = i as u32;
                                    }
                                }
                                if best_idx == EOS_TOKEN_ID {
                                    tracing::debug!(pos, "EOS from CPU fallback");
                                    break;
                                }
                                output_tokens.push(best_idx);
                                context.push(best_idx);
                                let window_start = context.len().saturating_sub(3);
                                ngram_cache.record(&context[window_start..]);
                                next_token = best_idx;
                                pos += 1;
                            }
                        }
                    }
                }
            } else {
                match self.model.forward_greedy_gpu(next_token, pos - 1) {
                    Ok(token_id) => {
                        if token_id == EOS_TOKEN_ID {
                            tracing::debug!(pos, "EOS token generated (greedy GPU)");
                            break;
                        }
                        output_tokens.push(token_id);
                        context.push(token_id);
                        let window_start = context.len().saturating_sub(3);
                        ngram_cache.record(&context[window_start..]);
                        next_token = token_id;
                        pos += 1;
                    }
                    Err(e) => {
                        tracing::warn!(
                            error = %e, pos,
                            "greedy GPU path failed, falling back to normal forward"
                        );
                        let logits = self.model.forward(next_token, pos - 1, &self.kernel)?;
                        let mut best_idx = 0u32;
                        let mut best_val = f32::NEG_INFINITY;
                        for (i, &v) in logits.iter().enumerate() {
                            if v > best_val {
                                best_val = v;
                                best_idx = i as u32;
                            }
                        }
                        if best_idx == EOS_TOKEN_ID {
                            tracing::debug!(pos, "EOS from CPU fallback");
                            break;
                        }
                        output_tokens.push(best_idx);
                        context.push(best_idx);
                        let window_start = context.len().saturating_sub(3);
                        ngram_cache.record(&context[window_start..]);
                        next_token = best_idx;
                        pos += 1;
                    }
                }
            }

            if let Some(m) = &self.metrics {
                m.decode_token_duration_seconds
                    .observe(step_start.elapsed().as_secs_f64());
            }

            if output_tokens.last() == Some(&EOS_TOKEN_ID) {
                output_tokens.pop();
                break;
            }
        }

        if spec_attempts > 0 {
            let avg_accepted = spec_accepted_total as f64 / spec_attempts as f64;
            let accuracy =
                spec_accepted_total as f64 / (spec_attempts as f64 * speculation_k as f64).max(1.0);
            tracing::info!(
                spec_attempts,
                spec_accepted_total,
                avg_accepted = format!("{:.2}", avg_accepted),
                accuracy = format!("{:.1}%", accuracy * 100.0),
                "speculative decode stats"
            );
        }

        if let Some(m) = &self.metrics {
            let decode_elapsed = decode_start.elapsed().as_secs_f64();
            if decode_elapsed > 0.0 && !output_tokens.is_empty() {
                let tok_per_sec = output_tokens.len() as f64 / decode_elapsed;
                m.tokens_per_second.observe(tok_per_sec);
            }
            m.tokens_generated_total.inc_by(output_tokens.len() as u64);
            m.update_memory_from_rss();
        }

        self.stats.record_request(output_tokens.len());

        tracing::info!(
            prompt_len = prompt_tokens.len(),
            generated = output_tokens.len(),
            "greedy GPU generation complete"
        );

        Ok(output_tokens)
    }
}

impl InferenceEngine<'static> {
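    /// Memory-maps the GGUF file at `path`, leaks the mapping and the parsed
    /// file to obtain `'static` lifetimes, and builds an engine from it. The
    /// leaked allocations persist for the remainder of the process.
    ///
    /// A minimal usage sketch (marked `ignore`; the crate path, model path, and
    /// parameter values below are placeholders, not part of this API's contract):
    ///
    /// ```ignore
    /// use oxibonsai_runtime::engine::InferenceEngine; // crate path assumed
    /// use oxibonsai_runtime::sampling::SamplingParams;
    ///
    /// let mut engine = InferenceEngine::from_gguf_path(
    ///     "models/bonsai-8b.gguf",   // placeholder path
    ///     SamplingParams::default(),
    ///     42,                        // sampler RNG seed
    ///     4096,                      // max sequence length
    /// )?;
    /// // prompt_tokens: Vec<u32> produced by your tokenizer
    /// let tokens = engine.generate(&prompt_tokens, 64)?;
    /// ```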
    pub fn from_gguf_path(
        path: impl AsRef<std::path::Path>,
        sampling_params: SamplingParams,
        seed: u64,
        max_seq_len: usize,
    ) -> RuntimeResult<Self> {
        let path_ref = path.as_ref();
        if !path_ref.exists() {
            return Err(RuntimeError::FileNotFound {
                path: path_ref.display().to_string(),
            });
        }

        let mmap = oxibonsai_core::gguf::reader::mmap_gguf_file(path_ref)?;
        let mmap: &'static memmap2::Mmap = Box::leak(Box::new(mmap));
        let gguf = oxibonsai_core::gguf::reader::GgufFile::parse(mmap)?;
        let gguf: &'static oxibonsai_core::gguf::reader::GgufFile<'static> =
            Box::leak(Box::new(gguf));

        Self::from_gguf(gguf, sampling_params, seed, max_seq_len)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn engine_creation() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        assert_eq!(engine.model().config().num_layers, 36);
    }

    #[test]
    fn engine_stats_initial() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let stats = engine.stats();
        assert_eq!(stats.tokens_generated(), 0);
        assert_eq!(stats.requests_completed(), 0);
        assert_eq!(stats.active_session_count(), 0);
        assert!(stats.uptime_seconds() >= 0.0);
        assert!((stats.avg_tokens_per_request() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn engine_stats_record() {
        let stats = EngineStats::new();
        stats.record_request(10);
        stats.record_request(20);
        assert_eq!(stats.tokens_generated(), 30);
        assert_eq!(stats.requests_completed(), 2);
        assert!((stats.avg_tokens_per_request() - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn engine_session_tracking() {
        let config = Qwen3Config::bonsai_8b();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        assert_eq!(engine.active_sessions(), 0);
        assert_eq!(engine.session_count(), 0);
    }

    #[test]
    fn engine_batch_generate_empty() {
        let config = Qwen3Config::bonsai_8b();
        let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let results = engine.batch_generate(&[], 10);
        assert!(results.is_empty());
        assert_eq!(engine.session_count(), 0);
    }

    #[test]
    fn engine_batch_generate_empty_prompts() {
        let config = Qwen3Config::bonsai_8b();
        let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        let prompts = vec![vec![], vec![]];
        let results = engine.batch_generate(&prompts, 5);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.is_ok());
        }
        assert_eq!(engine.stats().requests_completed(), 2);
    }

    #[test]
    fn engine_stats_default() {
        let stats = EngineStats::default();
        assert_eq!(stats.tokens_generated(), 0);
        assert_eq!(stats.requests_completed(), 0);
    }
}