use crate::callable::{Callable, CallableInvoker, DynCallable};
use crate::graph::CheckpointStore;
use crate::kernel::ids::{StepId, StepSourceType};
use crate::policy::LongRunningExecutionPolicy;
use crate::runner::Runner;
use crate::streaming::StreamEvent;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// A queued unit of work for `AgenticLoop::run_with_invoker`, laid out as
/// `(callable_name, input, depth, triggered_by, reason)`:
/// - `callable_name`: name looked up in the `CallableInvoker` registry.
/// - `input`: text passed to that callable.
/// - `depth`: discovery depth (0 for the initial input, parent depth + 1 otherwise).
/// - `triggered_by`: id of the step whose output discovered this one, if any.
/// - `reason`: optional explanation supplied with the discovery.
type InvokerWorkItem = (String, String, u32, Option<StepId>, Option<String>);
/// A follow-up unit of work reported by a callable's output.
///
/// Deserialized from the `discovered_steps` array of a callable's JSON output,
/// or built in code with [`DiscoveredStep::new`] and the `with_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveredStep {
    /// Callable that should run this step. `AgenticLoop::run_with_invoker`
    /// falls back to the discovering callable when this is `None`;
    /// `AgenticLoop::run_with_details` ignores it.
    #[serde(default)]
    pub name: Option<String>,
    /// Input text passed to the callable.
    pub input: String,
    /// Optional explanation, reported in the `step_discovered` stream event.
    #[serde(default)]
    pub reason: Option<String>,
    /// Scheduling priority; within one batch of discoveries, higher runs first.
    /// Defaults to 50 both for [`DiscoveredStep::new`] and when the field is
    /// absent from JSON.
    #[serde(default = "DiscoveredStep::default_priority")]
    pub priority: u8,
}

impl DiscoveredStep {
    /// Priority used when none is given explicitly.
    const DEFAULT_PRIORITY: u8 = 50;

    /// Serde default for `priority`. JSON that omits the field gets the same
    /// priority as the builder, instead of `u8::default()` (0), which used to rank
    /// such steps below explicitly low-priority ones.
    fn default_priority() -> u8 {
        Self::DEFAULT_PRIORITY
    }

    /// Creates a step with the given input, no name or reason, and the default
    /// priority (50).
    pub fn new(input: impl Into<String>) -> Self {
        Self {
            name: None,
            input: input.into(),
            reason: None,
            priority: Self::DEFAULT_PRIORITY,
        }
    }

    /// Sets the name of the callable that should execute this step.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the human-readable reason this step was discovered.
    pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
        self.reason = Some(reason.into());
        self
    }

    /// Sets the scheduling priority (higher runs first within a batch).
    pub fn with_priority(mut self, priority: u8) -> Self {
        self.priority = priority;
        self
    }
}
/// Structured interpretation of one callable invocation's raw output.
#[derive(Debug, Clone)]
pub struct DiscoveryOutput {
    /// Result text. This is the JSON `result` (or `output`) field, or the raw JSON
    /// text if neither is present, or the plain text with completion markers removed.
    pub result: String,
    /// Follow-up steps requested by the callable; only JSON output carries them.
    pub discovered_steps: Vec<DiscoveredStep>,
    /// Whether the callable signaled that the overall task is finished.
    pub is_complete: bool,
}

impl DiscoveryOutput {
    /// Parses raw callable output.
    ///
    /// JSON is tried first. The parser takes the whole trimmed text if it starts
    /// with `{`, otherwise the contents of a fenced `json` code block, otherwise
    /// the span from the first `{` to the last `}`. A JSON `status` of
    /// `"complete"` or `"done"` marks completion.
    ///
    /// If no JSON parses, the output is treated as plain text. It is complete
    /// when it ends with `[DONE]` or `[COMPLETE]` (trailing whitespace is
    /// ignored), or when it contains a `"status": "complete"` /
    /// `"status":"complete"` fragment. Those end markers are stripped from
    /// `result`, which is then trimmed. Plain text never yields discovered steps.
    pub fn parse(output: &str) -> Self {
        if let Some(parsed) = Self::try_parse_json(output) {
            return parsed;
        }
        // Callables often end their output with a newline. Without this trim, a
        // marker such as "[DONE]\n" would not signal completion and would stay in
        // `result`.
        let text = output.trim_end();
        let is_complete = text.ends_with("[DONE]")
            || text.ends_with("[COMPLETE]")
            || text.contains("\"status\": \"complete\"")
            || text.contains("\"status\":\"complete\"");
        let result = text
            .trim_end_matches("[DONE]")
            .trim_end_matches("[COMPLETE]")
            .trim()
            .to_string();
        Self {
            result,
            discovered_steps: Vec::new(),
            is_complete,
        }
    }

