1use crate::{event::Event, InvocationContext, Result};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use std::pin::Pin;
5use std::sync::Arc;
6
7pub type EventStream = Pin<Box<dyn Stream<Item = Result<Event>> + Send>>;
8
9#[async_trait]
10pub trait Agent: Send + Sync {
11 fn name(&self) -> &str;
12 fn description(&self) -> &str;
13 fn sub_agents(&self) -> &[Arc<dyn Agent>];
14
15 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream>;
16}
17
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CallbackContext, Content, ReadonlyContext, RunConfig, Session, State};
    use async_stream::stream;
    use futures::{executor::block_on, StreamExt};
    use std::collections::HashMap;

    /// Minimal leaf agent: no sub-agents, and every run yields exactly one event.
    struct TestAgent {
        name: String,
    }

    /// State stub that stores nothing: every lookup misses and writes are dropped.
    struct MockState;

    impl State for MockState {
        fn get(&self, _key: &str) -> Option<serde_json::Value> {
            None
        }
        fn set(&mut self, _key: String, _value: serde_json::Value) {}
        fn all(&self) -> HashMap<String, serde_json::Value> {
            HashMap::new()
        }
    }

    /// Session stub with fixed identifiers and an empty history.
    struct MockSession;

    impl Session for MockSession {
        fn id(&self) -> &str {
            "session"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn state(&self) -> &dyn State {
            // `MockState` is a unit struct, so `&MockState` is promoted to a
            // `'static` reference and can be returned directly.
            &MockState
        }
        fn conversation_history(&self) -> Vec<Content> {
            Vec::new()
        }
    }

    /// Invocation context stub handed to `Agent::run` in tests.
    struct TestContext {
        content: Content,
        config: RunConfig,
        session: MockSession,
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                content: Content::new("user"),
                config: RunConfig::default(),
                session: MockSession,
            }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }
        fn agent_name(&self) -> &str {
            "test"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl InvocationContext for TestContext {
        fn agent(&self) -> Arc<dyn Agent> {
            // Not reachable from these tests: `TestAgent::run` ignores its context.
            unimplemented!()
        }
        fn memory(&self) -> Option<Arc<dyn crate::Memory>> {
            None
        }
        fn session(&self) -> &dyn Session {
            &self.session
        }
        fn run_config(&self) -> &RunConfig {
            &self.config
        }
        fn end_invocation(&self) {}
        fn ended(&self) -> bool {
            false
        }
    }

    #[async_trait]
    impl Agent for TestAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "test agent"
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
            let s = stream! {
                yield Ok(Event::new("test"));
            };
            Ok(Box::pin(s))
        }
    }

    #[test]
    fn test_agent_trait() {
        let agent = TestAgent { name: "test".to_string() };
        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test agent");
        assert!(agent.sub_agents().is_empty());
    }

    #[test]
    fn test_agent_run_yields_single_event() {
        let agent = TestAgent { name: "test".to_string() };
        let ctx: Arc<dyn InvocationContext> = Arc::new(TestContext::new());

        block_on(async {
            // Match instead of `expect` so the error type need not implement `Debug`.
            let mut events = match agent.run(ctx).await {
                Ok(events) => events,
                Err(_) => panic!("TestAgent::run should always produce a stream"),
            };
            assert!(matches!(events.next().await, Some(Ok(_))));
            assert!(events.next().await.is_none());
        });
    }
}