use crate::error::{Error, Result};
use crate::recursive::cli::CliCapture;
use crate::recursive::llm::Llm;
use crate::recursive::validate::Validate;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::fmt;
use std::time::Duration;
/// Reason a refinement run stopped.
///
/// `Default` is [`StopReason::MaxIterations`], which is also the value
/// `RefineResult::new` starts with. The conditions that select each variant
/// are decided by the refinement loop, which is not part of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum StopReason {
/// Target reached; also the reason recorded by `RefineResult::cached`.
TargetReached,
/// Maximum number of iterations reached (the default).
#[default]
MaxIterations,
/// Budget exhausted.
BudgetExhausted,
/// Timeout reached.
TimeoutReached,
/// Scores plateaued.
Plateau,
/// A human accepted the result.
HumanAccepted,
/// A human rejected the result.
HumanRejected,
}
impl fmt::Display for StopReason {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TargetReached => write!(f, "target reached"),
Self::MaxIterations => write!(f, "max iterations"),
Self::BudgetExhausted => write!(f, "budget exhausted"),
Self::TimeoutReached => write!(f, "timeout reached"),
Self::Plateau => write!(f, "plateau"),
Self::HumanAccepted => write!(f, "human accepted"),
Self::HumanRejected => write!(f, "human rejected"),
}
}
}
/// Opaque 64-bit identifier attached to each `RefineResult`.
///
/// The derived `Default` is `ContextId(0)`; fresh ids come from
/// `ContextId::new`. `Display` renders the raw value as 16 zero-padded
/// lowercase hex digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub struct ContextId(pub u64);
impl ContextId {
    /// Generate a fresh context id.
    ///
    /// The id is derived from the wall-clock time in nanoseconds since the Unix
    /// epoch (`0` if the clock reads earlier than the epoch). The id is also
    /// forced to be strictly greater than the previous id issued by this process.
    /// Ids therefore stay unique and increasing within a process. This holds
    /// when many ids are created within one clock tick, and when the system
    /// clock steps backwards (e.g. an NTP adjustment). The previous
    /// `nanos + counter` scheme could repeat a value in that case. As a side
    /// effect, `new()` does not return the `Default` id `0` (short of
    /// exhausting the `u64` range).
    ///
    /// Uniqueness across separate processes is not guaranteed.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // Last id handed out by this process; 0 means "none yet".
        static LAST: AtomicU64 = AtomicU64::new(0);

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0);

        // Prefer the current time, but never reuse or go below an issued id.
        let next_after = |last: u64| nanos.max(last.wrapping_add(1));

        // The update closure always returns `Some`, so `fetch_update` yields
        // `Ok(previous)`; both arms carry the previous value regardless.
        let previous = match LAST.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
            Some(next_after(last))
        }) {
            Ok(prev) | Err(prev) => prev,
        };
        Self(next_after(previous))
    }
}
impl fmt::Display for ContextId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:016x}", self.0)
}
}
/// Record of a single refinement iteration, stored in `RefineResult::history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Iteration {
/// Iteration number (the tests in this file number iterations from 0).
pub number: u32,
/// Output produced in this iteration.
pub output: String,
/// Score of `output`; `RefineResult::improvement` compares the first
/// recorded score against the final `RefineResult::score`.
pub score: f64,
/// Optional feedback text attached to this iteration.
pub feedback: Option<String>,
}
/// An error encountered during refinement together with how it was resolved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correction {
/// Description of the error.
pub error: String,
/// Description of the resolution.
pub resolution: String,
/// Value of `RefineResult::iterations` when the correction was recorded
/// through `RefineResult::add_correction`.
pub iteration: u32,
}
/// A single few-shot input/output pair, stored in `OptimizedPrompt::examples`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Example {
/// Example input text.
pub input: String,
/// Expected output text for `input`.
pub output: String,
}
impl Example {
    /// Build an example pair from any values convertible into owned strings.
    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
        let (input, output) = (input.into(), output.into());
        Self { input, output }
    }
}
/// A prompt produced by optimization: a signature plus optional instructions,
/// few-shot examples and a template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedPrompt {
/// Task signature, e.g. `"question -> answer"` as used in the tests.
pub signature: String,
/// Free-form instructions; omitted from `render()` when empty.
pub instructions: String,
/// Few-shot examples; up to four are stored inline before `SmallVec` spills to the heap.
pub examples: SmallVec<[Example; 4]>,
/// Template text.
// NOTE(review): `render()` in this file does not use `template` — confirm where it is consumed.
pub template: String,
}
impl OptimizedPrompt {
    /// Create a prompt for `signature` with no instructions, examples or template.
    pub fn new(signature: impl Into<String>) -> Self {
        let signature = signature.into();
        Self {
            signature,
            instructions: String::new(),
            examples: SmallVec::new(),
            template: String::new(),
        }
    }

