use std::sync::Arc;
use entelix_core::ir::{ContentPart, Message, Role, SystemPrompt, ToolResultContent};
use entelix_core::{Error, ExecutionContext, LlmRenderable, Result, RunOverrides, ToolRegistry};
use entelix_graph::{CompiledGraph, StateGraph};
use entelix_runnable::{Configured, Runnable, RunnableLambda};
use crate::agent::{Agent, AgentBuilder};
use crate::state::ReActState;
/// Builds and compiles the ReAct planner/tool loop for `model` and `tools`.
///
/// The graph's recursion limit is left at the `StateGraph` default; see
/// `build_react_graph_with_recursion_limit` for an explicit limit.
///
/// # Errors
///
/// Returns any error produced while compiling the assembled graph.
pub fn build_react_graph<M>(model: M, tools: ToolRegistry) -> Result<CompiledGraph<ReActState>>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    let no_explicit_limit: Option<usize> = None;
    build_react_graph_inner(model, tools, no_explicit_limit)
}
/// Same as `build_react_graph`, but applies `recursion_limit` to the graph
/// before it is compiled.
///
/// # Errors
///
/// Returns any error produced while compiling the assembled graph.
pub(crate) fn build_react_graph_with_recursion_limit<M>(
    model: M,
    tools: ToolRegistry,
    recursion_limit: usize,
) -> Result<CompiledGraph<ReActState>>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    build_react_graph_inner(model, tools, recursion_limit.into())
}
fn build_react_graph_inner<M>(
model: M,
tools: ToolRegistry,
recursion_limit: Option<usize>,
) -> Result<CompiledGraph<ReActState>>
where
M: Runnable<Vec<Message>, Message> + 'static,
{
let model = Arc::new(model);
let tools = Arc::new(tools);
let planner_node = RunnableLambda::new(move |mut state: ReActState, ctx: ExecutionContext| {
let model = model.clone();
async move {
let reply = model.invoke(state.messages.clone(), &ctx).await?;
state.messages.push(reply);
state.steps = state.steps.saturating_add(1);
Ok::<_, _>(state)
}
});
let tool_node = RunnableLambda::new(move |mut state: ReActState, ctx: ExecutionContext| {
let tools = tools.clone();
async move {
let last = state.messages.last().cloned().ok_or_else(|| {
Error::invalid_request("ReActAgent: tool dispatch with empty conversation")
})?;
let mut results: Vec<ContentPart> = Vec::new();
for part in &last.content {
if let ContentPart::ToolUse {
id, name, input, ..
} = part
{
let (content, is_error) =
match tools.dispatch(id, name, input.clone(), &ctx).await {
Ok(value) => (ToolResultContent::Json(value), false),
Err(Error::Interrupted { kind, payload }) => {
return Err(Error::Interrupted { kind, payload });
}
Err(e) => (ToolResultContent::Text(e.render_for_llm()), true),
};
results.push(ContentPart::ToolResult {
tool_use_id: id.clone(),
name: name.clone(),
content,
is_error,
cache_control: None,
provider_echoes: Vec::new(),
});
}
}
if results.is_empty() {
return Err(Error::invalid_request(
"ReActAgent: tool node reached without any ToolUse parts",
));
}
state.messages.push(Message::new(Role::Tool, results));
Ok::<_, _>(state)
}
});
let finish_node =
RunnableLambda::new(|state: ReActState, _ctx| async move { Ok::<_, _>(state) });
let mut builder = StateGraph::<ReActState>::new()
.add_node("planner", planner_node)
.add_node("tools", tool_node)
.add_node("finish", finish_node)
.set_entry_point("planner")
.add_finish_point("finish")
.add_edge("tools", "planner")
.add_conditional_edges(
"planner",
|state: &ReActState| {
if last_message_has_tool_use(state) {
"tools".to_owned()
} else {
"finish".to_owned()
}
},
[("tools", "tools"), ("finish", "finish")],
);
if let Some(limit) = recursion_limit {
builder = builder.with_recursion_limit(limit);
}
builder.compile()
}
/// Builds a ReAct [`Agent`] from `model` and `tools` with default settings.
///
/// If `tools` exposes any tool specs, they are installed as the agent's
/// default [`RunOverrides`] (without a system prompt). Use
/// [`ReActAgentBuilder`] to configure a system prompt, approver, sinks,
/// observers or a recursion limit.
///
/// # Errors
///
/// Propagates failures from compiling the graph or from building the agent.
pub fn create_react_agent<M>(model: M, tools: ToolRegistry) -> Result<Agent<ReActState>>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    let specs = tools.tool_specs();
    let agent_builder = react_agent_builder(model, tools, build_run_overrides(None, specs))?;
    agent_builder.build()
}
/// Collects an agent's default [`RunOverrides`] from an optional system prompt
/// and its tool specs.
///
/// Returns `None` when there is neither a system prompt nor any tool spec, so
/// callers can skip the override wrapper entirely. Otherwise only the pieces
/// that are present are set: the system prompt first, then the tool specs.
fn build_run_overrides(
    system: Option<SystemPrompt>,
    tool_specs: Arc<[entelix_core::ir::ToolSpec]>,
) -> Option<RunOverrides> {
    let has_tools = !tool_specs.is_empty();
    if !has_tools && system.is_none() {
        return None;
    }
    let base = match system {
        Some(prompt) => RunOverrides::new().with_system_prompt(prompt),
        None => RunOverrides::new(),
    };
    Some(if has_tools {
        base.with_tool_specs(tool_specs)
    } else {
        base
    })
}
/// Wraps `graph` with a context hook that installs `defaults` as the
/// [`RunOverrides`] extension whenever the incoming [`ExecutionContext`] does
/// not already carry one.
///
/// A caller-supplied `RunOverrides` takes precedence over `defaults` as a
/// whole; the two are never merged.
// NOTE(review): because nothing is merged, a per-run override suppresses this
// agent's default system prompt / tool specs entirely — confirm that is intended.
fn wrap_with_run_overrides(
    graph: CompiledGraph<ReActState>,
    defaults: RunOverrides,
) -> Configured<
    CompiledGraph<ReActState>,
    impl Fn(&mut ExecutionContext) + Send + Sync + 'static,
    ReActState,
    ReActState,
