// agent_relay/server.rs

//! HTTP relay server with SQLite persistence.
//!
//! Messages survive server restarts. Run with:
//!   `agent-relay server --port 4800`
//!
//! Data is stored in `~/.agent-relay/relay.db` by default.
use axum::{
    extract::{Query, State},
    http::StatusCode,
    routing::{get, post},
    Json, Router,
};
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::{Arc, Mutex, PoisonError};

use crate::{AgentRegistration, Message};
21// ── Shared state ──
22
23pub type SharedState = Arc<Mutex<Connection>>;
24
25/// Initialize the database schema.
26pub fn init_db(conn: &Connection) {
27    conn.execute_batch(
28        "CREATE TABLE IF NOT EXISTS messages (
29            id TEXT PRIMARY KEY,
30            from_session TEXT NOT NULL,
31            from_agent TEXT NOT NULL,
32            to_session TEXT,
33            content TEXT NOT NULL,
34            timestamp INTEGER NOT NULL
35        );
36        CREATE TABLE IF NOT EXISTS message_reads (
37            message_id TEXT NOT NULL,
38            session_id TEXT NOT NULL,
39            PRIMARY KEY (message_id, session_id)
40        );
41        CREATE TABLE IF NOT EXISTS agents (
42            session_id TEXT PRIMARY KEY,
43            agent_id TEXT NOT NULL,
44            pid INTEGER NOT NULL,
45            registered_at INTEGER NOT NULL,
46            last_heartbeat INTEGER NOT NULL,
47            metadata TEXT NOT NULL DEFAULT '{}'
48        );",
49    )
50    .expect("Failed to initialize database schema");
51}
52
53// ── Request/Response types ──
54
55#[derive(Deserialize)]
56pub struct SendRequest {
57    pub from_session: String,
58    pub from_agent: String,
59    pub to_session: Option<String>,
60    pub content: String,
61}
62
63#[derive(Deserialize)]
64pub struct InboxQuery {
65    pub session: String,
66    #[serde(default = "default_limit")]
67    pub limit: usize,
68}
69
/// Page size used for `InboxQuery::limit` when the query string omits `limit`.
fn default_limit() -> usize {
    const FALLBACK_LIMIT: usize = 50;
    FALLBACK_LIMIT
}
74#[derive(Deserialize)]
75pub struct RegisterRequest {
76    pub session_id: String,
77    pub agent_id: String,
78    pub pid: u32,
79    #[serde(default)]
80    pub metadata: serde_json::Value,
81}
82
83#[derive(Serialize)]
84pub struct CountResponse {
85    pub count: u64,
86}
87
88// ── Router ──
89
90/// Build the router with a SQLite-backed database at the given path.
91pub fn build_router(db_path: &Path) -> Router {
92    let conn = Connection::open(db_path).expect("Failed to open SQLite database");
93    conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
94        .expect("Failed to set pragmas");
95    init_db(&conn);
96
97    let state: SharedState = Arc::new(Mutex::new(conn));
98
99    Router::new()
100        .route("/health", get(health))
101        .route("/agents", get(list_agents))
102        .route("/agents/register", post(register_agent))
103        .route("/agents/unregister", post(unregister_agent))
104        .route("/messages/send", post(send_message))
105        .route("/messages/inbox", get(inbox))
106        .route("/messages/unread", get(unread_count))
107        .layer(tower_http::cors::CorsLayer::permissive())
108        .with_state(state)
109}
110
// ── Handlers ──

