use async_trait::async_trait;
use tokio::sync::mpsc;
use crate::error::{GraphError, TerminalError};
use crate::event::{GraphEvent, NodeEvent, SpanId};
use crate::node::{GraphNode, NextStep, StreamNodeResult};
use crate::state::State;
/// Graph node that runs a `lellm_agent::ToolUseLoop` and records its outcome
/// in the graph state under keys namespaced by `prefix`.
pub struct AgentNode {
/// Node name; attached to forwarded events and to execution errors.
pub name: String,
/// The tool-use loop driven by this node.
pub agent: lellm_agent::ToolUseLoop,
/// Namespace for state keys (`<prefix>.output`, `<prefix>.messages`, ...).
pub prefix: String,
/// When `true`, the resulting conversation is stored under `<prefix>.messages`.
pub write_messages: bool,
/// When `true`, `<prefix>.iterations`, `<prefix>.tool_calls` and
/// `<prefix>.stop_reason` are stored after a run.
pub write_stats: bool,
}
impl AgentNode {
    /// Creates a node called `name` that drives `agent`.
    ///
    /// Defaults: state prefix `"agent"`, with both message and statistics
    /// write-back enabled.
    pub fn new(name: impl Into<String>, agent: lellm_agent::ToolUseLoop) -> Self {
        let name = name.into();
        let prefix = String::from("agent");
        Self {
            name,
            agent,
            prefix,
            write_messages: true,
            write_stats: true,
        }
    }

    /// Returns the node with its state-key prefix replaced by `prefix`.
    pub fn with_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
            ..self
        }
    }

    /// Returns the node with write-back of `<prefix>.messages` switched on or off.
    pub fn with_write_messages(self, enabled: bool) -> Self {
        Self {
            write_messages: enabled,
            ..self
        }
    }

    /// Returns the node with write-back of the run statistics switched on or off.
    pub fn with_write_stats(self, enabled: bool) -> Self {
        Self {
            write_stats: enabled,
            ..self
        }
    }
}
/// Maps a stop reason to the short label stored under `<prefix>.stop_reason`.
///
/// The match has no wildcard arm on purpose: adding a variant to
/// `StopReason` must fail to compile here until a label is chosen for it.
fn stop_reason_str(reason: &lellm_agent::StopReason) -> &'static str {
    match *reason {
        lellm_agent::StopReason::Complete => "Complete",
        lellm_agent::StopReason::Cancelled => "Cancelled",
        lellm_agent::StopReason::MaxIterationsReached => "MaxIterations",
        lellm_agent::StopReason::OutputBudgetExceeded => "OutputBudget",
        lellm_agent::StopReason::ReasoningBudgetExceeded => "ReasoningBudget",
    }
}
/// Copies the outcome of an agent run into `state`, namespaced by `node.prefix`.
///
/// Keys written:
/// - `<prefix>.output`: concatenation of all text blocks of the final
///   response; only written when that concatenation is non-empty.
/// - `<prefix>.messages`: the serialized conversation, when
///   `node.write_messages` is set. A serialization failure stores JSON `null`.
/// - `<prefix>.iterations`, `<prefix>.tool_calls`, `<prefix>.stop_reason`:
///   written when `node.write_stats` is set.
fn write_agent_result(node: &AgentNode, result: &lellm_agent::ToolUseResult, state: &mut State) {
    let prefix = &node.prefix;

    // Join the text blocks in order, with no separator; other block kinds are skipped.
    let mut text = String::new();
    for block in result.response.content.iter() {
        if let lellm_core::ContentBlock::Text(t) = block {
            text.push_str(t.text.as_str());
        }
    }
    if !text.is_empty() {
        state.insert(format!("{prefix}.output"), serde_json::Value::String(text));
    }

    if node.write_messages {
        let serialized =
            serde_json::to_value(&result.messages).unwrap_or(serde_json::Value::Null);
        state.insert(format!("{prefix}.messages"), serialized);
    }

    if node.write_stats {
        let iterations = serde_json::json!(result.iterations);
        let tool_calls = serde_json::json!(result.tool_calls_executed);
        let stop_reason = serde_json::json!(stop_reason_str(&result.stop_reason));

        state.insert(format!("{prefix}.iterations"), iterations);
        state.insert(format!("{prefix}.tool_calls"), tool_calls);
        state.insert(format!("{prefix}.stop_reason"), stop_reason);
    }
}
/// Loads the conversation an agent should start from.
///
/// `<prefix>.messages` is preferred. If that key is absent, cannot be
/// deserialized, or holds an empty list, the shared `messages` key is used
/// instead. If that also yields nothing, an empty list is returned.
fn read_messages(state: &State, prefix: &str) -> Vec<lellm_core::Message> {
    // Missing keys and values that fail to deserialize are both treated as "no messages".
    let load = |key: &str| -> Option<Vec<lellm_core::Message>> {
        let raw = state.get(key)?;
        serde_json::from_value::<Vec<lellm_core::Message>>(raw.clone()).ok()
    };

    match load(&format!("{prefix}.messages")) {
        Some(scoped) if !scoped.is_empty() => scoped,
        _ => load("messages").unwrap_or_default(),
    }
}
#[async_trait]
impl GraphNode for AgentNode {
async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
let messages = read_messages(state, &self.prefix);
let result =
self.agent
.execute(messages)
.await
.map_err(|e| GraphError::Terminal(TerminalError::NodeExecutionFailed {
node: self.name.clone(),
source: Box::new(e),
}))?;
write_agent_result(self, &result, state);
Ok(NextStep::GoToNext)
}
async fn execute_stream(
&self,
state: &mut State,
sink: &mpsc::Sender<GraphEvent>,
span_id: SpanId,
) -> Result<StreamNodeResult, GraphError> {
let messages = read_messages(state, &self.prefix);
let node_name = self.name.clone();
let mut stream = self.agent.execute_stream(messages);
struct ExtractedResult {
write_result: Option<lellm_agent::ToolUseResult>,
error_msg: Option<String>,
}
while let Some(event) = stream.recv().await {
let extracted = match &event {
lellm_agent::AgentEvent::LoopEnd { result } => ExtractedResult {
write_result: Some(result.clone()),
error_msg: None,
},
lellm_agent::AgentEvent::LoopError { error, .. } => ExtractedResult {
write_result: None,
error_msg: Some(error.to_string()),
},
_ => ExtractedResult {
write_result: None,
error_msg: None,
},
};
let _ = sink
.send(GraphEvent::Node {
span_id,
node_name: node_name.clone(),
event: NodeEvent::Agent(event),
})
.await;
if let Some(result) = extracted.write_result {
write_agent_result(self, &result, state);
return Ok(StreamNodeResult::Done {
next: NextStep::GoToNext,
span_id,
});
}
if let Some(err_msg) = extracted.error_msg {
return Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
node: self.name.clone(),
source: err_msg.into(),
}));
}
}
Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
node: self.name.clone(),
source: "agent stream closed without terminal event".into(),
}))
}
}
/// Graph node that performs a single chat call against a resolved model and
/// appends the assistant reply to the conversation stored in the graph state.
pub struct LLMNode {
/// Node name; reported in `NodeExecutionFailed` errors.
pub name: String,
/// Supplies the model identifier for the request and the provider that is called.
model: lellm_agent::ResolvedModel,
/// When set, replaces any system messages in the conversation before the call.
system_prompt: Option<String>,
/// State key the conversation is read from and written back to.
messages_key: String,
}
impl LLMNode {
    /// Creates a node called `name` that calls `model`.
    ///
    /// Defaults: no system prompt, conversation stored under the `"messages"` key.
    pub fn new(name: impl Into<String>, model: lellm_agent::ResolvedModel) -> Self {
        let name = name.into();
        Self {
            name,
            model,
            system_prompt: None,
            messages_key: String::from("messages"),
        }
    }

