1use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16#[derive(Clone)]
18pub struct ExecutionConfig {
19 pub thread_id: String,
21 pub resume_from: Option<String>,
23 pub recursion_limit: usize,
25 pub metadata: HashMap<String, Value>,
27}
28
29impl ExecutionConfig {
30 pub fn new(thread_id: &str) -> Self {
32 Self {
33 thread_id: thread_id.to_string(),
34 resume_from: None,
35 recursion_limit: 50,
36 metadata: HashMap::new(),
37 }
38 }
39
40 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
42 self.recursion_limit = limit;
43 self
44 }
45
46 pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
48 self.resume_from = Some(checkpoint_id.to_string());
49 self
50 }
51
52 pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
54 self.metadata.insert(key.to_string(), value);
55 self
56 }
57}
58
59impl Default for ExecutionConfig {
60 fn default() -> Self {
61 Self::new(&uuid::Uuid::new_v4().to_string())
62 }
63}
64
65pub struct NodeContext {
67 pub state: State,
69 pub config: ExecutionConfig,
71 pub step: usize,
73}
74
75impl NodeContext {
76 pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
78 Self { state, config, step }
79 }
80
81 pub fn get(&self, key: &str) -> Option<&Value> {
83 self.state.get(key)
84 }
85
86 pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
88 self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
89 }
90}
91
92#[derive(Default)]
94pub struct NodeOutput {
95 pub updates: HashMap<String, Value>,
97 pub interrupt: Option<Interrupt>,
99 pub events: Vec<StreamEvent>,
101}
102
103impl NodeOutput {
104 pub fn new() -> Self {
106 Self::default()
107 }
108
109 pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
111 self.updates.insert(key.to_string(), value.into());
112 self
113 }
114
115 pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
117 self.updates.extend(updates);
118 self
119 }
120
121 pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
123 self.interrupt = Some(interrupt);
124 self
125 }
126
127 pub fn with_event(mut self, event: StreamEvent) -> Self {
129 self.events.push(event);
130 self
131 }
132
133 pub fn interrupt(message: &str) -> Self {
135 Self::new().with_interrupt(crate::interrupt::interrupt(message))
136 }
137
138 pub fn interrupt_with_data(message: &str, data: Value) -> Self {
140 Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
141 }
142}
143
144#[async_trait]
146pub trait Node: Send + Sync {
147 fn name(&self) -> &str;
149
150 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
152}
153
154pub type BoxedNode = Box<dyn Node>;
156
157pub type AsyncNodeFn = Box<
159 dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
160>;
161
162pub struct FunctionNode {
164 name: String,
165 func: AsyncNodeFn,
166}
167
168impl FunctionNode {
169 pub fn new<F, Fut>(name: &str, func: F) -> Self
171 where
172 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
173 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
174 {
175 Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
176 }
177}
178
179#[async_trait]
180impl Node for FunctionNode {
181 fn name(&self) -> &str {
182 &self.name
183 }
184
185 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
186 let ctx_owned =
187 NodeContext { state: ctx.state.clone(), config: ctx.config.clone(), step: ctx.step };
188 (self.func)(ctx_owned).await
189 }
190}
191
/// A [`Node`] that does nothing: its `execute` returns an empty
/// [`NodeOutput`].
pub struct PassthroughNode {
    /// Name reported through [`Node::name`].
    name: String,
}