    /// Locates a JSON object in `output` and deserializes it.
    ///
    /// Returns `None` if no candidate span is found or the span does not
    /// deserialize into the expected schema.
    fn try_parse_json(output: &str) -> Option<Self> {
        let trimmed = output.trim();
        let json_str = if trimmed.starts_with('{') {
            trimmed
        } else if let Some(start) = trimmed.find("```json") {
            // Fence contents begin on the line after "```json" and run up to
            // the next "```".
            let content_start = trimmed[start..].find('\n').map(|i| start + i + 1)?;
            let content_end = trimmed[content_start..].find("```")?;
            &trimmed[content_start..content_start + content_end]
        } else if let Some(start) = trimmed.find('{') {
            // JSON embedded in prose: take the outermost brace span.
            let end = trimmed.rfind('}').filter(|&end| end > start)?;
            &trimmed[start..=end]
        } else {
            return None;
        };

        // Wire schema: every top-level field is optional, so partial objects still parse.
        #[derive(Deserialize)]
        struct DiscoveryOutputJson {
            #[serde(default)]
            result: Option<String>,
            #[serde(default)]
            output: Option<String>,
            #[serde(default)]
            discovered_steps: Vec<DiscoveredStep>,
            #[serde(default)]
            status: Option<String>,
        }

        let parsed: DiscoveryOutputJson = serde_json::from_str(json_str).ok()?;
        // `result` takes precedence over `output`; fall back to the JSON text itself.
        let result = parsed
            .result
            .or(parsed.output)
            .unwrap_or_else(|| json_str.to_string());
        let is_complete = parsed
            .status
            .is_some_and(|s| s == "complete" || s == "done");
        Some(Self {
            result,
            discovered_steps: parsed.discovered_steps,
            is_complete,
        })
    }

