Skip to main content

oxibonsai_runtime/
pipeline.rs

1//! High-level inference pipeline API for OxiBonsai.
2//!
3//! The pipeline composes token healing, context management, sampling strategies,
4//! beam search, and constrained decoding into a single fluent builder that
5//! produces a configured [`InferencePipeline`] ready to run.
6//!
7//! ## Quick Start
8//!
9//! ```rust
10//! use oxibonsai_runtime::pipeline::{PipelineBuilder, greedy_pipeline};
11//! use oxibonsai_runtime::context_manager::TruncationStrategy;
12//!
13//! // Pre-built convenience preset
14//! let pipeline = greedy_pipeline(32);
15//! assert_eq!(pipeline.max_tokens(), 32);
16//! assert!(!pipeline.has_healing());
17//!
18//! // Custom pipeline via builder
19//! use oxibonsai_runtime::token_healing::TokenHealingConfig;
20//! let custom = PipelineBuilder::new()
21//!     .max_tokens(128)
22//!     .with_token_healing(TokenHealingConfig::default())
23//!     .stop_on(vec!["<|end|>".to_string()])
24//!     .build();
25//! assert!(custom.has_healing());
26//! assert_eq!(custom.stop_sequences(), &["<|end|>"]);
27//! ```
28
29use std::time::Instant;
30
31use crate::beam_search::{BeamSearchConfig, BeamSearchEngine};
32use crate::constrained_decoding::TokenConstraint;
33use crate::context_manager::{ContextWindow, TruncationStrategy};
34use crate::engine::InferenceEngine;
35use crate::sampling_advanced::{LcgRng, SamplerChain, SamplerStep};
36use crate::token_healing::{TokenHealer, TokenHealingConfig};
37
38// ─────────────────────────────────────────────────────────────────────────────
39// GenerationStrategy
40// ─────────────────────────────────────────────────────────────────────────────
41
42/// How the pipeline generates tokens at each step.
43pub enum GenerationStrategy {
44    /// Standard autoregressive sampling via a composable sampler chain.
45    Sampling(SamplerChain),
46    /// Beam search — deterministic search over the top-`beam_width` candidates.
47    BeamSearch(BeamSearchConfig),
48    /// Greedy decoding — always pick the highest-logit token.
49    Greedy,
50}
51
52// ─────────────────────────────────────────────────────────────────────────────
53// StopReason
54// ─────────────────────────────────────────────────────────────────────────────
55
/// Why generation terminated.
///
/// Every payload is plain owned data, so the type is `Eq` as well as
/// `PartialEq` and can be compared exactly (e.g. in assertions or lookups).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// The `max_tokens` budget was exhausted.
    MaxTokens,
    /// A user-supplied stop sequence was encountered in the output; carries the
    /// sequence that matched.
    StopSequence(String),
    /// Generation ended before `max_tokens` without matching a stop sequence —
    /// normally because the model emitted an end-of-sequence token.
    EndOfSequence,
    /// The active [`TokenConstraint`] reported completion.
    ///
    /// Not currently produced by `InferencePipeline::run`, which does not yet
    /// consult the attached constraint.
    ConstraintComplete,
}
68
69// ─────────────────────────────────────────────────────────────────────────────
70// PipelineOutput
71// ─────────────────────────────────────────────────────────────────────────────
72
73/// The result of a complete pipeline run.
74#[derive(Debug)]
75pub struct PipelineOutput {
76    /// Decoded text of the generated tokens.
77    ///
78    /// In the absence of a real tokenizer the token IDs are serialised as
79    /// space-separated decimal strings.
80    pub text: String,
81    /// Generated token IDs (not including the prompt).
82    pub token_ids: Vec<u32>,
83    /// Number of prompt tokens (after healing/context management).
84    pub prompt_tokens: usize,
85    /// Number of generated (completion) tokens.
86    pub completion_tokens: usize,
87    /// Reason generation ended.
88    pub stop_reason: StopReason,
89    /// Whether token healing was applied and changed the prompt.
90    pub healing_applied: bool,
91    /// Wall-clock time for the entire pipeline run in milliseconds.
92    pub elapsed_ms: u64,
93}
94
95// ─────────────────────────────────────────────────────────────────────────────
96// PipelineConfig  (private)
97// ─────────────────────────────────────────────────────────────────────────────
98
99struct PipelineConfig {
100    max_tokens: usize,
101    strategy: GenerationStrategy,
102    healing_config: Option<TokenHealingConfig>,
103    constraint: Option<Box<dyn TokenConstraint>>,
104    context_max_tokens: usize,
105    truncation: TruncationStrategy,
106    stop_sequences: Vec<String>,
107    /// Stored for reproducibility and future use by strategies that need a
108    /// standalone RNG (e.g. beam search with stochastic expansion).
109    #[allow(dead_code)]
110    seed: u64,
111}
112
113// ─────────────────────────────────────────────────────────────────────────────
114// PipelineBuilder
115// ─────────────────────────────────────────────────────────────────────────────
116
117/// Builder that composes all inference options into an [`InferencePipeline`].
118pub struct PipelineBuilder {
119    max_tokens: usize,
120    strategy: Option<GenerationStrategy>,
121    healing_config: Option<TokenHealingConfig>,
122    constraint: Option<Box<dyn TokenConstraint>>,
123    context_max_tokens: usize,
124    truncation: TruncationStrategy,
125    stop_sequences: Vec<String>,
126    seed: u64,
127}
128
129impl Default for PipelineBuilder {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl PipelineBuilder {
136    /// Create a new builder with sensible defaults.
137    ///
138    /// Defaults:
139    /// - `max_tokens` = 256
140    /// - strategy = `Greedy`
141    /// - no healing, no constraint
142    /// - `context_max_tokens` = 2048, `TruncationStrategy::TruncateLeft`
143    /// - no stop sequences
144    /// - `seed` = 0
145    pub fn new() -> Self {
146        Self {
147            max_tokens: 256,
148            strategy: None,
149            healing_config: None,
150            constraint: None,
151            context_max_tokens: 2048,
152            truncation: TruncationStrategy::TruncateLeft,
153            stop_sequences: Vec::new(),
154            seed: 0,
155        }
156    }
157
158    /// Set the maximum number of tokens to generate.
159    pub fn max_tokens(mut self, n: usize) -> Self {
160        self.max_tokens = n;
161        self
162    }
163
164    /// Use greedy (argmax) decoding.
165    pub fn greedy(mut self) -> Self {
166        self.strategy = Some(GenerationStrategy::Greedy);
167        self
168    }
169
170    /// Use a [`SamplerChain`] for token selection.
171    pub fn with_sampling(mut self, chain: SamplerChain) -> Self {
172        self.strategy = Some(GenerationStrategy::Sampling(chain));
173        self
174    }
175
176    /// Use beam search with the supplied configuration.
177    pub fn with_beam_search(mut self, config: BeamSearchConfig) -> Self {
178        self.strategy = Some(GenerationStrategy::BeamSearch(config));
179        self
180    }
181
182    /// Enable token healing with the supplied configuration.
183    pub fn with_token_healing(mut self, config: TokenHealingConfig) -> Self {
184        self.healing_config = Some(config);
185        self
186    }
187
188    /// Attach a token constraint (e.g. JSON or regex).
189    pub fn with_constraint(mut self, c: Box<dyn TokenConstraint>) -> Self {
190        self.constraint = Some(c);
191        self
192    }
193
194    /// Stop generation when any of the given string sequences appear in the output.
195    pub fn stop_on(mut self, sequences: Vec<String>) -> Self {
196        self.stop_sequences = sequences;
197        self
198    }
199
200    /// Configure the context window size and truncation strategy.
201    pub fn context_window(mut self, max_tokens: usize, strategy: TruncationStrategy) -> Self {
202        self.context_max_tokens = max_tokens;
203        self.truncation = strategy;
204        self
205    }
206
207    /// Set the random seed used by sampling strategies.
208    pub fn seed(mut self, s: u64) -> Self {
209        self.seed = s;
210        self
211    }
212
213    /// Consume the builder and produce an [`InferencePipeline`].
214    pub fn build(self) -> InferencePipeline {
215        let strategy = self.strategy.unwrap_or(GenerationStrategy::Greedy);
216        InferencePipeline {
217            config: PipelineConfig {
218                max_tokens: self.max_tokens,
219                strategy,
220                healing_config: self.healing_config,
221                constraint: self.constraint,
222                context_max_tokens: self.context_max_tokens,
223                truncation: self.truncation,
224                stop_sequences: self.stop_sequences,
225                seed: self.seed,
226            },
227        }
228    }
229}
230
231// ─────────────────────────────────────────────────────────────────────────────
232// InferencePipeline
233// ─────────────────────────────────────────────────────────────────────────────
234
235/// A fully configured inference pipeline.
236///
237/// Obtain one via [`PipelineBuilder`] or one of the convenience constructors
238/// ([`chat_pipeline`], [`code_pipeline`], [`greedy_pipeline`]).
239pub struct InferencePipeline {
240    config: PipelineConfig,
241}
242
243impl InferencePipeline {
244    /// Run the pipeline against the supplied engine.
245    ///
246    /// The pipeline:
247    ///
248    /// 1. Applies token healing to the prompt (if configured).
249    /// 2. Trims the prompt to `context_max_tokens` using the configured truncation.
250    /// 3. Generates tokens according to the selected strategy.
251    /// 4. Stops at `max_tokens`, an EOS token, a stop sequence, or constraint
252    ///    completion — whichever comes first.
253    ///
254    /// Because the engine API works with raw token IDs (no vocabulary metadata is
255    /// available at this layer), the `text` field of the returned [`PipelineOutput`]
256    /// encodes token IDs as space-separated decimal strings.
257    pub fn run(
258        &mut self,
259        prompt_token_ids: Vec<u32>,
260        engine: &mut InferenceEngine,
261    ) -> PipelineOutput {
262        let wall_start = Instant::now();
263
264        // ── 1. Token healing ────────────────────────────────────────────────
265        let (healed_prompt, healing_applied) =
266            if let Some(ref healing_cfg) = self.config.healing_config {
267                let healer = TokenHealer::new(healing_cfg.clone());
268                // We cannot call the real model during healing without knowing
269                // the vocab size, so we use a conservative heuristic: healing
270                // is applied via a forward pass on the prefix.
271                // For now, with no vocab_size metadata on the engine, we skip
272                // the logit query and return unchanged — healing can only fire
273                // when the caller supplies a vocab-aware callback.  The
274                // HealingDecoder is the richer entry point for that use case.
275                let result = healer.heal(&prompt_token_ids, 0, |_prefix| Vec::new());
276                let changed = result.changed;
277                (result.healed_tokens, changed)
278            } else {
279                (prompt_token_ids, false)
280            };
281
282        // ── 2. Context window management ────────────────────────────────────
283        let mut window = ContextWindow::new(self.config.context_max_tokens, self.config.truncation);
284        window.append(&healed_prompt);
285        let context_tokens = window.tokens();
286        let prompt_tokens = context_tokens.len();
287
288        // ── 3. Generation ───────────────────────────────────────────────────
289        let (generated, stop_reason) = match &self.config.strategy {
290            GenerationStrategy::Greedy | GenerationStrategy::Sampling(_) => {
291                self.run_autoregressive(&context_tokens, engine)
292            }
293            GenerationStrategy::BeamSearch(beam_cfg) => {
294                self.run_beam_search(&context_tokens, beam_cfg.clone(), engine)
295            }
296        };
297
298        // ── 4. Build output ──────────────────────────────────────────────────
299        let text: String = generated
300            .iter()
301            .map(|id| id.to_string())
302            .collect::<Vec<_>>()
303            .join(" ");
304
305        let elapsed_ms = wall_start.elapsed().as_millis() as u64;
306
307        PipelineOutput {
308            text,
309            completion_tokens: generated.len(),
310            token_ids: generated,
311            prompt_tokens,
312            stop_reason,
313            healing_applied,
314            elapsed_ms,
315        }
316    }
317
318    /// Autoregressive generation (greedy or sampled).
319    fn run_autoregressive(
320        &mut self,
321        context_tokens: &[u32],
322        engine: &mut InferenceEngine,
323    ) -> (Vec<u32>, StopReason) {
324        // Use the engine's built-in generate(); it already handles EOS.
325        let max = self.config.max_tokens;
326
327        // We need to track stop sequences ourselves since generate() only
328        // knows about the EOS token ID.
329        let raw = engine
330            .generate(context_tokens, max)
331            .expect("generation must not fail in pipeline");
332
333        // Walk the generated tokens and check stop sequences.
334        self.check_stop_sequences(raw)
335    }
336
337    /// Beam-search generation.
338    fn run_beam_search(
339        &mut self,
340        context_tokens: &[u32],
341        beam_cfg: BeamSearchConfig,
342        _engine: &mut InferenceEngine,
343    ) -> (Vec<u32>, StopReason) {
344        let beam_engine = BeamSearchEngine::new(beam_cfg.clone());
345        let result = beam_engine.search(
346            context_tokens.to_vec(),
347            0, // vocab_size hint (not used by current implementation)
348            |_tokens, _step| {
349                // Real beam search would call engine.forward() here; since
350                // InferenceEngine::generate() is the public API we fall back to
351                // an empty logit vector (search will stall after the prompt).
352                // Full integration requires exposing engine.forward() publicly.
353                Vec::new()
354            },
355        );
356
357        let best = result.best().to_vec();
358        // Strip the prompt prefix from the beam result.
359        let generated = if best.len() > context_tokens.len() {
360            best[context_tokens.len()..].to_vec()
361        } else {
362            Vec::new()
363        };
364
365        let (trimmed, stop_reason) = self.check_stop_sequences(generated);
366        (trimmed, stop_reason)
367    }
368
369    /// Walk `tokens`, checking whether any stop sequence appears in the partial
370    /// decoded text.  Returns the tokens up to (but not including) the stop
371    /// sequence, plus the stop reason.
372    fn check_stop_sequences(&self, tokens: Vec<u32>) -> (Vec<u32>, StopReason) {
373        if self.config.stop_sequences.is_empty() {
374            let stop = if tokens.len() >= self.config.max_tokens {
375                StopReason::MaxTokens
376            } else {
377                StopReason::EndOfSequence
378            };
379            return (tokens, stop);
380        }
381
382        // Build the text token-by-token and scan for stop sequences.
383        let mut text_so_far = String::new();
384        for (i, &tok) in tokens.iter().enumerate() {
385            text_so_far.push_str(&tok.to_string());
386            text_so_far.push(' ');
387
388            for seq in &self.config.stop_sequences {
389                if text_so_far.contains(seq.as_str()) {
390                    return (tokens[..i].to_vec(), StopReason::StopSequence(seq.clone()));
391                }
392            }
393        }
394
395        let stop = if tokens.len() >= self.config.max_tokens {
396            StopReason::MaxTokens
397        } else {
398            StopReason::EndOfSequence
399        };
400        (tokens, stop)
401    }
402
403    /// Maximum number of tokens this pipeline will generate.
404    pub fn max_tokens(&self) -> usize {
405        self.config.max_tokens
406    }
407
408    /// Returns `true` if token healing is configured.
409    pub fn has_healing(&self) -> bool {
410        self.config.healing_config.is_some()
411    }
412
413    /// Returns `true` if a token constraint is attached.
414    pub fn has_constraint(&self) -> bool {
415        self.config.constraint.is_some()
416    }
417
418    /// The list of stop sequences that will halt generation early.
419    pub fn stop_sequences(&self) -> &[String] {
420        &self.config.stop_sequences
421    }
422}
423
424// ─────────────────────────────────────────────────────────────────────────────
425// Convenience constructors
426// ─────────────────────────────────────────────────────────────────────────────
427
428/// Build a standard chat pipeline.
429///
430/// Settings:
431/// - Temperature = 0.7, top-p = 0.9, min-p = 0.05
432/// - Context window = 4096 tokens (TruncateLeft)
433/// - No healing, no constraint
434pub fn chat_pipeline(seed: u64, max_tokens: usize) -> InferencePipeline {
435    let chain = SamplerChain::new(seed)
436        .add(SamplerStep::Temperature(0.7))
437        .add(SamplerStep::TopP(0.9))
438        .add(SamplerStep::MinP(0.05));
439
440    PipelineBuilder::new()
441        .max_tokens(max_tokens)
442        .with_sampling(chain)
443        .context_window(4096, TruncationStrategy::TruncateLeft)
444        .seed(seed)
445        .build()
446}
447
448/// Build a code-generation pipeline.
449///
450/// Settings:
451/// - Temperature = 0.2, top-k = 40
452/// - Token healing enabled (default config)
453/// - Stop on `"\n\n"` (blank line)
454pub fn code_pipeline(seed: u64, max_tokens: usize) -> InferencePipeline {
455    let chain = SamplerChain::new(seed)
456        .add(SamplerStep::Temperature(0.2))
457        .add(SamplerStep::TopK(40));
458
459    PipelineBuilder::new()
460        .max_tokens(max_tokens)
461        .with_sampling(chain)
462        .with_token_healing(TokenHealingConfig::default())
463        .stop_on(vec!["\n\n".to_string()])
464        .seed(seed)
465        .build()
466}
467
468/// Build a greedy (deterministic) pipeline.
469pub fn greedy_pipeline(max_tokens: usize) -> InferencePipeline {
470    PipelineBuilder::new()
471        .max_tokens(max_tokens)
472        .greedy()
473        .build()
474}
475
476// ─────────────────────────────────────────────────────────────────────────────
477// Helper: unused but part of internal plumbing
478// ─────────────────────────────────────────────────────────────────────────────
479
/// Greedy argmax over a logit slice.
///
/// Returns the index of the largest non-NaN logit. When several logits tie for
/// the maximum, the last such index is returned. An empty slice, or one that
/// contains only NaNs, yields `0`.
#[allow(dead_code)] // not called yet; kept as a building block for local decoding
fn argmax_logits(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        // Skip NaNs outright: comparing them as `Equal` would let a NaN replace
        // the running maximum (e.g. `[1.0, NaN, 0.5]` would yield index 2).
        .filter(|(_, v)| !v.is_nan())
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        // Vocabulary indices are assumed to fit in u32.
        .map_or(0, |(i, _)| i as u32)
}
490
491/// Build a greedy sampler chain (single Greedy step).
492#[allow(dead_code)]
493fn greedy_chain(seed: u64) -> SamplerChain {
494    SamplerChain::new(seed).add(SamplerStep::Greedy)
495}
496
497/// LCG-based sampler: temperature + weighted draw, no external deps.
498#[allow(dead_code)]
499fn sample_from_logits(logits: &[f32], temperature: f32, rng: &mut LcgRng) -> u32 {
500    if logits.is_empty() {
501        return 0;
502    }
503    if temperature < 1e-6 {
504        return argmax_logits(logits);
505    }
506    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
507    let exps: Vec<f32> = logits
508        .iter()
509        .map(|&v| ((v - max) / temperature).exp())
510        .collect();
511    let sum: f32 = exps.iter().sum();
512    if sum == 0.0 {
513        return 0;
514    }
515    let target = rng.next_f32() * sum;
516    let mut cum = 0.0f32;
517    for (i, &e) in exps.iter().enumerate() {
518        cum += e;
519        if cum >= target {
520            return i as u32;
521        }
522    }
523    (exps.len() - 1) as u32
524}
525
526// ─────────────────────────────────────────────────────────────────────────────
527// Tests
528// ─────────────────────────────────────────────────────────────────────────────
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;

    // ── Builder tests ────────────────────────────────────────────────────────

    /// An unconfigured builder produces the documented defaults.
    #[test]
    fn test_pipeline_builder_default() {
        let built = PipelineBuilder::new().build();
        assert_eq!(built.max_tokens(), 256);
        assert!(!built.has_healing());
        assert!(!built.has_constraint());
        assert!(built.stop_sequences().is_empty());
    }

    /// `max_tokens` overrides the default budget.
    #[test]
    fn test_pipeline_builder_max_tokens() {
        let built = PipelineBuilder::new().max_tokens(512).build();
        assert_eq!(built.max_tokens(), 512);
    }

    /// `greedy()` selects the greedy strategy.
    #[test]
    fn test_pipeline_builder_greedy() {
        let built = PipelineBuilder::new().greedy().build();
        let is_greedy = matches!(built.config.strategy, GenerationStrategy::Greedy);
        assert!(is_greedy);
    }

    /// Stop sequences are stored verbatim and in order.
    #[test]
    fn test_pipeline_builder_stop_sequences() {
        let expected = vec!["<|end|>".to_string(), "STOP".to_string()];
        let built = PipelineBuilder::new().stop_on(expected.clone()).build();
        assert_eq!(built.stop_sequences(), expected.as_slice());
    }

    /// Supplying a healing config enables healing.
    #[test]
    fn test_pipeline_builder_with_healing() {
        let healing = TokenHealingConfig {
            lookback: 2,
            min_prob: 0.1,
            enabled: true,
        };
        let built = PipelineBuilder::new().with_token_healing(healing).build();
        assert!(built.has_healing());
    }

    // ── Output / StopReason tests ────────────────────────────────────────────

    /// `PipelineOutput` fields read back exactly as constructed.
    #[test]
    fn test_pipeline_output_stop_reason() {
        let stop = StopReason::StopSequence("STOP".to_string());
        let output = PipelineOutput {
            text: "hello".to_string(),
            token_ids: vec![1, 2, 3],
            prompt_tokens: 5,
            completion_tokens: 3,
            stop_reason: stop.clone(),
            healing_applied: false,
            elapsed_ms: 10,
        };
        assert_eq!(output.stop_reason, stop);
        assert_eq!(output.completion_tokens, 3);
        assert_eq!(output.prompt_tokens, 5);
    }

    // ── Preset tests ─────────────────────────────────────────────────────────

    /// The chat preset samples without healing or stops, in a 4096-token window.
    #[test]
    fn test_chat_pipeline_preset() {
        let chat = chat_pipeline(42, 256);
        assert_eq!(chat.max_tokens(), 256);
        assert!(!chat.has_healing());
        assert!(chat.stop_sequences().is_empty());
        assert_eq!(chat.config.context_max_tokens, 4096);
    }

    /// The code preset enables healing and stops on a blank line.
    #[test]
    fn test_code_pipeline_preset() {
        let code = code_pipeline(0, 128);
        assert_eq!(code.max_tokens(), 128);
        assert!(code.has_healing());
        assert_eq!(code.stop_sequences(), &["\n\n"]);
    }

    /// The greedy preset is bare greedy decoding.
    #[test]
    fn test_greedy_pipeline_preset() {
        let greedy = greedy_pipeline(64);
        assert_eq!(greedy.max_tokens(), 64);
        assert!(!greedy.has_healing());
        assert!(!greedy.has_constraint());
        let is_greedy = matches!(greedy.config.strategy, GenerationStrategy::Greedy);
        assert!(is_greedy);
    }

    // ── Full run test ────────────────────────────────────────────────────────

    /// End-to-end smoke test against the tiny test model: the run must finish
    /// without panicking and report the prompt length it was given.
    #[test]
    fn test_pipeline_run_basic() {
        use oxibonsai_core::config::Qwen3Config;

        let sampling = SamplingParams {
            temperature: 0.0,
            ..SamplingParams::default()
        };
        let mut engine = InferenceEngine::new(Qwen3Config::tiny_test(), sampling, 42);
        let mut pipeline = PipelineBuilder::new().max_tokens(5).greedy().build();

        let prompt = vec![151644u32, 872];
        let output = pipeline.run(prompt, &mut engine);

        assert_eq!(output.prompt_tokens, 2);
        assert!(output.elapsed_ms < 60_000, "should finish in under 60s");
    }
}