use crate::recursive::defaults::Defaults;
use crate::recursive::executor::{CodeExecutor, DynCodeExecutor};
use crate::recursive::llm::Llm;
use crate::recursive::shared;
use crate::recursive::skill::Skill;
use crate::recursive::validate::{NoValidation, Validate};
/// Starts building a [`Program`] that asks `llm` to write code solving `problem`.
///
/// The returned builder uses [`NoValidation`] and has no executor configured;
/// attach one with `.executor(...)` before calling `.go()` or `.run()`.
pub fn program<'a, L: Llm>(llm: &'a L, problem: &'a str) -> Program<'a, L, NoValidation> {
    Program::<L, NoValidation>::new(llm, problem)
}
/// Tunable settings for a [`Program`] run.
#[derive(Clone)]
pub struct ProgramConfig {
    /// Maximum number of generate/execute attempts.
    pub max_iter: u32,
    /// Whether the generated code is returned in `ProgramResult::code`.
    pub include_code: bool,
    /// Language named in the prompt and used to locate the fenced code block.
    pub language: String,
    /// When set, a fenced block in this language is pulled out of the execution output.
    pub extract_lang: Option<String>,
    /// Optional substitutions applied to the final output text.
    pub defaults: Option<Defaults>,
    /// Rendered skill text prepended to every prompt.
    pub skill_text: Option<String>,
}
/// Defaults: three attempts, code included in the result, Python as the
/// language, and no extraction, defaults or skill text.
impl Default for ProgramConfig {
    fn default() -> Self {
        let language = String::from("python");
        Self {
            max_iter: 3,
            include_code: true,
            language,
            extract_lang: None,
            defaults: None,
            skill_text: None,
        }
    }
}
/// Builder and runner for "ask an LLM for code, execute it, validate, retry".
pub struct Program<'a, L: Llm, V: Validate> {
    /// Model used to generate code.
    llm: &'a L,
    /// Problem statement embedded in every prompt.
    problem: &'a str,
    /// Executor for generated code; `run` fails immediately if this is `None`.
    executor: Option<Box<dyn DynCodeExecutor + 'static>>,
    /// Scores the (post-processed) execution output.
    validator: V,
    /// Run settings; public so callers can inspect or tweak them directly.
    pub config: ProgramConfig,
}
impl<'a, L: Llm> Program<'a, L, NoValidation> {
pub fn new(llm: &'a L, problem: &'a str) -> Self {
Self {
llm,
problem,
executor: None,
validator: NoValidation,
config: ProgramConfig::default(),
}
}
}
impl<'a, L: Llm, V: Validate> Program<'a, L, V> {
    /// Replaces the validator, returning a builder typed over the new validator.
    ///
    /// The LLM, problem, executor and config carry over unchanged.
    pub fn validate<V2: Validate>(self, validator: V2) -> Program<'a, L, V2> {
        Program {
            llm: self.llm,
            problem: self.problem,
            executor: self.executor,
            validator,
            config: self.config,
        }
    }

    /// Sets the code executor and adopts its language for prompting and code extraction.
    pub fn executor<E: CodeExecutor + 'static>(mut self, executor: E) -> Self {
        // `E` is also boxed as a `DynCodeExecutor` below, which has its own
        // `language()`; the explicit trait path avoids an ambiguous call.
        self.config.language = CodeExecutor::language(&executor).to_string();
        self.executor = Some(Box::new(executor));
        self
    }

    /// Sets an already-boxed executor and adopts its language.
    pub fn executor_dyn(mut self, executor: Box<dyn DynCodeExecutor + 'static>) -> Self {
        self.config.language = executor.language().to_string();
        self.executor = Some(executor);
        self
    }

    /// Sets the maximum number of attempts; values below 1 are clamped to 1.
    pub fn max_iter(mut self, n: u32) -> Self {
        self.config.max_iter = n.max(1);
        self
    }

    /// Overrides the language used in the prompt and for locating the fenced code block.
    ///
    /// A later `.executor(...)` / `.executor_dyn(...)` call overwrites this value.
    pub fn language(mut self, lang: &str) -> Self {
        self.config.language = lang.to_string();
        self
    }

    /// Extracts a fenced block in `lang` from successful execution output.
    ///
    /// If no such block is found, the trimmed stdout is used instead.
    pub fn extract(mut self, lang: impl Into<String>) -> Self {
        self.config.extract_lang = Some(lang.into());
        self
    }

    /// Prepends the rendered skill to every prompt.
    ///
    /// An empty rendering clears any previously set skill text.
    pub fn skill(mut self, skill: &Skill<'_>) -> Self {
        let rendered = skill.render();
        if rendered.is_empty() {
            self.config.skill_text = None;
        } else {
            self.config.skill_text = Some(rendered);
        }
        self
    }

    /// Applies `defaults` to the output text before it is validated and returned.
    pub fn defaults(mut self, defaults: Defaults) -> Self {
        self.config.defaults = Some(defaults);
        self
    }

    /// Omits the generated code from the returned [`ProgramResult`].
    pub fn no_code(mut self) -> Self {
        self.config.include_code = false;
        self
    }

    /// Synchronous wrapper: drives [`Program::run`] to completion via `shared::block_on`.
    pub fn go(self) -> ProgramResult {
        shared::block_on(self.run())
    }

    /// Runs the generate → execute → validate loop.
    ///
    /// Each attempt does the following:
    /// 1. Prompts the LLM, including feedback from the previous failure.
    /// 2. Extracts the fenced code from the reply and executes it.
    /// 3. On successful execution, post-processes the output: `extract_lang`
    ///    first, then `defaults`.
    /// 4. Validates the post-processed output.
    ///
    /// The run returns successfully as soon as the score reaches `1.0`. It also
    /// returns successfully on the final attempt as long as execution succeeded,
    /// even if validation did not pass.
    ///
    /// Failure results (`success == false`) are returned in three cases:
    /// - no executor is configured;
    /// - the LLM call fails;
    /// - the final attempt's execution fails.
    ///
    /// Generated code is only reported when `include_code` is set.
    pub async fn run(self) -> ProgramResult {
        let executor = match self.executor.as_ref() {
            Some(e) => e,
            None => {
                return ProgramResult {
                    output: String::new(),
                    code: String::new(),
                    attempts: 0,
                    tokens: 0,
                    success: false,
                    error: Some("No executor configured. Use .executor() to set one.".to_string()),
                };
            }
        };
        // `config` is public, so a zero can bypass the `.max_iter()` clamp; stay consistent.
        let max_iter = self.config.max_iter.max(1);
        let mut last_error: Option<String> = None;
        let mut last_code = String::new();
        let mut total_tokens = 0u32;
        for attempt in 0u32..max_iter {
            let prompt = self.build_prompt(last_error.as_deref());
            let output = match self.llm.generate(&prompt, "", None).await {
                Ok(out) => out,
                Err(e) => {
                    return ProgramResult {
                        output: String::new(),
                        // Honour `.no_code()` on this path too.
                        code: self.visible_code(last_code),
                        attempts: attempt + 1,
                        tokens: total_tokens,
                        success: false,
                        error: Some(e.to_string()),
                    };
                }
            };
            // Saturate instead of panicking on overflow in debug builds.
            total_tokens = total_tokens
                .saturating_add(output.prompt_tokens)
                .saturating_add(output.completion_tokens);
            let code = self.extract_code(&output.text);
            last_code = code.to_string();
            let result = executor.execute_dyn(code).await;
            if result.success {
                let output_text = if let Some(ref lang) = self.config.extract_lang {
                    use crate::recursive::rewrite::extract_code;
                    extract_code(result.output(), lang)
                        .map(|s| s.to_string())
                        .unwrap_or_else(|| result.stdout.trim().to_string())
                } else {
                    result.stdout.trim().to_string()
                };
                let output_text = match self.config.defaults {
                    Some(ref d) => d.apply(&output_text),
                    None => output_text,
                };
                let score = self.validator.validate(&output_text);
                if score.value >= 1.0 || attempt == max_iter - 1 {
                    return ProgramResult {
                        output: output_text,
                        code: self.visible_code(last_code),
                        attempts: attempt + 1,
                        tokens: total_tokens,
                        success: true,
                        error: None,
                    };
                }
                // Without feedback the retry prompt would be identical to this one.
                last_error = Some(
                    score
                        .feedback_str()
                        .map(|s| s.to_string())
                        .unwrap_or_else(|| "Output did not pass validation.".to_string()),
                );
            } else {
                last_error = Some(Self::execution_feedback(&result.stderr));
            }
        }
        ProgramResult {
            output: String::new(),
            code: self.visible_code(last_code),
            attempts: max_iter,
            tokens: total_tokens,
            success: false,
            error: last_error,
        }
    }

    /// Returns `code` when `include_code` is enabled, otherwise an empty string.
    fn visible_code(&self, code: String) -> String {
        if self.config.include_code {
            code
        } else {
            String::new()
        }
    }

    /// Turns an execution failure's stderr into retry feedback.
    ///
    /// Substitutes a generic message when stderr is blank, so that neither the
    /// prompt nor the final error is empty.
    fn execution_feedback(stderr: &str) -> String {
        if stderr.trim().is_empty() {
            "Execution failed without any error output.".to_string()
        } else {
            stderr.to_string()
        }
    }

    /// Builds the LLM prompt.
    ///
    /// Parts, in order:
    /// - the optional skill text;
    /// - the problem statement;
    /// - the previous error, if any;
    /// - an opening fence for the configured language.
    fn build_prompt(&self, previous_error: Option<&str>) -> String {
        let mut prompt = String::new();
        if let Some(ref skill_text) = self.config.skill_text {
            prompt.push_str(skill_text);
            prompt.push('\n');
        }
        prompt.push_str(&format!(
            "Write {} code to solve the following problem:\n\n{}\n\n",
            self.config.language, self.problem
        ));
        if let Some(error) = previous_error {
            prompt.push_str(&format!(
                "Previous attempt failed with error:\n```\n{}\n```\n\n\
                 Please fix the code and try again.\n\n",
                error
            ));
        }
        prompt.push_str(&format!(
            "Provide your solution in a code block:\n```{}\n",
            self.config.language
        ));
        prompt
    }

    /// Pulls the code out of an LLM reply.
    ///
    /// Candidates are tried in order:
    /// 1. the first fence tagged with the configured language;
    /// 2. the first fence of any kind;
    /// 3. the whole reply, trimmed.
    fn extract_code<'b>(&self, response: &'b str) -> &'b str {
        let lang_marker = format!("```{}", self.config.language);
        Self::fenced_body(response, &lang_marker)
            .or_else(|| Self::fenced_body(response, "```"))
            .unwrap_or_else(|| response.trim())
    }

    /// Returns the text between the first `marker` and the next closing fence.
    ///
    /// The body starts after the line that contains `marker`. Returns `None`
    /// when the marker is missing or the fence is not closed.
    fn fenced_body<'b>(response: &'b str, marker: &str) -> Option<&'b str> {
        let after_marker = response.find(marker)? + marker.len();
        // Skip the rest of the fence line (e.g. a language tag); no newline means the body starts here.
        let body_start = response[after_marker..]
            .find('\n')
            .map_or(after_marker, |i| after_marker + i + 1);
        let body_len = response[body_start..].find("```")?;
        Some(&response[body_start..body_start + body_len])
    }
}
/// Outcome of a [`Program`] run.
///
/// Derives `PartialEq`/`Eq` (all fields are comparable) so results can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgramResult {
    /// Post-processed output of the accepted execution; empty on failure.
    pub output: String,
    /// Last generated code, or empty when code inclusion is disabled.
    pub code: String,
    /// Number of attempts made.
    pub attempts: u32,
    /// Prompt plus completion tokens accumulated across attempts.
    pub tokens: u32,
    /// Whether an execution was accepted.
    pub success: bool,
    /// Error description when the run failed.
    pub error: Option<String>,
}
impl ProgramResult {
    /// Borrows the generated code (empty when code inclusion was disabled).
    pub fn code(&self) -> &str {
        self.code.as_str()
    }