    /// Builder: replace the instructions text.
    pub fn with_instructions(self, instructions: impl Into<String>) -> Self {
        Self {
            instructions: instructions.into(),
            ..self
        }
    }

    /// Builder: append one few-shot example.
    pub fn with_example(mut self, input: impl Into<String>, output: impl Into<String>) -> Self {
        let example = Example::new(input, output);
        self.examples.push(example);
        self
    }

    /// Builder: replace the template text.
    pub fn with_template(self, template: impl Into<String>) -> Self {
        Self {
            template: template.into(),
            ..self
        }
    }

    /// Write this prompt to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Returns a module error if serialization fails, or the converted I/O
    /// error if the file cannot be written.
    #[cfg(feature = "std")]
    pub fn save(&self, path: &str) -> Result<()> {
        match serde_json::to_string_pretty(self) {
            Ok(json) => {
                std::fs::write(path, json)?;
                Ok(())
            }
            Err(err) => Err(Error::module(format!(
                "Failed to serialize OptimizedPrompt: {}",
                err
            ))),
        }
    }

    /// Read a prompt previously written by [`OptimizedPrompt::save`].
    ///
    /// # Errors
    /// Returns the converted I/O error if the file cannot be read, or a module
    /// error if its contents are not a valid serialized prompt.
    #[cfg(feature = "std")]
    pub fn load(path: &str) -> Result<Self> {
        let json = std::fs::read_to_string(path)?;
        serde_json::from_str::<Self>(&json).map_err(|err| {
            Error::module(format!("Failed to deserialize OptimizedPrompt: {}", err))
        })
    }

    /// Render the prompt as plain text.
    ///
    /// The output always starts with the signature line. An "Instructions"
    /// section follows when instructions are non-empty, and a numbered
    /// "Examples" section follows when there are examples. The template is not
    /// included.
    pub fn render(&self) -> String {
        let mut out = String::from("Signature: ");
        out.push_str(&self.signature);
        out.push('\n');

        if !self.instructions.is_empty() {
            out += "\nInstructions:\n";
            out += &self.instructions;
            out += "\n";
        }

        if !self.examples.is_empty() {
            out += "\nExamples:\n";
            for (idx, example) in self.examples.iter().enumerate() {
                out += &format!(" {}. Input: {}\n", idx + 1, example.input);
                out += &format!(" Output: {}\n", example.output);
            }
        }

        out
    }
}
/// Outcome of a refinement run, including its history and diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefineResult {
/// Final output text.
pub output: String,
/// Final score; `is_success` treats `>= 1.0 - f64::EPSILON` as success.
pub score: f64,
/// Number of iterations performed (`0` for cached results).
pub iterations: u32,
/// Identifier generated for this result by the constructors.
pub context_id: ContextId,
/// `true` when the result was produced by `RefineResult::cached`.
pub from_cache: bool,
/// Optimized prompt associated with this result, if any.
pub prompt: Option<OptimizedPrompt>,
/// Per-iteration records (up to 8 stored inline).
pub history: SmallVec<[Iteration; 8]>,
/// Corrections recorded via `add_correction` (up to 8 stored inline).
pub corrections: SmallVec<[Correction; 8]>,
/// CLI stage captures recorded via `add_cli_capture` (up to 4 stored inline).
pub cli_captures: SmallVec<[CliCapture; 4]>,
/// Why the run stopped.
pub stop_reason: StopReason,
/// Token count; starts at 0. NOTE(review): presumably LLM tokens consumed — confirm with the caller that fills it.
pub total_tokens: u32,
/// Elapsed duration; starts at zero. NOTE(review): presumably wall time of the run — confirm.
pub elapsed: Duration,
/// Confidence value; the constructors initialise it to `1.0`.
pub confidence: f64,
}
impl RefineResult {
    /// Create a result for a freshly computed (non-cached) refinement.
    ///
    /// A new [`ContextId`] is generated. `stop_reason` starts as
    /// [`StopReason::MaxIterations`] and `confidence` as `1.0`. The prompt,
    /// history, corrections, CLI captures, token count and elapsed time all
    /// start empty/zero.
    pub fn new(output: impl Into<String>, score: f64, iterations: u32) -> Self {
        Self {
            output: output.into(),
            score,
            iterations,
            context_id: ContextId::new(),
            from_cache: false,
            prompt: None,
            history: SmallVec::new(),
            corrections: SmallVec::new(),
            cli_captures: SmallVec::new(),
            stop_reason: StopReason::MaxIterations,
            total_tokens: 0,
            elapsed: Duration::ZERO,
            confidence: 1.0,
        }
    }

