use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::{Mutex, Notify, RwLock, broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use adk_core::Agent;
#[cfg(feature = "memory")]
use adk_core::Memory;
#[cfg(feature = "sandbox")]
use adk_sandbox::SandboxBackend;
use adk_session::service::{CreateRequest, SessionService};
use crate::agent_builder::{BuildError, build_agent};
use crate::checkpoint::CheckpointManager;
use crate::parking::ToolParkingLot;
use crate::replay::create_event_stream;
use crate::resolver::ModelResolver;
use crate::runtime::{AgentHandle, EnvironmentConfig, ManagedAgentRuntime, SessionHandle};
use crate::session_loop::SessionLoop;
use crate::types::{ManagedAgentDef, RuntimeError, SessionEvent, SessionStatus, UserEvent};
#[allow(dead_code)] pub(crate) struct ActiveSession {
pub(crate) agent: Arc<dyn Agent>,
pub(crate) event_tx: mpsc::Sender<crate::types::UserEvent>,
pub(crate) broadcast_tx: broadcast::Sender<crate::types::SessionEvent>,
pub(crate) cancel_token: CancellationToken,
pub(crate) pause_flag: Arc<Mutex<bool>>,
pub(crate) pause_notify: Arc<Notify>,
pub(crate) status: Arc<RwLock<SessionStatus>>,
pub(crate) checkpoint: Arc<RwLock<CheckpointManager>>,
}
/// In-process implementation of [`ManagedAgentRuntime`].
///
/// Keeps two in-memory registries — created agents and active sessions — and
/// delegates model lookup and session persistence to the injected services.
pub struct DefaultManagedAgentRuntime {
    /// Resolves the model reference of an agent definition when an agent is created.
    model_resolver: Arc<dyn ModelResolver>,
    /// Session store; a record is created for every started session and the
    /// service is also handed to each session loop.
    session_service: Arc<dyn SessionService>,
    /// Optional sandbox backend passed to the agent builder.
    #[cfg(feature = "sandbox")]
    sandbox: Option<Arc<dyn SandboxBackend>>,
    /// Optional memory service passed to each session loop.
    #[cfg(feature = "memory")]
    memory: Option<Arc<dyn Memory>>,
    /// Created agents, keyed by the id inside their `AgentHandle`.
    agents: Arc<RwLock<HashMap<String, RegisteredAgent>>>,
    /// Active sessions, keyed by the id inside their `SessionHandle`.
    sessions: Arc<RwLock<HashMap<String, ActiveSession>>>,
}
/// An agent that has been built and registered with the runtime.
// `def` is stored alongside the agent but is not read back in this module.
#[allow(dead_code)]
struct RegisteredAgent {
    /// The built agent; cloned into every session started from this registration.
    agent: Arc<dyn Agent>,
    /// The definition the agent was built from.
    def: ManagedAgentDef,
}
impl DefaultManagedAgentRuntime {
pub fn new(
model_resolver: Arc<dyn ModelResolver>,
session_service: Arc<dyn SessionService>,
) -> Self {
Self {
model_resolver,
session_service,
#[cfg(feature = "sandbox")]
sandbox: None,
#[cfg(feature = "memory")]
memory: None,
agents: Arc::new(RwLock::new(HashMap::new())),
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
#[cfg(feature = "sandbox")]
pub fn with_sandbox(mut self, sandbox: Arc<dyn SandboxBackend>) -> Self {
self.sandbox = Some(sandbox);
self
}
#[cfg(feature = "memory")]
pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
self.memory = Some(memory);
self
}
pub fn model_resolver(&self) -> &Arc<dyn ModelResolver> {
&self.model_resolver
}
pub fn session_service(&self) -> &Arc<dyn SessionService> {
&self.session_service
}
#[cfg(feature = "sandbox")]
pub fn sandbox(&self) -> Option<&Arc<dyn SandboxBackend>> {
self.sandbox.as_ref()
}
#[cfg(feature = "memory")]
pub fn memory(&self) -> Option<&Arc<dyn Memory>> {
self.memory.as_ref()
}
#[cfg(test)]
pub(crate) fn sessions(&self) -> &Arc<RwLock<HashMap<String, ActiveSession>>> {
&self.sessions
}
}
/// Capacity of the bounded mpsc channel that carries user events into a session loop.
const DEFAULT_EVENT_CHANNEL_CAPACITY: usize = 64;

/// Capacity of the broadcast channel that carries session events to subscribers.
const DEFAULT_BROADCAST_CHANNEL_CAPACITY: usize = 256;

/// Timeout handed to each session's `ToolParkingLot` (five minutes).
const DEFAULT_PARKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
#[async_trait]
impl ManagedAgentRuntime for DefaultManagedAgentRuntime {
async fn create(&self, def: ManagedAgentDef) -> Result<AgentHandle, RuntimeError> {
let model = self.model_resolver.resolve(&def.model).await.map_err(|e| {
RuntimeError::ProviderError {
provider: format!("{:?}", def.model),
message: e.to_string(),
}
})?;
#[cfg(feature = "sandbox")]
let agent = build_agent(&def, model, self.sandbox.clone()).map_err(|e| match e {
BuildError::InvalidDef(msg) => RuntimeError::invalid_request(msg),
BuildError::BuildFailed(msg) => RuntimeError::internal(msg),
})?;
#[cfg(not(feature = "sandbox"))]
let agent = build_agent(&def, model).map_err(|e| match e {
BuildError::InvalidDef(msg) => RuntimeError::invalid_request(msg),
BuildError::BuildFailed(msg) => RuntimeError::internal(msg),
})?;
let handle_id = uuid::Uuid::new_v4().to_string();
info!(agent_handle = %handle_id, agent_name = %def.name, "agent created");
let registered = RegisteredAgent { agent, def };
self.agents.write().await.insert(handle_id.clone(), registered);
Ok(AgentHandle(handle_id))
}
async fn start_session(
&self,
agent: &AgentHandle,
_env: Option<EnvironmentConfig>,
) -> Result<SessionHandle, RuntimeError> {
let agents = self.agents.read().await;
let registered = agents
.get(&agent.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: agent.0.clone() })?;
let agent_arc = Arc::clone(®istered.agent);
drop(agents);
let session_id = uuid::Uuid::new_v4().to_string();
let (event_tx, event_rx) = mpsc::channel(DEFAULT_EVENT_CHANNEL_CAPACITY);
let (broadcast_tx, _) = broadcast::channel(DEFAULT_BROADCAST_CHANNEL_CAPACITY);
let cancel_token = CancellationToken::new();
let pause_flag = Arc::new(Mutex::new(false));
let pause_notify = Arc::new(Notify::new());
let parking = Arc::new(ToolParkingLot::new(DEFAULT_PARKING_TIMEOUT));
let checkpoint = Arc::new(RwLock::new(CheckpointManager::new(session_id.clone())));
self.session_service
.create(CreateRequest {
app_name: "managed".to_string(),
user_id: "managed_user".to_string(),
session_id: Some(session_id.clone()),
state: std::collections::HashMap::new(),
})
.await
.map_err(|e| RuntimeError::internal(format!("failed to seed session: {e}")))?;
#[cfg(feature = "memory")]
let session_loop = SessionLoop::with_pause_controls(
session_id.clone(),
event_rx,
broadcast_tx.clone(),
Arc::clone(&parking),
cancel_token.clone(),
Arc::clone(&pause_flag),
Arc::clone(&pause_notify),
Arc::clone(&checkpoint),
Arc::clone(&agent_arc),
Arc::clone(&self.session_service),
self.memory.clone(),
);
#[cfg(not(feature = "memory"))]
let session_loop = SessionLoop::with_pause_controls(
session_id.clone(),
event_rx,
broadcast_tx.clone(),
Arc::clone(&parking),
cancel_token.clone(),
Arc::clone(&pause_flag),
Arc::clone(&pause_notify),
Arc::clone(&checkpoint),
Arc::clone(&agent_arc),
Arc::clone(&self.session_service),
);
tokio::spawn(session_loop.run());
let status = Arc::new(RwLock::new(SessionStatus::Queued));
let active_session = ActiveSession {
agent: agent_arc,
event_tx,
broadcast_tx,
cancel_token,
pause_flag,
pause_notify,
status,
checkpoint,
};
self.sessions.write().await.insert(session_id.clone(), active_session);
info!(session_id = %session_id, "session started");
Ok(SessionHandle(session_id))
}
async fn send_event(
&self,
session: &SessionHandle,
event: UserEvent,
) -> Result<(), RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
active
.event_tx
.send(event)
.await
.map_err(|_| RuntimeError::conflict("session loop channel closed"))?;
Ok(())
}
async fn stream_events(
&self,
session: &SessionHandle,
from_seq: Option<u64>,
) -> Result<BoxStream<'static, SessionEvent>, RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
let broadcast_rx = active.broadcast_tx.subscribe();
let checkpoint = active.checkpoint.read().await;
let stream = create_event_stream(&checkpoint, broadcast_rx, from_seq);
Ok(stream)
}
async fn interrupt(&self, session: &SessionHandle) -> Result<(), RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
debug!(session_id = %session.0, "interrupting session");
active.cancel_token.cancel();
Ok(())
}
async fn pause(&self, session: &SessionHandle) -> Result<(), RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
debug!(session_id = %session.0, "pausing session");
*active.pause_flag.lock().await = true;
*active.status.write().await = SessionStatus::Paused;
Ok(())
}
async fn resume(&self, session: &SessionHandle) -> Result<(), RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
debug!(session_id = %session.0, "resuming session");
*active.pause_flag.lock().await = false;
*active.status.write().await = SessionStatus::Running;
active.pause_notify.notify_one();
Ok(())
}
async fn status(&self, session: &SessionHandle) -> Result<SessionStatus, RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
Ok(*active.status.read().await)
}
async fn archive(&self, session: &SessionHandle) -> Result<(), RuntimeError> {
let sessions = self.sessions.read().await;
let active = sessions
.get(&session.0)
.ok_or_else(|| RuntimeError::NotFound { session_id: session.0.clone() })?;
debug!(session_id = %session.0, "archiving session");
*active.status.write().await = SessionStatus::Archived;
active.cancel_token.cancel();
Ok(())
}
async fn delete_session(&self, session: &SessionHandle) -> Result<(), RuntimeError> {
{
let sessions = self.sessions.read().await;
if let Some(active) = sessions.get(&session.0) {
*active.status.write().await = SessionStatus::Archived;
active.cancel_token.cancel();
}
}
let removed = self.sessions.write().await.remove(&session.0);
if removed.is_none() {
return Err(RuntimeError::NotFound { session_id: session.0.clone() });
}
debug!(session_id = %session.0, "session deleted");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resolver::DefaultModelResolver;
    use crate::types::{ContentBlock, ModelRef};
    use adk_core::{Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream};
    use async_stream::stream;
    use futures::StreamExt;
    use std::time::Duration;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// In-memory session service used by every test.
    fn mock_session_service() -> Arc<dyn SessionService> {
        Arc::new(adk_session::InMemorySessionService::new())
    }

    /// LLM stub that answers every request with a single fixed, final response.
    struct MockLlm {
        name: String,
    }

    impl MockLlm {
        fn new(name: &str) -> Self {
            Self { name: name.to_owned() }
        }
    }

    #[async_trait]
    impl Llm for MockLlm {
        fn name(&self) -> &str {
            &self.name
        }

        async fn generate_content(
            &self,
            _request: LlmRequest,
            _stream: bool,
        ) -> adk_core::Result<LlmResponseStream> {
            let responses = stream! {
                yield Ok(LlmResponse {
                    content: Some(Content::new("model").with_text("Hello from mock")),
                    partial: false,
                    turn_complete: true,
                    finish_reason: Some(FinishReason::Stop),
                    ..Default::default()
                });
            };
            Ok(Box::pin(responses))
        }
    }

    /// Resolver that maps every model reference to a [`MockLlm`].
    struct MockResolver;

    #[async_trait]
    impl ModelResolver for MockResolver {
        async fn resolve(
            &self,
            _model_ref: &ModelRef,
        ) -> crate::resolver::ResolverResult<Arc<dyn Llm>> {
            Ok(Arc::new(MockLlm::new("mock-model")))
        }
    }

    /// Runtime wired to the mock resolver and an in-memory session service.
    fn create_test_runtime() -> DefaultManagedAgentRuntime {
        let resolver: Arc<dyn ModelResolver> = Arc::new(MockResolver);
        DefaultManagedAgentRuntime::new(resolver, mock_session_service())
    }

    /// Minimal agent definition: only the name varies between tests.
    fn agent_def(name: &str) -> ManagedAgentDef {
        ManagedAgentDef {
            name: name.to_string(),
            model: ModelRef::Shorthand("gemini-2.5-flash".to_string()),
            system: None,
            description: None,
            tools: vec![],
            mcp_servers: vec![],
            skills: vec![],
            permission_policy: None,
            metadata: None,
        }
    }

    /// Creates a test runtime, registers one agent and starts one session for it.
    async fn runtime_with_session(
        agent_name: &str,
    ) -> (DefaultManagedAgentRuntime, SessionHandle) {
        let runtime = create_test_runtime();
        let agent = runtime.create(agent_def(agent_name)).await.unwrap();
        let session = runtime.start_session(&agent, None).await.unwrap();
        (runtime, session)
    }

    /// A user message event carrying a single text block.
    fn text_message(text: &str) -> UserEvent {
        UserEvent::Message { content: vec![ContentBlock::Text { text: text.to_string() }] }
    }

    /// A session handle that was never registered with any runtime.
    fn unknown_session() -> SessionHandle {
        SessionHandle("nonexistent".to_string())
    }

    // ---------------------------------------------------------------------
    // Construction and accessors
    // ---------------------------------------------------------------------

    #[test]
    fn test_new_with_minimal_config() {
        let resolver = Arc::new(DefaultModelResolver::new());
        let _runtime = DefaultManagedAgentRuntime::new(resolver, mock_session_service());
        #[cfg(feature = "sandbox")]
        assert!(_runtime.sandbox().is_none());
        #[cfg(feature = "memory")]
        assert!(_runtime.memory().is_none());
    }

    #[cfg(all(feature = "sandbox", feature = "memory"))]
    #[test]
    fn test_new_with_sandbox_and_memory() {
        use adk_sandbox::{
            BackendCapabilities, EnforcedLimits, ExecRequest, ExecResult, Language, SandboxBackend,
            SandboxError,
        };

        struct FakeSandbox;

        #[async_trait]
        impl SandboxBackend for FakeSandbox {
            fn name(&self) -> &str {
                "fake"
            }

            fn capabilities(&self) -> BackendCapabilities {
                BackendCapabilities {
                    supported_languages: vec![Language::Python],
                    isolation_class: "fake".to_string(),
                    enforced_limits: EnforcedLimits {
                        timeout: true,
                        memory: false,
                        network_isolation: false,
                        filesystem_isolation: false,
                        environment_isolation: false,
                    },
                }
            }

            async fn execute(&self, _request: ExecRequest) -> Result<ExecResult, SandboxError> {
                Ok(ExecResult {
                    stdout: "ok".to_string(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::from_millis(1),
                })
            }
        }

        struct FakeMemory;

        #[async_trait]
        impl adk_core::Memory for FakeMemory {
            async fn search(&self, _query: &str) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
                Ok(vec![])
            }
        }

        let resolver = Arc::new(DefaultModelResolver::new());
        let runtime = DefaultManagedAgentRuntime::new(resolver, mock_session_service())
            .with_sandbox(Arc::new(FakeSandbox))
            .with_memory(Arc::new(FakeMemory));
        assert!(runtime.sandbox().is_some());
        assert!(runtime.memory().is_some());
    }

    #[test]
    fn test_sessions_map_starts_empty() {
        let resolver = Arc::new(DefaultModelResolver::new());
        let runtime = DefaultManagedAgentRuntime::new(resolver, mock_session_service());
        let registry = runtime.sessions().try_read().unwrap();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_accessors_return_injected_services() {
        let resolver: Arc<dyn ModelResolver> = Arc::new(DefaultModelResolver::new());
        let session_service = mock_session_service();
        let runtime =
            DefaultManagedAgentRuntime::new(Arc::clone(&resolver), Arc::clone(&session_service));
        // Type-level check only: the accessors hand back the trait-object Arcs.
        let _r: &Arc<dyn ModelResolver> = runtime.model_resolver();
        let _s: &Arc<dyn SessionService> = runtime.session_service();
    }

    // ---------------------------------------------------------------------
    // Agent creation
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_create_agent_returns_handle() {
        let runtime = create_test_runtime();
        let def = ManagedAgentDef {
            system: Some("You are helpful.".to_string()),
            ..agent_def("test-agent")
        };
        let handle = runtime.create(def).await.unwrap();
        assert!(!handle.0.is_empty());
    }

    #[tokio::test]
    async fn test_create_agent_stores_in_registry() {
        let runtime = create_test_runtime();
        let handle = runtime.create(agent_def("stored-agent")).await.unwrap();
        let agents = runtime.agents.read().await;
        assert!(agents.contains_key(&handle.0));
    }

    #[tokio::test]
    async fn test_create_multiple_agents() {
        let runtime = create_test_runtime();
        let h1 = runtime.create(agent_def("agent-1")).await.unwrap();
        let h2 = runtime.create(agent_def("agent-2")).await.unwrap();
        assert_ne!(h1.0, h2.0);
        assert_eq!(runtime.agents.read().await.len(), 2);
    }

    // ---------------------------------------------------------------------
    // Session start
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_start_session_returns_handle() {
        let runtime = create_test_runtime();
        let agent = runtime.create(agent_def("session-agent")).await.unwrap();
        let session = runtime.start_session(&agent, None).await.unwrap();
        assert!(!session.0.is_empty());
    }

    #[tokio::test]
    async fn test_start_session_initial_status_queued() {
        let (runtime, session) = runtime_with_session("status-agent").await;
        let status = runtime.status(&session).await.unwrap();
        assert_eq!(status, SessionStatus::Queued);
    }

    #[tokio::test]
    async fn test_start_session_unknown_agent_returns_error() {
        let runtime = create_test_runtime();
        let fake_agent = AgentHandle("nonexistent".to_string());
        let result = runtime.start_session(&fake_agent, None).await;
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // Events
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_send_event_message() {
        let (runtime, session) = runtime_with_session("event-agent").await;
        let result = runtime.send_event(&session, text_message("Hello")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_event_unknown_session_returns_error() {
        let runtime = create_test_runtime();
        let result = runtime.send_event(&unknown_session(), text_message("Hello")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_stream_events_receives_broadcast() {
        let (runtime, session) = runtime_with_session("stream-agent").await;
        // Subscribe before sending so the first broadcast event is not missed.
        let mut stream = runtime.stream_events(&session, None).await.unwrap();
        runtime.send_event(&session, text_message("Test")).await.unwrap();
        let first_event = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timed out waiting for event")
            .expect("stream ended unexpectedly");
        match first_event {
            SessionEvent::StatusRunning { .. } => {}
            other => panic!("expected StatusRunning, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_stream_events_unknown_session_returns_error() {
        let runtime = create_test_runtime();
        let result = runtime.stream_events(&unknown_session(), None).await;
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // Lifecycle controls
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_interrupt_cancels_session() {
        let (runtime, session) = runtime_with_session("interrupt-agent").await;
        let result = runtime.interrupt(&session).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_pause_sets_paused_status() {
        let (runtime, session) = runtime_with_session("pause-agent").await;
        runtime.pause(&session).await.unwrap();
        let status = runtime.status(&session).await.unwrap();
        assert_eq!(status, SessionStatus::Paused);
    }

    #[tokio::test]
    async fn test_resume_clears_pause() {
        let (runtime, session) = runtime_with_session("resume-agent").await;
        runtime.pause(&session).await.unwrap();
        assert_eq!(runtime.status(&session).await.unwrap(), SessionStatus::Paused);
        runtime.resume(&session).await.unwrap();
        assert_eq!(runtime.status(&session).await.unwrap(), SessionStatus::Running);
    }

    #[tokio::test]
    async fn test_archive_sets_archived_status() {
        let (runtime, session) = runtime_with_session("archive-agent").await;
        runtime.archive(&session).await.unwrap();
        let status = runtime.status(&session).await.unwrap();
        assert_eq!(status, SessionStatus::Archived);
    }

    #[tokio::test]
    async fn test_delete_session_removes_from_registry() {
        let (runtime, session) = runtime_with_session("delete-agent").await;
        runtime.delete_session(&session).await.unwrap();
        // Once deleted, the session can no longer be looked up.
        let result = runtime.status(&session).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_session_returns_error() {
        let runtime = create_test_runtime();
        let result = runtime.delete_session(&unknown_session()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_interrupt_nonexistent_session_returns_error() {
        let runtime = create_test_runtime();
        let result = runtime.interrupt(&unknown_session()).await;
        assert!(result.is_err());
    }
}