Skip to main content

agent_relay/
server.rs

//! HTTP relay server for cross-machine agent messaging.
//!
//! Start it with `agent-relay server --port 4800` on any reachable machine.
//! Agents then connect with `agent-relay --server http://host:4800 send "msg"`.
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    routing::{get, post},
10    Json, Router,
11};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15
16use crate::{AgentRegistration, Message};
17
18// ── In-memory state ──
19
20#[derive(Default)]
21pub struct RelayState {
22    pub messages: Vec<Message>,
23    pub agents: HashMap<String, AgentRegistration>, // session_id -> reg
24}
25
26pub type SharedState = Arc<Mutex<RelayState>>;
27
28// ── Request/Response types ──
29
30#[derive(Deserialize)]
31pub struct SendRequest {
32    pub from_session: String,
33    pub from_agent: String,
34    pub to_session: Option<String>,
35    pub content: String,
36}
37
38#[derive(Deserialize)]
39pub struct InboxQuery {
40    pub session: String,
41    #[serde(default = "default_limit")]
42    pub limit: usize,
43}
44
/// Fallback for `InboxQuery::limit` when the query string omits it.
fn default_limit() -> usize {
    const DEFAULT_INBOX_LIMIT: usize = 50;
    DEFAULT_INBOX_LIMIT
}
48
49#[derive(Deserialize)]
50pub struct RegisterRequest {
51    pub session_id: String,
52    pub agent_id: String,
53    pub pid: u32,
54    #[serde(default)]
55    pub metadata: serde_json::Value,
56}
57
58#[derive(Serialize)]
59pub struct CountResponse {
60    pub count: u64,
61}
62
63// ── Router ──
64
65pub fn build_router() -> Router {
66    let state: SharedState = Arc::new(Mutex::new(RelayState::default()));
67
68    Router::new()
69        .route("/health", get(health))
70        .route("/agents", get(list_agents))
71        .route("/agents/register", post(register_agent))
72        .route("/agents/unregister", post(unregister_agent))
73        .route("/messages/send", post(send_message))
74        .route("/messages/inbox", get(inbox))
75        .route("/messages/unread", get(unread_count))
76        .layer(tower_http::cors::CorsLayer::permissive())
77        .with_state(state)
78}
79
80// ── Handlers ──
81
/// Liveness probe for `GET /health`: always answers with a fixed banner.
async fn health() -> &'static str {
    const BANNER: &str = "agent-relay server ok";
    BANNER
}
85
86async fn register_agent(
87    State(state): State<SharedState>,
88    Json(req): Json<RegisterRequest>,
89) -> (StatusCode, Json<AgentRegistration>) {
90    let now = std::time::SystemTime::now()
91        .duration_since(std::time::UNIX_EPOCH)
92        .unwrap_or_default()
93        .as_secs();
94
95    let reg = AgentRegistration {
96        session_id: req.session_id.clone(),
97        agent_id: req.agent_id,
98        pid: req.pid,
99        registered_at: now,
100        last_heartbeat: now,
101        metadata: req.metadata,
102    };
103
104    let mut s = state.lock().unwrap();
105    s.agents.insert(req.session_id, reg.clone());
106
107    (StatusCode::CREATED, Json(reg))
108}
109
110async fn unregister_agent(
111    State(state): State<SharedState>,
112    Json(req): Json<serde_json::Value>,
113) -> StatusCode {
114    let session = req["session_id"].as_str().unwrap_or("");
115    let mut s = state.lock().unwrap();
116    s.agents.remove(session);
117    StatusCode::OK
118}
119
120async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
121    let s = state.lock().unwrap();
122    let mut agents: Vec<AgentRegistration> = s.agents.values().cloned().collect();
123    agents.sort_by(|a, b| b.last_heartbeat.cmp(&a.last_heartbeat));
124    Json(agents)
125}
126
127async fn send_message(
128    State(state): State<SharedState>,
129    Json(req): Json<SendRequest>,
130) -> (StatusCode, Json<Message>) {
131    let now = std::time::SystemTime::now()
132        .duration_since(std::time::UNIX_EPOCH)
133        .unwrap_or_default()
134        .as_secs();
135
136    let msg = Message {
137        id: format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]),
138        from_session: req.from_session.clone(),
139        from_agent: req.from_agent,
140        to_session: req.to_session,
141        content: req.content,
142        timestamp: now,
143        read_by: vec![req.from_session],
144    };
145
146    let mut s = state.lock().unwrap();
147    s.messages.push(msg.clone());
148
149    // Cap at 10k messages in memory
150    if s.messages.len() > 10_000 {
151        s.messages.drain(..5_000);
152    }
153
154    (StatusCode::CREATED, Json(msg))
155}
156
157async fn inbox(
158    State(state): State<SharedState>,
159    Query(q): Query<InboxQuery>,
160) -> Json<Vec<Message>> {
161    let mut s = state.lock().unwrap();
162
163    let mut result = Vec::new();
164    for msg in s.messages.iter_mut().rev() {
165        let dominated = msg.to_session.is_none()
166            || msg.to_session.as_deref() == Some(&q.session)
167            || msg.from_session == q.session;
168
169        if dominated {
170            if !msg.read_by.contains(&q.session) {
171                msg.read_by.push(q.session.clone());
172            }
173            result.push(msg.clone());
174            if result.len() >= q.limit {
175                break;
176            }
177        }
178    }
179
180    Json(result)
181}
182
183async fn unread_count(
184    State(state): State<SharedState>,
185    Query(q): Query<InboxQuery>,
186) -> Json<CountResponse> {
187    let s = state.lock().unwrap();
188    let count = s
189        .messages
190        .iter()
191        .filter(|msg| {
192            let dominated =
193                msg.to_session.is_none() || msg.to_session.as_deref() == Some(&q.session);
194            dominated && msg.from_session != q.session && !msg.read_by.contains(&q.session)
195        })
196        .count() as u64;
197
198    Json(CountResponse { count })
199}