use crate::agent::Payload;
use std::fmt;
/// A value that can be rendered into a [`Payload`].
///
/// `react_loop` calls [`Expandable::expand`] on the selected registry item and
/// passes the resulting payload to the agent.
pub trait Expandable {
    /// Builds the payload that represents this value.
    fn expand(&self) -> Payload;
}
/// An [`Expandable`] item that can be offered to a model as a named choice.
///
/// [`SelectionRegistry`] stores implementors keyed by
/// [`selection_id`](Selectable::selection_id) and lists each one together with
/// its [`description`](Selectable::description) when building prompts.
pub trait Selectable: Expandable {
    /// Identifier the model uses to pick this item; a registry rejects duplicates.
    fn selection_id(&self) -> &str;

    /// Human-readable summary shown next to the id in the prompt listing.
    fn description(&self) -> &str;
}
/// An ordered collection of [`Selectable`] items with unique selection ids.
///
/// Items are kept in registration order. [`SelectionRegistry::to_prompt_section`]
/// also lists them in that order. Lookups by id are a linear scan over the
/// stored items.
pub struct SelectionRegistry<T> {
    items: Vec<T>,
}

impl<T: Selectable> SelectionRegistry<T> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self { items: Vec::new() }
    }

    /// Adds `item` and returns `self` so that calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics with "Item with id '<id>' is already registered" if an item
    /// with the same [`Selectable::selection_id`] is already present. Use
    /// [`SelectionRegistry::try_register`] to handle duplicates without
    /// panicking.
    #[track_caller]
    pub fn register(&mut self, item: T) -> &mut Self {
        match self.try_register(item) {
            Ok(registry) => registry,
            // `RegistryError::DuplicateId` displays as exactly the panic
            // message documented above.
            Err(err) => panic!("{}", err),
        }
    }

    /// Adds `item` and returns `self` so that calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::DuplicateId`] if an item with the same
    /// [`Selectable::selection_id`] is already present. In that case `item`
    /// is dropped and the registry is left unchanged.
    pub fn try_register(&mut self, item: T) -> Result<&mut Self, RegistryError> {
        if self.get(item.selection_id()).is_some() {
            // Allocate the owned id only on the (rare) error path.
            return Err(RegistryError::DuplicateId {
                id: item.selection_id().to_string(),
            });
        }
        self.items.push(item);
        Ok(self)
    }

    /// Returns the item whose selection id equals `id`, if any.
    pub fn get(&self, id: &str) -> Option<&T> {
        self.items.iter().find(|item| item.selection_id() == id)
    }

    /// Returns all registered items in registration order.
    pub fn items(&self) -> &[T] {
        &self.items
    }

    /// Renders the items as a Markdown section titled "Available Actions".
    ///
    /// This is equivalent to
    /// [`to_prompt_section_with_title`](Self::to_prompt_section_with_title)
    /// with that title.
    pub fn to_prompt_section(&self) -> String {
        self.to_prompt_section_with_title("Available Actions")
    }

    /// Renders the items as a Markdown section under a custom heading.
    ///
    /// The output is `"## {title}\n\n"`, followed by one
    /// ``"- `{id}`: {description}\n"`` line per item in registration order.
    pub fn to_prompt_section_with_title(&self, title: &str) -> String {
        use std::fmt::Write as _;

        let mut output = format!("## {}\n\n", title);
        for item in &self.items {
            // Formatting into a `String` cannot fail, so the `fmt::Result`
            // is discarded.
            let _ = writeln!(
                output,
                "- `{}`: {}",
                item.selection_id(),
                item.description()
            );
        }
        output
    }

    /// Returns `true` if no items are registered.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns the number of registered items.
    pub fn len(&self) -> usize {
        self.items.len()
    }
}