    /// Create a result served from a cache.
    ///
    /// This is the same as [`RefineResult::new`] with zero iterations, except
    /// that `from_cache` is `true` and `stop_reason` is
    /// [`StopReason::TargetReached`].
    pub fn cached(output: impl Into<String>, score: f64) -> Self {
        // Build on `new` so both constructors stay in sync when fields are added.
        Self {
            from_cache: true,
            stop_reason: StopReason::TargetReached,
            ..Self::new(output, score, 0)
        }
    }

    /// Append an iteration record to `history`.
    pub fn add_iteration(&mut self, iteration: Iteration) {
        self.history.push(iteration);
    }

    /// Record a correction, tagged with the current value of `self.iterations`.
    pub fn add_correction(&mut self, error: impl Into<String>, resolution: impl Into<String>) {
        self.corrections.push(Correction {
            error: error.into(),
            resolution: resolution.into(),
            iteration: self.iterations,
        });
    }

    /// Record the captured output of a CLI stage.
    pub fn add_cli_capture(&mut self, capture: CliCapture) {
        self.cli_captures.push(capture);
    }

    /// Return the stdout of the first capture whose `stage` equals `stage`, if any.
    pub fn cli_output(&self, stage: &str) -> Option<&str> {
        self.cli_captures
            .iter()
            .find(|c| c.stage == stage)
            .map(|c| c.stdout.as_str())
    }

    /// Render all CLI captures as a Markdown section.
    ///
    /// Returns an empty string when there are no captures. Each stage gets a
    /// heading, a status line, and fenced blocks for whichever of
    /// stdout/stderr are non-empty.
    pub fn cli_summary(&self) -> String {
        if self.cli_captures.is_empty() {
            return String::new();
        }
        let mut summary = String::from("## CLI Outputs\n\n");
        for capture in &self.cli_captures {
            summary.push_str(&format!("### {}\n", capture.stage));
            match (capture.success, capture.exit_code) {
                (true, _) => summary.push_str("**Status:** Success\n"),
                (false, Some(code)) => summary.push_str(&format!(
                    "**Status:** Failed (exit code: {})\n",
                    code
                )),
                // Say so explicitly instead of printing a sentinel such as -1,
                // which is indistinguishable from a genuine exit code of -1.
                (false, None) => summary.push_str("**Status:** Failed (no exit code)\n"),
            }
            if !capture.stdout.is_empty() {
                summary.push_str("\n**stdout:**\n```\n");
                summary.push_str(&capture.stdout);
                summary.push_str("\n```\n");
            }
            if !capture.stderr.is_empty() {
                summary.push_str("\n**stderr:**\n```\n");
                summary.push_str(&capture.stderr);
                summary.push_str("\n```\n");
            }
            summary.push('\n');
        }
        summary
    }

    /// Render corrections as a Markdown table (empty string when there are none).
    ///
    /// Cell text is escaped so that pipes and multi-line messages stay
    /// inside their cell (see `table_cell`).
    pub fn corrections_table(&self) -> String {
        if self.corrections.is_empty() {
            return String::new();
        }
        let mut table = String::from("| Iteration | Error | Resolution |\n");
        table.push_str("|-----------|-------|------------|\n");
        for correction in &self.corrections {
            table.push_str(&format!(
                "| {} | {} | {} |\n",
                correction.iteration,
                Self::table_cell(&correction.error),
                Self::table_cell(&correction.resolution)
            ));
        }
        table
    }

    /// Escape `text` so it fits in a single Markdown table cell.
    ///
    /// An unescaped `|` starts a new column and a raw line break ends the row.
    /// Pipes are therefore backslash-escaped and line breaks are converted to
    /// `<br>`. Trailing line breaks are dropped first.
    fn table_cell(text: &str) -> String {
        text.trim_end_matches(|c: char| c == '\n' || c == '\r')
            .replace('|', "\\|")
            .replace("\r\n", "<br>")
            .replace('\n', "<br>")
            .replace('\r', "<br>")
    }

