use std::sync::Arc;
use crate::artifacts::{ArtifactService, InMemoryArtifactService};
use crate::error::AgentError;
use crate::events::Event;
use crate::memory::{InMemoryMemoryService, MemoryService};
use crate::plugin::{Plugin, PluginManager};
use crate::session::{InMemorySessionService, SessionId, SessionService};
use crate::state::State;
use crate::text::TextAgent;
/// Runs a root [`TextAgent`] against pluggable session, memory and artifact
/// services. `InMemoryRunner::new` wires in-memory implementations of all three
/// and an empty plugin manager; the builder methods can replace them.
pub struct InMemoryRunner {
/// Agent invoked for every `run` / `run_ephemeral` call.
root_agent: Arc<dyn TextAgent>,
/// Creates and looks up sessions and stores their event history.
session_service: Arc<dyn SessionService>,
/// Memory service. NOTE(review): not read by `run`/`run_ephemeral` in this
/// file — confirm whether it is consumed elsewhere.
memory_service: Arc<dyn MemoryService>,
/// Artifact service. NOTE(review): not read by `run`/`run_ephemeral` in this
/// file — confirm whether it is consumed elsewhere.
artifact_service: Arc<dyn ArtifactService>,
/// Plugins registered through the `plugin` builder method. NOTE(review): no
/// plugin hook is invoked by `run`/`run_ephemeral` in this file.
plugins: PluginManager,
/// Application name passed to the session service when creating sessions.
app_name: String,
}
impl InMemoryRunner {
pub fn new(agent: Arc<dyn TextAgent>, app_name: impl Into<String>) -> Self {
Self {
root_agent: agent,
session_service: Arc::new(InMemorySessionService::new()),
memory_service: Arc::new(InMemoryMemoryService::new()),
artifact_service: Arc::new(InMemoryArtifactService::new()),
plugins: PluginManager::new(),
app_name: app_name.into(),
}
}
pub fn session_service(mut self, svc: Arc<dyn SessionService>) -> Self {
self.session_service = svc;
self
}
pub fn memory_service(mut self, svc: Arc<dyn MemoryService>) -> Self {
self.memory_service = svc;
self
}
pub fn artifact_service(mut self, svc: Arc<dyn ArtifactService>) -> Self {
self.artifact_service = svc;
self
}
pub fn plugin(mut self, p: impl Plugin + 'static) -> Self {
self.plugins.add(Arc::new(p));
self
}
pub async fn run(
&self,
prompt: &str,
user_id: &str,
session_id: Option<&SessionId>,
) -> Result<String, AgentError> {
let session = match session_id {
Some(id) => self
.session_service
.get_session(id)
.await
.map_err(|e| AgentError::Other(format!("Session error: {e}")))?
.ok_or_else(|| AgentError::Other(format!("Session not found: {id}")))?,
None => self
.session_service
.create_session(&self.app_name, user_id)
.await
.map_err(|e| AgentError::Other(format!("Session create error: {e}")))?,
};
let state = State::new();
let events = self
.session_service
.get_events(&session.id)
.await
.map_err(|e| AgentError::Other(format!("Events error: {e}")))?;
for event in &events {
for (key, value) in &event.actions.state_delta {
state.set(key.clone(), value.clone());
}
}
state.set("input", prompt);
let user_event = Event::new("user", Some(prompt.to_string()));
self.session_service
.append_event(&session.id, user_event)
.await
.map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
let result = self.root_agent.run(&state).await?;
let result_event = Event::new(self.root_agent.name(), Some(result.clone()));
self.session_service
.append_event(&session.id, result_event)
.await
.map_err(|e| AgentError::Other(format!("Event append error: {e}")))?;
Ok(result)
}
pub async fn run_ephemeral(&self, prompt: &str) -> Result<String, AgentError> {
let state = State::new();
state.set("input", prompt);
self.root_agent.run(&state).await
}
pub fn session_service_ref(&self) -> &dyn SessionService {
self.session_service.as_ref()
}
pub fn app_name(&self) -> &str {
&self.app_name
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text::FnTextAgent;

    /// Agent named "echo" that answers `Echo: <input>` (empty input if unset).
    fn echo_agent() -> Arc<dyn TextAgent> {
        let agent = FnTextAgent::new("echo", |state| {
            let text: String = state.get("input").unwrap_or_default();
            Ok(format!("Echo: {text}"))
        });
        Arc::new(agent)
    }

    #[tokio::test]
    async fn run_ephemeral() {
        let reply = InMemoryRunner::new(echo_agent(), "test-app")
            .run_ephemeral("Hello")
            .await
            .unwrap();
        assert_eq!(reply, "Echo: Hello");
    }

    #[tokio::test]
    async fn run_with_session_creates_and_persists() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let store = runner.session_service_ref();

        let reply = runner.run("Hello", "user-1", None).await.unwrap();
        assert_eq!(reply, "Echo: Hello");

        let sessions = store.list_sessions("test-app", "user-1").await.unwrap();
        assert_eq!(sessions.len(), 1);

        // One user turn followed by the agent's reply.
        let history = store.get_events(&sessions[0].id).await.unwrap();
        assert_eq!(history.len(), 2);
        for (event, expected_author) in history.iter().zip(["user", "echo"]) {
            assert_eq!(event.author, expected_author);
        }
    }

    #[tokio::test]
    async fn run_resumes_existing_session() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let store = runner.session_service_ref();

        let first = runner.run("First", "user-1", None).await.unwrap();
        assert_eq!(first, "Echo: First");

        let sessions = store.list_sessions("test-app", "user-1").await.unwrap();
        let id = &sessions[0].id;

        let second = runner.run("Second", "user-1", Some(id)).await.unwrap();
        assert_eq!(second, "Echo: Second");

        // Two turns on the same session, each stored as a user event plus a reply.
        let history = store.get_events(id).await.unwrap();
        assert_eq!(history.len(), 4);
    }

    #[tokio::test]
    async fn run_with_nonexistent_session_errors() {
        let runner = InMemoryRunner::new(echo_agent(), "test-app");
        let unknown = SessionId::new();
        let outcome = runner.run("Hello", "user-1", Some(&unknown)).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn custom_session_service() {
        let injected = Arc::new(InMemorySessionService::new());
        let runner = InMemoryRunner::new(echo_agent(), "app").session_service(injected.clone());
        runner.run("Hi", "u1", None).await.unwrap();

        // The session must land in the injected service, not the default one.
        let found = injected.list_sessions("app", "u1").await.unwrap();
        assert_eq!(found.len(), 1);
    }
}