1use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use crate::timeout::ProgressHandle;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17#[derive(Clone)]
19pub struct ExecutionConfig {
20 pub thread_id: String,
22 pub resume_from: Option<String>,
24 pub recursion_limit: usize,
26 pub metadata: HashMap<String, Value>,
28}
29
30impl ExecutionConfig {
31 pub fn new(thread_id: &str) -> Self {
33 Self {
34 thread_id: thread_id.to_string(),
35 resume_from: None,
36 recursion_limit: 50,
37 metadata: HashMap::new(),
38 }
39 }
40
41 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
43 self.recursion_limit = limit;
44 self
45 }
46
47 pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
49 self.resume_from = Some(checkpoint_id.to_string());
50 self
51 }
52
53 pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
55 self.metadata.insert(key.to_string(), value);
56 self
57 }
58}
59
60impl Default for ExecutionConfig {
61 fn default() -> Self {
62 Self::new(&uuid::Uuid::new_v4().to_string())
63 }
64}
65
66pub struct NodeContext {
68 pub state: State,
70 pub config: ExecutionConfig,
72 pub step: usize,
74 progress_handle: Option<ProgressHandle>,
77}
78
79impl NodeContext {
80 pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
82 Self { state, config, step, progress_handle: None }
83 }
84
85 pub fn get(&self, key: &str) -> Option<&Value> {
87 self.state.get(key)
88 }
89
90 pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
92 self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
93 }
94
95 pub fn report_progress(&self) {
113 if let Some(handle) = &self.progress_handle {
114 handle.report_progress();
115 }
116 }
117
118 pub fn set_progress_handle(&mut self, handle: ProgressHandle) {
123 self.progress_handle = Some(handle);
124 }
125
126 pub fn progress_handle(&self) -> Option<&ProgressHandle> {
128 self.progress_handle.as_ref()
129 }
130}
131
132#[derive(Default)]
134pub struct NodeOutput {
135 pub updates: HashMap<String, Value>,
137 pub interrupt: Option<Interrupt>,
139 pub events: Vec<StreamEvent>,
141}
142
143impl NodeOutput {
144 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
151 self.updates.insert(key.to_string(), value.into());
152 self
153 }
154
155 pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
157 self.updates.extend(updates);
158 self
159 }
160
161 pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
163 self.interrupt = Some(interrupt);
164 self
165 }
166
167 pub fn with_event(mut self, event: StreamEvent) -> Self {
169 self.events.push(event);
170 self
171 }
172
173 pub fn interrupt(message: &str) -> Self {
175 Self::new().with_interrupt(crate::interrupt::interrupt(message))
176 }
177
178 pub fn interrupt_with_data(message: &str, data: Value) -> Self {
180 Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
181 }
182}
183
184#[async_trait]
186pub trait Node: Send + Sync {
187 fn name(&self) -> &str;
189
190 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
192
193 fn execute_stream<'a>(
195 &'a self,
196 ctx: &'a NodeContext,
197 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
198 let _name = self.name().to_string();
199 Box::pin(async_stream::stream! {
200 match self.execute(ctx).await {
201 Ok(output) => {
202 for event in output.events {
203 yield Ok(event);
204 }
205 }
206 Err(e) => yield Err(e),
207 }
208 })
209 }
210}
211
212pub type BoxedNode = Box<dyn Node>;
214
215pub type AsyncNodeFn = Box<
217 dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
218>;
219
220pub struct FunctionNode {
222 name: String,
223 func: AsyncNodeFn,
224}
225
226impl FunctionNode {
227 pub fn new<F, Fut>(name: &str, func: F) -> Self
229 where
230 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
231 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
232 {
233 Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
234 }
235}
236
237#[async_trait]
238impl Node for FunctionNode {
239 fn name(&self) -> &str {
240 &self.name
241 }
242
243 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
244 let mut ctx_owned = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
245 if let Some(handle) = ctx.progress_handle() {
246 ctx_owned.set_progress_handle(handle.clone());
247 }
248 (self.func)(ctx_owned).await
249 }
250}
251
/// A node that produces an empty output.
pub struct PassthroughNode {
    /// Name reported by `Node::name`.
    name: String,
}

