// adk_core/agent.rs

1use crate::{event::Event, InvocationContext, Result};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use std::pin::Pin;
5use std::sync::Arc;
6
7pub type EventStream = Pin<Box<dyn Stream<Item = Result<Event>> + Send>>;
8
9#[async_trait]
10pub trait Agent: Send + Sync {
11    fn name(&self) -> &str;
12    fn description(&self) -> &str;
13    fn sub_agents(&self) -> &[Arc<dyn Agent>];
14
15    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream>;
16}
17
#[cfg(test)]
mod tests {
    // Unit tests for the `Agent` trait, backed by minimal stand-ins for the
    // session, state and context traits an agent is run against.

    use super::*;
    use crate::{CallbackContext, Content, ReadonlyContext, RunConfig, Session, State};
    use async_stream::stream;
    use std::collections::HashMap;

    /// Agent whose `run` yields exactly one `Event::new("test")`.
    struct TestAgent {
        name: String,
    }

    /// State stub: every lookup misses and every write is discarded.
    struct MockState;

    impl State for MockState {
        fn get(&self, _key: &str) -> Option<serde_json::Value> {
            None
        }

        fn set(&mut self, _key: String, _value: serde_json::Value) {}

        fn all(&self) -> HashMap<String, serde_json::Value> {
            HashMap::new()
        }
    }

    /// Session stub with fixed identifiers, empty state and no history.
    struct MockSession;

    impl Session for MockSession {
        fn id(&self) -> &str {
            "session"
        }

        fn app_name(&self) -> &str {
            "app"
        }

        fn user_id(&self) -> &str {
            "user"
        }

        fn state(&self) -> &dyn State {
            // `MockState` is a field-less unit struct, so `&MockState` is promoted to a
            // `'static` reference and may be returned from a `&self` method.
            &MockState
        }

        fn conversation_history(&self) -> Vec<Content> {
            Vec::new()
        }
    }

    /// Invocation context stub backed by `MockSession` and a default `RunConfig`.
    struct TestContext {
        content: Content,
        config: RunConfig,
        session: MockSession,
    }

    impl TestContext {
        /// Builds a context whose user content is `Content::new("user")`.
        fn new() -> Self {
            Self {
                content: Content::new("user"),
                config: RunConfig::default(),
                session: MockSession,
            }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }

        fn agent_name(&self) -> &str {
            "test"
        }

        fn user_id(&self) -> &str {
            "user"
        }

        fn app_name(&self) -> &str {
            "app"
        }

        fn session_id(&self) -> &str {
            "session"
        }

        fn branch(&self) -> &str {
            ""
        }

        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl InvocationContext for TestContext {
        fn agent(&self) -> Arc<dyn Agent> {
            // Return a real agent (named to match `agent_name`) instead of panicking,
            // so tests can call this accessor safely.
            Arc::new(TestAgent {
                name: "test".to_string(),
            })
        }

        fn memory(&self) -> Option<Arc<dyn crate::Memory>> {
            None
        }

        fn session(&self) -> &dyn Session {
            &self.session
        }

        fn run_config(&self) -> &RunConfig {
            &self.config
        }

        fn end_invocation(&self) {}

        fn ended(&self) -> bool {
            false
        }
    }

    #[async_trait]
    impl Agent for TestAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "test agent"
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
            let s = stream! {
                yield Ok(Event::new("test"));
            };
            Ok(Box::pin(s))
        }
    }

    #[test]
    fn test_agent_trait() {
        let agent = TestAgent {
            name: "test".to_string(),
        };
        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test agent");
        assert!(agent.sub_agents().is_empty());
    }

    #[test]
    fn test_context_mocks() {
        let ctx = TestContext::new();
        assert_eq!(ctx.invocation_id(), "test");
        assert_eq!(ctx.agent_name(), "test");
        assert_eq!(ctx.session().id(), "session");
        assert_eq!(ctx.session().app_name(), "app");
        assert!(ctx.session().state().get("missing").is_none());
        assert!(ctx.session().conversation_history().is_empty());
        assert!(ctx.artifacts().is_none());
        assert!(ctx.memory().is_none());
        assert!(!ctx.ended());
        assert_eq!(ctx.agent().name(), "test");
    }
}