    /// Render corrections as Markdown sections, one per correction.
    ///
    /// Returns `"No corrections were needed.\n"` when there are none.
    pub fn corrections_markdown(&self) -> String {
        if self.corrections.is_empty() {
            return String::from("No corrections were needed.\n");
        }
        let mut md = String::from("## Corrections Made\n\n");
        for (i, correction) in self.corrections.iter().enumerate() {
            md.push_str(&format!(
                "### Correction {} (Iteration {})\n\n",
                i + 1,
                correction.iteration
            ));
            md.push_str(&format!("**Error:** {}\n\n", correction.error));
            md.push_str(&format!("**Resolution:** {}\n\n", correction.resolution));
        }
        md
    }

    /// `true` when the final score is 1.0, allowing for `f64::EPSILON` of rounding error.
    pub fn is_success(&self) -> bool {
        self.score >= 1.0 - f64::EPSILON
    }

    /// Final score minus the score of the first recorded iteration.
    ///
    /// Returns `0.0` when fewer than two iterations are recorded.
    pub fn improvement(&self) -> f64 {
        if self.history.len() < 2 {
            return 0.0;
        }
        self.history
            .first()
            .map_or(0.0, |first| self.score - first.score)
    }

    /// Write this result to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Returns a module error if serialization fails, or the converted I/O
    /// error if the file cannot be written.
    #[cfg(feature = "std")]
    pub fn save(&self, path: &str) -> Result<()> {
        std::fs::write(path, self.to_json()?)?;
        Ok(())
    }

    /// Read a result previously written by [`RefineResult::save`].
    ///
    /// # Errors
    /// Returns the converted I/O error if the file cannot be read, or a module
    /// error if its contents are not a valid serialized result.
    #[cfg(feature = "std")]
    pub fn load(path: &str) -> Result<Self> {
        let json = std::fs::read_to_string(path)?;
        Self::from_json(&json)
    }

    /// Serialize this result to pretty-printed JSON.
    ///
    /// # Errors
    /// Returns a module error if serialization fails.
    pub fn to_json(&self) -> Result<String> {
        serde_json::to_string_pretty(self)
            .map_err(|e| Error::module(format!("Failed to serialize RefineResult: {}", e)))
    }

    /// Deserialize a result from JSON produced by [`RefineResult::to_json`].
    ///
    /// # Errors
    /// Returns a module error if `json` is not a valid serialized result.
    pub fn from_json(json: &str) -> Result<Self> {
        serde_json::from_str(json)
            .map_err(|e| Error::module(format!("Failed to deserialize RefineResult: {}", e)))
    }
}
impl fmt::Display for RefineResult {
    /// Multi-line, human-readable summary.
    ///
    /// `output` is shown only as a short preview, truncated on a character
    /// boundary, with "..." appended only when something was cut off.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Maximum number of bytes of `output` included in the preview line.
        const PREVIEW_BYTES: usize = 50;

        writeln!(f, "RefineResult {{")?;
        writeln!(f, " score: {:.2}", self.score)?;
        writeln!(f, " confidence: {:.2}", self.confidence)?;
        writeln!(f, " iterations: {}", self.iterations)?;
        writeln!(f, " stop_reason: {}", self.stop_reason)?;
        writeln!(f, " total_tokens: {}", self.total_tokens)?;
        writeln!(f, " elapsed: {:?}", self.elapsed)?;
        writeln!(f, " context_id: {}", self.context_id)?;
        writeln!(f, " from_cache: {}", self.from_cache)?;
        writeln!(f, " corrections: {}", self.corrections.len())?;

        // Slicing a `str` inside a multi-byte UTF-8 character panics, so step
        // back from the byte limit to the nearest character boundary
        // (index 0 is always a boundary, so this terminates).
        let mut end = self.output.len().min(PREVIEW_BYTES);
        while !self.output.is_char_boundary(end) {
            end -= 1;
        }
        let ellipsis = if end < self.output.len() { "..." } else { "" };
        writeln!(f, " output: {}{}", &self.output[..end], ellipsis)?;
        write!(f, "}}")
    }
}
/// A predictor built from an LLM, a validator and an optimized prompt.
pub struct Compiled<L: Llm, V: Validate> {
/// Model used by `predict`.
llm: L,
/// Validator used by `predict_scored`.
validator: V,
/// Prompt whose examples are sent as context by `predict`.
prompt: OptimizedPrompt,
/// Target score set by `with_target` (defaults to 1.0).
// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere.
target_score: f64,
}
impl<L: Llm, V: Validate> Compiled<L, V> {
    /// Bundle an LLM, a validator and a prompt; the target score starts at `1.0`.
    pub fn new(llm: L, validator: V, prompt: OptimizedPrompt) -> Self {
        let target_score = 1.0;
        Self {
            llm,
            validator,
            prompt,
            target_score,
        }
    }