    /// Returns `true` when the output requests more work: at least one
    /// discovered step and no completion signal.
    pub fn has_pending_work(&self) -> bool {
        !self.discovered_steps.is_empty() && !self.is_complete
    }
}
/// Detailed outcome of an `AgenticLoop` run.
#[derive(Debug, Clone)]
pub struct AgenticLoopResult {
/// Parsed result text of the last successfully executed step (empty if none ran).
pub output: String,
/// Number of callable invocations that returned `Ok`.
pub steps_executed: u32,
/// Number of dequeued discovered (non-initial) steps for which a
/// `step_discovered` event was emitted.
pub discovered_steps_processed: u32,
/// Largest depth dequeued. It is updated before the limit checks, so when the
/// loop stops on `max_discovery_depth` it holds the depth that exceeded the limit.
pub max_depth_reached: u32,
/// `true` if a step signaled completion or the work queue drained;
/// `false` if a policy limit stopped the loop.
pub completed: bool,
/// Limit that stopped the loop: `"max_discovered_steps"`,
/// `"max_discovery_depth"` or `"idle_timeout"`; `None` when `completed` is true.
pub stop_reason: Option<String>,
/// Transcript: a `User: ...` entry followed by one entry per successful step.
pub history: Vec<String>,
}
pub struct AgenticLoop;
impl AgenticLoop {
pub async fn run<S: CheckpointStore>(
runner: &mut Runner<S>,
callable: DynCallable,
input: String,
policy: LongRunningExecutionPolicy,
) -> anyhow::Result<String> {
let result = Self::run_with_details(runner, callable, input, policy).await?;
Ok(result.output)
}
pub async fn run_with_details<S: CheckpointStore>(
runner: &mut Runner<S>,
callable: DynCallable,
input: String,
policy: LongRunningExecutionPolicy,
) -> anyhow::Result<AgenticLoopResult> {
let start_time = Instant::now();
let mut steps_executed: u32 = 0;
let mut discovered_steps_processed: u32 = 0;
let mut max_depth_reached: u32 = 0;
let mut history: Vec<String> = Vec::new();
let mut work_queue: VecDeque<(String, u32, Option<StepId>, Option<String>)> =
VecDeque::new();
work_queue.push_back((input.clone(), 0, None, None));
let mut last_output = String::new();
history.push(format!("User: {}", input));
while let Some((current_input, depth, triggered_by, reason)) = work_queue.pop_front() {
if depth > max_depth_reached {
max_depth_reached = depth;
}
if let Some(max_steps) = policy.max_discovered_steps {
if steps_executed >= max_steps {
runner.emitter().emit(StreamEvent::execution_failed(
runner.execution_id(),
crate::kernel::ExecutionError::quota_exceeded(format!(
"Max discovered steps exceeded: {} >= {}",
steps_executed, max_steps
)),
));
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("max_discovered_steps".to_string()),
history,
});
}
}
if let Some(max_depth) = policy.max_discovery_depth {
if depth > max_depth {
runner.emitter().emit(StreamEvent::execution_failed(
runner.execution_id(),
crate::kernel::ExecutionError::quota_exceeded(format!(
"Max discovery depth exceeded: {} > {}",
depth, max_depth
)),
));
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("max_discovery_depth".to_string()),
history,
});
}
}
if let Some(timeout) = policy.idle_timeout_seconds {
if start_time.elapsed() > Duration::from_secs(timeout) {
runner.emitter().emit(StreamEvent::execution_failed(
runner.execution_id(),
crate::kernel::ExecutionError::timeout(format!(
"Idle timeout after {}s",
timeout
)),
));
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("idle_timeout".to_string()),
history,
});
}
}
let step_id = StepId::new();
if triggered_by.is_some() {
runner.emitter().emit(StreamEvent::step_discovered(
runner.execution_id(),
&step_id,
triggered_by.as_ref(),
StepSourceType::Discovered,
reason.as_deref().unwrap_or("Discovered by previous step"),
depth,
));
discovered_steps_processed += 1;
}
let result = runner
.run_callable(callable.as_ref() as &dyn Callable, ¤t_input)
.await;
match result {
Ok(output) => {
steps_executed += 1;
history.push(format!("Assistant [depth={}]: {}", depth, &output));
let should_checkpoint = policy.checkpointing.on_discovery
|| policy
.checkpointing
.interval_steps
.is_some_and(|i| steps_executed.is_multiple_of(i));
if should_checkpoint {
let state = crate::graph::NodeState::from_string(&output);
if let Err(e) = runner
.save_checkpoint(state, Some(callable.name()), Some(callable.name()))
.await
{
tracing::warn!("Failed to save checkpoint: {}", e);
}
}
let discovery = DiscoveryOutput::parse(&output);
last_output = discovery.result.clone();
if discovery.is_complete {
tracing::debug!(steps = steps_executed, "Callable signaled completion");
return Ok(AgenticLoopResult {
output: discovery.result,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: true,
stop_reason: None,
history,
});
}
if !discovery.discovered_steps.is_empty() {
tracing::debug!(
count = discovery.discovered_steps.len(),
"Discovered new steps"
);
let mut sorted_steps = discovery.discovered_steps;
sorted_steps.sort_by(|a, b| b.priority.cmp(&a.priority));
for discovered in sorted_steps {
work_queue.push_back((
discovered.input,
depth + 1,
Some(step_id.clone()),
discovered.reason,
));
}
}
}
Err(e) => {
runner.emitter().emit(StreamEvent::execution_failed(
runner.execution_id(),
crate::kernel::ExecutionError::kernel_internal(e.to_string()),
));
return Err(e);
}
}
}
Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: true,
stop_reason: None,
history,
})
}
pub async fn run_with_invoker<S: CheckpointStore>(
runner: &mut Runner<S>,
invoker: &CallableInvoker,
initial_callable_name: &str,
input: String,
policy: LongRunningExecutionPolicy,
) -> anyhow::Result<AgenticLoopResult> {
let start_time = Instant::now();
let mut steps_executed: u32 = 0;
let mut discovered_steps_processed: u32 = 0;
let mut max_depth_reached: u32 = 0;
let mut history: Vec<String> = Vec::new();
let mut work_queue: VecDeque<InvokerWorkItem> = VecDeque::new();
work_queue.push_back((
initial_callable_name.to_string(),
input.clone(),
0,
None,
None,
));
let mut last_output = String::new();
history.push(format!("User: {}", input));
while let Some((callable_name, current_input, depth, triggered_by, reason)) =
work_queue.pop_front()
{
if depth > max_depth_reached {
max_depth_reached = depth;
}
if let Some(max_steps) = policy.max_discovered_steps {
if steps_executed >= max_steps {
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("max_discovered_steps".to_string()),
history,
});
}
}
if let Some(max_depth) = policy.max_discovery_depth {
if depth > max_depth {
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("max_discovery_depth".to_string()),
history,
});
}
}
if let Some(timeout) = policy.idle_timeout_seconds {
if start_time.elapsed() > Duration::from_secs(timeout) {
return Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: false,
stop_reason: Some("idle_timeout".to_string()),
history,
});
}
}
let step_id = StepId::new();
if triggered_by.is_some() {
runner.emitter().emit(StreamEvent::step_discovered(
runner.execution_id(),
&step_id,
triggered_by.as_ref(),
StepSourceType::Discovered,
reason.as_deref().unwrap_or("Discovered by previous step"),
depth,
));
discovered_steps_processed += 1;
}
let callable = invoker.get(&callable_name).ok_or_else(|| {
anyhow::anyhow!("Callable '{}' not found in registry", callable_name)
})?;
let result = runner
.run_callable(callable.as_ref() as &dyn Callable, ¤t_input)
.await;
match result {
Ok(output) => {
steps_executed += 1;
history.push(format!("{}[depth={}]: {}", callable_name, depth, &output));
let should_checkpoint = policy.checkpointing.on_discovery
|| policy
.checkpointing
.interval_steps
.is_some_and(|i| steps_executed.is_multiple_of(i));
if should_checkpoint {
let state = crate::graph::NodeState::from_string(&output);
if let Err(e) = runner
.save_checkpoint(state, Some(&callable_name), Some(&callable_name))
.await
{
tracing::warn!("Failed to save checkpoint: {}", e);
}
}
let discovery = DiscoveryOutput::parse(&output);
last_output = discovery.result.clone();
if discovery.is_complete {
return Ok(AgenticLoopResult {
output: discovery.result,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: true,
stop_reason: None,
history,
});
}
let mut sorted_steps = discovery.discovered_steps;
sorted_steps.sort_by(|a, b| b.priority.cmp(&a.priority));
for discovered in sorted_steps {
let target_callable =
discovered.name.unwrap_or_else(|| callable_name.clone());
work_queue.push_back((
target_callable,
discovered.input,
depth + 1,
Some(step_id.clone()),
discovered.reason,
));
}
}
Err(e) => {
runner.emitter().emit(StreamEvent::execution_failed(
runner.execution_id(),
crate::kernel::ExecutionError::kernel_internal(e.to_string()),
));
return Err(e);
}
}
}
Ok(AgenticLoopResult {
output: last_output,
steps_executed,
discovered_steps_processed,
max_depth_reached,
completed: true,
stop_reason: None,
history,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::callable::Callable;
    use crate::graph::InMemoryCheckpointStore;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Returns JSON announcing one follow-up step while its counter is
    /// positive, then a plain `[DONE]` answer once the counter reaches zero.
    ///
    /// NOTE: `AtomicU32::fetch_sub` wraps below zero, so this callable must not
    /// be invoked again after it has answered `[DONE]`.
    struct DiscoveryCallable {
        name: String,
        discover_count: AtomicU32,
    }

    impl DiscoveryCallable {
        fn new(name: &str, discover_count: u32) -> Self {
            Self {
                name: name.to_string(),
                discover_count: AtomicU32::new(discover_count),
            }
        }
    }

    #[async_trait]
    impl Callable for DiscoveryCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            let remaining = self.discover_count.fetch_sub(1, Ordering::SeqCst);
            if remaining > 0 {
                Ok(format!(
                    r#"{{
"result": "Processed: {}",
"discovered_steps": [
{{"input": "Follow-up task {}", "reason": "Discovered work"}}
]
}}"#,
                    input, remaining
                ))
            } else {
                Ok(format!("Final result for: {} [DONE]", input))
            }
        }
    }

    /// Always answers with plain text and never discovers steps.
    struct SingleShotCallable {
        name: String,
    }

    impl SingleShotCallable {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for SingleShotCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            Ok(format!("Processed: {}", input))
        }
    }

    /// Succeeds `fail_after` times, discovering one follow-up step each time,
    /// then fails.
    struct FailingCallable {
        name: String,
        fail_after: AtomicU32,
    }

    impl FailingCallable {
        fn new(name: &str, fail_after: u32) -> Self {
            Self {
                name: name.to_string(),
                fail_after: AtomicU32::new(fail_after),
            }
        }
    }

    #[async_trait]
    impl Callable for FailingCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            let remaining = self.fail_after.fetch_sub(1, Ordering::SeqCst);
            if remaining > 0 {
                Ok(format!(
                    r#"{{
"result": "Step {}",
"discovered_steps": [{{"input": "Next step"}}]
}}"#,
                    input
                ))
            } else {
                anyhow::bail!("Simulated failure")
            }
        }
    }

    #[test]
    fn test_parse_plain_output() {
        let output = "Just a simple response";
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "Just a simple response");
        assert!(discovery.discovered_steps.is_empty());
        assert!(!discovery.is_complete);
    }

    #[test]
    fn test_parse_done_marker() {
        let output = "Task completed successfully [DONE]";
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "Task completed successfully");
        assert!(discovery.discovered_steps.is_empty());
        assert!(discovery.is_complete);
    }

    #[test]
    fn test_parse_complete_marker() {
        let output = "All work finished [COMPLETE]";
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "All work finished");
        assert!(discovery.is_complete);
    }

    #[test]
    fn test_parse_json_with_discovered_steps() {
        let output = r#"{
"result": "Analyzed the system",
"discovered_steps": [
{"input": "Refactor module A", "reason": "Code smell detected"},
{"input": "Add tests for B"}
]
}"#;
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "Analyzed the system");
        assert_eq!(discovery.discovered_steps.len(), 2);
        assert_eq!(discovery.discovered_steps[0].input, "Refactor module A");
        assert_eq!(
            discovery.discovered_steps[0].reason,
            Some("Code smell detected".to_string())
        );
        assert_eq!(discovery.discovered_steps[1].input, "Add tests for B");
        assert!(!discovery.is_complete);
    }

    #[test]
    fn test_parse_json_with_status_complete() {
        let output = r#"{"result": "Done", "status": "complete"}"#;
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "Done");
        assert!(discovery.is_complete);
    }

    #[test]
    fn test_parse_json_embedded_in_text() {
        let output = r#"Here is the analysis:
{
"result": "Found issues",
"discovered_steps": [{"input": "Fix issue 1"}]
}
End of response."#;
        let discovery = DiscoveryOutput::parse(output);
        assert!(!discovery.discovered_steps.is_empty());
        assert_eq!(discovery.discovered_steps[0].input, "Fix issue 1");
    }

    #[test]
    fn test_parse_json_code_fence() {
        let output =
            "Here is the plan:\n```json\n{\"result\": \"Fenced\", \"status\": \"done\"}\n```\nThanks.";
        let discovery = DiscoveryOutput::parse(output);
        assert_eq!(discovery.result, "Fenced");
        assert!(discovery.discovered_steps.is_empty());
        assert!(discovery.is_complete);
    }

    #[test]
    fn test_parse_json_output_field_fallback() {
        let discovery = DiscoveryOutput::parse(r#"{"output": "via output field"}"#);
        assert_eq!(discovery.result, "via output field");
        assert!(discovery.discovered_steps.is_empty());
        assert!(!discovery.is_complete);
    }

    #[test]
    fn test_has_pending_work() {
        let no_steps = DiscoveryOutput {
            result: "test".to_string(),
            discovered_steps: vec![],
            is_complete: false,
        };
        assert!(!no_steps.has_pending_work());
        let complete = DiscoveryOutput {
            result: "test".to_string(),
            discovered_steps: vec![DiscoveredStep::new("task")],
            is_complete: true,
        };
        assert!(!complete.has_pending_work());
        let pending = DiscoveryOutput {
            result: "test".to_string(),
            discovered_steps: vec![DiscoveredStep::new("task")],
            is_complete: false,
        };
        assert!(pending.has_pending_work());
    }

    #[test]
    fn test_discovered_step_builder() {
        let step = DiscoveredStep::new("Process data")
            .with_name("data-processor")
            .with_reason("Data needs processing")
            .with_priority(80);
        assert_eq!(step.input, "Process data");
        assert_eq!(step.name, Some("data-processor".to_string()));
        assert_eq!(step.reason, Some("Data needs processing".to_string()));
        assert_eq!(step.priority, 80);
    }

    #[tokio::test]
    async fn test_single_shot_execution() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(SingleShotCallable::new("single"));
        let policy = LongRunningExecutionPolicy::standard();
        let result = AgenticLoop::run(&mut runner, callable, "test input".to_string(), policy)
            .await
            .unwrap();
        assert!(result.contains("Processed: test input"));
    }

    #[tokio::test]
    async fn test_discovery_loop_multiple_steps() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(DiscoveryCallable::new("discover", 3));
        let policy = LongRunningExecutionPolicy::standard();
        let result = AgenticLoop::run_with_details(
            &mut runner,
            callable,
            "initial task".to_string(),
            policy,
        )
        .await
        .unwrap();
        assert!(result.steps_executed >= 3);
        assert!(result.completed);
        assert!(result.stop_reason.is_none());
    }

    #[tokio::test]
    async fn test_max_steps_limit() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(DiscoveryCallable::new("discover", 100));
        let policy = LongRunningExecutionPolicy {
            max_discovered_steps: Some(5),
            ..LongRunningExecutionPolicy::standard()
        };
        let result =
            AgenticLoop::run_with_details(&mut runner, callable, "task".to_string(), policy)
                .await
                .unwrap();
        assert!(result.steps_executed <= 5);
        assert!(!result.completed);
        assert_eq!(result.stop_reason, Some("max_discovered_steps".to_string()));
    }

    #[tokio::test]
    async fn test_max_depth_limit() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(DiscoveryCallable::new("discover", 20));
        let policy = LongRunningExecutionPolicy {
            max_discovery_depth: Some(3),
            max_discovered_steps: Some(100),
            ..LongRunningExecutionPolicy::standard()
        };
        let result =
            AgenticLoop::run_with_details(&mut runner, callable, "task".to_string(), policy)
                .await
                .unwrap();
        // The depth-4 item is dequeued (and recorded) before the limit stops the loop.
        assert!(result.max_depth_reached <= 4);
        assert!(!result.completed);
        assert_eq!(result.stop_reason, Some("max_discovery_depth".to_string()));
    }

    #[tokio::test]
    async fn test_error_propagation() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(FailingCallable::new("failing", 2));
        let policy = LongRunningExecutionPolicy::standard();
        let result = AgenticLoop::run(&mut runner, callable, "task".to_string(), policy).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Simulated failure"));
    }

    #[tokio::test]
    async fn test_history_tracking() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store);
        let callable: DynCallable = Arc::new(DiscoveryCallable::new("discover", 2));
        let policy = LongRunningExecutionPolicy::standard();
        let result =
            AgenticLoop::run_with_details(&mut runner, callable, "start".to_string(), policy)
                .await
                .unwrap();
        assert!(!result.history.is_empty());
        assert!(result.history[0].contains("User: start"));
    }

    // Pure parsing/sorting check: no async work, so a plain #[test] suffices.
    #[test]
    fn test_priority_ordering() {
        let output = r#"{
"result": "test",
"discovered_steps": [
{"input": "low priority", "priority": 10},
{"input": "high priority", "priority": 90},
{"input": "medium priority", "priority": 50}
]
}"#;
        let discovery = DiscoveryOutput::parse(output);
        let mut sorted = discovery.discovered_steps;
        sorted.sort_by(|a, b| b.priority.cmp(&a.priority));
        assert_eq!(sorted[0].input, "high priority");
        assert_eq!(sorted[1].input, "medium priority");
        assert_eq!(sorted[2].input, "low priority");
    }

    #[tokio::test]
    async fn test_checkpointing_on_discovery() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let mut runner = Runner::new(store.clone());
        let callable: DynCallable = Arc::new(DiscoveryCallable::new("discover", 2));
        let policy = LongRunningExecutionPolicy {
            checkpointing: crate::policy::CheckpointPolicy {
                on_discovery: true,
                ..Default::default()
            },
            ..LongRunningExecutionPolicy::standard()
        };
        let _ = AgenticLoop::run(&mut runner, callable, "task".to_string(), policy).await;
        let checkpoint = store
            .load_latest(runner.execution_id().as_str())
            .await
            .unwrap();
        assert!(checkpoint.is_some());
    }
}