use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use awaken_contract::contract::config_store::ConfigStore;
use awaken_contract::contract::storage::ThreadRunStore;
use awaken_runtime::{AgentResolver, AgentRuntime};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use crate::mailbox::{Mailbox, MailboxLifecycleConfig};
use crate::transport::replay_buffer::EventReplayBuffer;
/// A replay buffer together with the instant it was created; the timestamp
/// drives staleness purging in `AppState::purge_stale_replay_buffers`.
pub type ReplayBufferEntry = (Arc<EventReplayBuffer>, Instant);
/// Shared, lock-protected map of replay buffers keyed by string id
/// (keys look like run ids, e.g. "run-1" in tests — confirm with callers).
pub type ReplayBufferMap = Arc<Mutex<HashMap<String, ReplayBufferEntry>>>;
/// Execution context advertised for a catalog skill.
/// Serializes as `"inline"` / `"fork"` (snake_case). Serialize-only by design
/// here — the catalog is produced, not consumed, by this crate as far as
/// this file shows.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SkillCatalogContext {
// NOTE(review): presumably "runs in the current agent context" — confirm
// against the SkillCatalogProvider implementations.
Inline,
// NOTE(review): presumably "runs in a forked/isolated context" — confirm.
Fork,
}
/// One named argument a skill accepts, as exposed in the skill catalog.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct SkillCatalogArgument {
/// Argument name as the invoker must supply it.
pub name: String,
/// Optional human-readable description; omitted from JSON when `None`.
/// NOTE(review): `serde(default)` has no effect without `Deserialize` —
/// harmless now, relevant only if deserialization is added later.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Whether the argument must be provided.
pub required: bool,
}
/// A single skill as listed by a [`SkillCatalogProvider`].
///
/// Serialize-only; empty vectors and `None` options are omitted from the
/// JSON output to keep the catalog compact. The `serde(default)` attributes
/// are inert without `Deserialize` (NOTE(review): kept as written).
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct SkillCatalogEntry {
/// Stable identifier for the skill.
pub id: String,
/// Display name.
pub name: String,
/// Human-readable description of what the skill does.
pub description: String,
/// Tools the skill is permitted to use; omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub allowed_tools: Vec<String>,
/// Guidance on when a model should pick this skill; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub when_to_use: Option<String>,
/// Structured argument declarations; omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub arguments: Vec<SkillCatalogArgument>,
/// Free-form hint about expected arguments; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub argument_hint: Option<String>,
/// Whether a human user may invoke the skill directly.
pub user_invocable: bool,
/// Whether a model may invoke the skill.
pub model_invocable: bool,
/// Model to use instead of the default, if any; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_override: Option<String>,
/// Execution context (inline vs fork).
pub context: SkillCatalogContext,
/// Associated paths (NOTE(review): presumably skill source/resource
/// paths — confirm with the provider); omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub paths: Vec<String>,
}
/// Source of the skill catalog exposed by the server.
///
/// `Send + Sync` so a provider can be shared via `Arc` across request
/// handlers (see `AppState::skill_catalog_provider`).
pub trait SkillCatalogProvider: Send + Sync {
/// Returns the current set of catalog entries. Called per request as far
/// as this file shows; implementations decide whether to cache.
fn list_skills(&self) -> Vec<SkillCatalogEntry>;
}
/// Graceful-shutdown settings for the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShutdownConfig {
/// Seconds to wait for in-flight requests to drain after a shutdown
/// signal before forcing exit. Defaults to 30.
#[serde(default = "default_shutdown_timeout")]
pub timeout_secs: u64,
}
/// Fallback shutdown drain timeout, in seconds, used by serde and `Default`.
fn default_shutdown_timeout() -> u64 {
    const SECS: u64 = 30;
    SECS
}
impl Default for ShutdownConfig {
fn default() -> Self {
Self {
timeout_secs: default_shutdown_timeout(),
}
}
}
/// Controls whether the server manages the mailbox lifecycle itself.
/// Serializes as `"auto"` / `"manual"` (snake_case).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MailboxLifecycleMode {
/// Server starts the mailbox lifecycle on `serve` and shuts it down on
/// exit (see `serve`). This is the default.
#[default]
Auto,
/// Caller is responsible for the mailbox lifecycle; `serve` does nothing.
Manual,
}
/// Top-level server configuration; all fields except `address` have
/// serde defaults so a minimal `{"address": ...}` document deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Socket address to bind, e.g. "0.0.0.0:3000".
pub address: String,
/// Per-connection SSE channel capacity (default 64).
#[serde(default = "default_sse_buffer")]
pub sse_buffer_size: usize,
/// Capacity of each event replay buffer (default 1024).
#[serde(default = "default_replay_buffer_capacity")]
pub replay_buffer_capacity: usize,
/// Graceful-shutdown settings.
#[serde(default)]
pub shutdown: ShutdownConfig,
/// Concurrency limit applied to the router (default 100).
#[serde(default = "default_max_concurrent")]
pub max_concurrent_requests: usize,
/// Bearer token guarding the A2A extended card endpoint, if any;
/// omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub a2a_extended_card_bearer_token: Option<String>,
/// Whether `serve` manages the mailbox lifecycle (default: auto).
#[serde(default)]
pub mailbox_lifecycle: MailboxLifecycleMode,
}
/// Default SSE channel capacity used by serde when the field is absent.
fn default_sse_buffer() -> usize {
    const EVENTS: usize = 64;
    EVENTS
}
/// Default replay-buffer capacity used by serde when the field is absent.
fn default_replay_buffer_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
/// Default request-concurrency limit used by serde when the field is absent.
fn default_max_concurrent() -> usize {
    const REQUESTS: usize = 100;
    REQUESTS
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
address: "0.0.0.0:3000".to_string(),
sse_buffer_size: default_sse_buffer(),
replay_buffer_capacity: default_replay_buffer_capacity(),
shutdown: ShutdownConfig::default(),
max_concurrent_requests: default_max_concurrent(),
a2a_extended_card_bearer_token: None,
mailbox_lifecycle: MailboxLifecycleMode::Auto,
}
}
}
/// Shared application state handed to every request handler.
/// Cheap to clone: every field is an `Arc` or a small config struct.
#[derive(Clone)]
pub struct AppState {
/// Agent execution runtime.
pub runtime: Arc<AgentRuntime>,
/// Server mailbox; its lifecycle is managed by `serve` in `Auto` mode.
pub mailbox: Arc<Mailbox>,
/// Persistence for threads and runs.
pub store: Arc<dyn ThreadRunStore>,
/// Resolves agent identifiers to agents.
pub resolver: Arc<dyn AgentResolver>,
/// Server configuration (address, buffers, shutdown, limits).
pub config: ServerConfig,
/// Optional dynamic config store; set via `with_config_store`.
pub config_store: Option<Arc<dyn ConfigStore>>,
/// Optional runtime manager for config changes; set via
/// `with_config_runtime_manager`.
pub config_runtime_manager: Option<Arc<crate::services::config_runtime::ConfigRuntimeManager>>,
/// Optional skill catalog source; set via `with_skill_catalog_provider`.
pub skill_catalog_provider: Option<Arc<dyn SkillCatalogProvider>>,
/// Replay buffers keyed by id, each tagged with its creation instant.
pub replay_buffers: ReplayBufferMap,
/// State for the MCP HTTP protocol endpoint.
pub mcp_http: Arc<crate::protocols::mcp::http::McpHttpState>,
}
impl AppState {
    /// Creates the state with the required services; optional services start
    /// as `None` and are attached with the `with_*` builder methods.
    pub fn new(
        runtime: Arc<AgentRuntime>,
        mailbox: Arc<Mailbox>,
        store: Arc<dyn ThreadRunStore>,
        resolver: Arc<dyn AgentResolver>,
        config: ServerConfig,
    ) -> Self {
        Self {
            runtime,
            mailbox,
            store,
            resolver,
            config,
            config_store: None,
            config_runtime_manager: None,
            skill_catalog_provider: None,
            replay_buffers: Arc::new(Mutex::new(HashMap::new())),
            mcp_http: Arc::new(crate::protocols::mcp::http::McpHttpState::new()),
        }
    }
    /// Builder-style setter for the optional config store.
    pub fn with_config_store(mut self, store: Arc<dyn ConfigStore>) -> Self {
        self.config_store = Some(store);
        self
    }
    /// Builder-style setter for the optional config runtime manager.
    pub fn with_config_runtime_manager(
        mut self,
        manager: Arc<crate::services::config_runtime::ConfigRuntimeManager>,
    ) -> Self {
        self.config_runtime_manager = Some(manager);
        self
    }
    /// Builder-style setter for the optional skill catalog provider.
    pub fn with_skill_catalog_provider(mut self, provider: Arc<dyn SkillCatalogProvider>) -> Self {
        self.skill_catalog_provider = Some(provider);
        self
    }
    /// Registers a replay buffer under `key`, stamping it with the current
    /// time so it can later be purged as stale. Replaces any existing entry.
    pub fn insert_replay_buffer(&self, key: String, buffer: Arc<EventReplayBuffer>) {
        self.replay_buffers
            .lock()
            .insert(key, (buffer, Instant::now()));
    }
    /// Looks up a replay buffer by key, cloning the `Arc` (not the buffer).
    pub fn get_replay_buffer(&self, key: &str) -> Option<Arc<EventReplayBuffer>> {
        self.replay_buffers
            .lock()
            .get(key)
            .map(|(buf, _)| Arc::clone(buf))
    }
    /// Removes the replay buffer under `key`, if present.
    pub fn remove_replay_buffer(&self, key: &str) {
        self.replay_buffers.lock().remove(key);
    }
    /// Drops every replay buffer older than `max_age` (entries exactly
    /// `max_age` old are purged too, so a zero `max_age` clears everything).
    /// Logs at debug level only when something was actually removed.
    pub fn purge_stale_replay_buffers(&self, max_age: std::time::Duration) {
        let now = Instant::now();
        let mut buffers = self.replay_buffers.lock();
        let before = buffers.len();
        // Keep only entries younger than max_age; single O(n) in-place pass.
        buffers.retain(|_key, (_buf, created_at)| now.duration_since(*created_at) < max_age);
        // retain only shrinks the map, so this subtraction cannot underflow.
        let purged = before - buffers.len();
        if purged > 0 {
            tracing::debug!(purged, "purged stale replay buffers");
        }
    }
}
/// Resolves when the process receives a shutdown signal.
///
/// On Unix this is either Ctrl-C (SIGINT) or SIGTERM, whichever arrives
/// first; elsewhere only Ctrl-C is watched. Panics if the SIGTERM handler
/// cannot be installed, which can only happen at startup.
async fn shutdown_signal() {
// Created before the cfg blocks so both paths can consume it.
let ctrl_c = tokio::signal::ctrl_c();
#[cfg(unix)]
{
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("install SIGTERM handler");
// Race the two signals; first one wins.
tokio::select! {
_ = ctrl_c => {}
_ = sigterm.recv() => {}
}
}
#[cfg(not(unix))]
{
// No SIGTERM off Unix; errors installing the handler are ignored.
ctrl_c.await.ok();
}
tracing::info!("shutting down gracefully...");
}
/// Serves `app` on `listener` with graceful shutdown and a hard deadline.
///
/// On a shutdown signal, axum stops accepting connections and drains
/// in-flight requests. If draining exceeds `shutdown_timeout`, the select
/// below abandons the server future and returns `Ok(())`, logging a warning
/// — note the server's own result is intentionally discarded on that path.
pub async fn serve_with_shutdown(
listener: tokio::net::TcpListener,
app: axum::Router,
shutdown_timeout: std::time::Duration,
) -> std::io::Result<()> {
let drain_notify = Arc::new(tokio::sync::Notify::new());
let drain_notify2 = drain_notify.clone();
// Wraps the OS signal so the deadline timer starts exactly when
// graceful shutdown begins.
let graceful_signal = async move {
shutdown_signal().await;
drain_notify2.notify_one();
};
let server = axum::serve(listener, app).with_graceful_shutdown(graceful_signal);
// Arms only after the signal fires: notified() first, then the timeout.
let drain_deadline = async {
drain_notify.notified().await;
tokio::time::sleep(shutdown_timeout).await;
tracing::warn!(
"server did not drain within {}s — forcing exit",
shutdown_timeout.as_secs()
);
};
tokio::select! {
result = server => result,
() = drain_deadline => Ok(()),
}
}
/// Binds the configured address and runs the server until shutdown.
///
/// In `Auto` mailbox mode this also starts the mailbox lifecycle with a
/// maintenance callback that purges replay buffers older than 300 s, and
/// stops the lifecycle after the server exits (failures to stop are logged,
/// not propagated). Errors: socket bind failures and lifecycle start
/// failures are returned as `std::io::Error`.
pub async fn serve(state: AppState) -> std::io::Result<()> {
let addr = state.config.address.clone();
let timeout = std::time::Duration::from_secs(state.config.shutdown.timeout_secs);
let max_concurrent = state.config.max_concurrent_requests;
// Lifecycle handle is kept so it can be shut down after the server exits.
let mailbox_lifecycle = match state.config.mailbox_lifecycle {
MailboxLifecycleMode::Auto => {
// Clone moved into the maintenance callback closure.
let cleanup_state = state.clone();
Some(
state
.mailbox
.start_lifecycle_ready(MailboxLifecycleConfig {
maintenance_callback: Some(Arc::new(move || {
cleanup_state
.purge_stale_replay_buffers(std::time::Duration::from_secs(300));
})),
..Default::default()
})
.await
.map_err(|error| {
std::io::Error::other(format!("failed to start mailbox lifecycle: {error}"))
})?,
)
}
MailboxLifecycleMode::Manual => None,
};
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("listening on {addr}");
let app = crate::routes::build_router()
.layer(tower::limit::ConcurrencyLimitLayer::new(max_concurrent))
.with_state(state);
let result = serve_with_shutdown(listener, app, timeout).await;
// Best-effort lifecycle shutdown; the server result still wins below.
if let Some(mailbox_lifecycle) = mailbox_lifecycle
&& let Err(error) = mailbox_lifecycle.shutdown().await
{
tracing::warn!(error = %error, "failed to stop mailbox lifecycle cleanly");
}
result
}
#[cfg(test)]
mod tests {
use super::*;
// Default must match every serde fallback helper.
#[test]
fn server_config_default_values() {
let config = ServerConfig::default();
assert_eq!(config.address, "0.0.0.0:3000");
assert_eq!(config.sse_buffer_size, 64);
assert_eq!(config.replay_buffer_capacity, 1024);
assert_eq!(config.shutdown.timeout_secs, 30);
assert_eq!(config.max_concurrent_requests, 100);
assert_eq!(config.mailbox_lifecycle, MailboxLifecycleMode::Auto);
}
// Serialize then deserialize: all non-default values must survive.
#[test]
fn server_config_serde_roundtrip() {
let config = ServerConfig {
address: "127.0.0.1:8080".to_string(),
sse_buffer_size: 128,
replay_buffer_capacity: 512,
shutdown: ShutdownConfig { timeout_secs: 10 },
max_concurrent_requests: 50,
a2a_extended_card_bearer_token: None,
mailbox_lifecycle: MailboxLifecycleMode::Manual,
};
let json = serde_json::to_string(&config).unwrap();
let parsed: ServerConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.address, "127.0.0.1:8080");
assert_eq!(parsed.sse_buffer_size, 128);
assert_eq!(parsed.replay_buffer_capacity, 512);
assert_eq!(parsed.shutdown.timeout_secs, 10);
assert_eq!(parsed.max_concurrent_requests, 50);
assert_eq!(parsed.mailbox_lifecycle, MailboxLifecycleMode::Manual);
}
// A minimal document with only `address` must deserialize via defaults.
#[test]
fn server_config_deserialize_with_defaults() {
let json = r#"{"address": "localhost:9000"}"#;
let config: ServerConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.address, "localhost:9000");
assert_eq!(config.sse_buffer_size, 64);
assert_eq!(config.shutdown.timeout_secs, 30);
assert_eq!(config.max_concurrent_requests, 100);
assert_eq!(config.mailbox_lifecycle, MailboxLifecycleMode::Auto);
}
// snake_case rename: the wire value is "manual".
#[test]
fn mailbox_lifecycle_mode_deserializes_manual() {
let json = r#"{"address": "localhost:9000", "mailbox_lifecycle": "manual"}"#;
let config: ServerConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.mailbox_lifecycle, MailboxLifecycleMode::Manual);
}
#[test]
fn shutdown_config_defaults() {
let config = ShutdownConfig::default();
assert_eq!(config.timeout_secs, 30);
}
#[test]
fn shutdown_config_custom() {
let json = r#"{"timeout_secs": 60}"#;
let config: ShutdownConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.timeout_secs, 60);
}
// Shared fixture: a fresh, empty replay-buffer map.
fn make_replay_map() -> ReplayBufferMap {
Arc::new(Mutex::new(HashMap::new()))
}
#[test]
fn insert_and_get_replay_buffer() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
buf.push_json(r#"{"hello":1}"#);
map.lock()
.insert("run-1".to_string(), (Arc::clone(&buf), Instant::now()));
let retrieved = map.lock().get("run-1").map(|(b, _)| Arc::clone(b));
assert!(retrieved.is_some());
// One pushed event -> sequence counter at 1.
assert_eq!(retrieved.unwrap().current_seq(), 1);
}
#[test]
fn remove_replay_buffer_works() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-2".to_string(), (buf, Instant::now()));
assert!(map.lock().get("run-2").is_some());
map.lock().remove("run-2");
assert!(map.lock().get("run-2").is_none());
}
// Mirrors purge_stale_replay_buffers' retain predicate: with a zero
// max_age, `age < ZERO` is never true, so everything is purged.
#[test]
fn purge_stale_replay_buffers_removes_all_with_zero_max_age() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-a".to_string(), (Arc::clone(&buf), Instant::now()));
map.lock()
.insert("run-b".to_string(), (buf, Instant::now()));
assert_eq!(map.lock().len(), 2);
let now = Instant::now();
map.lock().retain(|_key, (_buf, created_at)| {
now.duration_since(*created_at) < std::time::Duration::ZERO
});
assert_eq!(map.lock().len(), 0);
}
// Fresh entries survive a purge with a generous max_age.
#[test]
fn purge_stale_replay_buffers_keeps_recent() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-c".to_string(), (buf, Instant::now()));
let now = Instant::now();
let max_age = std::time::Duration::from_secs(3600);
map.lock()
.retain(|_key, (_buf, created_at)| now.duration_since(*created_at) < max_age);
assert_eq!(map.lock().len(), 1);
}
// A 120 s-old entry is purged at max_age 60 s; a fresh one is kept.
#[test]
fn purge_stale_mixed_ages() {
let map = make_replay_map();
// checked_sub guards against Instant underflow near process start.
let old_instant = Instant::now()
.checked_sub(std::time::Duration::from_secs(120))
.unwrap_or_else(Instant::now);
let recent_instant = Instant::now();
let buf_old = Arc::new(EventReplayBuffer::new(16));
let buf_recent = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("old-run".to_string(), (buf_old, old_instant));
map.lock()
.insert("recent-run".to_string(), (buf_recent, recent_instant));
assert_eq!(map.lock().len(), 2);
let now = Instant::now();
let max_age = std::time::Duration::from_secs(60);
map.lock()
.retain(|_key, (_buf, created_at)| now.duration_since(*created_at) < max_age);
assert_eq!(map.lock().len(), 1);
assert!(map.lock().get("recent-run").is_some());
assert!(map.lock().get("old-run").is_none());
}
}