    /// Builder: set the target score.
    pub fn with_target(mut self, score: f64) -> Self {
        self.target_score = score;
        self
    }

    /// Borrow the optimized prompt.
    pub fn prompt(&self) -> &OptimizedPrompt {
        &self.prompt
    }

    /// Borrow the validator.
    pub fn validator(&self) -> &V {
        &self.validator
    }

    /// Generate an output for `input`.
    ///
    /// The prompt's examples are passed to the LLM as context blocks of the
    /// form `Input: ...\nOutput: ...\n\n`. Any LLM error is swallowed and an
    /// empty string is returned instead.
    pub async fn predict(&self, input: &str) -> String {
        let context: String = self
            .prompt
            .examples
            .iter()
            .map(|ex| format!("Input: {}\nOutput: {}\n\n", ex.input, ex.output))
            .collect();
        self.llm
            .generate(input, &context, None)
            .await
            .map(|generated| generated.text.to_string())
            .unwrap_or_default()
    }

    /// Generate an output for `input` and return it together with the
    /// validator's score value for that output.
    pub async fn predict_scored(&self, input: &str) -> (String, f64) {
        let output = self.predict(input).await;
        let value = self.validator.validate(&output).value;
        (output, value)
    }

    /// Write the rendered prompt (plain text, not JSON) to `path`.
    ///
    /// # Errors
    /// Returns the converted I/O error if the file cannot be created or written.
    #[cfg(feature = "std")]
    pub fn save(&self, path: &str) -> Result<()> {
        std::fs::write(path, self.prompt.render())?;
        Ok(())
    }
}
/// Unit tests for the result, prompt and event types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_id() {
        let id1 = ContextId::new();
        let id2 = ContextId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_refine_result_new() {
        let result = RefineResult::new("output", 0.8, 3);
        assert_eq!(result.output, "output");
        assert!((result.score - 0.8).abs() < f64::EPSILON);
        assert_eq!(result.iterations, 3);
        assert!(!result.from_cache);
    }

    #[test]
    fn test_refine_result_cached() {
        let result = RefineResult::cached("cached output", 1.0);
        assert!(result.from_cache);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_add_iteration() {
        let mut result = RefineResult::new("output", 0.8, 1);
        result.add_iteration(Iteration {
            number: 0,
            output: "first".to_string(),
            score: 0.5,
            feedback: Some("improve".to_string()),
        });
        assert_eq!(result.history.len(), 1);
        assert_eq!(result.history[0].score, 0.5);
    }

    #[test]
    fn test_add_correction() {
        let mut result = RefineResult::new("output", 0.8, 1);
        result.add_correction("missing return type", "added -> i32");
        assert_eq!(result.corrections.len(), 1);
        assert_eq!(result.corrections[0].error, "missing return type");
    }

    #[test]
    fn test_corrections_table() {
        let mut result = RefineResult::new("output", 1.0, 2);
        result.add_correction("error 1", "fix 1");
        result.add_correction("error 2", "fix 2");
        let table = result.corrections_table();
        assert!(table.contains("Iteration"));
        assert!(table.contains("error 1"));
        assert!(table.contains("fix 2"));
    }

    #[test]
    fn test_cli_summary() {
        let mut result = RefineResult::new("output", 1.0, 1);
        result.add_cli_capture(CliCapture {
            stage: "compile".to_string(),
            command: "rustc".to_string(),
            stdout: "compiled".to_string(),
            stderr: String::new(),
            success: true,
            exit_code: Some(0),
            duration_ms: 100,
        });
        let summary = result.cli_summary();
        assert!(summary.contains("compile"));
        assert!(summary.contains("Success"));
    }

    #[test]
    fn test_optimized_prompt() {
        let prompt = OptimizedPrompt::new("question -> answer")
            .with_instructions("Be concise")
            .with_example("What is 2+2?", "4")
            .with_example("What is the capital of France?", "Paris");
        assert_eq!(prompt.signature, "question -> answer");
        assert_eq!(prompt.examples.len(), 2);
        let rendered = prompt.render();
        assert!(rendered.contains("question -> answer"));
        assert!(rendered.contains("Be concise"));
        assert!(rendered.contains("2+2"));
    }

    #[test]
    fn test_is_success() {
        let success = RefineResult::new("output", 1.0, 1);
        assert!(success.is_success());
        let partial = RefineResult::new("output", 0.8, 1);
        assert!(!partial.is_success());
    }

    #[test]
    fn test_improvement() {
        let mut result = RefineResult::new("output", 0.9, 2);
        result.add_iteration(Iteration {
            number: 0,
            output: "first".to_string(),
            score: 0.3,
            feedback: None,
        });
        result.add_iteration(Iteration {
            number: 1,
            output: "second".to_string(),
            score: 0.9,
            feedback: None,
        });
        assert!((result.improvement() - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_refine_result_json_roundtrip() {
        let mut result = RefineResult::new("test output", 0.95, 2);
        result.stop_reason = StopReason::TargetReached;
        result.total_tokens = 500;
        result.confidence = 0.9;
        result.add_iteration(Iteration {
            number: 0,
            output: "first attempt".to_string(),
            score: 0.5,
            feedback: Some("needs improvement".to_string()),
        });
        result.add_iteration(Iteration {
            number: 1,
            output: "test output".to_string(),
            score: 0.95,
            feedback: None,
        });
        result.add_correction("missing semicolon", "added semicolon at line 5");
        result.prompt = Some(
            OptimizedPrompt::new("question -> answer")
                .with_instructions("Be concise")
                .with_example("Q1", "A1"),
        );
        let json = result.to_json().unwrap();
        let loaded = RefineResult::from_json(&json).unwrap();
        assert_eq!(loaded.output, "test output");
        assert!((loaded.score - 0.95).abs() < f64::EPSILON);
        assert_eq!(loaded.iterations, 2);
        assert_eq!(loaded.stop_reason, StopReason::TargetReached);
        assert_eq!(loaded.total_tokens, 500);
        assert_eq!(loaded.history.len(), 2);
        assert_eq!(loaded.corrections.len(), 1);
        assert!(loaded.prompt.is_some());
        assert_eq!(loaded.prompt.unwrap().signature, "question -> answer");
    }

    // `save`/`load` only exist with the "std" feature, so the test must be
    // gated the same way or `--no-default-features` builds fail to compile.
    #[cfg(feature = "std")]
    #[test]
    fn test_refine_result_save_load() {
        let result = RefineResult::new("saved output", 1.0, 1);
        // Include the process id so concurrent test runs don't clobber each other's file.
        let path = std::env::temp_dir().join(format!(
            "kkachi_test_result_{}.json",
            std::process::id()
        ));
        let path_str = path.to_str().expect("temp path should be valid UTF-8");
        result.save(path_str).unwrap();
        let loaded = RefineResult::load(path_str);
        // Clean up before asserting so a failure doesn't leave the file behind.
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.unwrap();
        assert_eq!(loaded.output, "saved output");
        assert!((loaded.score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_optimized_prompt_json_roundtrip() {
        let prompt = OptimizedPrompt::new("input -> output")
            .with_instructions("Follow the format")
            .with_example("hello", "world")
            .with_example("foo", "bar")
            .with_template("Template: {{ input }}");
        let json = serde_json::to_string(&prompt).unwrap();
        let loaded: OptimizedPrompt = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.signature, "input -> output");
        assert_eq!(loaded.instructions, "Follow the format");
        assert_eq!(loaded.examples.len(), 2);
        assert_eq!(loaded.template, "Template: {{ input }}");
    }

    #[test]
    fn test_refine_event_serialize() {
        let event = RefineEvent::IterationComplete {
            iteration: 2,
            score: 0.8,
            output: "some output".to_string(),
            feedback: Some("try harder".to_string()),
        };
        let json = serde_json::to_string(&event).unwrap();
        let loaded: RefineEvent = serde_json::from_str(&json).unwrap();
        if let RefineEvent::IterationComplete {
            iteration, score, ..
        } = loaded
        {
            assert_eq!(iteration, 2);
            assert!((score - 0.8).abs() < f64::EPSILON);
        } else {
            panic!("Wrong variant");
        }
    }
}
/// Serializable event describing the progress or outcome of a refinement run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RefineEvent {
/// An iteration is starting.
IterationStart {
iteration: u32,
},
/// An iteration finished with this score, output and optional feedback.
IterationComplete {
iteration: u32,
score: f64,
output: String,
feedback: Option<String>,
},
/// The run finished with this result.
// Boxed presumably to keep the enum small: `RefineResult` carries several
// inline `SmallVec` buffers.
Complete(Box<RefineResult>),
/// The run failed with this error message.
Error(String),
}