impl<T: Selectable> Default for SelectionRegistry<T> {
    /// Same as [`SelectionRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Selectable> fmt::Debug for SelectionRegistry<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SelectionRegistry")
            .field("items", &self.items)
            .finish()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
#[error("Item with id '{id}' is already registered")]
DuplicateId { id: String },
#[error("Item with id '{id}' not found in registry")]
NotFound { id: String },
}
#[derive(Debug, thiserror::Error)]
pub enum ReActError {
#[error("Agent error: {0}")]
Agent(#[from] crate::agent::AgentError),
#[error("Selection not found: {0}")]
SelectionNotFound(String),
#[error("Max iterations ({0}) reached without completion")]
MaxIterationsReached(usize),
#[error("Failed to extract selection from response: {0}")]
ExtractionFailed(String),
}
/// Outcome of a single ReAct step.
///
/// NOTE(review): `react_loop` does not construct this type; it returns
/// `Result<String, ReActError>` directly. The variant meanings below are
/// inferred from their names.
#[derive(Clone, Debug, PartialEq)]
pub enum ReActResult {
    /// Presumably the final answer once the task is finished.
    Complete(String),
    /// Presumably the updated context to carry into the next step.
    Continue {
        /// Context text for the next step.
        context: String,
    },
}
/// Tuning knobs for [`react_loop`].
///
/// Build one with [`ReActConfig::new`] or [`Default`], then adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReActConfig {
    /// Maximum number of prompt/select rounds before the loop gives up.
    /// Defaults to 10.
    pub max_iterations: usize,
    /// Whether to prepend the registry's action listing to every prompt.
    /// Defaults to `true`.
    pub include_selection_prompt: bool,
    /// Text the prompt tells the agent to reply with when the task is done.
    /// Defaults to `"DONE"`. `react_loop` only inserts it into the prompt;
    /// detecting completion is up to the selector.
    pub completion_marker: String,
    /// If `true`, each action result is appended to the running context.
    /// If `false`, the result replaces the context. Defaults to `true`.
    pub accumulate_results: bool,
}

impl Default for ReActConfig {
    fn default() -> Self {
        Self {
            max_iterations: 10,
            include_selection_prompt: true,
            completion_marker: "DONE".to_string(),
            accumulate_results: true,
        }
    }
}

