1use std::time::Instant;
30
31use crate::beam_search::{BeamSearchConfig, BeamSearchEngine};
32use crate::constrained_decoding::TokenConstraint;
33use crate::context_manager::{ContextWindow, TruncationStrategy};
34use crate::engine::InferenceEngine;
35use crate::sampling_advanced::{LcgRng, SamplerChain, SamplerStep};
36use crate::token_healing::{TokenHealer, TokenHealingConfig};
37
38pub enum GenerationStrategy {
44 Sampling(SamplerChain),
46 BeamSearch(BeamSearchConfig),
48 Greedy,
50}
51
/// Why a pipeline run stopped producing tokens.
///
/// Every variant's payload is `Eq`, so the enum derives full equality as well as
/// `PartialEq`, which lets it be used as a map key or compared with `assert_eq!`
/// in `Eq`-bounded contexts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// The completion reached the configured `max_tokens` budget.
    MaxTokens,
    /// A configured stop sequence (carried here) matched the output.
    StopSequence(String),
    /// Generation ended below the token budget without any stop-sequence match.
    EndOfSequence,
    /// A decoding constraint reported completion.
    ///
    /// NOTE(review): nothing in the visible pipeline code emits this variant yet.
    ConstraintComplete,
}
68
69#[derive(Debug)]
75pub struct PipelineOutput {
76 pub text: String,
81 pub token_ids: Vec<u32>,
83 pub prompt_tokens: usize,
85 pub completion_tokens: usize,
87 pub stop_reason: StopReason,
89 pub healing_applied: bool,
91 pub elapsed_ms: u64,
93}
94
95struct PipelineConfig {
100 max_tokens: usize,
101 strategy: GenerationStrategy,
102 healing_config: Option<TokenHealingConfig>,
103 constraint: Option<Box<dyn TokenConstraint>>,
104 context_max_tokens: usize,
105 truncation: TruncationStrategy,
106 stop_sequences: Vec<String>,
107 #[allow(dead_code)]
110 seed: u64,
111}
112
113pub struct PipelineBuilder {
119 max_tokens: usize,
120 strategy: Option<GenerationStrategy>,
121 healing_config: Option<TokenHealingConfig>,
122 constraint: Option<Box<dyn TokenConstraint>>,
123 context_max_tokens: usize,
124 truncation: TruncationStrategy,
125 stop_sequences: Vec<String>,
126 seed: u64,
127}
128
129impl Default for PipelineBuilder {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl PipelineBuilder {
136 pub fn new() -> Self {
146 Self {
147 max_tokens: 256,
148 strategy: None,
149 healing_config: None,
150 constraint: None,
151 context_max_tokens: 2048,
152 truncation: TruncationStrategy::TruncateLeft,
153 stop_sequences: Vec::new(),
154 seed: 0,
155 }
156 }
157
158 pub fn max_tokens(mut self, n: usize) -> Self {
160 self.max_tokens = n;
161 self
162 }
163
164 pub fn greedy(mut self) -> Self {
166 self.strategy = Some(GenerationStrategy::Greedy);
167 self
168 }
169
170 pub fn with_sampling(mut self, chain: SamplerChain) -> Self {
172 self.strategy = Some(GenerationStrategy::Sampling(chain));
173 self
174 }
175
176 pub fn with_beam_search(mut self, config: BeamSearchConfig) -> Self {
178 self.strategy = Some(GenerationStrategy::BeamSearch(config));
179 self
180 }
181
182 pub fn with_token_healing(mut self, config: TokenHealingConfig) -> Self {
184 self.healing_config = Some(config);
185 self
186 }
187
188 pub fn with_constraint(mut self, c: Box<dyn TokenConstraint>) -> Self {
190 self.constraint = Some(c);
191 self
192 }
193
194 pub fn stop_on(mut self, sequences: Vec<String>) -> Self {
196 self.stop_sequences = sequences;
197 self
198 }
199
200 pub fn context_window(mut self, max_tokens: usize, strategy: TruncationStrategy) -> Self {
202 self.context_max_tokens = max_tokens;
203 self.truncation = strategy;
204 self
205 }
206
207 pub fn seed(mut self, s: u64) -> Self {
209 self.seed = s;
210 self
211 }
212
213 pub fn build(self) -> InferencePipeline {
215 let strategy = self.strategy.unwrap_or(GenerationStrategy::Greedy);
216 InferencePipeline {
217 config: PipelineConfig {
218 max_tokens: self.max_tokens,
219 strategy,
220 healing_config: self.healing_config,
221 constraint: self.constraint,
222 context_max_tokens: self.context_max_tokens,
223 truncation: self.truncation,
224 stop_sequences: self.stop_sequences,
225 seed: self.seed,
226 },
227 }
228 }
229}
230
231pub struct InferencePipeline {
240 config: PipelineConfig,
241}
242
243impl InferencePipeline {
244 pub fn run(
258 &mut self,
259 prompt_token_ids: Vec<u32>,
260 engine: &mut InferenceEngine,
261 ) -> PipelineOutput {
262 let wall_start = Instant::now();
263
264 let (healed_prompt, healing_applied) =
266 if let Some(ref healing_cfg) = self.config.healing_config {
267 let healer = TokenHealer::new(healing_cfg.clone());
268 let result = healer.heal(&prompt_token_ids, 0, |_prefix| Vec::new());
276 let changed = result.changed;
277 (result.healed_tokens, changed)
278 } else {
279 (prompt_token_ids, false)
280 };
281
282 let mut window = ContextWindow::new(self.config.context_max_tokens, self.config.truncation);
284 window.append(&healed_prompt);
285 let context_tokens = window.tokens();
286 let prompt_tokens = context_tokens.len();
287
288 let (generated, stop_reason) = match &self.config.strategy {
290 GenerationStrategy::Greedy | GenerationStrategy::Sampling(_) => {
291 self.run_autoregressive(&context_tokens, engine)
292 }
293 GenerationStrategy::BeamSearch(beam_cfg) => {
294 self.run_beam_search(&context_tokens, beam_cfg.clone(), engine)
295 }
296 };
297
298 let text: String = generated
300 .iter()
301 .map(|id| id.to_string())
302 .collect::<Vec<_>>()
303 .join(" ");
304
305 let elapsed_ms = wall_start.elapsed().as_millis() as u64;
306
307 PipelineOutput {
308 text,
309 completion_tokens: generated.len(),
310 token_ids: generated,
311 prompt_tokens,
312 stop_reason,
313 healing_applied,
314 elapsed_ms,
315 }
316 }
317
318 fn run_autoregressive(
320 &mut self,
321 context_tokens: &[u32],
322 engine: &mut InferenceEngine,
323 ) -> (Vec<u32>, StopReason) {
324 let max = self.config.max_tokens;
326
327 let raw = engine
330 .generate(context_tokens, max)
331 .expect("generation must not fail in pipeline");
332
333 self.check_stop_sequences(raw)
335 }
336
337 fn run_beam_search(
339 &mut self,
340 context_tokens: &[u32],
341 beam_cfg: BeamSearchConfig,
342 _engine: &mut InferenceEngine,
343 ) -> (Vec<u32>, StopReason) {
344 let beam_engine = BeamSearchEngine::new(beam_cfg.clone());
345 let result = beam_engine.search(
346 context_tokens.to_vec(),
347 0, |_tokens, _step| {
349 Vec::new()
354 },
355 );
356
357 let best = result.best().to_vec();
358 let generated = if best.len() > context_tokens.len() {
360 best[context_tokens.len()..].to_vec()
361 } else {
362 Vec::new()
363 };
364
365 let (trimmed, stop_reason) = self.check_stop_sequences(generated);
366 (trimmed, stop_reason)
367 }
368
369 fn check_stop_sequences(&self, tokens: Vec<u32>) -> (Vec<u32>, StopReason) {
373 if self.config.stop_sequences.is_empty() {
374 let stop = if tokens.len() >= self.config.max_tokens {
375 StopReason::MaxTokens
376 } else {
377 StopReason::EndOfSequence
378 };
379 return (tokens, stop);
380 }
381
382 let mut text_so_far = String::new();
384 for (i, &tok) in tokens.iter().enumerate() {
385 text_so_far.push_str(&tok.to_string());
386 text_so_far.push(' ');
387
388 for seq in &self.config.stop_sequences {
389 if text_so_far.contains(seq.as_str()) {
390 return (tokens[..i].to_vec(), StopReason::StopSequence(seq.clone()));
391 }
392 }
393 }
394
395 let stop = if tokens.len() >= self.config.max_tokens {
396 StopReason::MaxTokens
397 } else {
398 StopReason::EndOfSequence
399 };
400 (tokens, stop)
401 }
402
403 pub fn max_tokens(&self) -> usize {
405 self.config.max_tokens
406 }
407
408 pub fn has_healing(&self) -> bool {
410 self.config.healing_config.is_some()
411 }
412
413 pub fn has_constraint(&self) -> bool {
415 self.config.constraint.is_some()
416 }
417
418 pub fn stop_sequences(&self) -> &[String] {
420 &self.config.stop_sequences
421 }
422}
423
424pub fn chat_pipeline(seed: u64, max_tokens: usize) -> InferencePipeline {
435 let chain = SamplerChain::new(seed)
436 .add(SamplerStep::Temperature(0.7))
437 .add(SamplerStep::TopP(0.9))
438 .add(SamplerStep::MinP(0.05));
439
440 PipelineBuilder::new()
441 .max_tokens(max_tokens)
442 .with_sampling(chain)
443 .context_window(4096, TruncationStrategy::TruncateLeft)
444 .seed(seed)
445 .build()
446}
447
448pub fn code_pipeline(seed: u64, max_tokens: usize) -> InferencePipeline {
455 let chain = SamplerChain::new(seed)
456 .add(SamplerStep::Temperature(0.2))
457 .add(SamplerStep::TopK(40));
458
459 PipelineBuilder::new()
460 .max_tokens(max_tokens)
461 .with_sampling(chain)
462 .with_token_healing(TokenHealingConfig::default())
463 .stop_on(vec!["\n\n".to_string()])
464 .seed(seed)
465 .build()
466}
467
468pub fn greedy_pipeline(max_tokens: usize) -> InferencePipeline {
470 PipelineBuilder::new()
471 .max_tokens(max_tokens)
472 .greedy()
473 .build()
474}
475
/// Returns the index of the largest logit, or 0 when `logits` is empty or entirely NaN.
///
/// NaN entries are skipped. Comparing them as "equal" lets a NaN displace the running
/// maximum and produce an arbitrary index. Among equal maxima the last index wins,
/// matching `Iterator::max_by`.
#[allow(dead_code)] // only reached from `sample_from_logits`, which is itself unused
fn argmax_logits(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // After the NaN filter `partial_cmp` always returns `Some`.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
490
491#[allow(dead_code)]
493fn greedy_chain(seed: u64) -> SamplerChain {
494 SamplerChain::new(seed).add(SamplerStep::Greedy)
495}
496
497#[allow(dead_code)]
499fn sample_from_logits(logits: &[f32], temperature: f32, rng: &mut LcgRng) -> u32 {
500 if logits.is_empty() {
501 return 0;
502 }
503 if temperature < 1e-6 {
504 return argmax_logits(logits);
505 }
506 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
507 let exps: Vec<f32> = logits
508 .iter()
509 .map(|&v| ((v - max) / temperature).exp())
510 .collect();
511 let sum: f32 = exps.iter().sum();
512 if sum == 0.0 {
513 return 0;
514 }
515 let target = rng.next_f32() * sum;
516 let mut cum = 0.0f32;
517 for (i, &e) in exps.iter().enumerate() {
518 cum += e;
519 if cum >= target {
520 return i as u32;
521 }
522 }
523 (exps.len() - 1) as u32
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;

    #[test]
    fn test_pipeline_builder_default() {
        let pipeline = PipelineBuilder::new().build();
        assert_eq!(pipeline.max_tokens(), 256);
        assert!(!pipeline.has_healing());
        assert!(!pipeline.has_constraint());
        assert!(pipeline.stop_sequences().is_empty());
    }

    #[test]
    fn test_pipeline_builder_max_tokens() {
        let pipeline = PipelineBuilder::new().max_tokens(512).build();
        assert_eq!(pipeline.max_tokens(), 512);
    }

    #[test]
    fn test_pipeline_builder_greedy() {
        let pipeline = PipelineBuilder::new().greedy().build();
        assert!(matches!(
            pipeline.config.strategy,
            GenerationStrategy::Greedy
        ));
    }

    #[test]
    fn test_pipeline_builder_stop_sequences() {
        let stops = vec!["<|end|>".to_string(), "STOP".to_string()];
        let pipeline = PipelineBuilder::new().stop_on(stops.clone()).build();
        assert_eq!(pipeline.stop_sequences(), stops.as_slice());
    }

    #[test]
    fn test_pipeline_builder_with_healing() {
        let cfg = TokenHealingConfig {
            lookback: 2,
            min_prob: 0.1,
            enabled: true,
        };
        let pipeline = PipelineBuilder::new().with_token_healing(cfg).build();
        assert!(pipeline.has_healing());
    }

    #[test]
    fn test_pipeline_output_stop_reason() {
        let output = PipelineOutput {
            text: "hello".to_string(),
            token_ids: vec![1, 2, 3],
            prompt_tokens: 5,
            completion_tokens: 3,
            stop_reason: StopReason::StopSequence("STOP".to_string()),
            healing_applied: false,
            elapsed_ms: 10,
        };
        assert_eq!(
            output.stop_reason,
            StopReason::StopSequence("STOP".to_string())
        );
        assert_eq!(output.completion_tokens, 3);
        assert_eq!(output.prompt_tokens, 5);
    }

    #[test]
    fn test_chat_pipeline_preset() {
        let pipeline = chat_pipeline(42, 256);
        assert_eq!(pipeline.max_tokens(), 256);
        assert!(!pipeline.has_healing());
        assert!(pipeline.stop_sequences().is_empty());
        assert_eq!(pipeline.config.context_max_tokens, 4096);
    }

    #[test]
    fn test_code_pipeline_preset() {
        let pipeline = code_pipeline(0, 128);
        assert_eq!(pipeline.max_tokens(), 128);
        assert!(pipeline.has_healing());
        assert_eq!(pipeline.stop_sequences(), &["\n\n"]);
    }

    #[test]
    fn test_greedy_pipeline_preset() {
        let pipeline = greedy_pipeline(64);
        assert_eq!(pipeline.max_tokens(), 64);
        assert!(!pipeline.has_healing());
        assert!(!pipeline.has_constraint());
        assert!(matches!(
            pipeline.config.strategy,
            GenerationStrategy::Greedy
        ));
    }

    // A stop sequence contained in a single token drops that token and everything after.
    #[test]
    fn test_check_stop_sequences_single_token_match() {
        let pipeline = PipelineBuilder::new()
            .stop_on(vec!["2".to_string()])
            .build();
        let (kept, reason) = pipeline.check_stop_sequences(vec![1, 2, 3]);
        assert_eq!(kept, vec![1]);
        assert_eq!(reason, StopReason::StopSequence("2".to_string()));
    }

    // Without stop sequences the reason depends only on whether the budget was reached.
    #[test]
    fn test_check_stop_sequences_length_based_reason() {
        let pipeline = PipelineBuilder::new().max_tokens(3).build();

        let (kept, reason) = pipeline.check_stop_sequences(vec![4, 5, 6]);
        assert_eq!(kept, vec![4, 5, 6]);
        assert_eq!(reason, StopReason::MaxTokens);

        let (kept, reason) = pipeline.check_stop_sequences(vec![4]);
        assert_eq!(kept, vec![4]);
        assert_eq!(reason, StopReason::EndOfSequence);
    }

    #[test]
    fn test_argmax_logits_basic() {
        assert_eq!(argmax_logits(&[0.1, 0.9, 0.3]), 1);
        assert_eq!(argmax_logits(&[]), 0);
    }

    // Requires the model configuration crate; runs greedy generation end to end.
    #[test]
    fn test_pipeline_run_basic() {
        use oxibonsai_core::config::Qwen3Config;

        let config = Qwen3Config::tiny_test();
        let mut engine = InferenceEngine::new(
            config,
            SamplingParams {
                temperature: 0.0,
                ..SamplingParams::default()
            },
            42,
        );

        let mut pipeline = PipelineBuilder::new().max_tokens(5).greedy().build();

        let output = pipeline.run(vec![151644u32, 872], &mut engine);
        assert_eq!(output.prompt_tokens, 2);
        assert!(output.elapsed_ms < 60_000, "should finish in under 60s");
    }
}