use crate::recursive::defaults::Defaults;
use crate::recursive::llm::Llm;
use crate::recursive::shared::{self, STYLE_HINTS};
use smallvec::SmallVec;
/// Options controlling how candidates are generated.
#[derive(Clone, Default)]
pub struct GenerationConfig {
    /// Forwarded to `shared::assemble_prompt`; presumably asks the model to
    /// include reasoning in its output — TODO confirm against `assemble_prompt`.
    pub with_reasoning: bool,
    /// When true, each candidate gets a per-index style hint and (in the
    /// sequential path) an "avoid repeating" context built from prior outputs.
    pub diverse: bool,
    /// Forwarded to `shared::transform_output`; presumably a language tag used
    /// to extract a code block from the raw output — verify against `transform_output`.
    pub extract_lang: Option<String>,
    /// Generate all `n` candidates concurrently instead of one at a time.
    /// Note: the parallel path cannot feed earlier outputs into later
    /// diversity contexts.
    pub parallel: bool,
    /// Forwarded to `shared::transform_output` when transforming raw output.
    pub defaults: Option<Defaults>,
    /// Optional skill text passed to `shared::assemble_prompt` ahead of the
    /// user prompt.
    pub skill_text: Option<String>,
}
/// One generated candidate: the model's raw output plus its transformed form.
///
/// Derives `Debug` and `Clone` — all fields are owned plain data, and a public
/// result type should be printable in diagnostics and duplicable by callers.
#[derive(Debug, Clone)]
pub struct RawCandidate {
    /// Zero-based index of the generation request that produced this candidate.
    pub index: usize,
    /// The model output exactly as returned.
    pub raw_text: String,
    /// `raw_text` after `shared::transform_output` (e.g. language extraction).
    pub transformed_text: String,
    /// Total tokens consumed: prompt tokens + completion tokens.
    pub tokens: u32,
}
/// Borrows an LLM client, a prompt, and a config, and generates `n`
/// candidate completions (see `generate_candidates`).
pub struct GenerationEngine<'a, L: Llm> {
    // The LLM backend used for every generation call.
    llm: &'a L,
    // The user prompt; combined with skill text and reasoning flag at call time.
    prompt: &'a str,
    // Number of candidates to request.
    n: usize,
    // Generation options (diversity, parallelism, transformation settings).
    config: &'a GenerationConfig,
}
impl<'a, L: Llm> GenerationEngine<'a, L> {
    /// Creates an engine that will generate `n` candidates for `prompt`
    /// using `llm`, governed by `config`.
    pub fn new(llm: &'a L, prompt: &'a str, n: usize, config: &'a GenerationConfig) -> Self {
        Self {
            llm,
            prompt,
            n,
            config,
        }
    }

    /// Generates up to `self.n` candidates.
    ///
    /// Runs requests concurrently when `config.parallel` is set, otherwise
    /// sequentially. Failed generations are skipped, so fewer than `n`
    /// candidates may be returned. The second tuple element carries the
    /// message of the *last* error encountered, if any (earlier errors are
    /// overwritten).
    ///
    /// In both modes the returned candidates are ordered by request index.
    pub async fn generate_candidates(&self) -> (Vec<RawCandidate>, Option<String>) {
        let prompt = shared::assemble_prompt(
            self.config.skill_text.as_deref(),
            self.prompt,
            self.config.with_reasoning,
        );
        let mut candidates = Vec::with_capacity(self.n);
        // Only the most recent error is retained; see doc comment above.
        let mut error: Option<String> = None;
        if self.config.parallel {
            use futures::stream::{FuturesUnordered, StreamExt};
            // Parallel requests cannot see each other's outputs, so no
            // previous-snippet context is available here (hence `None`).
            let contexts: Vec<String> = (0..self.n)
                .map(|i| self.diversity_context(i, None))
                .collect();
            let mut futs = FuturesUnordered::new();
            for (i, ctx) in contexts.iter().enumerate() {
                let fut = self.llm.generate(&prompt, ctx, None);
                futs.push(async move { (i, fut.await) });
            }
            let mut outputs: Vec<(usize, crate::error::Result<crate::recursive::llm::LmOutput>)> =
                Vec::with_capacity(self.n);
            while let Some(result) = futs.next().await {
                outputs.push(result);
            }
            // FuturesUnordered yields results in *completion* order, which is
            // nondeterministic. Sort by request index so candidate order is
            // stable and consistent with the sequential path.
            outputs.sort_by_key(|(i, _)| *i);
            for (i, result) in outputs {
                match result {
                    Ok(output) => {
                        let tokens = output.prompt_tokens + output.completion_tokens;
                        let transformed = shared::transform_output(
                            &output.text,
                            self.config.extract_lang.as_deref(),
                            self.config.defaults.as_ref(),
                        );
                        candidates.push(RawCandidate {
                            index: i,
                            raw_text: output.text.to_string(),
                            transformed_text: transformed,
                            tokens,
                        });
                    }
                    Err(e) => {
                        error = Some(e.to_string());
                    }
                }
            }
        } else {
            // Sequential mode: earlier outputs seed the diversity context of
            // later requests (capped at 3 snippets, inline-stored up to 4).
            let mut prev_snippets: SmallVec<[String; 4]> = SmallVec::new();
            for i in 0..self.n {
                let context = self.diversity_context(i, Some(&prev_snippets));
                let output = match self.llm.generate(&prompt, &context, None).await {
                    Ok(out) => out,
                    Err(e) => {
                        error = Some(e.to_string());
                        continue;
                    }
                };
                let tokens = output.prompt_tokens + output.completion_tokens;
                let transformed = shared::transform_output(
                    &output.text,
                    self.config.extract_lang.as_deref(),
                    self.config.defaults.as_ref(),
                );
                if self.config.diverse && prev_snippets.len() < 3 {
                    // Keep at most the first 80 bytes, backing up to a char
                    // boundary so multi-byte UTF-8 sequences are never split
                    // (slicing mid-character would panic).
                    let snippet = if output.text.len() > 80 {
                        let mut end = 80;
                        while end > 0 && !output.text.is_char_boundary(end) {
                            end -= 1;
                        }
                        output.text[..end].to_string()
                    } else {
                        output.text.to_string()
                    };
                    prev_snippets.push(snippet);
                }
                candidates.push(RawCandidate {
                    index: i,
                    raw_text: output.text.to_string(),
                    transformed_text: transformed,
                    tokens,
                });
            }
        }
        (candidates, error)
    }

