use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use awaken_contract::contract::storage::ThreadRunStore;
use awaken_runtime::{AgentResolver, AgentRuntime};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use crate::mailbox::Mailbox;
use crate::transport::replay_buffer::EventReplayBuffer;
pub type ReplayBufferEntry = (Arc<EventReplayBuffer>, Instant);
pub type ReplayBufferMap = Arc<Mutex<HashMap<String, ReplayBufferEntry>>>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShutdownConfig {
#[serde(default = "default_shutdown_timeout")]
pub timeout_secs: u64,
}
fn default_shutdown_timeout() -> u64 {
30
}
impl Default for ShutdownConfig {
fn default() -> Self {
Self {
timeout_secs: default_shutdown_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub address: String,
#[serde(default = "default_sse_buffer")]
pub sse_buffer_size: usize,
#[serde(default = "default_replay_buffer_capacity")]
pub replay_buffer_capacity: usize,
#[serde(default)]
pub shutdown: ShutdownConfig,
#[serde(default = "default_max_concurrent")]
pub max_concurrent_requests: usize,
}
fn default_sse_buffer() -> usize {
64
}
fn default_replay_buffer_capacity() -> usize {
1024
}
fn default_max_concurrent() -> usize {
100
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
address: "0.0.0.0:3000".to_string(),
sse_buffer_size: default_sse_buffer(),
replay_buffer_capacity: default_replay_buffer_capacity(),
shutdown: ShutdownConfig::default(),
max_concurrent_requests: default_max_concurrent(),
}
}
}
#[derive(Clone)]
pub struct AppState {
pub runtime: Arc<AgentRuntime>,
pub mailbox: Arc<Mailbox>,
pub store: Arc<dyn ThreadRunStore>,
pub resolver: Arc<dyn AgentResolver>,
pub config: ServerConfig,
pub replay_buffers: ReplayBufferMap,
}
impl AppState {
pub fn new(
runtime: Arc<AgentRuntime>,
mailbox: Arc<Mailbox>,
store: Arc<dyn ThreadRunStore>,
resolver: Arc<dyn AgentResolver>,
config: ServerConfig,
) -> Self {
Self {
runtime,
mailbox,
store,
resolver,
config,
replay_buffers: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn insert_replay_buffer(&self, key: String, buffer: Arc<EventReplayBuffer>) {
self.replay_buffers
.lock()
.insert(key, (buffer, Instant::now()));
}
pub fn get_replay_buffer(&self, key: &str) -> Option<Arc<EventReplayBuffer>> {
self.replay_buffers
.lock()
.get(key)
.map(|(buf, _)| Arc::clone(buf))
}
pub fn remove_replay_buffer(&self, key: &str) {
self.replay_buffers.lock().remove(key);
}
pub fn purge_stale_replay_buffers(&self, max_age: std::time::Duration) {
let now = Instant::now();
let mut buffers = self.replay_buffers.lock();
let before = buffers.len();
buffers.retain(|_key, (_buf, created_at)| {
let age = now.duration_since(*created_at);
if age < max_age {
return true;
}
false
});
let purged = before - buffers.len();
if purged > 0 {
tracing::debug!(purged, "purged stale replay buffers");
}
}
}
async fn shutdown_signal() {
let ctrl_c = tokio::signal::ctrl_c();
#[cfg(unix)]
{
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("install SIGTERM handler");
tokio::select! {
_ = ctrl_c => {}
_ = sigterm.recv() => {}
}
}
#[cfg(not(unix))]
{
ctrl_c.await.ok();
}
tracing::info!("shutting down gracefully...");
}
pub async fn serve_with_shutdown(
listener: tokio::net::TcpListener,
app: axum::Router,
shutdown_timeout: std::time::Duration,
) -> std::io::Result<()> {
let drain_notify = Arc::new(tokio::sync::Notify::new());
let drain_notify2 = drain_notify.clone();
let graceful_signal = async move {
shutdown_signal().await;
drain_notify2.notify_one();
};
let server = axum::serve(listener, app).with_graceful_shutdown(graceful_signal);
let drain_deadline = async {
drain_notify.notified().await;
tokio::time::sleep(shutdown_timeout).await;
tracing::warn!(
"server did not drain within {}s — forcing exit",
shutdown_timeout.as_secs()
);
};
tokio::select! {
result = server => result,
() = drain_deadline => Ok(()),
}
}
pub async fn serve(state: AppState) -> std::io::Result<()> {
let addr = state.config.address.clone();
let timeout = std::time::Duration::from_secs(state.config.shutdown.timeout_secs);
let max_concurrent = state.config.max_concurrent_requests;
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("listening on {addr}");
let app = crate::routes::build_router()
.layer(tower::limit::ConcurrencyLimitLayer::new(max_concurrent))
.with_state(state);
serve_with_shutdown(listener, app, timeout).await
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn server_config_default_values() {
let config = ServerConfig::default();
assert_eq!(config.address, "0.0.0.0:3000");
assert_eq!(config.sse_buffer_size, 64);
assert_eq!(config.replay_buffer_capacity, 1024);
assert_eq!(config.shutdown.timeout_secs, 30);
assert_eq!(config.max_concurrent_requests, 100);
}
#[test]
fn server_config_serde_roundtrip() {
let config = ServerConfig {
address: "127.0.0.1:8080".to_string(),
sse_buffer_size: 128,
replay_buffer_capacity: 512,
shutdown: ShutdownConfig { timeout_secs: 10 },
max_concurrent_requests: 50,
};
let json = serde_json::to_string(&config).unwrap();
let parsed: ServerConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.address, "127.0.0.1:8080");
assert_eq!(parsed.sse_buffer_size, 128);
assert_eq!(parsed.replay_buffer_capacity, 512);
assert_eq!(parsed.shutdown.timeout_secs, 10);
assert_eq!(parsed.max_concurrent_requests, 50);
}
#[test]
fn server_config_deserialize_with_defaults() {
let json = r#"{"address": "localhost:9000"}"#;
let config: ServerConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.address, "localhost:9000");
assert_eq!(config.sse_buffer_size, 64);
assert_eq!(config.shutdown.timeout_secs, 30);
assert_eq!(config.max_concurrent_requests, 100);
}
#[test]
fn shutdown_config_defaults() {
let config = ShutdownConfig::default();
assert_eq!(config.timeout_secs, 30);
}
#[test]
fn shutdown_config_custom() {
let json = r#"{"timeout_secs": 60}"#;
let config: ShutdownConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.timeout_secs, 60);
}
fn make_replay_map() -> ReplayBufferMap {
Arc::new(Mutex::new(HashMap::new()))
}
#[test]
fn insert_and_get_replay_buffer() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
buf.push_json(r#"{"hello":1}"#);
map.lock()
.insert("run-1".to_string(), (Arc::clone(&buf), Instant::now()));
let retrieved = map.lock().get("run-1").map(|(b, _)| Arc::clone(b));
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().current_seq(), 1);
}
#[test]
fn remove_replay_buffer_works() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-2".to_string(), (buf, Instant::now()));
assert!(map.lock().get("run-2").is_some());
map.lock().remove("run-2");
assert!(map.lock().get("run-2").is_none());
}
#[test]
fn purge_stale_replay_buffers_removes_all_with_zero_max_age() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-a".to_string(), (Arc::clone(&buf), Instant::now()));
map.lock()
.insert("run-b".to_string(), (buf, Instant::now()));
assert_eq!(map.lock().len(), 2);
let now = Instant::now();
map.lock().retain(|_key, (_buf, created_at)| {
now.duration_since(*created_at) < std::time::Duration::ZERO
});
assert_eq!(map.lock().len(), 0);
}
#[test]
fn purge_stale_replay_buffers_keeps_recent() {
let map = make_replay_map();
let buf = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("run-c".to_string(), (buf, Instant::now()));
let now = Instant::now();
let max_age = std::time::Duration::from_secs(3600);
map.lock()
.retain(|_key, (_buf, created_at)| now.duration_since(*created_at) < max_age);
assert_eq!(map.lock().len(), 1);
}
#[test]
fn purge_stale_mixed_ages() {
let map = make_replay_map();
let old_instant = Instant::now()
.checked_sub(std::time::Duration::from_secs(120))
.unwrap_or_else(Instant::now);
let recent_instant = Instant::now();
let buf_old = Arc::new(EventReplayBuffer::new(16));
let buf_recent = Arc::new(EventReplayBuffer::new(16));
map.lock()
.insert("old-run".to_string(), (buf_old, old_instant));
map.lock()
.insert("recent-run".to_string(), (buf_recent, recent_instant));
assert_eq!(map.lock().len(), 2);
let now = Instant::now();
let max_age = std::time::Duration::from_secs(60);
map.lock()
.retain(|_key, (_buf, created_at)| now.duration_since(*created_at) < max_age);
assert_eq!(map.lock().len(), 1);
assert!(map.lock().get("recent-run").is_some());
assert!(map.lock().get("old-run").is_none());
}
}