// rs_adk/text_runner.rs

//! InMemoryRunner — runs TextAgents with session management and services.
//!
//! Provides a runtime for text-based agent execution with automatic session
//! management. Memory, artifact, and plugin services are held by the runner and
//! can be overridden, but only the session service is driven during a run.
5
6use std::sync::Arc;
7
8use crate::artifacts::{ArtifactService, InMemoryArtifactService};
9use crate::error::AgentError;
10use crate::events::Event;
11use crate::memory::{InMemoryMemoryService, MemoryService};
12use crate::plugin::{Plugin, PluginManager};
13use crate::session::{InMemorySessionService, SessionId, SessionService};
14use crate::state::State;
15use crate::text::TextAgent;
16
17/// Runs TextAgents with full service wiring (session, memory, artifacts, plugins).
18///
19/// Auto-wires in-memory service implementations by default; override with
20/// builder methods for custom persistence.
21pub struct InMemoryRunner {
22    root_agent: Arc<dyn TextAgent>,
23    session_service: Arc<dyn SessionService>,
24    memory_service: Arc<dyn MemoryService>,
25    artifact_service: Arc<dyn ArtifactService>,
26    plugins: PluginManager,
27    app_name: String,
28}
29
30impl InMemoryRunner {
31    /// Create a new runner with in-memory defaults for all services.
32    pub fn new(agent: Arc<dyn TextAgent>, app_name: impl Into<String>) -> Self {
33        Self {
34            root_agent: agent,
35            session_service: Arc::new(InMemorySessionService::new()),
36            memory_service: Arc::new(InMemoryMemoryService::new()),
37            artifact_service: Arc::new(InMemoryArtifactService::new()),
38            plugins: PluginManager::new(),
39            app_name: app_name.into(),
40        }
41    }
42
43    /// Override the session service.
44    pub fn session_service(mut self, svc: Arc<dyn SessionService>) -> Self {
45        self.session_service = svc;
46        self
47    }
48
49    /// Override the memory service.
50    pub fn memory_service(mut self, svc: Arc<dyn MemoryService>) -> Self {
51        self.memory_service = svc;
52        self
53    }
54
55    /// Override the artifact service.
56    pub fn artifact_service(mut self, svc: Arc<dyn ArtifactService>) -> Self {
57        self.artifact_service = svc;
58        self
59    }
60
61    /// Add a plugin.
62    pub fn plugin(mut self, p: impl Plugin + 'static) -> Self {
63        self.plugins.add(Arc::new(p));
64        self
65    }
66
67    /// Run with session management. Creates or resumes a session.
68    ///
69    /// 1. Creates a new session or loads an existing one
70    /// 2. Sets `"input"` in state from `prompt`
71    /// 3. Runs the agent
72    /// 4. Persists the result as an event in the session
73    /// 5. Returns the agent's text output
74    pub async fn run(
75        &self,
76        prompt: &str,
77        user_id: &str,
78        session_id: Option<&SessionId>,
79    ) -> Result<String, AgentError> {
80        // 1. Create or load session
81        let session = match session_id {
82            Some(id) => self
83                .session_service
84                .get_session(id)
85                .await
86                .map_err(|e| AgentError::Other(format!("Session error: {e}")))?
87                .ok_or_else(|| AgentError::Other(format!("Session not found: {id}")))?,
88            None => self
89                .session_service
90                .create_session(&self.app_name, user_id)
91                .await
92                .map_err(|e| AgentError::Other(format!("Session create error: {e}")))?,
93        };
94
95        // 2. Build state and set input
96        let state = State::new();
97
98        // Load existing events to rebuild state (state deltas)
99        let events = self
100            .session_service
101            .get_events(&session.id)
102            .await
103            .map_err(|e| AgentError::Other(format!("Events error: {e}")))?;
104        for event in &events {
105            for (key, value) in &event.actions.state_delta {
106                state.set(key.clone(), value.clone());
107            }
108        }
109
110        state.set("input", prompt);
111
112        // Persist user input event
113        let user_event = Event::new("user", Some(prompt.to_string()));
114        self.session_service
115            .append_event(&session.id, user_event)
116            .await
117            .map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
118
119        // 3. Run agent
120        let result = self.root_agent.run(&state).await?;
121
122        // 4. Persist result event
123        let result_event = Event::new(self.root_agent.name(), Some(result.clone()));
124        self.session_service
125            .append_event(&session.id, result_event)
126            .await
127            .map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
128
129        // 5. Return result
130        Ok(result)
131    }
132
133    /// Run without persistence (one-shot, ephemeral).
134    pub async fn run_ephemeral(&self, prompt: &str) -> Result<String, AgentError> {
135        let state = State::new();
136        state.set("input", prompt);
137        self.root_agent.run(&state).await
138    }
139
140    /// Access the session service.
141    pub fn session_service_ref(&self) -> &dyn SessionService {
142        self.session_service.as_ref()
143    }
144
145    /// Access the app name.
146    pub fn app_name(&self) -> &str {
147        &self.app_name
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text::FnTextAgent;

    /// Agent that answers `"Echo: <input>"` using the `"input"` state key.
    fn echo_agent() -> Arc<dyn TextAgent> {
        Arc::new(FnTextAgent::new("echo", |state| {
            let input: String = state.get("input").unwrap_or_default();
            Ok(format!("Echo: {input}"))
        }))
    }

    #[tokio::test]
    async fn run_ephemeral() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let result = runner.run_ephemeral("Hello").await.unwrap();
        assert_eq!(result, "Echo: Hello");
    }

    #[tokio::test]
    async fn run_ephemeral_does_not_create_sessions() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        runner.run_ephemeral("Hello").await.unwrap();

        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn run_with_session_creates_and_persists() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");

        // First run — creates session
        let result = runner.run("Hello", "user-1", None).await.unwrap();
        assert_eq!(result, "Echo: Hello");

        // Verify session was created
        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        assert_eq!(sessions.len(), 1);

        // Verify events were persisted (user input + agent response)
        let events = runner
            .session_service_ref()
            .get_events(&sessions[0].id)
            .await
            .unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].author, "user");
        assert_eq!(events[1].author, "echo");
    }

    #[tokio::test]
    async fn run_without_session_id_creates_new_session_each_time() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        runner.run("One", "user-1", None).await.unwrap();
        runner.run("Two", "user-1", None).await.unwrap();

        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn run_resumes_existing_session() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");

        // Create a session via first run
        let result1 = runner.run("First", "user-1", None).await.unwrap();
        assert_eq!(result1, "Echo: First");

        // Get the session ID
        let sessions = runner
            .session_service_ref()
            .list_sessions("test-app", "user-1")
            .await
            .unwrap();
        let session_id = &sessions[0].id;

        // Resume with the same session
        let result2 = runner
            .run("Second", "user-1", Some(session_id))
            .await
            .unwrap();
        assert_eq!(result2, "Echo: Second");

        // Should have 4 events total (2 per run)
        let events = runner
            .session_service_ref()
            .get_events(session_id)
            .await
            .unwrap();
        assert_eq!(events.len(), 4);
    }

    #[tokio::test]
    async fn run_with_nonexistent_session_errors() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let fake_id = SessionId::new();
        let result = runner.run("Hello", "user-1", Some(&fake_id)).await;
        // Both the lookup-failure and not-found paths report `AgentError::Other`.
        assert!(matches!(result, Err(AgentError::Other(_))));
    }

    #[tokio::test]
    async fn custom_session_service() {
        let custom_svc = Arc::new(InMemorySessionService::new());
        let runner = InMemoryRunner::new(echo_agent(), "app").session_service(custom_svc.clone());

        runner.run("Hi", "u1", None).await.unwrap();

        let sessions = custom_svc.list_sessions("app", "u1").await.unwrap();
        assert_eq!(sessions.len(), 1);
    }
}