1use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use crate::timeout::ProgressHandle;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17#[derive(Clone)]
19pub struct ExecutionConfig {
20 pub thread_id: String,
22 pub resume_from: Option<String>,
24 pub recursion_limit: usize,
26 pub metadata: HashMap<String, Value>,
28}
29
30impl ExecutionConfig {
31 pub fn new(thread_id: &str) -> Self {
33 Self {
34 thread_id: thread_id.to_string(),
35 resume_from: None,
36 recursion_limit: 50,
37 metadata: HashMap::new(),
38 }
39 }
40
41 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
43 self.recursion_limit = limit;
44 self
45 }
46
47 pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
49 self.resume_from = Some(checkpoint_id.to_string());
50 self
51 }
52
53 pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
55 self.metadata.insert(key.to_string(), value);
56 self
57 }
58}
59
60impl Default for ExecutionConfig {
61 fn default() -> Self {
62 Self::new(&uuid::Uuid::new_v4().to_string())
63 }
64}
65
66pub struct NodeContext {
68 pub state: State,
70 pub config: ExecutionConfig,
72 pub step: usize,
74 progress_handle: Option<ProgressHandle>,
77}
78
79impl NodeContext {
80 pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
82 Self { state, config, step, progress_handle: None }
83 }
84
85 pub fn get(&self, key: &str) -> Option<&Value> {
87 self.state.get(key)
88 }
89
90 pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
92 self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
93 }
94
95 pub fn report_progress(&self) {
113 if let Some(handle) = &self.progress_handle {
114 handle.report_progress();
115 }
116 }
117
118 pub fn set_progress_handle(&mut self, handle: ProgressHandle) {
123 self.progress_handle = Some(handle);
124 }
125
126 pub fn progress_handle(&self) -> Option<&ProgressHandle> {
128 self.progress_handle.as_ref()
129 }
130}
131
132#[derive(Default)]
134pub struct NodeOutput {
135 pub updates: HashMap<String, Value>,
137 pub interrupt: Option<Interrupt>,
139 pub events: Vec<StreamEvent>,
141}
142
143impl NodeOutput {
144 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
151 self.updates.insert(key.to_string(), value.into());
152 self
153 }
154
155 pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
157 self.updates.extend(updates);
158 self
159 }
160
161 pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
163 self.interrupt = Some(interrupt);
164 self
165 }
166
167 pub fn with_event(mut self, event: StreamEvent) -> Self {
169 self.events.push(event);
170 self
171 }
172
173 pub fn interrupt(message: &str) -> Self {
175 Self::new().with_interrupt(crate::interrupt::interrupt(message))
176 }
177
178 pub fn interrupt_with_data(message: &str, data: Value) -> Self {
180 Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
181 }
182}
183
184#[async_trait]
186pub trait Node: Send + Sync {
187 fn name(&self) -> &str;
189
190 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
192
193 fn execute_stream<'a>(
195 &'a self,
196 ctx: &'a NodeContext,
197 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
198 let _name = self.name().to_string();
199 Box::pin(async_stream::stream! {
200 match self.execute(ctx).await {
201 Ok(output) => {
202 for event in output.events {
203 yield Ok(event);
204 }
205 }
206 Err(e) => yield Err(e),
207 }
208 })
209 }
210}
211
212pub type BoxedNode = Box<dyn Node>;
214
215pub type AsyncNodeFn = Box<
217 dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
218>;
219
220pub struct FunctionNode {
222 name: String,
223 func: AsyncNodeFn,
224}
225
226impl FunctionNode {
227 pub fn new<F, Fut>(name: &str, func: F) -> Self
229 where
230 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
231 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
232 {
233 Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
234 }
235}
236
237#[async_trait]
238impl Node for FunctionNode {
239 fn name(&self) -> &str {
240 &self.name
241 }
242
243 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
244 let mut ctx_owned = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
245 if let Some(handle) = ctx.progress_handle() {
246 ctx_owned.set_progress_handle(handle.clone());
247 }
248 (self.func)(ctx_owned).await
249 }
250}
251
252pub struct PassthroughNode {
254 name: String,
255}
256
257impl PassthroughNode {
258 pub fn new(name: &str) -> Self {
260 Self { name: name.to_string() }
261 }
262}
263
264#[async_trait]
265impl Node for PassthroughNode {
266 fn name(&self) -> &str {
267 &self.name
268 }
269
270 async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
271 Ok(NodeOutput::new())
272 }
273}
274
275pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
277
278pub type AgentOutputMapper =
280 Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
281
282pub struct AgentNode {
284 name: String,
285 #[allow(dead_code)]
286 agent: Arc<dyn adk_core::Agent>,
287 input_mapper: AgentInputMapper,
289 output_mapper: AgentOutputMapper,
291}
292
293impl AgentNode {
294 pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
296 let name = agent.name().to_string();
297 Self {
298 name,
299 agent,
300 input_mapper: Box::new(default_input_mapper),
301 output_mapper: Box::new(default_output_mapper),
302 }
303 }
304
305 pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
307 where
308 F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
309 {
310 self.input_mapper = Box::new(mapper);
311 self
312 }
313
314 pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
316 where
317 F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
318 {
319 self.output_mapper = Box::new(mapper);
320 self
321 }
322}
323
324fn default_input_mapper(state: &State) -> adk_core::Content {
326 if let Some(messages) = state.get("messages")
328 && let Some(arr) = messages.as_array()
329 && let Some(last) = arr.last()
330 && let Some(content) = last.get("content").and_then(|c| c.as_str())
331 {
332 return adk_core::Content::new("user").with_text(content);
333 }
334
335 if let Some(input) = state.get("input")
337 && let Some(text) = input.as_str()
338 {
339 return adk_core::Content::new("user").with_text(text);
340 }
341
342 adk_core::Content::new("user")
343}
344
345fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
347 let mut updates = HashMap::new();
348
349 let mut messages = Vec::new();
351 for event in events {
352 if let Some(content) = event.content() {
353 let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
354
355 if !text.is_empty() {
356 messages.push(serde_json::json!({
357 "role": "assistant",
358 "content": text
359 }));
360 }
361 }
362 }
363
364 if !messages.is_empty() {
365 updates.insert("messages".to_string(), serde_json::json!(messages));
366 }
367
368 updates
369}
370
371#[async_trait]
372impl Node for AgentNode {
373 fn name(&self) -> &str {
374 &self.name
375 }
376
377 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
378 use futures::StreamExt;
379
380 let content = (self.input_mapper)(&ctx.state);
382
383 let invocation_ctx = Arc::new(GraphInvocationContext::new(
385 ctx.config.thread_id.clone(),
386 content,
387 self.agent.clone(),
388 ));
389
390 let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
392 crate::error::GraphError::NodeExecutionFailed {
393 node: self.name.clone(),
394 message: e.to_string(),
395 }
396 })?;
397
398 let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
399
400 let updates = (self.output_mapper)(&events);
402
403 let mut output = NodeOutput::new().with_updates(updates);
405 for event in &events {
406 if let Ok(json) = serde_json::to_value(event) {
407 output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
408 }
409 }
410
411 Ok(output)
412 }
413
414 fn execute_stream<'a>(
415 &'a self,
416 ctx: &'a NodeContext,
417 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
418 use futures::StreamExt;
419 let name = self.name.clone();
420 let agent = self.agent.clone();
421 let input_mapper = &self.input_mapper;
422 let thread_id = ctx.config.thread_id.clone();
423 let content = (input_mapper)(&ctx.state);
424
425 Box::pin(async_stream::stream! {
426 tracing::debug!("AgentNode::execute_stream called for {}", name);
427 let invocation_ctx = Arc::new(GraphInvocationContext::new(
428 thread_id,
429 content,
430 agent.clone(),
431 ));
432
433 let stream = match agent.run(invocation_ctx).await {
434 Ok(s) => s,
435 Err(e) => {
436 yield Err(crate::error::GraphError::NodeExecutionFailed {
437 node: name.clone(),
438 message: e.to_string(),
439 });
440 return;
441 }
442 };
443
444 tokio::pin!(stream);
445 let mut all_events = Vec::new();
446
447 while let Some(result) = stream.next().await {
448 match result {
449 Ok(event) => {
450 if let Some(content) = event.content() {
452 let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
453 if !text.is_empty() {
454 yield Ok(StreamEvent::Message {
455 node: name.clone(),
456 content: text,
457 is_final: false,
458 });
459 }
460 }
461 all_events.push(event);
462 }
463 Err(e) => {
464 yield Err(crate::error::GraphError::NodeExecutionFailed {
465 node: name.clone(),
466 message: e.to_string(),
467 });
468 return;
469 }
470 }
471 }
472
473 for event in &all_events {
475 if let Ok(json) = serde_json::to_value(event) {
476 yield Ok(StreamEvent::custom(&name, "agent_event", json));
477 }
478 }
479 })
480 }
481}
482
483struct GraphInvocationContext {
485 invocation_id: String,
486 user_content: adk_core::Content,
487 agent: Arc<dyn adk_core::Agent>,
488 session: Arc<GraphSession>,
489 run_config: adk_core::RunConfig,
490 ended: std::sync::atomic::AtomicBool,
491}
492
493impl GraphInvocationContext {
494 fn new(
495 session_id: String,
496 user_content: adk_core::Content,
497 agent: Arc<dyn adk_core::Agent>,
498 ) -> Self {
499 let invocation_id = uuid::Uuid::new_v4().to_string();
500 let session = Arc::new(GraphSession::new(session_id));
501 session.append_content(user_content.clone());
503 Self {
504 invocation_id,
505 user_content,
506 agent,
507 session,
508 run_config: adk_core::RunConfig::default(),
509 ended: std::sync::atomic::AtomicBool::new(false),
510 }
511 }
512}
513
514impl adk_core::ReadonlyContext for GraphInvocationContext {
516 fn invocation_id(&self) -> &str {
517 &self.invocation_id
518 }
519
520 fn agent_name(&self) -> &str {
521 self.agent.name()
522 }
523
524 fn user_id(&self) -> &str {
525 "graph_user"
526 }
527
528 fn app_name(&self) -> &str {
529 "graph_app"
530 }
531
532 fn session_id(&self) -> &str {
533 &self.session.id
534 }
535
536 fn branch(&self) -> &str {
537 "main"
538 }
539
540 fn user_content(&self) -> &adk_core::Content {
541 &self.user_content
542 }
543}
544
545#[async_trait]
547impl adk_core::CallbackContext for GraphInvocationContext {
548 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
549 None
550 }
551}
552
553#[async_trait]
555impl adk_core::InvocationContext for GraphInvocationContext {
556 fn agent(&self) -> Arc<dyn adk_core::Agent> {
557 self.agent.clone()
558 }
559
560 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
561 None
562 }
563
564 fn session(&self) -> &dyn adk_core::Session {
565 self.session.as_ref()
566 }
567
568 fn run_config(&self) -> &adk_core::RunConfig {
569 &self.run_config
570 }
571
572 fn end_invocation(&self) {
573 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
574 }
575
576 fn ended(&self) -> bool {
577 self.ended.load(std::sync::atomic::Ordering::SeqCst)
578 }
579}
580
581struct GraphSession {
583 id: String,
584 state: GraphState,
585 history: std::sync::RwLock<Vec<adk_core::Content>>,
586}
587
588impl GraphSession {
589 fn new(id: String) -> Self {
590 Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
591 }
592
593 fn append_content(&self, content: adk_core::Content) {
594 if let Ok(mut h) = self.history.write() {
595 h.push(content);
596 }
597 }
598}
599
600impl adk_core::Session for GraphSession {
601 fn id(&self) -> &str {
602 &self.id
603 }
604
605 fn app_name(&self) -> &str {
606 "graph_app"
607 }
608
609 fn user_id(&self) -> &str {
610 "graph_user"
611 }
612
613 fn state(&self) -> &dyn adk_core::State {
614 &self.state
615 }
616
617 fn conversation_history(&self) -> Vec<adk_core::Content> {
618 self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
619 }
620
621 fn append_to_history(&self, content: adk_core::Content) {
622 self.append_content(content);
623 }
624}
625
626struct GraphState {
628 data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
629}
630
631impl GraphState {
632 fn new() -> Self {
633 Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
634 }
635}
636
637impl adk_core::State for GraphState {
638 fn get(&self, key: &str) -> Option<serde_json::Value> {
639 self.data.read().ok()?.get(key).cloned()
640 }
641
642 fn set(&mut self, key: String, value: serde_json::Value) {
643 if let Err(msg) = adk_core::validate_state_key(&key) {
644 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
645 return;
646 }
647 if let Ok(mut data) = self.data.write() {
648 data.insert(key, value);
649 }
650 }
651
652 fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
653 self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", serde_json::json!("success")))
        });

        assert_eq!(node.name(), "test");

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert_eq!(output.updates.get("result"), Some(&serde_json::json!("success")));
    }

    #[tokio::test]
    async fn test_passthrough_node() {
        let node = PassthroughNode::new("pass");
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        assert_eq!(output.updates.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(output.updates.get("b"), Some(&serde_json::json!("hello")));
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::new("thread-1")
            .with_recursion_limit(10)
            .with_resume_from("cp-1")
            .with_metadata("k", serde_json::json!(true));

        assert_eq!(config.thread_id, "thread-1");
        assert_eq!(config.recursion_limit, 10);
        assert_eq!(config.resume_from.as_deref(), Some("cp-1"));
        assert_eq!(config.metadata.get("k"), Some(&serde_json::json!(true)));
    }

    #[test]
    fn test_execution_config_defaults() {
        let config = ExecutionConfig::new("t");
        assert_eq!(config.recursion_limit, 50);
        assert!(config.resume_from.is_none());
        assert!(config.metadata.is_empty());

        // Each default configuration gets its own random thread id.
        assert_ne!(ExecutionConfig::default().thread_id, ExecutionConfig::default().thread_id);
    }

    #[test]
    fn test_get_as_missing_key() {
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        assert!(ctx.get("missing").is_none());
        assert_eq!(ctx.get_as::<i64>("missing"), None);
    }

    #[tokio::test]
    async fn test_default_execute_stream_without_events() {
        let node = PassthroughNode::new("pass");
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);

        let items: Vec<_> = node.execute_stream(&ctx).collect().await;
        assert!(items.is_empty());
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_error() {
        let node = FunctionNode::new("boom", |_ctx| async {
            Err::<NodeOutput, _>(crate::error::GraphError::NodeExecutionFailed {
                node: "boom".to_string(),
                message: "fail".to_string(),
            })
        });
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);

        let items: Vec<_> = node.execute_stream(&ctx).collect().await;
        assert_eq!(items.len(), 1);
        assert!(items[0].is_err());
    }
}