impl PassthroughNode {
    /// Creates a passthrough node called `name`.
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}
263
264#[async_trait]
265impl Node for PassthroughNode {
266 fn name(&self) -> &str {
267 &self.name
268 }
269
270 async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
271 Ok(NodeOutput::new())
272 }
273}
274
275pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
277
278pub type AgentOutputMapper =
280 Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
281
282pub struct AgentNode {
284 name: String,
285 #[allow(dead_code)]
286 agent: Arc<dyn adk_core::Agent>,
287 input_mapper: AgentInputMapper,
289 output_mapper: AgentOutputMapper,
291}
292
293impl AgentNode {
294 pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
296 let name = agent.name().to_string();
297 Self {
298 name,
299 agent,
300 input_mapper: Box::new(default_input_mapper),
301 output_mapper: Box::new(default_output_mapper),
302 }
303 }
304
305 pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
307 where
308 F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
309 {
310 self.input_mapper = Box::new(mapper);
311 self
312 }
313
314 pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
316 where
317 F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
318 {
319 self.output_mapper = Box::new(mapper);
320 self
321 }
322}
323
324fn default_input_mapper(state: &State) -> adk_core::Content {
326 if let Some(messages) = state.get("messages") {
328 if let Some(arr) = messages.as_array() {
329 if let Some(last) = arr.last() {
330 if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
331 return adk_core::Content::new("user").with_text(content);
332 }
333 }
334 }
335 }
336
337 if let Some(input) = state.get("input") {
339 if let Some(text) = input.as_str() {
340 return adk_core::Content::new("user").with_text(text);
341 }
342 }
343
344 adk_core::Content::new("user")
345}
346
347fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
349 let mut updates = HashMap::new();
350
351 let mut messages = Vec::new();
353 for event in events {
354 if let Some(content) = event.content() {
355 let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
356
357 if !text.is_empty() {
358 messages.push(serde_json::json!({
359 "role": "assistant",
360 "content": text
361 }));
362 }
363 }
364 }
365
366 if !messages.is_empty() {
367 updates.insert("messages".to_string(), serde_json::json!(messages));
368 }
369
370 updates
371}
372
373#[async_trait]
374impl Node for AgentNode {
375 fn name(&self) -> &str {
376 &self.name
377 }
378
379 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
380 use futures::StreamExt;
381
382 let content = (self.input_mapper)(&ctx.state);
384
385 let invocation_ctx = Arc::new(GraphInvocationContext::new(
387 ctx.config.thread_id.clone(),
388 content,
389 self.agent.clone(),
390 ));
391
392 let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
394 crate::error::GraphError::NodeExecutionFailed {
395 node: self.name.clone(),
396 message: e.to_string(),
397 }
398 })?;
399
400 let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
401
402 let updates = (self.output_mapper)(&events);
404
405 let mut output = NodeOutput::new().with_updates(updates);
407 for event in &events {
408 if let Ok(json) = serde_json::to_value(event) {
409 output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
410 }
411 }
412
413 Ok(output)
414 }
415
416 fn execute_stream<'a>(
417 &'a self,
418 ctx: &'a NodeContext,
419 ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
420 use futures::StreamExt;
421 let name = self.name.clone();
422 let agent = self.agent.clone();
423 let input_mapper = &self.input_mapper;
424 let thread_id = ctx.config.thread_id.clone();
425 let content = (input_mapper)(&ctx.state);
426
427 Box::pin(async_stream::stream! {
428 tracing::debug!("AgentNode::execute_stream called for {}", name);
429 let invocation_ctx = Arc::new(GraphInvocationContext::new(
430 thread_id,
431 content,
432 agent.clone(),
433 ));
434
435 let stream = match agent.run(invocation_ctx).await {
436 Ok(s) => s,
437 Err(e) => {
438 yield Err(crate::error::GraphError::NodeExecutionFailed {
439 node: name.clone(),
440 message: e.to_string(),
441 });
442 return;
443 }
444 };
445
446 tokio::pin!(stream);
447 let mut all_events = Vec::new();
448
449 while let Some(result) = stream.next().await {
450 match result {
451 Ok(event) => {
452 if let Some(content) = event.content() {
454 let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
455 if !text.is_empty() {
456 yield Ok(StreamEvent::Message {
457 node: name.clone(),
458 content: text,
459 is_final: false,
460 });
461 }
462 }
463 all_events.push(event);
464 }
465 Err(e) => {
466 yield Err(crate::error::GraphError::NodeExecutionFailed {
467 node: name.clone(),
468 message: e.to_string(),
469 });
470 return;
471 }
472 }
473 }
474
475 for event in &all_events {
477 if let Ok(json) = serde_json::to_value(event) {
478 yield Ok(StreamEvent::custom(&name, "agent_event", json));
479 }
480 }
481 })
482 }
483}
484
485struct GraphInvocationContext {
487 invocation_id: String,
488 user_content: adk_core::Content,
489 agent: Arc<dyn adk_core::Agent>,
490 session: Arc<GraphSession>,
491 run_config: adk_core::RunConfig,
492 ended: std::sync::atomic::AtomicBool,
493}
494
495impl GraphInvocationContext {
496 fn new(
497 session_id: String,
498 user_content: adk_core::Content,
499 agent: Arc<dyn adk_core::Agent>,
500 ) -> Self {
501 let invocation_id = uuid::Uuid::new_v4().to_string();
502 let session = Arc::new(GraphSession::new(session_id));
503 session.append_content(user_content.clone());
505 Self {
506 invocation_id,
507 user_content,
508 agent,
509 session,
510 run_config: adk_core::RunConfig::default(),
511 ended: std::sync::atomic::AtomicBool::new(false),
512 }
513 }
514}
515
516impl adk_core::ReadonlyContext for GraphInvocationContext {
518 fn invocation_id(&self) -> &str {
519 &self.invocation_id
520 }
521
522 fn agent_name(&self) -> &str {
523 self.agent.name()
524 }
525
526 fn user_id(&self) -> &str {
527 "graph_user"
528 }
529
530 fn app_name(&self) -> &str {
531 "graph_app"
532 }
533
534 fn session_id(&self) -> &str {
535 &self.session.id
536 }
537
538 fn branch(&self) -> &str {
539 "main"
540 }
541
542 fn user_content(&self) -> &adk_core::Content {
543 &self.user_content
544 }
545}
546
547#[async_trait]
549impl adk_core::CallbackContext for GraphInvocationContext {
550 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
551 None
552 }
553}
554
555#[async_trait]
557impl adk_core::InvocationContext for GraphInvocationContext {
558 fn agent(&self) -> Arc<dyn adk_core::Agent> {
559 self.agent.clone()
560 }
561
562 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
563 None
564 }
565
566 fn session(&self) -> &dyn adk_core::Session {
567 self.session.as_ref()
568 }
569
570 fn run_config(&self) -> &adk_core::RunConfig {
571 &self.run_config
572 }
573
574 fn end_invocation(&self) {
575 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
576 }
577
578 fn ended(&self) -> bool {
579 self.ended.load(std::sync::atomic::Ordering::SeqCst)
580 }
581}
582
583struct GraphSession {
585 id: String,
586 state: GraphState,
587 history: std::sync::RwLock<Vec<adk_core::Content>>,
588}
589
590impl GraphSession {
591 fn new(id: String) -> Self {
592 Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
593 }
594
595 fn append_content(&self, content: adk_core::Content) {
596 if let Ok(mut h) = self.history.write() {
597 h.push(content);
598 }
599 }
600}
601
602impl adk_core::Session for GraphSession {
603 fn id(&self) -> &str {
604 &self.id
605 }
606
607 fn app_name(&self) -> &str {
608 "graph_app"
609 }
610
611 fn user_id(&self) -> &str {
612 "graph_user"
613 }
614
615 fn state(&self) -> &dyn adk_core::State {
616 &self.state
617 }
618
619 fn conversation_history(&self) -> Vec<adk_core::Content> {
620 self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
621 }
622
623 fn append_to_history(&self, content: adk_core::Content) {
624 self.append_content(content);
625 }
626}
627
628struct GraphState {
630 data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
631}
632
633impl GraphState {
634 fn new() -> Self {
635 Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
636 }
637}
638
639impl adk_core::State for GraphState {
640 fn get(&self, key: &str) -> Option<serde_json::Value> {
641 self.data.read().ok()?.get(key).cloned()
642 }
643
644 fn set(&mut self, key: String, value: serde_json::Value) {
645 if let Err(msg) = adk_core::validate_state_key(&key) {
646 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
647 return;
648 }
649 if let Ok(mut data) = self.data.write() {
650 data.insert(key, value);
651 }
652 }
653
654 fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
655 self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Context over empty state with a default config at step 0.
    fn empty_context() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    /// A `FunctionNode` reports its name and returns what the closure produced.
    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", serde_json::json!("success")))
        });
        assert_eq!(node.name(), "test");

        let ctx = empty_context();
        let output = node.execute(&ctx).await.unwrap();

        let expected = serde_json::json!("success");
        assert_eq!(output.updates.get("result"), Some(&expected));
    }

    /// A `PassthroughNode` yields neither updates nor an interrupt.
    #[tokio::test]
    async fn test_passthrough_node() {
        let node = PassthroughNode::new("pass");
        let ctx = empty_context();

        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    /// `with_update` converts its argument into a JSON value and stores it.
    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        let expected_a = serde_json::json!(1);
        let expected_b = serde_json::json!("hello");
        assert_eq!(output.updates.get("a"), Some(&expected_a));
        assert_eq!(output.updates.get("b"), Some(&expected_b));
    }
}