use std::fmt;
use std::sync::Arc;

use crate::artifacts::{ArtifactService, InMemoryArtifactService};
use crate::error::AgentError;
use crate::events::Event;
use crate::memory::{InMemoryMemoryService, MemoryService};
use crate::plugin::{Plugin, PluginManager};
use crate::session::{InMemorySessionService, SessionId, SessionService};
use crate::state::State;
use crate::text::TextAgent;
16
17pub struct InMemoryRunner {
22 root_agent: Arc<dyn TextAgent>,
23 session_service: Arc<dyn SessionService>,
24 memory_service: Arc<dyn MemoryService>,
25 artifact_service: Arc<dyn ArtifactService>,
26 plugins: PluginManager,
27 app_name: String,
28}
29
30impl InMemoryRunner {
31 pub fn new(agent: Arc<dyn TextAgent>, app_name: impl Into<String>) -> Self {
33 Self {
34 root_agent: agent,
35 session_service: Arc::new(InMemorySessionService::new()),
36 memory_service: Arc::new(InMemoryMemoryService::new()),
37 artifact_service: Arc::new(InMemoryArtifactService::new()),
38 plugins: PluginManager::new(),
39 app_name: app_name.into(),
40 }
41 }
42
43 pub fn session_service(mut self, svc: Arc<dyn SessionService>) -> Self {
45 self.session_service = svc;
46 self
47 }
48
49 pub fn memory_service(mut self, svc: Arc<dyn MemoryService>) -> Self {
51 self.memory_service = svc;
52 self
53 }
54
55 pub fn artifact_service(mut self, svc: Arc<dyn ArtifactService>) -> Self {
57 self.artifact_service = svc;
58 self
59 }
60
61 pub fn plugin(mut self, p: impl Plugin + 'static) -> Self {
63 self.plugins.add(Arc::new(p));
64 self
65 }
66
67 pub async fn run(
75 &self,
76 prompt: &str,
77 user_id: &str,
78 session_id: Option<&SessionId>,
79 ) -> Result<String, AgentError> {
80 let session = match session_id {
82 Some(id) => self
83 .session_service
84 .get_session(id)
85 .await
86 .map_err(|e| AgentError::Other(format!("Session error: {e}")))?
87 .ok_or_else(|| AgentError::Other(format!("Session not found: {id}")))?,
88 None => self
89 .session_service
90 .create_session(&self.app_name, user_id)
91 .await
92 .map_err(|e| AgentError::Other(format!("Session create error: {e}")))?,
93 };
94
95 let state = State::new();
97
98 let events = self
100 .session_service
101 .get_events(&session.id)
102 .await
103 .map_err(|e| AgentError::Other(format!("Events error: {e}")))?;
104 for event in &events {
105 for (key, value) in &event.actions.state_delta {
106 state.set(key.clone(), value.clone());
107 }
108 }
109
110 state.set("input", prompt);
111
112 let user_event = Event::new("user", Some(prompt.to_string()));
114 self.session_service
115 .append_event(&session.id, user_event)
116 .await
117 .map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
118
119 let result = self.root_agent.run(&state).await?;
121
122 let result_event = Event::new(self.root_agent.name(), Some(result.clone()));
124 self.session_service
125 .append_event(&session.id, result_event)
126 .await
127 .map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
128
129 Ok(result)
131 }
132
133 pub async fn run_ephemeral(&self, prompt: &str) -> Result<String, AgentError> {
135 let state = State::new();
136 state.set("input", prompt);
137 self.root_agent.run(&state).await
138 }
139
140 pub fn session_service_ref(&self) -> &dyn SessionService {
142 self.session_service.as_ref()
143 }
144
145 pub fn app_name(&self) -> &str {
147 &self.app_name
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text::FnTextAgent;

    /// Builds an agent named `"echo"` that returns the `"input"` state value
    /// prefixed with `Echo: `.
    fn echo_agent() -> Arc<dyn TextAgent> {
        Arc::new(FnTextAgent::new("echo", |state| {
            let input: String = state.get("input").unwrap_or_default();
            Ok(format!("Echo: {input}"))
        }))
    }

    #[tokio::test]
    async fn run_ephemeral() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let result = runner.run_ephemeral("Hello").await.unwrap();
        assert_eq!(result, "Echo: Hello");
    }

    #[tokio::test]
    async fn app_name_is_exposed() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        assert_eq!(runner.app_name(), "test-app");
    }

    #[tokio::test]
    async fn run_with_session_creates_and_persists() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");

        // No session id: the runner must create one.
        let result = runner.run("Hello", "user-1", None).await.unwrap();
        assert_eq!(result, "Echo: Hello");

        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        assert_eq!(sessions.len(), 1);

        // One user event followed by one agent event.
        let events = runner
            .session_service_ref()
            .get_events(&sessions[0].id)
            .await
            .unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].author, "user");
        assert_eq!(events[1].author, "echo");
    }

    #[tokio::test]
    async fn run_resumes_existing_session() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");

        let result1 = runner.run("First", "user-1", None).await.unwrap();
        assert_eq!(result1, "Echo: First");

        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        let session_id = &sessions[0].id;

        // Second turn reuses the session created by the first.
        let result2 = runner
            .run("Second", "user-1", Some(session_id))
            .await
            .unwrap();
        assert_eq!(result2, "Echo: Second");

        let events = runner
            .session_service_ref()
            .get_events(session_id)
            .await
            .unwrap();
        assert_eq!(events.len(), 4);
        // The resumed turn appends its own user/agent pair after the first one.
        assert_eq!(events[2].author, "user");
        assert_eq!(events[3].author, "echo");
    }

    #[tokio::test]
    async fn run_with_nonexistent_session_errors() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let fake_id = SessionId::new();
        let result = runner.run("Hello", "user-1", Some(&fake_id)).await;
        assert!(result.is_err());
        // Both the lookup-failure and the not-found paths map to `AgentError::Other`.
        assert!(matches!(result, Err(AgentError::Other(_))));
    }

    #[tokio::test]
    async fn custom_session_service() {
        let custom_svc = Arc::new(InMemorySessionService::new());
        let runner = InMemoryRunner::new(echo_agent(), "app").session_service(custom_svc.clone());

        runner.run("Hi", "u1", None).await.unwrap();

        // The session must land in the injected service, not a default one.
        let sessions = custom_svc.list_sessions("app", "u1").await.unwrap();
        assert_eq!(sessions.len(), 1);
    }
}