    /// Builds the per-candidate steering context for candidate `i`.
    ///
    /// Returns an empty string when diversity is disabled or only one
    /// candidate is requested. Otherwise includes a cycling style hint and,
    /// when `prev_snippets` is provided (sequential mode only), a list of
    /// earlier outputs to avoid repeating.
    fn diversity_context(&self, i: usize, prev_snippets: Option<&SmallVec<[String; 4]>>) -> String {
        if !self.config.diverse || self.n <= 1 {
            return String::new();
        }
        // Cycle through the shared style hints so every candidate gets one.
        let hint = STYLE_HINTS[i % STYLE_HINTS.len()];
        let mut ctx = format!(
            "Generate candidate {} of {}. Style: {}.",
            i + 1,
            self.n,
            hint
        );
        if let Some(snippets) = prev_snippets {
            if i > 0 && !snippets.is_empty() {
                ctx.push_str("\n\nAvoid repeating these previous approaches:\n");
                for snippet in snippets.iter() {
                    // Flatten newlines so each prior approach stays one bullet.
                    ctx.push_str(&format!("- {}\n", snippet.replace('\n', " ")));
                }
            }
        }
        ctx
    }
}
/// Subset of generation options needed when transforming output during
/// iteration (no diversity/parallelism settings).
#[derive(Clone, Default)]
#[allow(dead_code)]
pub struct IterationConfig {
    /// Forwarded to `shared::transform_output`; presumably a language tag for
    /// code extraction — verify against `transform_output`.
    pub extract_lang: Option<String>,
    /// Forwarded to `shared::transform_output`.
    pub defaults: Option<Defaults>,
    /// Optional skill text; not used by `transform` — kept for callers that
    /// assemble prompts elsewhere.
    pub skill_text: Option<String>,
}
#[allow(dead_code)]
impl IterationConfig {
    /// Applies the configured output transformation to `text`, returning the
    /// transformed string.
    pub fn transform(&self, text: &str) -> String {
        let lang = self.extract_lang.as_deref();
        let defaults = self.defaults.as_ref();
        shared::transform_output(text, lang, defaults)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recursive::llm::MockLlm;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Sequential mode: one candidate per request, returned in request order.
    #[tokio::test]
    async fn test_engine_sequential() {
        let calls = AtomicUsize::new(0);
        let llm = MockLlm::new(move |_, _| {
            format!("Response {}", calls.fetch_add(1, Ordering::SeqCst))
        });
        let config = GenerationConfig {
            diverse: true,
            ..Default::default()
        };
        let engine = GenerationEngine::new(&llm, "Generate", 3, &config);
        let (candidates, error) = engine.generate_candidates().await;
        assert!(error.is_none());
        assert_eq!(candidates.len(), 3);
        assert_eq!(candidates[0].raw_text, "Response 0");
        assert_eq!(candidates[2].raw_text, "Response 2");
    }

    /// Parallel mode: all requests complete and produce candidates.
    #[tokio::test]
    async fn test_engine_parallel() {
        let calls = AtomicUsize::new(0);
        let llm = MockLlm::new(move |_, _| {
            format!("Parallel {}", calls.fetch_add(1, Ordering::SeqCst))
        });
        let config = GenerationConfig {
            diverse: true,
            parallel: true,
            ..Default::default()
        };
        let engine = GenerationEngine::new(&llm, "Generate", 4, &config);
        let (candidates, error) = engine.generate_candidates().await;
        assert!(error.is_none());
        assert_eq!(candidates.len(), 4);
    }

    /// With no extraction language or defaults, transform is the identity.
    #[test]
    fn test_iteration_config_transform() {
        let config = IterationConfig::default();
        assert_eq!(config.transform("hello"), "hello");
    }
}