    /// Reports whether the run produced an accepted result.
    pub fn is_success(&self) -> bool {
        let ok = self.success;
        ok
    }
}
#[cfg(test)]
mod tests {
    // Tests that shell out through `bash_executor` only run on Unix targets.
    use super::*;
    #[cfg(unix)]
    use crate::recursive::executor::bash_executor;
    use crate::recursive::llm::MockLlm;

    #[test]
    #[cfg(unix)]
    fn test_program_basic() {
        let mock = MockLlm::new(|_, _| "Here's the code:\n```bash\necho 42\n```".to_string());
        let outcome = program(&mock, "Print 42").executor(bash_executor()).go();
        assert!(outcome.success);
        assert!(outcome.output.contains("42"));
        assert!(outcome.code.contains("echo"));
    }

    #[test]
    fn test_program_no_executor() {
        let mock = MockLlm::new(|_, _| "print(42)".to_string());
        let outcome = program(&mock, "Print 42").go();
        assert!(!outcome.success);
        // The error must be present and name the missing executor.
        let message = outcome.error.expect("missing executor should produce an error");
        assert!(message.contains("No executor"));
    }

    #[test]
    fn test_program_extract_code() {
        let mock = MockLlm::new(|_, _| String::new());
        let builder = program(&mock, "test");
        // (LLM reply, expected extracted code)
        let cases = [
            (
                "Here's the solution:\n```python\nprint('hello')\n```\nDone!",
                "print('hello')\n",
            ),
            ("```\necho test\n```", "echo test\n"),
            ("just plain text", "just plain text"),
        ];
        for &(reply, expected) in cases.iter() {
            assert_eq!(builder.extract_code(reply), expected);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_program_with_error_retry() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // The first reply fails at runtime; every later reply succeeds.
        let calls = AtomicUsize::new(0);
        let mock = MockLlm::new(move |_prompt, _| {
            if calls.fetch_add(1, Ordering::SeqCst) == 0 {
                "```bash\nexit 1\n```".to_string()
            } else {
                "```bash\necho success\n```".to_string()
            }
        });
        let outcome = program(&mock, "Succeed")
            .executor(bash_executor())
            .max_iter(3)
            .go();
        assert!(outcome.success);
        assert_eq!(outcome.attempts, 2);
    }

    #[test]
    #[cfg(unix)]
    fn test_program_no_code() {
        let mock = MockLlm::new(|_, _| "```bash\necho test\n```".to_string());
        let outcome = program(&mock, "Test")
            .executor(bash_executor())
            .no_code()
            .go();
        assert!(outcome.success);
        assert!(outcome.code.is_empty());
    }

    #[test]
    fn test_program_config() {
        let mock = MockLlm::new(|_, _| String::new());
        let builder = program(&mock, "test").max_iter(5).language("python");
        let cfg = &builder.config;
        assert_eq!(cfg.max_iter, 5);
        assert_eq!(cfg.language, "python");
    }

    #[test]
    fn test_program_result_methods() {
        let outcome = ProgramResult {
            output: "42".to_string(),
            code: "print(42)".to_string(),
            attempts: 1,
            tokens: 100,
            success: true,
            error: None,
        };
        assert!(outcome.is_success());
        assert_eq!(outcome.code(), "print(42)");
    }

    #[test]
    #[cfg(unix)]
    fn test_program_with_skill() {
        use crate::recursive::skill::Skill;
        // The mock echoes a different marker depending on whether the skill text reached the prompt.
        let mock = MockLlm::new(|prompt, _| {
            let marker = if prompt.contains("deletionProtection") {
                "skill_applied"
            } else {
                "no_skill"
            };
            format!("```bash\necho {}\n```", marker)
        });
        let skill = Skill::new().instruct(
            "deletionProtection",
            "Always set deletionProtection: false.",
        );
        let outcome = program(&mock, "Generate config")
            .executor(bash_executor())
            .skill(&skill)
            .go();
        assert!(outcome.success);
        assert!(outcome.output.contains("skill_applied"));
    }

    #[test]
    #[cfg(unix)]
    fn test_program_with_defaults() {
        use crate::recursive::defaults::Defaults;
        let mock = MockLlm::new(|_, _| "```bash\necho admin@example.com\n```".to_string());
        // Replace the placeholder address with the real one in the final output.
        let defaults = Defaults::new().set("email", r"admin@example\.com", "real@company.com");
        let outcome = program(&mock, "Generate IAM")
            .executor(bash_executor())
            .defaults(defaults)
            .go();
        assert!(outcome.success);
        assert!(outcome.output.contains("real@company.com"));
        assert!(!outcome.output.contains("admin@example.com"));
    }
}