impl ReActConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`max_iterations`](Self::max_iterations).
    #[must_use]
    pub fn with_max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Sets [`include_selection_prompt`](Self::include_selection_prompt).
    #[must_use]
    pub fn with_include_selection_prompt(mut self, include: bool) -> Self {
        self.include_selection_prompt = include;
        self
    }

    /// Sets [`completion_marker`](Self::completion_marker).
    #[must_use]
    pub fn with_completion_marker(mut self, marker: impl Into<String>) -> Self {
        self.completion_marker = marker.into();
        self
    }

    /// Sets [`accumulate_results`](Self::accumulate_results).
    #[must_use]
    pub fn with_accumulate_results(mut self, accumulate: bool) -> Self {
        self.accumulate_results = accumulate;
        self
    }
}
/// Runs a ReAct-style select/act loop against `agent`.
///
/// Each iteration builds a prompt and sends it to `agent`. The prompt holds:
///
/// - the registry listing, if `config.include_selection_prompt` is set;
/// - the current context;
/// - an instruction that mentions `config.completion_marker`.
///
/// The response is then passed to `selector`:
///
/// - `Ok(None)` ends the loop, and the agent's response is returned unchanged.
/// - `Ok(Some(id))` looks `id` up in `registry` and executes the item's
///   [`Expandable::expand`] payload with `agent`. The result is appended to
///   the context when `config.accumulate_results` is set; otherwise it
///   replaces the context.
///
/// `config.completion_marker` only appears in the prompt text. Detecting
/// completion is entirely up to `selector`.
///
/// # Errors
///
/// - [`ReActError::Agent`] if either agent call fails.
/// - Any error returned by `selector`.
/// - [`ReActError::SelectionNotFound`] if `selector` names an unregistered id.
/// - [`ReActError::MaxIterationsReached`] if `config.max_iterations` rounds
///   pass without `selector` returning `Ok(None)`. This happens immediately
///   when the limit is 0.
pub async fn react_loop<T, A, F>(
    agent: &A,
    registry: &SelectionRegistry<T>,
    initial_task: impl Into<Payload>,
    selector: F,
    config: ReActConfig,
) -> Result<String, ReActError>
where
    T: Selectable,
    A: crate::agent::Agent<Output = String>,
    F: Fn(&str) -> Result<Option<String>, ReActError>,
{
    use std::fmt::Write as _;

    // The registry listing is the same on every iteration, so render it once.
    let selection_section = if config.include_selection_prompt {
        let mut section = registry.to_prompt_section();
        section.push_str("\n\n");
        section
    } else {
        String::new()
    };

    let mut context = initial_task.into().to_text();
    for _iteration in 0..config.max_iterations {
        let prompt = format!(
            "{}Current context:\n{}\n\nSelect an action or respond with '{}' if the task is complete.",
            selection_section, context, config.completion_marker
        );
        let response = agent.execute(Payload::from(prompt)).await?;
        match selector(&response)? {
            None => {
                return Ok(response);
            }
            Some(action_id) => {
                let item = registry
                    .get(&action_id)
                    .ok_or_else(|| ReActError::SelectionNotFound(action_id.clone()))?;
                let expanded = item.expand();
                let result = agent.execute(expanded).await?;
                if config.accumulate_results {
                    // Append in place. Re-formatting the whole context on every
                    // step would make accumulation quadratic in the step count.
                    // Writing into a `String` cannot fail.
                    let _ = write!(context, "\n\n[Action: {}]\nResult: {}", action_id, result);
                } else {
                    context = result;
                }
            }
        }
    }
    Err(ReActError::MaxIterationsReached(config.max_iterations))
}
/// Builds a selector for [`react_loop`] that reads the chosen action id from a
/// `<tag>` element in the agent's response.
///
/// The returned closure applies these rules in order:
///
/// 1. If the response contains `completion_marker` anywhere, it returns
///    `Ok(None)`, even if a tag is also present. This is a plain substring
///    match, so an empty marker matches every response.
/// 2. Otherwise, if `FlexibleExtractor::extract_tagged` finds a value for
///    `tag`, it returns that value as `Ok(Some(id))`. Presumably the value is
///    the text inside `<tag>…</tag>`; confirm against the extractor.
/// 3. Otherwise, it fails with [`ReActError::ExtractionFailed`].
pub fn simple_tag_selector(
    tag: &'static str,
    completion_marker: &'static str,
) -> impl Fn(&str) -> Result<Option<String>, ReActError> {
    move |response: &str| {
        use crate::extract::core::ContentExtractor;
        use crate::extract::FlexibleExtractor;

        let task_finished = response.contains(completion_marker);
        if task_finished {
            return Ok(None);
        }

        let tag_reader = FlexibleExtractor::new();
        match tag_reader.extract_tagged(response, tag) {
            Some(chosen) => Ok(Some(chosen)),
            None => {
                let reason = format!(
                    "No <{}> tag or '{}' found in response",
                    tag, completion_marker
                );
                Err(ReActError::ExtractionFailed(reason))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Selectable` fixture with two distinct selection ids.
    #[derive(Debug, Clone, PartialEq)]
    enum TestAction {
        Greet { name: String },
        Calculate { expr: String },
    }

    impl Expandable for TestAction {
        fn expand(&self) -> Payload {
            match self {
                TestAction::Greet { name } => Payload::from(format!("Say hello to {}", name)),
                TestAction::Calculate { expr } => Payload::from(format!("Calculate: {}", expr)),
            }
        }
    }

    impl Selectable for TestAction {
        fn selection_id(&self) -> &str {
            match self {
                TestAction::Greet { .. } => "greet",
                TestAction::Calculate { .. } => "calculate",
            }
        }

        fn description(&self) -> &str {
            match self {
                TestAction::Greet { .. } => "Greet a person by name",
                TestAction::Calculate { .. } => "Perform a calculation",
            }
        }
    }

    #[test]
    fn test_expandable() {
        let action = TestAction::Greet {
            name: "Alice".to_string(),
        };
        let payload = action.expand();
        assert_eq!(payload.to_text(), "Say hello to Alice");
    }

    #[test]
    fn test_selectable() {
        let action = TestAction::Greet {
            name: "Bob".to_string(),
        };
        assert_eq!(action.selection_id(), "greet");
        assert_eq!(action.description(), "Greet a person by name");
    }

    #[test]
    fn test_registry_basic() {
        let mut registry = SelectionRegistry::new();
        registry.register(TestAction::Greet {
            name: "Charlie".to_string(),
        });
        registry.register(TestAction::Calculate {
            expr: "2+2".to_string(),
        });
        assert_eq!(registry.len(), 2);
        assert!(!registry.is_empty());
        let greet = registry.get("greet").unwrap();
        assert_eq!(greet.selection_id(), "greet");
    }

    #[test]
    fn test_registry_get_missing_returns_none() {
        let mut registry = SelectionRegistry::new();
        registry.register(TestAction::Greet {
            name: "Ivan".to_string(),
        });
        assert!(registry.get("calculate").is_none());
    }

    #[test]
    fn test_registry_to_prompt_section() {
        let mut registry = SelectionRegistry::new();
        registry.register(TestAction::Greet {
            name: "Dave".to_string(),
        });
        registry.register(TestAction::Calculate {
            expr: "5*5".to_string(),
        });
        let section = registry.to_prompt_section();
        assert!(section.contains("## Available Actions"));
        assert!(section.contains("- `greet`: Greet a person by name"));
        assert!(section.contains("- `calculate`: Perform a calculation"));
    }

    #[test]
    fn test_registry_empty_prompt_section() {
        let registry: SelectionRegistry<TestAction> = SelectionRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.to_prompt_section(), "## Available Actions\n\n");
    }

    #[test]
    fn test_registry_custom_title_preserves_order() {
        let mut registry = SelectionRegistry::new();
        registry
            .register(TestAction::Calculate {
                expr: "1+1".to_string(),
            })
            .register(TestAction::Greet {
                name: "Judy".to_string(),
            });
        assert_eq!(
            registry.to_prompt_section_with_title("Tools"),
            "## Tools\n\n- `calculate`: Perform a calculation\n- `greet`: Greet a person by name\n"
        );
    }

    #[test]
    #[should_panic(expected = "already registered")]
    fn test_registry_duplicate_panic() {
        let mut registry = SelectionRegistry::new();
        registry.register(TestAction::Greet {
            name: "Eve".to_string(),
        });
        registry.register(TestAction::Greet {
            name: "Frank".to_string(),
        });
    }

    #[test]
    fn test_registry_try_register_duplicate() {
        let mut registry = SelectionRegistry::new();
        registry
            .try_register(TestAction::Greet {
                name: "Grace".to_string(),
            })
            .unwrap();
        let result = registry.try_register(TestAction::Greet {
            name: "Heidi".to_string(),
        });
        assert!(matches!(result, Err(RegistryError::DuplicateId { .. })));
    }

    #[test]
    fn test_registry_duplicate_leaves_registry_unchanged() {
        let mut registry = SelectionRegistry::new();
        registry.register(TestAction::Greet {
            name: "Ken".to_string(),
        });
        let err = match registry.try_register(TestAction::Greet {
            name: "Leo".to_string(),
        }) {
            Err(err) => err,
            Ok(_) => panic!("duplicate id was accepted"),
        };
        assert_eq!(err.to_string(), "Item with id 'greet' is already registered");
        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.items()[0],
            TestAction::Greet {
                name: "Ken".to_string()
            }
        );
    }

    #[test]
    fn test_react_config_defaults_and_builders() {
        let defaults = ReActConfig::default();
        assert_eq!(defaults.max_iterations, 10);
        assert!(defaults.include_selection_prompt);
        assert_eq!(defaults.completion_marker, "DONE");
        assert!(defaults.accumulate_results);

        let custom = ReActConfig::new()
            .with_max_iterations(2)
            .with_include_selection_prompt(false)
            .with_completion_marker("FINISHED")
            .with_accumulate_results(false);
        assert_eq!(custom.max_iterations, 2);
        assert!(!custom.include_selection_prompt);
        assert_eq!(custom.completion_marker, "FINISHED");
        assert!(!custom.accumulate_results);
    }

    #[test]
    fn test_react_error_display() {
        assert_eq!(
            ReActError::MaxIterationsReached(3).to_string(),
            "Max iterations (3) reached without completion"
        );
        assert_eq!(
            ReActError::SelectionNotFound("x".to_string()).to_string(),
            "Selection not found: x"
        );
    }

    #[test]
    fn test_simple_tag_selector_completion() {
        let selector = simple_tag_selector("action", "DONE");
        assert!(matches!(selector("All tasks DONE"), Ok(None)));
    }
}