/// `GET /health`: liveness probe that always answers with a fixed plain-text body.
async fn health() -> &'static str {
    const BODY: &str = "agent-relay server ok";
    BODY
}
117async fn register_agent(
118    State(state): State<SharedState>,
119    Json(req): Json<RegisterRequest>,
120) -> (StatusCode, Json<AgentRegistration>) {
121    let now = crate::Relay::now();
122    let metadata_str = req.metadata.to_string();
123
124    let conn = state.lock().unwrap();
125    conn.execute(
126        "INSERT OR REPLACE INTO agents (session_id, agent_id, pid, registered_at, last_heartbeat, metadata)
127         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
128        rusqlite::params![req.session_id, req.agent_id, req.pid, now, now, metadata_str],
129    )
130    .expect("Failed to insert agent");
131
132    let reg = AgentRegistration {
133        session_id: req.session_id,
134        agent_id: req.agent_id,
135        pid: req.pid,
136        registered_at: now,
137        last_heartbeat: now,
138        metadata: req.metadata,
139    };
140
141    (StatusCode::CREATED, Json(reg))
142}
143
144async fn unregister_agent(
145    State(state): State<SharedState>,
146    Json(req): Json<serde_json::Value>,
147) -> StatusCode {
148    let session = req["session_id"].as_str().unwrap_or("");
149    let conn = state.lock().unwrap();
150    let _ = conn.execute("DELETE FROM agents WHERE session_id = ?1", [session]);
151    StatusCode::OK
152}
153
154async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
155    let conn = state.lock().unwrap();
156    let mut stmt = conn
157        .prepare("SELECT session_id, agent_id, pid, registered_at, last_heartbeat, metadata FROM agents ORDER BY last_heartbeat DESC")
158        .unwrap();
159
160    let agents: Vec<AgentRegistration> = stmt
161        .query_map([], |row| {
162            let metadata_str: String = row.get(5)?;
163            Ok(AgentRegistration {
164                session_id: row.get(0)?,
165                agent_id: row.get(1)?,
166                pid: row.get::<_, u32>(2)?,
167                registered_at: row.get(3)?,
168                last_heartbeat: row.get(4)?,
169                metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
170            })
171        })
172        .unwrap()
173        .filter_map(|r| r.ok())
174        .collect();
175
176    Json(agents)
177}
178
179async fn send_message(
180    State(state): State<SharedState>,
181    Json(req): Json<SendRequest>,
182) -> (StatusCode, Json<Message>) {
183    let now = crate::Relay::now();
184    let id = format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]);
185
186    let conn = state.lock().unwrap();
187    conn.execute(
188        "INSERT INTO messages (id, from_session, from_agent, to_session, content, timestamp)
189         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
190        rusqlite::params![id, req.from_session, req.from_agent, req.to_session, req.content, now],
191    )
192    .expect("Failed to insert message");
193
194    // Mark as read by sender
195    let _ = conn.execute(
196        "INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
197        rusqlite::params![id, req.from_session],
198    );
199
200    let msg = Message {
201        id,
202        from_session: req.from_session,
203        from_agent: req.from_agent,
204        to_session: req.to_session,
205        content: req.content,
206        timestamp: now,
207        read_by: vec![],
208    };
209
210    (StatusCode::CREATED, Json(msg))
211}
212
213async fn inbox(
214    State(state): State<SharedState>,
215    Query(q): Query<InboxQuery>,
216) -> Json<Vec<Message>> {
217    let conn = state.lock().unwrap();
218
219    let mut stmt = conn
220        .prepare(
221            "SELECT id, from_session, from_agent, to_session, content, timestamp
222             FROM messages
223             WHERE to_session IS NULL OR to_session = ?1 OR from_session = ?1
224             ORDER BY timestamp DESC
225             LIMIT ?2",
226        )
227        .unwrap();
228
229    let messages: Vec<Message> = stmt
230        .query_map(rusqlite::params![q.session, q.limit], |row| {
231            Ok(Message {
232                id: row.get(0)?,
233                from_session: row.get(1)?,
234                from_agent: row.get(2)?,
235                to_session: row.get(3)?,
236                content: row.get(4)?,
237                timestamp: row.get(5)?,
238                read_by: vec![],
239            })
240        })
241        .unwrap()
242        .filter_map(|r| r.ok())
243        .collect();
244
245    // Mark all returned messages as read by this session
246    for msg in &messages {
247        let _ = conn.execute(
248            "INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
249            rusqlite::params![msg.id, q.session],
250        );
251    }
252
253    Json(messages)
254}
255
256async fn unread_count(
257    State(state): State<SharedState>,
258    Query(q): Query<InboxQuery>,
259) -> Json<CountResponse> {
260    let conn = state.lock().unwrap();
261
262    let count: u64 = conn
263        .query_row(
264            "SELECT COUNT(*) FROM messages m
265             WHERE (m.to_session IS NULL OR m.to_session = ?1)
266               AND m.from_session != ?1
267               AND NOT EXISTS (
268                   SELECT 1 FROM message_reads mr
269                   WHERE mr.message_id = m.id AND mr.session_id = ?1
270               )",
271            [&q.session],
272            |row| row.get(0),
273        )
274        .unwrap_or(0);
275
276    Json(CountResponse { count })
277}