// Copyright © 2025 lituus-io <spicyzhug@gmail.com>
// All Rights Reserved.
// Licensed under PolyForm Noncommercial 1.0.0

//! Shared candidate-generation engine for `best_of` and `ensemble`.
//!
//! Extracts the ~70% duplicated code between `best_of.rs` and `ensemble.rs`
//! into a single, reusable generation pipeline.
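//!
//! A minimal usage sketch (the module path and the `llm` value are
//! assumptions; adjust to however this engine is exposed):
//!
//! ```ignore
//! let config = GenerationConfig {
//!     diverse: true,
//!     ..Default::default()
//! };
//! let engine = GenerationEngine::new(&llm, "Summarize the input", 3, &config);
//! let (candidates, error) = engine.generate_candidates().await;
//! ```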

use crate::recursive::defaults::Defaults;
use crate::recursive::llm::Llm;
use crate::recursive::shared::{self, STYLE_HINTS};
use smallvec::SmallVec;

/// Configuration shared between BestOf and Ensemble.
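///
/// Construction sketch (field values are illustrative only):
///
/// ```ignore
/// let config = GenerationConfig {
///     with_reasoning: true,
///     diverse: true,
///     extract_lang: Some("rust".to_string()),
///     parallel: false,
///     ..Default::default()
/// };
/// ```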
#[derive(Clone, Default)]
pub struct GenerationConfig {
    /// Whether to use Chain of Thought for each candidate.
    pub with_reasoning: bool,
    /// Whether to inject diversity hints for each candidate.
    pub diverse: bool,
    /// Language to extract from code fences before validation (e.g., "rust").
    pub extract_lang: Option<String>,
    /// Whether to generate candidates in parallel.
    pub parallel: bool,
    /// Runtime defaults applied via regex substitution.
    pub defaults: Option<Defaults>,
    /// Pre-rendered skill instructions.
    pub skill_text: Option<String>,
}

/// A raw candidate generated by the engine.
pub struct RawCandidate {
    /// Index in generation order.
    pub index: usize,
    /// The raw LLM output text.
    pub raw_text: String,
    /// The text after extract + defaults transformation.
    pub transformed_text: String,
    /// Tokens used for this candidate.
    pub tokens: u32,
}

/// Shared generation engine that produces N candidates from an LLM.
pub struct GenerationEngine<'a, L: Llm> {
    llm: &'a L,
    prompt: &'a str,
    n: usize,
    config: &'a GenerationConfig,
}

impl<'a, L: Llm> GenerationEngine<'a, L> {
    /// Create a new generation engine.
    pub fn new(llm: &'a L, prompt: &'a str, n: usize, config: &'a GenerationConfig) -> Self {
        Self {
            llm,
            prompt,
            n,
            config,
        }
    }

    /// Generate N candidates, returning every candidate that was produced
    /// successfully together with the last error encountered, if any.
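    ///
    /// Handling sketch (partial results are possible when some calls fail):
    ///
    /// ```ignore
    /// let (candidates, error) = engine.generate_candidates().await;
    /// if let Some(e) = error {
    ///     eprintln!("at least one candidate failed: {e}");
    /// }
    /// for c in &candidates {
    ///     println!("#{}: {} tokens", c.index, c.tokens);
    /// }
    /// ```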
    pub async fn generate_candidates(&self) -> (Vec<RawCandidate>, Option<String>) {
        let prompt = shared::assemble_prompt(
            self.config.skill_text.as_deref(),
            self.prompt,
            self.config.with_reasoning,
        );

        let mut candidates = Vec::with_capacity(self.n);
        let mut error: Option<String> = None;

        if self.config.parallel {
            use futures::stream::{FuturesUnordered, StreamExt};

            let contexts: Vec<String> = (0..self.n)
                .map(|i| self.diversity_context(i, None))
                .collect();

            let mut futs = FuturesUnordered::new();
            for (i, ctx) in contexts.iter().enumerate() {
                let fut = self.llm.generate(&prompt, ctx, None);
                futs.push(async move { (i, fut.await) });
            }

            let mut outputs: Vec<(usize, crate::error::Result<crate::recursive::llm::LmOutput>)> =
                Vec::with_capacity(self.n);
            // Futures complete in arbitrary order; each tuple carries the
            // original index so callers can recover generation order.
            while let Some(result) = futs.next().await {
                outputs.push(result);
            }

            for (i, result) in outputs {
                match result {
                    Ok(output) => {
                        let tokens = output.prompt_tokens + output.completion_tokens;
                        let transformed = shared::transform_output(
                            &output.text,
                            self.config.extract_lang.as_deref(),
                            self.config.defaults.as_ref(),
                        );
                        candidates.push(RawCandidate {
                            index: i,
                            raw_text: output.text.to_string(),
                            transformed_text: transformed,
                            tokens,
                        });
                    }
                    Err(e) => {
                        // Keep only the most recent error message.
                        error = Some(e.to_string());
                    }
                }
            }
        } else {
            // Sequential path with richer diversity context
            let mut prev_snippets: SmallVec<[String; 4]> = SmallVec::new();

            for i in 0..self.n {
                let context = self.diversity_context(i, Some(&prev_snippets));

                let output = match self.llm.generate(&prompt, &context, None).await {
                    Ok(out) => out,
                    Err(e) => {
                        // Record the failure and skip this slot; the
                        // remaining candidates are still attempted.
                        error = Some(e.to_string());
                        continue;
                    }
                };

                let tokens = output.prompt_tokens + output.completion_tokens;
                let transformed = shared::transform_output(
                    &output.text,
                    self.config.extract_lang.as_deref(),
                    self.config.defaults.as_ref(),
                );

                // Track up to three previous snippets so later candidates
                // can be steered away from earlier approaches.
                if self.config.diverse && prev_snippets.len() < 3 {
                    let snippet = if output.text.len() > 80 {
                        // Truncate to at most 80 bytes, backing up to the
                        // nearest UTF-8 char boundary so slicing cannot panic.
                        let mut end = 80;
                        while end > 0 && !output.text.is_char_boundary(end) {
                            end -= 1;
                        }
                        output.text[..end].to_string()
                    } else {
                        output.text.to_string()
                    };
                    prev_snippets.push(snippet);
                }

                candidates.push(RawCandidate {
                    index: i,
                    raw_text: output.text.to_string(),
                    transformed_text: transformed,
                    tokens,
                });
            }
        }

        (candidates, error)
    }

    /// Build a diversity context string for the given candidate index.
    ///
    /// Returns an empty string when diversity is disabled or only one
    /// candidate is requested; otherwise yields a directive like
    /// "Generate candidate 1 of 3. Style: <hint>.", optionally followed by
    /// snippets of earlier outputs to avoid repeating.
    fn diversity_context(&self, i: usize, prev_snippets: Option<&SmallVec<[String; 4]>>) -> String {
        if !self.config.diverse || self.n <= 1 {
            return String::new();
        }

        let hint = STYLE_HINTS[i % STYLE_HINTS.len()];
        let mut ctx = format!(
            "Generate candidate {} of {}. Style: {}.",
            i + 1,
            self.n,
            hint
        );

        if let Some(snippets) = prev_snippets {
            if i > 0 && !snippets.is_empty() {
                ctx.push_str("\n\nAvoid repeating these previous approaches:\n");
                for snippet in snippets.iter() {
                    ctx.push_str(&format!("- {}\n", snippet.replace('\n', " ")));
                }
            }
        }

        ctx
    }
}

/// Shared iteration config for `reason` and `refine` overlap.
#[derive(Clone, Default)]
#[allow(dead_code)]
pub struct IterationConfig {
    /// Language to extract from code fences.
    pub extract_lang: Option<String>,
    /// Runtime defaults.
    pub defaults: Option<Defaults>,
    /// Pre-rendered skill text.
    pub skill_text: Option<String>,
}

#[allow(dead_code)]
impl IterationConfig {
    /// Apply extract + defaults transform to raw LLM output.
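    ///
    /// Sketch (assumes `shared::transform_output` extracts the fenced body
    /// for the configured language; the behavior shown is illustrative):
    ///
    /// ```ignore
    /// let cfg = IterationConfig {
    ///     extract_lang: Some("rust".to_string()),
    ///     ..Default::default()
    /// };
    /// let raw = "```rust\nfn main() {}\n```";
    /// // Likely yields the fence contents, e.g. "fn main() {}".
    /// let code = cfg.transform(raw);
    /// ```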
    pub fn transform(&self, text: &str) -> String {
        shared::transform_output(text, self.extract_lang.as_deref(), self.defaults.as_ref())
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::recursive::llm::MockLlm;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_engine_sequential() {
        let counter = AtomicUsize::new(0);
        let llm = MockLlm::new(move |_, _| {
            let n = counter.fetch_add(1, Ordering::SeqCst);
            format!("Response {}", n)
        });

        let config = GenerationConfig {
            diverse: true,
            ..Default::default()
        };
        let engine = GenerationEngine::new(&llm, "Generate", 3, &config);
        let (candidates, error) = engine.generate_candidates().await;

        assert!(error.is_none());
        assert_eq!(candidates.len(), 3);
        assert_eq!(candidates[0].raw_text, "Response 0");
        assert_eq!(candidates[2].raw_text, "Response 2");
    }

    #[tokio::test]
    async fn test_engine_parallel() {
        let counter = AtomicUsize::new(0);
        let llm = MockLlm::new(move |_, _| {
            let n = counter.fetch_add(1, Ordering::SeqCst);
            format!("Parallel {}", n)
        });

        let config = GenerationConfig {
            parallel: true,
            diverse: true,
            ..Default::default()
        };
        let engine = GenerationEngine::new(&llm, "Generate", 4, &config);
        let (candidates, error) = engine.generate_candidates().await;

        assert!(error.is_none());
        assert_eq!(candidates.len(), 4);
    }

    #[test]
    fn test_iteration_config_transform() {
        let config = IterationConfig {
            extract_lang: None,
            defaults: None,
            skill_text: None,
        };
        assert_eq!(config.transform("hello"), "hello");
    }
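
    // A sketch of a direct test for `diversity_context`; assumes MockLlm::new
    // accepts a two-argument closure returning String, as in the tests above.
    #[test]
    fn test_diversity_context() {
        let llm = MockLlm::new(|_, _| String::new());

        let diverse = GenerationConfig {
            diverse: true,
            ..Default::default()
        };
        let engine = GenerationEngine::new(&llm, "Generate", 3, &diverse);
        let ctx = engine.diversity_context(0, None);
        assert!(ctx.contains("candidate 1 of 3"));

        // With diversity disabled, the context should be empty.
        let plain = GenerationConfig::default();
        let engine = GenerationEngine::new(&llm, "Generate", 3, &plain);
        assert!(engine.diversity_context(0, None).is_empty());
    }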
}