> {
    Configured::new(graph, move |ctx: &mut ExecutionContext| {
        let caller_supplied = ctx.extension::<RunOverrides>().is_some();
        if !caller_supplied {
            // `add_extension` consumes the context, so move it out and back in.
            *ctx = std::mem::take(ctx).add_extension(defaults.clone());
        }
    })
}
pub(crate) fn react_agent_builder<M>(
model: M,
tools: ToolRegistry,
defaults: Option<RunOverrides>,
) -> Result<AgentBuilder<ReActState>>
where
M: Runnable<Vec<Message>, Message> + 'static,
{
let graph = build_react_graph(model, tools)?;
let builder = Agent::<ReActState>::builder().with_name("react");
let builder = match defaults {
Some(defaults) => builder.with_runnable(wrap_with_run_overrides(graph, defaults)),
None => builder.with_runnable(graph),
};
Ok(builder)
}
pub(crate) fn react_agent_builder_with_recursion_limit<M>(
model: M,
tools: ToolRegistry,
recursion_limit: usize,
defaults: Option<RunOverrides>,
) -> Result<AgentBuilder<ReActState>>
where
M: Runnable<Vec<Message>, Message> + 'static,
{
let graph = build_react_graph_with_recursion_limit(model, tools, recursion_limit)?;
let builder = Agent::<ReActState>::builder().with_name("react");
let builder = match defaults {
Some(defaults) => builder.with_runnable(wrap_with_run_overrides(graph, defaults)),
None => builder.with_runnable(graph),
};
Ok(builder)
}
/// Fluent builder for a ReAct [`Agent`] over a chat model and a [`ToolRegistry`].
///
/// Created with `ReActAgentBuilder::new`; every optional setting below is
/// applied when `build` is called.
pub struct ReActAgentBuilder<M>
where
M: Runnable<Vec<Message>, Message> + 'static,
{
/// Model invoked by the planner node with the full conversation.
model: M,
/// Tools dispatched by the tool node; wrapped in an `ApprovalLayer` at build
/// time when an approver is configured.
tools: ToolRegistry,
/// Agent name; `None` keeps the default name `"react"`.
name: Option<String>,
/// System prompt placed in the agent's default `RunOverrides`.
system: Option<SystemPrompt>,
/// Event sinks, registered on the agent in insertion order.
sinks: Vec<Arc<dyn crate::agent::AgentEventSink<ReActState>>>,
/// Approver layered onto the tool registry and registered on the agent; when
/// set, the default execution mode becomes `Supervised`.
approver: Option<Arc<dyn crate::agent::Approver>>,
/// Explicit execution mode; `None` lets `build` choose one.
execution_mode: Option<crate::agent::ExecutionMode>,
/// Observers, registered on the agent in insertion order.
observers: Vec<crate::agent::DynObserver<ReActState>>,
/// Recursion limit for the compiled graph; `None` keeps the graph default.
recursion_limit: Option<usize>,
}
impl<M> ReActAgentBuilder<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Starts a builder for `model` and `tools` with every optional setting unset.
    pub fn new(model: M, tools: ToolRegistry) -> Self {
        Self {
            model,
            tools,
            name: None,
            system: None,
            sinks: Vec::new(),
            approver: None,
            execution_mode: None,
            observers: Vec::new(),
            recursion_limit: None,
        }
    }

    /// Sets the system prompt that goes into the agent's default `RunOverrides`.
    #[must_use]
    pub fn with_system(mut self, system: SystemPrompt) -> Self {
        self.system = Some(system);
        self
    }

    /// Overrides the agent name (otherwise `"react"`).
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Appends an event sink; sinks are registered in the order they are added.
    #[must_use]
    pub fn add_sink(mut self, sink: Arc<dyn crate::agent::AgentEventSink<ReActState>>) -> Self {
        self.sinks.push(sink);
        self
    }

    /// Sets the approver. It is layered onto the tool registry, registered on
    /// the agent, and makes `Supervised` the default execution mode.
    #[must_use]
    pub fn with_approver(mut self, approver: Arc<dyn crate::agent::Approver>) -> Self {
        self.approver = Some(approver);
        self
    }

    /// Forces an execution mode instead of the approver-derived default.
    #[must_use]
    pub const fn with_execution_mode(mut self, mode: crate::agent::ExecutionMode) -> Self {
        self.execution_mode = Some(mode);
        self
    }

    /// Appends an observer; observers are registered in the order they are added.
    #[must_use]
    pub fn with_observer(mut self, observer: crate::agent::DynObserver<ReActState>) -> Self {
        self.observers.push(observer);
        self
    }

    /// Sets an explicit recursion limit for the compiled graph.
    #[must_use]
    pub const fn with_recursion_limit(mut self, n: usize) -> Self {
        self.recursion_limit = Some(n);
        self
    }

    /// Compiles the ReAct graph and assembles the configured [`Agent`].
    ///
    /// The execution mode is the explicit one if set; otherwise `Supervised`
    /// when an approver is configured, else `ExecutionMode::default()`.
    ///
    /// # Errors
    ///
    /// Propagates failures from compiling the graph or from building the agent.
    pub fn build(self) -> Result<Agent<ReActState>> {
        let Self {
            model,
            tools,
            name,
            system,
            sinks,
            approver,
            execution_mode,
            observers,
            recursion_limit,
        } = self;

        // Layer the approver onto the registry before its specs are captured.
        let tools = match approver.as_ref() {
            Some(gate) => tools.layer(crate::agent::ApprovalLayer::new(Arc::clone(gate))),
            None => tools,
        };
        let defaults = build_run_overrides(system, tools.tool_specs());

        let mut builder = if let Some(limit) = recursion_limit {
            react_agent_builder_with_recursion_limit(model, tools, limit, defaults)?
        } else {
            react_agent_builder(model, tools, defaults)?
        };

        if let Some(name) = name {
            builder = builder.with_name(name);
        }
        builder = sinks
            .into_iter()
            .fold(builder, |acc, sink| acc.add_sink_arc(sink));

        let mode = match (execution_mode, approver.is_some()) {
            (Some(explicit), _) => explicit,
            (None, true) => crate::agent::ExecutionMode::Supervised,
            (None, false) => crate::agent::ExecutionMode::default(),
        };
        builder = builder.with_execution_mode(mode);

        if let Some(gate) = approver {
            builder = builder.with_approver_arc(gate);
        }
        observers
            .into_iter()
            .fold(builder, |acc, observer| acc.with_observer_arc(observer))
            .build()
    }
}
/// Routing predicate for the planner node: `true` when the newest message in
/// `state` contains at least one [`ContentPart::ToolUse`].
///
/// An empty conversation yields `false`.
fn last_message_has_tool_use(state: &ReActState) -> bool {
    let Some(latest) = state.messages.last() else {
        return false;
    };
    latest
        .content
        .iter()
        .any(|part| matches!(part, ContentPart::ToolUse { .. }))
}