1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 routing::{get, post},
10 Json, Router,
11};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15
16use crate::{AgentRegistration, Message};
17
18#[derive(Default)]
21pub struct RelayState {
22 pub messages: Vec<Message>,
23 pub agents: HashMap<String, AgentRegistration>, }
25
26pub type SharedState = Arc<Mutex<RelayState>>;
27
28#[derive(Deserialize)]
31pub struct SendRequest {
32 pub from_session: String,
33 pub from_agent: String,
34 pub to_session: Option<String>,
35 pub content: String,
36}
37
38#[derive(Deserialize)]
39pub struct InboxQuery {
40 pub session: String,
41 #[serde(default = "default_limit")]
42 pub limit: usize,
43}
44
/// Fallback for `InboxQuery::limit` when the request does not supply one.
fn default_limit() -> usize {
    const DEFAULT_INBOX_LIMIT: usize = 50;
    DEFAULT_INBOX_LIMIT
}
48
49#[derive(Deserialize)]
50pub struct RegisterRequest {
51 pub session_id: String,
52 pub agent_id: String,
53 pub pid: u32,
54 #[serde(default)]
55 pub metadata: serde_json::Value,
56}
57
58#[derive(Serialize)]
59pub struct CountResponse {
60 pub count: u64,
61}
62
63pub fn build_router() -> Router {
66 let state: SharedState = Arc::new(Mutex::new(RelayState::default()));
67
68 Router::new()
69 .route("/health", get(health))
70 .route("/agents", get(list_agents))
71 .route("/agents/register", post(register_agent))
72 .route("/agents/unregister", post(unregister_agent))
73 .route("/messages/send", post(send_message))
74 .route("/messages/inbox", get(inbox))
75 .route("/messages/unread", get(unread_count))
76 .layer(tower_http::cors::CorsLayer::permissive())
77 .with_state(state)
78}
79
/// Liveness probe: always answers with the same fixed plain-text body.
async fn health() -> &'static str {
    const HEALTH_BODY: &str = "agent-relay server ok";
    HEALTH_BODY
}
85
86async fn register_agent(
87 State(state): State<SharedState>,
88 Json(req): Json<RegisterRequest>,
89) -> (StatusCode, Json<AgentRegistration>) {
90 let now = std::time::SystemTime::now()
91 .duration_since(std::time::UNIX_EPOCH)
92 .unwrap_or_default()
93 .as_secs();
94
95 let reg = AgentRegistration {
96 session_id: req.session_id.clone(),
97 agent_id: req.agent_id,
98 pid: req.pid,
99 registered_at: now,
100 last_heartbeat: now,
101 metadata: req.metadata,
102 };
103
104 let mut s = state.lock().unwrap();
105 s.agents.insert(req.session_id, reg.clone());
106
107 (StatusCode::CREATED, Json(reg))
108}
109
110async fn unregister_agent(
111 State(state): State<SharedState>,
112 Json(req): Json<serde_json::Value>,
113) -> StatusCode {
114 let session = req["session_id"].as_str().unwrap_or("");
115 let mut s = state.lock().unwrap();
116 s.agents.remove(session);
117 StatusCode::OK
118}
119
120async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
121 let s = state.lock().unwrap();
122 let mut agents: Vec<AgentRegistration> = s.agents.values().cloned().collect();
123 agents.sort_by(|a, b| b.last_heartbeat.cmp(&a.last_heartbeat));
124 Json(agents)
125}
126
127async fn send_message(
128 State(state): State<SharedState>,
129 Json(req): Json<SendRequest>,
130) -> (StatusCode, Json<Message>) {
131 let now = std::time::SystemTime::now()
132 .duration_since(std::time::UNIX_EPOCH)
133 .unwrap_or_default()
134 .as_secs();
135
136 let msg = Message {
137 id: format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]),
138 from_session: req.from_session.clone(),
139 from_agent: req.from_agent,
140 to_session: req.to_session,
141 content: req.content,
142 timestamp: now,
143 read_by: vec![req.from_session],
144 };
145
146 let mut s = state.lock().unwrap();
147 s.messages.push(msg.clone());
148
149 if s.messages.len() > 10_000 {
151 s.messages.drain(..5_000);
152 }
153
154 (StatusCode::CREATED, Json(msg))
155}
156
157async fn inbox(
158 State(state): State<SharedState>,
159 Query(q): Query<InboxQuery>,
160) -> Json<Vec<Message>> {
161 let mut s = state.lock().unwrap();
162
163 let mut result = Vec::new();
164 for msg in s.messages.iter_mut().rev() {
165 let dominated = msg.to_session.is_none()
166 || msg.to_session.as_deref() == Some(&q.session)
167 || msg.from_session == q.session;
168
169 if dominated {
170 if !msg.read_by.contains(&q.session) {
171 msg.read_by.push(q.session.clone());
172 }
173 result.push(msg.clone());
174 if result.len() >= q.limit {
175 break;
176 }
177 }
178 }
179
180 Json(result)
181}
182
183async fn unread_count(
184 State(state): State<SharedState>,
185 Query(q): Query<InboxQuery>,
186) -> Json<CountResponse> {
187 let s = state.lock().unwrap();
188 let count = s
189 .messages
190 .iter()
191 .filter(|msg| {
192 let dominated =
193 msg.to_session.is_none() || msg.to_session.as_deref() == Some(&q.session);
194 dominated && msg.from_session != q.session && !msg.read_by.contains(&q.session)
195 })
196 .count() as u64;
197
198 Json(CountResponse { count })
199}