    /// Returns the node configured to send `prompt` as its system message.
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Returns the node configured to keep its conversation under `key`.
    pub fn with_messages_key(self, key: impl Into<String>) -> Self {
        Self {
            messages_key: key.into(),
            ..self
        }
    }
}
#[async_trait]
impl GraphNode for LLMNode {
    /// Sends the conversation stored under `messages_key` to the model and
    /// writes the conversation, extended with the assistant reply, back to
    /// the same key.
    ///
    /// If a system prompt is configured, every existing system message is
    /// removed and the configured prompt is placed first; this normalized
    /// conversation is also what gets written back.
    ///
    /// # Errors
    ///
    /// - `TerminalError::NodeExecutionFailed` when the provider call fails;
    ///   `state` is not modified in that case.
    /// - `TerminalError::StateError` when the updated conversation cannot be
    ///   serialized.
    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
        // NOTE(review): a missing key and a value that fails to deserialize
        // both fall back to an empty conversation without reporting anything.
        // Confirm this best-effort behavior is intended.
        let mut messages = state
            .get(&self.messages_key)
            .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
            .unwrap_or_default();

        if let Some(ref sys) = self.system_prompt {
            // The configured prompt wins over any system message already in the history.
            messages.retain(|m| !matches!(m, lellm_core::Message::System { .. }));
            messages.insert(
                0,
                lellm_core::Message::System {
                    content: lellm_core::text_block(sys.clone()),
                },
            );
        }

        // Move the history into the request instead of cloning it: the
        // provider only borrows the request, so the history is taken back
        // afterwards without copying the whole conversation.
        let request = lellm_core::ChatRequest {
            model: self.model.model.clone(),
            messages,
            ..Default::default()
        };

        let response = self.model.provider.call(&request).await.map_err(|e| {
            GraphError::Terminal(TerminalError::NodeExecutionFailed {
                node: self.name.clone(),
                source: Box::new(e),
            })
        })?;

        let mut messages = request.messages;
        messages.push(lellm_core::Message::Assistant {
            content: response.content,
        });

        let serialized = serde_json::to_value(&messages).map_err(|e| {
            GraphError::Terminal(TerminalError::StateError(format!(
                "failed to serialize messages: {e}"
            )))
        })?;
        state.insert(self.messages_key.clone(), serialized);

        Ok(NextStep::GoToNext)
    }
}