impl PassthroughNode {
    /// Creates a passthrough node called `name`.
    pub fn new(name: &str) -> Self {
        let name = name.to_owned();
        Self { name }
    }
}
203
204#[async_trait]
205impl Node for PassthroughNode {
206 fn name(&self) -> &str {
207 &self.name
208 }
209
210 async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
211 Ok(NodeOutput::new())
212 }
213}
214
215pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
217
218pub type AgentOutputMapper =
220 Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
221
222pub struct AgentNode {
224 name: String,
225 #[allow(dead_code)]
226 agent: Arc<dyn adk_core::Agent>,
227 input_mapper: AgentInputMapper,
229 output_mapper: AgentOutputMapper,
231}
232
233impl AgentNode {
234 pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
236 let name = agent.name().to_string();
237 Self {
238 name,
239 agent,
240 input_mapper: Box::new(default_input_mapper),
241 output_mapper: Box::new(default_output_mapper),
242 }
243 }
244
245 pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
247 where
248 F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
249 {
250 self.input_mapper = Box::new(mapper);
251 self
252 }
253
254 pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
256 where
257 F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
258 {
259 self.output_mapper = Box::new(mapper);
260 self
261 }
262}
263
264fn default_input_mapper(state: &State) -> adk_core::Content {
266 if let Some(messages) = state.get("messages") {
268 if let Some(arr) = messages.as_array() {
269 if let Some(last) = arr.last() {
270 if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
271 return adk_core::Content::new("user").with_text(content);
272 }
273 }
274 }
275 }
276
277 if let Some(input) = state.get("input") {
279 if let Some(text) = input.as_str() {
280 return adk_core::Content::new("user").with_text(text);
281 }
282 }
283
284 adk_core::Content::new("user")
285}
286
287fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
289 let mut updates = HashMap::new();
290
291 let mut messages = Vec::new();
293 for event in events {
294 if let Some(content) = event.content() {
295 let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
296
297 if !text.is_empty() {
298 messages.push(serde_json::json!({
299 "role": "assistant",
300 "content": text
301 }));
302 }
303 }
304 }
305
306 if !messages.is_empty() {
307 updates.insert("messages".to_string(), serde_json::json!(messages));
308 }
309
310 updates
311}
312
313#[async_trait]
314impl Node for AgentNode {
315 fn name(&self) -> &str {
316 &self.name
317 }
318
319 async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
320 use futures::StreamExt;
321
322 let content = (self.input_mapper)(&ctx.state);
324
325 let invocation_ctx = Arc::new(GraphInvocationContext::new(
327 ctx.config.thread_id.clone(),
328 content,
329 self.agent.clone(),
330 ));
331
332 let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
334 crate::error::GraphError::NodeExecutionFailed {
335 node: self.name.clone(),
336 message: e.to_string(),
337 }
338 })?;
339
340 let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
341
342 let updates = (self.output_mapper)(&events);
344
345 Ok(NodeOutput::new().with_updates(updates))
346 }
347}
348
349struct GraphInvocationContext {
351 invocation_id: String,
352 user_content: adk_core::Content,
353 agent: Arc<dyn adk_core::Agent>,
354 session: GraphSession,
355 run_config: adk_core::RunConfig,
356 ended: std::sync::atomic::AtomicBool,
357}
358
359impl GraphInvocationContext {
360 fn new(
361 session_id: String,
362 user_content: adk_core::Content,
363 agent: Arc<dyn adk_core::Agent>,
364 ) -> Self {
365 let invocation_id = uuid::Uuid::new_v4().to_string();
366 Self {
367 invocation_id,
368 user_content,
369 agent,
370 session: GraphSession::new(session_id),
371 run_config: adk_core::RunConfig::default(),
372 ended: std::sync::atomic::AtomicBool::new(false),
373 }
374 }
375}
376
377impl adk_core::ReadonlyContext for GraphInvocationContext {
379 fn invocation_id(&self) -> &str {
380 &self.invocation_id
381 }
382
383 fn agent_name(&self) -> &str {
384 self.agent.name()
385 }
386
387 fn user_id(&self) -> &str {
388 "graph_user"
389 }
390
391 fn app_name(&self) -> &str {
392 "graph_app"
393 }
394
395 fn session_id(&self) -> &str {
396 &self.session.id
397 }
398
399 fn branch(&self) -> &str {
400 "main"
401 }
402
403 fn user_content(&self) -> &adk_core::Content {
404 &self.user_content
405 }
406}
407
408#[async_trait]
410impl adk_core::CallbackContext for GraphInvocationContext {
411 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
412 None
413 }
414}
415
416#[async_trait]
418impl adk_core::InvocationContext for GraphInvocationContext {
419 fn agent(&self) -> Arc<dyn adk_core::Agent> {
420 self.agent.clone()
421 }
422
423 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
424 None
425 }
426
427 fn session(&self) -> &dyn adk_core::Session {
428 &self.session
429 }
430
431 fn run_config(&self) -> &adk_core::RunConfig {
432 &self.run_config
433 }
434
435 fn end_invocation(&self) {
436 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
437 }
438
439 fn ended(&self) -> bool {
440 self.ended.load(std::sync::atomic::Ordering::SeqCst)
441 }
442}
443
444struct GraphSession {
446 id: String,
447 state: GraphState,
448}
449
450impl GraphSession {
451 fn new(id: String) -> Self {
452 Self { id, state: GraphState::new() }
453 }
454}
455
456impl adk_core::Session for GraphSession {
457 fn id(&self) -> &str {
458 &self.id
459 }
460
461 fn app_name(&self) -> &str {
462 "graph_app"
463 }
464
465 fn user_id(&self) -> &str {
466 "graph_user"
467 }
468
469 fn state(&self) -> &dyn adk_core::State {
470 &self.state
471 }
472
473 fn conversation_history(&self) -> Vec<adk_core::Content> {
474 vec![]
475 }
476}
477
478struct GraphState {
480 data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
481}
482
483impl GraphState {
484 fn new() -> Self {
485 Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
486 }
487}
488
489impl adk_core::State for GraphState {
490 fn get(&self, key: &str) -> Option<serde_json::Value> {
491 self.data.read().ok()?.get(key).cloned()
492 }
493
494 fn set(&mut self, key: String, value: serde_json::Value) {
495 if let Ok(mut data) = self.data.write() {
496 data.insert(key, value);
497 }
498 }
499
500 fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
501 self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Context with fresh state, a default config and step 0.
    fn empty_context() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    /// A function node reports its name and returns the callback's update.
    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", json!("success")))
        });
        assert_eq!(node.name(), "test");

        let ctx = empty_context();
        let output = node.execute(&ctx).await.unwrap();

        let expected = json!("success");
        assert_eq!(output.updates.get("result"), Some(&expected));
    }

    /// A passthrough node produces neither updates nor an interrupt.
    #[tokio::test]
    async fn test_passthrough_node() {
        let node = PassthroughNode::new("pass");

        let ctx = empty_context();
        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    /// `with_update` converts plain Rust values into JSON values.
    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        let expected = [("a", json!(1)), ("b", json!("hello"))];
        for (key, value) in &expected {
            assert_eq!(output.updates.get(*key), Some(value));
        }
    }
}