use std::sync::Arc;
use std::time::Instant;
use swink_agent::{AgentMessage, AgentResult, ContentBlock, LlmMessage, Usage};
use tokio_util::sync::CancellationToken;
use super::events::PipelineEvent;
use super::executor::AgentFactory;
use super::output::{PipelineError, PipelineOutput, StepResult};
use super::types::{ExitCondition, PipelineId};
#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_loop(
factory: &Arc<dyn AgentFactory>,
event_handler: &Option<Arc<dyn Fn(PipelineEvent) + Send + Sync>>,
id: PipelineId,
_name: String,
body: String,
exit_condition: ExitCondition,
max_iterations: usize,
input: String,
cancellation_token: CancellationToken,
) -> Result<PipelineOutput, PipelineError> {
let pipeline_start = Instant::now();
let mut steps: Vec<StepResult> = Vec::new();
let mut total_usage = Usage::default();
let mut accumulated_responses: Vec<String> = Vec::new();
for iteration in 0..max_iterations {
if cancellation_token.is_cancelled() {
return Err(PipelineError::Cancelled);
}
if let Some(handler) = event_handler {
handler(PipelineEvent::StepStarted {
pipeline_id: id.clone(),
step_index: iteration,
agent_name: body.clone(),
});
}
let step_start = Instant::now();
let mut agent = factory.create(&body)?;
let mut messages = Vec::new();
if accumulated_responses.is_empty() {
messages.push(make_user_message(&input));
} else {
let context = format!(
"{}\n\nPrevious iterations:\n{}",
input,
accumulated_responses
.iter()
.enumerate()
.map(|(i, r)| format!("Iteration {}: {}", i + 1, r))
.collect::<Vec<_>>()
.join("\n")
);
messages.push(make_user_message(&context));
}
let result = agent
.prompt_async(messages)
.await
.map_err(|e| PipelineError::StepFailed {
step_index: iteration,
agent_name: body.clone(),
source: Box::new(e),
})?;
let step_duration = step_start.elapsed();
let response_text = extract_text(&result);
total_usage.merge(&result.usage);
let step = StepResult {
agent_name: body.clone(),
response: response_text.clone(),
duration: step_duration,
usage: result.usage.clone(),
};
steps.push(step);
if let Some(handler) = event_handler {
handler(PipelineEvent::StepCompleted {
pipeline_id: id.clone(),
step_index: iteration,
agent_name: body.clone(),
duration: step_duration,
usage: result.usage.clone(),
});
}
accumulated_responses.push(response_text.clone());
let should_exit = match &exit_condition {
ExitCondition::ToolCalled { tool_name } => check_tool_called(&result, tool_name),
ExitCondition::OutputContains { compiled, .. } => compiled.is_match(&response_text),
ExitCondition::MaxIterations => false, };
if should_exit {
let total_duration = pipeline_start.elapsed();
if let Some(handler) = event_handler {
handler(PipelineEvent::Completed {
pipeline_id: id.clone(),
total_duration,
total_usage: total_usage.clone(),
});
}
return Ok(PipelineOutput {
pipeline_id: id,
final_response: response_text,
steps,
total_duration,
total_usage,
});
}
}
match exit_condition {
ExitCondition::MaxIterations => {
let total_duration = pipeline_start.elapsed();
let final_response = accumulated_responses.last().cloned().unwrap_or_default();
if let Some(handler) = event_handler {
handler(PipelineEvent::Completed {
pipeline_id: id.clone(),
total_duration,
total_usage: total_usage.clone(),
});
}
Ok(PipelineOutput {
pipeline_id: id,
final_response,
steps,
total_duration,
total_usage,
})
}
_ => Err(PipelineError::MaxIterationsReached {
iterations: max_iterations,
}),
}
}
/// Returns the concatenated text of the most recent assistant message.
///
/// Non-text content blocks (e.g. tool calls) are skipped. An empty string is
/// returned when the result contains no assistant message.
fn extract_text(result: &AgentResult) -> String {
    let newest_assistant = result.messages.iter().rev().find_map(|message| {
        if let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = message {
            Some(assistant)
        } else {
            None
        }
    });
    let Some(assistant) = newest_assistant else {
        return String::new();
    };

    let mut combined = String::new();
    for block in assistant.content.iter() {
        if let ContentBlock::Text { text } = block {
            combined.push_str(text);
        }
    }
    combined
}
/// Reports whether any assistant message in `result` calls the tool `tool_name`.
fn check_tool_called(result: &AgentResult, tool_name: &str) -> bool {
    result
        .messages
        .iter()
        .filter_map(|message| match message {
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => Some(assistant),
            _ => None,
        })
        .flat_map(|assistant| assistant.content.iter())
        .any(|block| match block {
            ContentBlock::ToolCall { name, .. } => name == tool_name,
            _ => false,
        })
}
/// Wraps `text` in a single-block user message.
///
/// The message has timestamp 0 and no cache hint.
fn make_user_message(text: &str) -> AgentMessage {
    let user = swink_agent::UserMessage {
        content: vec![ContentBlock::Text {
            text: text.to_owned(),
        }],
        timestamp: 0,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::User(user))
}
#[cfg(all(test, feature = "testkit"))]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    use swink_agent::testing::{
        MockStreamFn, default_convert, default_model, text_events, tool_call_events,
    };
    use swink_agent::{Agent, AgentOptions};

    use crate::pipeline::executor::SimpleAgentFactory;

    /// Builds a factory whose `name` agent replays `responses` in order.
    ///
    /// Each agent the factory creates consumes the next scripted response.
    /// Once the script is exhausted, agents are built with no responses.
    fn factory_with_responses(
        name: &str,
        responses: Vec<Vec<swink_agent::AssistantMessageEvent>>,
    ) -> Arc<SimpleAgentFactory> {
        let responses = Arc::new(Mutex::new(VecDeque::from(responses)));
        let mut factory = SimpleAgentFactory::new();
        factory.register(name.to_string(), move || {
            // The lock guard is a temporary, so it is released at the end of this statement.
            let next: Vec<_> = responses.lock().unwrap().pop_front().into_iter().collect();
            let options = AgentOptions::new(
                "loop-body",
                default_model(),
                Arc::new(MockStreamFn::new(next)),
                default_convert,
            );
            Agent::new(options)
        });
        Arc::new(factory)
    }

    /// Runs the loop over the `"body"` agent with no event handler.
    async fn run_body(
        factory: Arc<dyn AgentFactory>,
        exit_condition: ExitCondition,
        max_iterations: usize,
        input: &str,
        token: CancellationToken,
    ) -> Result<PipelineOutput, PipelineError> {
        run_loop(
            &factory,
            &None,
            PipelineId::new("test-loop"),
            "test".to_string(),
            "body".to_string(),
            exit_condition,
            max_iterations,
            input.to_string(),
            token,
        )
        .await
    }

    #[tokio::test]
    async fn loop_exits_on_tool_called() {
        let factory = factory_with_responses(
            "body",
            vec![
                text_events("iteration 1 output"),
                tool_call_events("tc-1", "done", "{}"),
            ],
        );
        let result = run_body(
            factory,
            ExitCondition::ToolCalled {
                tool_name: "done".to_string(),
            },
            10,
            "do something",
            CancellationToken::new(),
        )
        .await;
        let output = result.expect("should succeed");
        assert_eq!(output.steps.len(), 2);
        assert!(!output.steps[0].response.is_empty());
    }

    #[tokio::test]
    async fn loop_exits_on_output_contains() {
        let factory = factory_with_responses(
            "body",
            vec![
                text_events("still working..."),
                text_events("all finished DONE"),
            ],
        );
        let exit_cond = ExitCondition::output_contains(r"DONE").unwrap();
        let result = run_body(factory, exit_cond, 10, "process data", CancellationToken::new()).await;
        let output = result.expect("should succeed");
        assert_eq!(output.steps.len(), 2);
        assert!(output.steps[1].response.contains("DONE"));
    }

    #[tokio::test]
    async fn loop_errors_when_max_iterations_reached() {
        let factory = factory_with_responses(
            "body",
            vec![
                text_events("iter 1"),
                text_events("iter 2"),
                text_events("iter 3"),
            ],
        );
        let exit_cond = ExitCondition::output_contains(r"NEVER_MATCHES").unwrap();
        let result = run_body(factory, exit_cond, 3, "input", CancellationToken::new()).await;
        match result {
            Err(PipelineError::MaxIterationsReached { iterations }) => {
                assert_eq!(iterations, 3);
            }
            other => panic!("expected MaxIterationsReached, got: {other:?}"),
        }
    }

    /// An unregistered body agent makes the factory fail on the first iteration.
    #[tokio::test]
    async fn loop_halts_on_agent_error() {
        let factory: Arc<dyn AgentFactory> = Arc::new(SimpleAgentFactory::new());
        let result = run_body(
            factory,
            ExitCondition::MaxIterations,
            5,
            "input",
            CancellationToken::new(),
        )
        .await;
        assert!(
            matches!(result, Err(PipelineError::AgentNotFound { .. })),
            "expected AgentNotFound, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn loop_accumulates_context() {
        let factory = factory_with_responses(
            "body",
            vec![
                text_events("response A"),
                text_events("response B"),
                text_events("response C DONE"),
            ],
        );
        let exit_cond = ExitCondition::output_contains(r"DONE").unwrap();
        let result =
            run_body(factory, exit_cond, 10, "original input", CancellationToken::new()).await;
        let output = result.expect("should succeed");
        assert_eq!(output.steps.len(), 3);
        assert_eq!(output.steps[0].response, "response A");
        assert_eq!(output.steps[1].response, "response B");
        assert!(output.steps[2].response.contains("DONE"));
    }

    #[tokio::test]
    async fn loop_max_iterations_exit_condition_succeeds() {
        let factory = factory_with_responses(
            "body",
            vec![
                text_events("iter 1"),
                text_events("iter 2"),
                text_events("iter 3"),
            ],
        );
        let result = run_body(
            factory,
            ExitCondition::MaxIterations,
            3,
            "input",
            CancellationToken::new(),
        )
        .await;
        let output = result.expect("MaxIterations should succeed after running all iterations");
        assert_eq!(output.steps.len(), 3);
        assert_eq!(output.final_response, "iter 3");
    }

    /// A token cancelled before the loop starts stops it before any agent runs.
    #[tokio::test]
    async fn loop_returns_cancelled_when_token_already_cancelled() {
        let factory = factory_with_responses("body", vec![text_events("never used")]);
        let token = CancellationToken::new();
        token.cancel();
        let result = run_body(factory, ExitCondition::MaxIterations, 3, "input", token).await;
        assert!(
            matches!(result, Err(PipelineError::Cancelled)),
            "expected Cancelled, got: {result:?}"
        );
    }
}