use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use traitclaw_core::{Error, Result};
/// Pinned, heap-allocated future that is `Send` and may borrow data for `'a`.
///
/// Boxing erases the concrete future type, so futures produced by different
/// closures can be stored behind one type.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;

/// Type-erased, cheaply clonable (reference-counted) agent that can be shared
/// across threads.
///
/// Given the input text, it returns a `'static` boxed future that resolves to
/// the agent's output text or an error.
pub type BoundAgent =
    Arc<dyn Fn(String) -> BoxFuture<'static, Result<String>> + Sync + Send>;
pub struct TeamRunner {
agents: HashMap<String, BoundAgent>,
sequence: Vec<String>,
max_iterations: usize,
}
impl TeamRunner {
#[must_use]
pub fn new(max_iterations: usize) -> Self {
assert!(max_iterations > 0, "max_iterations must be > 0");
Self {
agents: HashMap::new(),
sequence: Vec::new(),
max_iterations,
}
}
pub fn bind<F, Fut>(&mut self, role: impl Into<String>, agent: F)
where
F: Fn(String) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<String>> + Send + 'static,
{
let name = role.into();
self.agents.insert(
name,
Arc::new(move |input: String| Box::pin(agent(input)) as BoxFuture<_>),
);
}
pub fn set_sequence(&mut self, roles: &[&str]) {
self.sequence = roles.iter().map(|&r| r.to_string()).collect();
}
#[must_use]
pub fn is_bound(&self, role: &str) -> bool {
self.agents.contains_key(role)
}
pub async fn run(&self, input: &str) -> Result<String> {
if self.sequence.is_empty() {
return Err(Error::Runtime("TeamRunner has no sequence defined".into()));
}
let mut current_input = input.to_string();
let mut iterations = 0;
for role in &self.sequence {
if iterations >= self.max_iterations {
return Err(Error::Runtime(format!(
"TeamRunner exceeded max_iterations ({}) at role '{}'",
self.max_iterations, role
)));
}
let agent = self.agents.get(role).ok_or_else(|| {
Error::Runtime(format!("Role '{}' not bound in TeamRunner", role))
})?;
current_input = agent(current_input).await?;
iterations += 1;
}
Ok(current_input)
}
}
/// Generates output and asks a verifier to accept it, retrying with the
/// verifier's feedback until it accepts or the retries run out.
///
/// The first attempt uses `initial_input` as the prompt. After a rejection,
/// the next prompt is `initial_input` followed by the rejected output and the
/// verifier's feedback. `generator` is therefore called at most
/// `max_retries + 1` times. On acceptance, the string returned by the
/// verifier is returned, and it may differ from the generated output.
///
/// Both closures may be `FnMut`, so they can keep state between attempts.
///
/// # Errors
///
/// Propagates the first error returned by `generator`. Returns
/// [`Error::Runtime`] with the last output and feedback when the verifier
/// rejects every attempt.
pub async fn run_verification_chain<G, GFut, V, VFut>(
    initial_input: &str,
    max_retries: usize,
    mut generator: G,
    mut verifier: V,
) -> Result<String>
where
    G: FnMut(String) -> GFut,
    GFut: Future<Output = Result<String>>,
    V: FnMut(String) -> VFut,
    VFut: Future<Output = std::result::Result<String, String>>,
{
    let mut prompt = initial_input.to_string();
    let mut attempt = 0;
    loop {
        let output = generator(prompt).await?;
        // The verifier takes ownership of its input, but `output` is still
        // needed for the retry prompt or the final error, so pass a copy.
        let feedback = match verifier(output.clone()).await {
            Ok(accepted) => return Ok(accepted),
            Err(feedback) => feedback,
        };
        // Checking before incrementing means `attempt` never exceeds
        // `max_retries`, even when `max_retries == usize::MAX`.
        if attempt == max_retries {
            return Err(Error::Runtime(format!(
                "VerificationChain exhausted {max_retries} retries. Last output: {output}. Last feedback: {feedback}"
            )));
        }
        attempt += 1;
        prompt = format!(
            "{initial_input}\n\nPrevious attempt:\n{output}\n\nFeedback: {feedback}\n\nPlease improve."
        );
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `TeamRunner` and `run_verification_chain`.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    #[tokio::test]
    async fn test_sequential_team_two_agents() {
        let mut runner = TeamRunner::new(10);
        runner.bind("researcher", |topic: String| async move {
            Ok(format!("Research: {topic}"))
        });
        runner.bind("writer", |notes: String| async move {
            Ok(format!("Written: {notes}"))
        });
        runner.set_sequence(&["researcher", "writer"]);

        let result = runner.run("AI history").await.unwrap();
        assert!(
            result.starts_with("Written: Research: AI history"),
            "got: {result}"
        );
    }

    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        // One allowed step but a two-step sequence.
        let mut runner = TeamRunner::new(1);
        for role in ["a", "b"] {
            runner.bind(role, |i: String| async move { Ok(i) });
        }
        runner.set_sequence(&["a", "b"]);

        let result = runner.run("test").await;
        assert!(result.is_err(), "expected error for max_iterations exceeded");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("max_iterations"),
            "error should mention max_iterations: {msg}"
        );
    }

    #[tokio::test]
    async fn test_unbound_role_returns_error() {
        let mut runner = TeamRunner::new(10);
        runner.set_sequence(&["missing_role"]);
        assert!(runner.run("test").await.is_err());
    }

    #[tokio::test]
    async fn test_empty_sequence_returns_error() {
        let runner = TeamRunner::new(10);
        let outcome = runner.run("test").await;
        assert!(outcome.is_err());
        let text = outcome.unwrap_err().to_string();
        assert!(text.contains("no sequence"));
    }

    #[tokio::test]
    async fn test_verification_chain_accepts_on_second_try() {
        // Each closure owns its own counter; only the verifier's counter
        // decides acceptance (reject the first output, accept afterwards).
        let generated = Arc::new(AtomicUsize::new(0));
        let verified = Arc::new(AtomicUsize::new(0));

        let result = run_verification_chain(
            "Write something",
            3,
            move |_input: String| {
                let counter = Arc::clone(&generated);
                async move {
                    let n = counter.fetch_add(1, Ordering::Relaxed);
                    Ok(format!("generated-attempt-{n}"))
                }
            },
            move |output: String| {
                let counter = Arc::clone(&verified);
                async move {
                    match counter.fetch_add(1, Ordering::Relaxed) {
                        0 => Err("Too brief".to_string()),
                        _ => Ok(output),
                    }
                }
            },
        )
        .await;

        assert!(
            result.is_ok(),
            "expected success on retry, got: {:?}",
            result
        );
        let output = result.unwrap();
        assert!(
            output.contains("attempt-1"),
            "expected 2nd attempt output, got: {output}"
        );
    }

    #[tokio::test]
    async fn test_verification_chain_all_retries_exhausted() {
        let result = run_verification_chain(
            "Write something",
            2,
            |input: String| async move { Ok(format!("draft: {input}")) },
            |output: String| async move { Err(format!("not good enough: {output}")) },
        )
        .await;

        assert!(result.is_err(), "expected error when all retries exhausted");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("exhausted"),
            "error should mention exhausted: {msg}"
        );
    }

    #[tokio::test]
    async fn test_verification_chain_accepts_immediately() {
        let accepted = run_verification_chain(
            "input",
            3,
            |_| async { Ok("perfect output".to_string()) },
            |output: String| async move { Ok(output) },
        )
        .await
        .unwrap();
        assert_eq!(accepted, "perfect output");
    }

    #[tokio::test]
    async fn test_verification_chain_feedback_included_in_retry() {
        // Set by the generator once it sees a retry prompt carrying feedback.
        let saw_feedback = Arc::new(Mutex::new(false));
        let flag = Arc::clone(&saw_feedback);

        let _ = run_verification_chain(
            "Write",
            1,
            move |input: String| {
                let flag = Arc::clone(&flag);
                async move {
                    if input.contains("Feedback:") {
                        *flag.lock().unwrap() = true;
                    }
                    Ok(format!("output: {input}"))
                }
            },
            |_| async move { Err("needs work".to_string()) },
        )
        .await;

        assert!(
            *saw_feedback.lock().unwrap(),
            "retry prompt should contain 'Feedback:'"
        );
    }
}