1use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16#[derive(Clone)]
18pub struct ExecutionConfig {
19 pub thread_id: String,
21 pub resume_from: Option<String>,
23 pub recursion_limit: usize,
25 pub metadata: HashMap<String, Value>,
27}
28
29impl ExecutionConfig {
30 pub fn new(thread_id: &str) -> Self {
32 Self {
33 thread_id: thread_id.to_string(),
34 resume_from: None,
35 recursion_limit: 50,
36 metadata: HashMap::new(),
37 }
38 }
39
40 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
42 self.recursion_limit = limit;
43 self
44 }
45
46 pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
48 self.resume_from = Some(checkpoint_id.to_string());
49 self
50 }
51
52 pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
54 self.metadata.insert(key.to_string(), value);
55 self
56 }
57}
58
59impl Default for ExecutionConfig {
60 fn default() -> Self {
61 Self::new(&uuid::Uuid::new_v4().to_string())
62 }
63}
64
65pub struct NodeContext {
67 pub state: State,
69 pub config: ExecutionConfig,
71 pub step: usize,
73}
74
75impl NodeContext {
76 pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
78 Self { state, config, step }
79 }
80
81 pub fn get(&self, key: &str) -> Option<&Value> {
83 self.state.get(key)
84 }
85
86 pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
88 self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
89 }
90}
91
92#[derive(Default)]
94pub struct NodeOutput {
95 pub updates: HashMap<String, Value>,
97 pub interrupt: Option<Interrupt>,
99 pub events: Vec<StreamEvent>,
101}
102
103impl NodeOutput {
104 pub fn new() -> Self {
106 Self::default()
107 }
108
109 pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
111 self.updates.insert(key.to_string(), value.into());
112 self
113 }
114
115 pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
117 self.updates.extend(updates);
118 self
119 }
120
121 pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
123 self.interrupt = Some(interrupt);
124 self
125 }
126
127 pub fn with_event(mut self, event: StreamEvent) -> Self {
129 self.events.push(event);
130 self
131 }
132
133 pub fn interrupt(message: &str) -> Self {
135 Self::new().with_interrupt(crate::interrupt::interrupt(message))
136 }
137
138 pub fn interrupt_with_data(message: &str, data: Value) -> Self {
140 Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
141 }
142}
143
144#[async_trait]
146pub trait Node: Send + Sync {
147 fn name(&self) -> &str;
149
150 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
152
153 fn execute_stream<'a>(
155 &'a self,
156 ctx: &'a NodeContext,
157 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
158 let _name = self.name().to_string();
159 Box::pin(async_stream::stream! {
160 match self.execute(ctx).await {
161 Ok(output) => {
162 for event in output.events {
163 yield Ok(event);
164 }
165 }
166 Err(e) => yield Err(e),
167 }
168 })
169 }
170}
171
172pub type BoxedNode = Box<dyn Node>;
174
175pub type AsyncNodeFn = Box<
177 dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
178>;
179
180pub struct FunctionNode {
182 name: String,
183 func: AsyncNodeFn,
184}
185
186impl FunctionNode {
187 pub fn new<F, Fut>(name: &str, func: F) -> Self
189 where
190 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
191 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
192 {
193 Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
194 }
195}
196
197#[async_trait]
198impl Node for FunctionNode {
199 fn name(&self) -> &str {
200 &self.name
201 }
202
203 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
204 let ctx_owned =
205 NodeContext { state: ctx.state.clone(), config: ctx.config.clone(), step: ctx.step };
206 (self.func)(ctx_owned).await
207 }
208}
209
210pub struct PassthroughNode {
212 name: String,
213}
214
215impl PassthroughNode {
216 pub fn new(name: &str) -> Self {
218 Self { name: name.to_string() }
219 }
220}
221
222#[async_trait]
223impl Node for PassthroughNode {
224 fn name(&self) -> &str {
225 &self.name
226 }
227
228 async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
229 Ok(NodeOutput::new())
230 }
231}
232
233pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
235
236pub type AgentOutputMapper =
238 Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
239
240pub struct AgentNode {
242 name: String,
243 #[allow(dead_code)]
244 agent: Arc<dyn adk_core::Agent>,
245 input_mapper: AgentInputMapper,
247 output_mapper: AgentOutputMapper,
249}
250
251impl AgentNode {
252 pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
254 let name = agent.name().to_string();
255 Self {
256 name,
257 agent,
258 input_mapper: Box::new(default_input_mapper),
259 output_mapper: Box::new(default_output_mapper),
260 }
261 }
262
263 pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
265 where
266 F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
267 {
268 self.input_mapper = Box::new(mapper);
269 self
270 }
271
272 pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
274 where
275 F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
276 {
277 self.output_mapper = Box::new(mapper);
278 self
279 }
280}
281
282fn default_input_mapper(state: &State) -> adk_core::Content {
284 if let Some(messages) = state.get("messages") {
286 if let Some(arr) = messages.as_array() {
287 if let Some(last) = arr.last() {
288 if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
289 return adk_core::Content::new("user").with_text(content);
290 }
291 }
292 }
293 }
294
295 if let Some(input) = state.get("input") {
297 if let Some(text) = input.as_str() {
298 return adk_core::Content::new("user").with_text(text);
299 }
300 }
301
302 adk_core::Content::new("user")
303}
304
305fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
307 let mut updates = HashMap::new();
308
309 let mut messages = Vec::new();
311 for event in events {
312 if let Some(content) = event.content() {
313 let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
314
315 if !text.is_empty() {
316 messages.push(serde_json::json!({
317 "role": "assistant",
318 "content": text
319 }));
320 }
321 }
322 }
323
324 if !messages.is_empty() {
325 updates.insert("messages".to_string(), serde_json::json!(messages));
326 }
327
328 updates
329}
330
331#[async_trait]
332impl Node for AgentNode {
333 fn name(&self) -> &str {
334 &self.name
335 }
336
337 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
338 use futures::StreamExt;
339
340 let content = (self.input_mapper)(&ctx.state);
342
343 let invocation_ctx = Arc::new(GraphInvocationContext::new(
345 ctx.config.thread_id.clone(),
346 content,
347 self.agent.clone(),
348 ));
349
350 let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
352 crate::error::GraphError::NodeExecutionFailed {
353 node: self.name.clone(),
354 message: e.to_string(),
355 }
356 })?;
357
358 let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
359
360 let updates = (self.output_mapper)(&events);
362
363 let mut output = NodeOutput::new().with_updates(updates);
365 for event in &events {
366 if let Ok(json) = serde_json::to_value(event) {
367 output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
368 }
369 }
370
371 Ok(output)
372 }
373
374 fn execute_stream<'a>(
375 &'a self,
376 ctx: &'a NodeContext,
377 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
378 use futures::StreamExt;
379 let name = self.name.clone();
380 let agent = self.agent.clone();
381 let input_mapper = &self.input_mapper;
382 let thread_id = ctx.config.thread_id.clone();
383 let content = (input_mapper)(&ctx.state);
384
385 Box::pin(async_stream::stream! {
386 eprintln!("DEBUG: AgentNode::execute_stream called for {}", name);
387 let invocation_ctx = Arc::new(GraphInvocationContext::new(
388 thread_id,
389 content,
390 agent.clone(),
391 ));
392
393 let stream = match agent.run(invocation_ctx).await {
394 Ok(s) => s,
395 Err(e) => {
396 yield Err(crate::error::GraphError::NodeExecutionFailed {
397 node: name.clone(),
398 message: e.to_string(),
399 });
400 return;
401 }
402 };
403
404 tokio::pin!(stream);
405 let mut all_events = Vec::new();
406
407 while let Some(result) = stream.next().await {
408 match result {
409 Ok(event) => {
410 if let Some(content) = event.content() {
412 let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
413 if !text.is_empty() {
414 yield Ok(StreamEvent::Message {
415 node: name.clone(),
416 content: text,
417 is_final: false,
418 });
419 }
420 }
421 all_events.push(event);
422 }
423 Err(e) => {
424 yield Err(crate::error::GraphError::NodeExecutionFailed {
425 node: name.clone(),
426 message: e.to_string(),
427 });
428 return;
429 }
430 }
431 }
432
433 for event in &all_events {
435 if let Ok(json) = serde_json::to_value(event) {
436 yield Ok(StreamEvent::custom(&name, "agent_event", json));
437 }
438 }
439 })
440 }
441}
442
443struct GraphInvocationContext {
445 invocation_id: String,
446 user_content: adk_core::Content,
447 agent: Arc<dyn adk_core::Agent>,
448 session: Arc<GraphSession>,
449 run_config: adk_core::RunConfig,
450 ended: std::sync::atomic::AtomicBool,
451}
452
453impl GraphInvocationContext {
454 fn new(
455 session_id: String,
456 user_content: adk_core::Content,
457 agent: Arc<dyn adk_core::Agent>,
458 ) -> Self {
459 let invocation_id = uuid::Uuid::new_v4().to_string();
460 let session = Arc::new(GraphSession::new(session_id));
461 session.append_content(user_content.clone());
463 Self {
464 invocation_id,
465 user_content,
466 agent,
467 session,
468 run_config: adk_core::RunConfig::default(),
469 ended: std::sync::atomic::AtomicBool::new(false),
470 }
471 }
472}
473
474impl adk_core::ReadonlyContext for GraphInvocationContext {
476 fn invocation_id(&self) -> &str {
477 &self.invocation_id
478 }
479
480 fn agent_name(&self) -> &str {
481 self.agent.name()
482 }
483
484 fn user_id(&self) -> &str {
485 "graph_user"
486 }
487
488 fn app_name(&self) -> &str {
489 "graph_app"
490 }
491
492 fn session_id(&self) -> &str {
493 &self.session.id
494 }
495
496 fn branch(&self) -> &str {
497 "main"
498 }
499
500 fn user_content(&self) -> &adk_core::Content {
501 &self.user_content
502 }
503}
504
505#[async_trait]
507impl adk_core::CallbackContext for GraphInvocationContext {
508 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
509 None
510 }
511}
512
513#[async_trait]
515impl adk_core::InvocationContext for GraphInvocationContext {
516 fn agent(&self) -> Arc<dyn adk_core::Agent> {
517 self.agent.clone()
518 }
519
520 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
521 None
522 }
523
524 fn session(&self) -> &dyn adk_core::Session {
525 self.session.as_ref()
526 }
527
528 fn run_config(&self) -> &adk_core::RunConfig {
529 &self.run_config
530 }
531
532 fn end_invocation(&self) {
533 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
534 }
535
536 fn ended(&self) -> bool {
537 self.ended.load(std::sync::atomic::Ordering::SeqCst)
538 }
539}
540
541struct GraphSession {
543 id: String,
544 state: GraphState,
545 history: std::sync::RwLock<Vec<adk_core::Content>>,
546}
547
548impl GraphSession {
549 fn new(id: String) -> Self {
550 Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
551 }
552
553 fn append_content(&self, content: adk_core::Content) {
554 if let Ok(mut h) = self.history.write() {
555 h.push(content);
556 }
557 }
558}
559
560impl adk_core::Session for GraphSession {
561 fn id(&self) -> &str {
562 &self.id
563 }
564
565 fn app_name(&self) -> &str {
566 "graph_app"
567 }
568
569 fn user_id(&self) -> &str {
570 "graph_user"
571 }
572
573 fn state(&self) -> &dyn adk_core::State {
574 &self.state
575 }
576
577 fn conversation_history(&self) -> Vec<adk_core::Content> {
578 self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
579 }
580
581 fn append_to_history(&self, content: adk_core::Content) {
582 self.append_content(content);
583 }
584}
585
586struct GraphState {
588 data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
589}
590
591impl GraphState {
592 fn new() -> Self {
593 Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
594 }
595}
596
597impl adk_core::State for GraphState {
598 fn get(&self, key: &str) -> Option<serde_json::Value> {
599 self.data.read().ok()?.get(key).cloned()
600 }
601
602 fn set(&mut self, key: String, value: serde_json::Value) {
603 if let Ok(mut data) = self.data.write() {
604 data.insert(key, value);
605 }
606 }
607
608 fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
609 self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", serde_json::json!("success")))
        });

        assert_eq!(node.name(), "test");

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert_eq!(output.updates.get("result"), Some(&serde_json::json!("success")));
    }

    #[tokio::test]
    async fn test_passthrough_node() {
        let node = PassthroughNode::new("pass");
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        assert_eq!(output.updates.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(output.updates.get("b"), Some(&serde_json::json!("hello")));
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::new("thread-1")
            .with_recursion_limit(7)
            .with_resume_from("ckpt-3")
            .with_metadata("k", serde_json::json!("v"));

        assert_eq!(config.thread_id, "thread-1");
        assert_eq!(config.recursion_limit, 7);
        assert_eq!(config.resume_from.as_deref(), Some("ckpt-3"));
        assert_eq!(config.metadata.get("k"), Some(&serde_json::json!("v")));

        let defaults = ExecutionConfig::new("t");
        assert_eq!(defaults.recursion_limit, 50);
        assert!(defaults.resume_from.is_none());
        assert!(defaults.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_execute_stream_replays_events() {
        let node = FunctionNode::new("emitter", |_ctx| async {
            let node_name = String::from("emitter");
            Ok(NodeOutput::new().with_event(StreamEvent::custom(
                &node_name,
                "ping",
                serde_json::json!(1),
            )))
        });
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);

        let items: Vec<_> = node.execute_stream(&ctx).collect().await;

        assert_eq!(items.len(), 1);
        assert!(items[0].is_ok());
    }

    #[tokio::test]
    async fn test_default_execute_stream_propagates_error() {
        let node = FunctionNode::new("failing", |_ctx| async {
            Err(crate::error::GraphError::NodeExecutionFailed {
                node: "failing".to_string(),
                message: "boom".to_string(),
            })
        });
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);

        let items: Vec<_> = node.execute_stream(&ctx).collect().await;

        assert_eq!(items.len(), 1);
        assert!(items[0].is_err());
    }
}