use axum::{
extract::{Query, State},
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::{AgentRegistration, Message};
/// In-memory state shared by every relay request handler.
#[derive(Default)]
pub struct RelayState {
    /// Message log in arrival order; new messages are appended at the end.
    pub messages: Vec<Message>,
    /// Registered agents, keyed by the session id they registered with.
    pub agents: HashMap<String, AgentRegistration>,
}

/// Handle to the relay state that the router hands to each handler via `State`.
pub type SharedState = std::sync::Arc<std::sync::Mutex<RelayState>>;
/// JSON body accepted by `POST /messages/send`.
#[derive(Debug, Deserialize)]
pub struct SendRequest {
    /// Session id of the sender; the sender is recorded as having read the message.
    pub from_session: String,
    /// Agent id of the sender.
    pub from_agent: String,
    /// Recipient session; `None` makes the message a broadcast visible to every session.
    pub to_session: Option<String>,
    /// Message body.
    pub content: String,
}
/// Query string accepted by `GET /messages/inbox` and `GET /messages/unread`.
#[derive(Debug, Deserialize)]
pub struct InboxQuery {
    /// Session whose messages are being requested.
    pub session: String,
    /// Maximum number of messages the inbox returns (defaults to 50 when omitted).
    /// The unread-count endpoint accepts but ignores it.
    #[serde(default = "default_limit")]
    pub limit: usize,
}
/// Serde default for `InboxQuery::limit` when the query string omits it.
fn default_limit() -> usize {
    const DEFAULT_INBOX_LIMIT: usize = 50;
    DEFAULT_INBOX_LIMIT
}
/// JSON body accepted by `POST /agents/register`.
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
    /// Session id; used as the registry key, so re-registering replaces the entry.
    pub session_id: String,
    /// Identifier of the agent running in that session.
    pub agent_id: String,
    /// Process id reported by the agent.
    pub pid: u32,
    /// Free-form metadata; `Value::Null` when omitted from the body.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
/// JSON body returned by `GET /messages/unread`.
#[derive(Debug, Serialize)]
pub struct CountResponse {
    /// Number of unread messages for the requested session.
    pub count: u64,
}
/// Builds the relay's HTTP router around a fresh, empty [`RelayState`].
///
/// Routes:
/// - `GET  /health`            — static health string
/// - `GET  /agents`            — list registered agents
/// - `POST /agents/register`   — register (or replace) an agent
/// - `POST /agents/unregister` — remove an agent
/// - `POST /messages/send`     — post a direct or broadcast message
/// - `GET  /messages/inbox`    — fetch visible messages and mark them read
/// - `GET  /messages/unread`   — count unread messages for a session
///
/// NOTE(review): `CorsLayer::permissive()` accepts any origin — confirm this is
/// intended wherever the relay is exposed.
pub fn build_router() -> Router {
    let shared: SharedState = Arc::new(Mutex::new(RelayState::default()));

    let routes: Router<SharedState> = Router::new()
        .route("/health", get(health))
        // Agent registry.
        .route("/agents", get(list_agents))
        .route("/agents/register", post(register_agent))
        .route("/agents/unregister", post(unregister_agent))
        // Messaging.
        .route("/messages/send", post(send_message))
        .route("/messages/inbox", get(inbox))
        .route("/messages/unread", get(unread_count));

    routes
        .layer(tower_http::cors::CorsLayer::permissive())
        .with_state(shared)
}
/// Health-check handler; always answers with a fixed status string.
async fn health() -> &'static str {
    const STATUS_LINE: &str = "agent-relay server ok";
    STATUS_LINE
}
async fn register_agent(
State(state): State<SharedState>,
Json(req): Json<RegisterRequest>,
) -> (StatusCode, Json<AgentRegistration>) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let reg = AgentRegistration {
session_id: req.session_id.clone(),
agent_id: req.agent_id,
pid: req.pid,
registered_at: now,
last_heartbeat: now,
metadata: req.metadata,
};
let mut s = state.lock().unwrap();
s.agents.insert(req.session_id, reg.clone());
(StatusCode::CREATED, Json(reg))
}
/// Removes the agent registered under the body's `session_id` field.
///
/// Responds `200 OK` whether or not an agent was registered under that id. Responds
/// `400 Bad Request` when the body has no string `session_id`; previously such a body
/// was silently treated as session `""` and acknowledged with `200 OK`.
async fn unregister_agent(
    State(state): State<SharedState>,
    Json(req): Json<serde_json::Value>,
) -> StatusCode {
    let session = match req["session_id"].as_str() {
        Some(session) => session,
        None => return StatusCode::BAD_REQUEST,
    };
    state.lock().unwrap().agents.remove(session);
    StatusCode::OK
}
/// Lists every registered agent, most recent `last_heartbeat` first.
async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
    // Snapshot the registry inside a tight scope; the sort works on the owned copy.
    let mut snapshot: Vec<AgentRegistration> = {
        let guard = state.lock().unwrap();
        guard.agents.values().cloned().collect()
    };
    // Stable sort, descending by heartbeat.
    snapshot.sort_by_key(|agent| std::cmp::Reverse(agent.last_heartbeat));
    Json(snapshot)
}
async fn send_message(
State(state): State<SharedState>,
Json(req): Json<SendRequest>,
) -> (StatusCode, Json<Message>) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let msg = Message {
id: format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]),
from_session: req.from_session.clone(),
from_agent: req.from_agent,
to_session: req.to_session,
content: req.content,
timestamp: now,
read_by: vec![req.from_session],
};
let mut s = state.lock().unwrap();
s.messages.push(msg.clone());
if s.messages.len() > 10_000 {
s.messages.drain(..5_000);
}
(StatusCode::CREATED, Json(msg))
}
/// Returns up to `limit` messages visible to `session`, newest first, and marks each
/// returned message as read by that session.
///
/// A message is visible to a session when it is a broadcast (`to_session` is `None`),
/// is addressed to the session, or was sent by it. Visible messages beyond `limit` are
/// neither returned nor marked read. `limit == 0` yields an empty list; previously it
/// still returned, and marked, one message.
async fn inbox(
    State(state): State<SharedState>,
    Query(q): Query<InboxQuery>,
) -> Json<Vec<Message>> {
    let session = q.session.as_str();
    let mut s = state.lock().unwrap();
    let result: Vec<Message> = s
        .messages
        .iter_mut()
        .rev()
        .filter(|msg| {
            msg.to_session.is_none()
                || msg.to_session.as_deref() == Some(session)
                || msg.from_session == session
        })
        // `take` sits before the marking `map`, so only messages actually returned
        // get marked read, and a zero limit never pulls (or marks) anything.
        .take(q.limit)
        .map(|msg| {
            if !msg.read_by.iter().any(|reader| reader == session) {
                msg.read_by.push(session.to_owned());
            }
            msg.clone()
        })
        .collect();
    Json(result)
}
/// Counts messages addressed to `session` (directly or by broadcast) that were sent by
/// another session and that `session` has not yet read.
///
/// The `limit` query parameter is accepted but has no effect here.
async fn unread_count(
    State(state): State<SharedState>,
    Query(q): Query<InboxQuery>,
) -> Json<CountResponse> {
    let me = q.session.as_str();
    let guard = state.lock().unwrap();

    let mut unread: u64 = 0;
    for msg in guard.messages.iter() {
        let addressed_to_me = match msg.to_session.as_deref() {
            None => true,
            Some(target) => target == me,
        };
        let from_someone_else = msg.from_session != me;
        if addressed_to_me && from_someone_else && !msg.read_by.contains(&q.session) {
            unread += 1;
        }
    }

    Json(CountResponse { count: unread })
}