1use async_trait::async_trait;
4use tokio::sync::mpsc;
5
6use crate::error::{GraphError, TerminalError};
7use crate::event::{GraphEvent, NodeEvent, SpanId};
8use crate::node::{GraphNode, NextStep, StreamNodeResult};
9use crate::state::State;
10
11pub struct AgentNode {
22 pub name: String,
23 pub agent: lellm_agent::ToolUseLoop,
24 pub prefix: String,
26 pub write_messages: bool,
28 pub write_stats: bool,
30}
31
32impl AgentNode {
33 pub fn new(name: impl Into<String>, agent: lellm_agent::ToolUseLoop) -> Self {
34 Self {
35 name: name.into(),
36 agent,
37 prefix: "agent".into(),
38 write_messages: true,
39 write_stats: true,
40 }
41 }
42
43 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
47 self.prefix = prefix.into();
48 self
49 }
50
51 pub fn with_write_messages(mut self, enabled: bool) -> Self {
53 self.write_messages = enabled;
54 self
55 }
56
57 pub fn with_write_stats(mut self, enabled: bool) -> Self {
59 self.write_stats = enabled;
60 self
61 }
62}
63
64fn stop_reason_str(reason: &lellm_agent::StopReason) -> &'static str {
66 match reason {
67 lellm_agent::StopReason::Complete => "Complete",
68 lellm_agent::StopReason::MaxIterationsReached => "MaxIterations",
69 lellm_agent::StopReason::Cancelled => "Cancelled",
70 lellm_agent::StopReason::OutputBudgetExceeded => "OutputBudget",
71 lellm_agent::StopReason::ReasoningBudgetExceeded => "ReasoningBudget",
72 }
73}
74
75fn write_agent_result(node: &AgentNode, result: &lellm_agent::ToolUseResult, state: &mut State) {
77 let text: String = result
79 .response
80 .content
81 .iter()
82 .filter_map(|b| match b {
83 lellm_core::ContentBlock::Text(t) => Some(t.text.as_str()),
84 _ => None,
85 })
86 .collect::<Vec<_>>()
87 .join("");
88
89 if !text.is_empty() {
90 state.insert(
91 format!("{}.output", node.prefix),
92 serde_json::Value::String(text),
93 );
94 }
95
96 if node.write_messages {
98 state.insert(
99 format!("{}.messages", node.prefix),
100 serde_json::to_value(&result.messages).unwrap_or(serde_json::Value::Null),
101 );
102 }
103
104 if node.write_stats {
106 state.insert(
107 format!("{}.iterations", node.prefix),
108 serde_json::json!(result.iterations),
109 );
110 state.insert(
111 format!("{}.tool_calls", node.prefix),
112 serde_json::json!(result.tool_calls_executed),
113 );
114 state.insert(
115 format!("{}.stop_reason", node.prefix),
116 serde_json::json!(stop_reason_str(&result.stop_reason)),
117 );
118 }
119}
120
121fn read_messages(state: &State, prefix: &str) -> Vec<lellm_core::Message> {
123 let input_key = format!("{}.messages", prefix);
124 let messages = state
125 .get(&input_key)
126 .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
127 .unwrap_or_default();
128
129 if messages.is_empty() {
131 state
132 .get("messages")
133 .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
134 .unwrap_or_default()
135 } else {
136 messages
137 }
138}
139
140#[async_trait]
141impl GraphNode for AgentNode {
142 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
143 let messages = read_messages(state, &self.prefix);
144
145 let result =
146 self.agent
147 .execute(messages)
148 .await
149 .map_err(|e| GraphError::Terminal(TerminalError::NodeExecutionFailed {
150 node: self.name.clone(),
151 source: Box::new(e),
152 }))?;
153
154 write_agent_result(self, &result, state);
155 Ok(NextStep::GoToNext)
156 }
157
158 async fn execute_stream(
160 &self,
161 state: &mut State,
162 sink: &mpsc::Sender<GraphEvent>,
163 span_id: SpanId,
164 ) -> Result<StreamNodeResult, GraphError> {
165 let messages = read_messages(state, &self.prefix);
166 let node_name = self.name.clone();
167
168 let mut stream = self.agent.execute_stream(messages);
170
171 struct ExtractedResult {
173 write_result: Option<lellm_agent::ToolUseResult>,
174 error_msg: Option<String>,
175 }
176
177 while let Some(event) = stream.recv().await {
179 let extracted = match &event {
180 lellm_agent::AgentEvent::LoopEnd { result } => ExtractedResult {
181 write_result: Some(result.clone()),
182 error_msg: None,
183 },
184 lellm_agent::AgentEvent::LoopError { error, .. } => ExtractedResult {
185 write_result: None,
186 error_msg: Some(error.to_string()),
187 },
188 _ => ExtractedResult {
189 write_result: None,
190 error_msg: None,
191 },
192 };
193
194 let _ = sink
196 .send(GraphEvent::Node {
197 span_id,
198 node_name: node_name.clone(),
199 event: NodeEvent::Agent(event),
200 })
201 .await;
202
203 if let Some(result) = extracted.write_result {
205 write_agent_result(self, &result, state);
206 return Ok(StreamNodeResult::Done {
207 next: NextStep::GoToNext,
208 span_id,
209 });
210 }
211 if let Some(err_msg) = extracted.error_msg {
212 return Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
213 node: self.name.clone(),
214 source: err_msg.into(),
215 }));
216 }
217 }
218
219 Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
220 node: self.name.clone(),
221 source: "agent stream closed without terminal event".into(),
222 }))
223 }
224}
225
226pub struct LLMNode {
265 pub name: String,
266 model: lellm_agent::ResolvedModel,
267 system_prompt: Option<String>,
268 messages_key: String,
269}
270
271impl LLMNode {
272 pub fn new(name: impl Into<String>, model: lellm_agent::ResolvedModel) -> Self {
273 Self {
274 name: name.into(),
275 model,
276 system_prompt: None,
277 messages_key: "messages".into(),
278 }
279 }
280
281 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
283 self.system_prompt = Some(prompt.into());
284 self
285 }
286
287 pub fn with_messages_key(mut self, key: impl Into<String>) -> Self {
289 self.messages_key = key.into();
290 self
291 }
292}
293
294#[async_trait]
295impl GraphNode for LLMNode {
296 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
297 let mut messages = state
299 .get(&self.messages_key)
300 .and_then(|v| serde_json::from_value::<Vec<lellm_core::Message>>(v.clone()).ok())
301 .unwrap_or_default();
302
303 if let Some(ref sys) = self.system_prompt {
305 messages.retain(|m| !matches!(m, lellm_core::Message::System { .. }));
307 messages.insert(
308 0,
309 lellm_core::Message::System {
310 content: lellm_core::text_block(sys.clone()),
311 },
312 );
313 }
314
315 let request = lellm_core::ChatRequest {
317 model: self.model.model.clone(),
318 messages: messages.clone(),
319 ..Default::default()
320 };
321
322 let response = self.model.provider.call(&request).await.map_err(|e| {
324 GraphError::Terminal(TerminalError::NodeExecutionFailed {
325 node: self.name.clone(),
326 source: Box::new(e),
327 })
328 })?;
329
330 let assistant_msg = lellm_core::Message::Assistant {
332 content: response.content,
333 };
334 messages.push(assistant_msg);
335 state.insert(
336 self.messages_key.clone(),
337 serde_json::to_value(&messages).map_err(|e| {
338 GraphError::Terminal(TerminalError::StateError(format!("failed to serialize messages: {e}")))
339 })?,
340 );
341
342 Ok(NextStep::GoToNext)
343 }
344}