1use axum::{
9 extract::{Query, State},
10 http::StatusCode,
11 routing::{get, post},
12 Json, Router,
13};
14use rusqlite::Connection;
15use serde::{Deserialize, Serialize};
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
19use crate::{AgentRegistration, Message};
20
/// Application state handed to every handler: the relay's single SQLite
/// connection, shared via `Arc` and serialized behind a `Mutex`.
pub type SharedState = Arc<Mutex<Connection>>;
24
25pub fn init_db(conn: &Connection) {
27 conn.execute_batch(
28 "CREATE TABLE IF NOT EXISTS messages (
29 id TEXT PRIMARY KEY,
30 from_session TEXT NOT NULL,
31 from_agent TEXT NOT NULL,
32 to_session TEXT,
33 content TEXT NOT NULL,
34 timestamp INTEGER NOT NULL
35 );
36 CREATE TABLE IF NOT EXISTS message_reads (
37 message_id TEXT NOT NULL,
38 session_id TEXT NOT NULL,
39 PRIMARY KEY (message_id, session_id)
40 );
41 CREATE TABLE IF NOT EXISTS agents (
42 session_id TEXT PRIMARY KEY,
43 agent_id TEXT NOT NULL,
44 pid INTEGER NOT NULL,
45 registered_at INTEGER NOT NULL,
46 last_heartbeat INTEGER NOT NULL,
47 metadata TEXT NOT NULL DEFAULT '{}'
48 );",
49 )
50 .expect("Failed to initialize database schema");
51}
52
/// JSON body accepted by `POST /messages/send`.
#[derive(Deserialize)]
pub struct SendRequest {
    /// Session id of the sender. `send_message` also records this session in
    /// `message_reads`, so the sender never sees its own message as unread.
    pub from_session: String,
    /// Agent id of the sender, stored alongside the message.
    pub from_agent: String,
    /// Recipient session. `None` (field absent or `null`) stores a NULL
    /// `to_session`, which the inbox query returns to every session.
    pub to_session: Option<String>,
    /// Message body, stored verbatim.
    pub content: String,
}
62
/// Query string accepted by `GET /messages/inbox` and `GET /messages/unread`.
#[derive(Deserialize)]
pub struct InboxQuery {
    /// Session whose messages are being requested.
    pub session: String,
    /// Maximum number of messages the inbox returns; `default_limit()` (50)
    /// when omitted. The unread-count endpoint accepts but ignores it.
    #[serde(default = "default_limit")]
    pub limit: usize,
}
69
/// Serde default for [`InboxQuery::limit`] when the query string omits it.
fn default_limit() -> usize {
    const DEFAULT_INBOX_LIMIT: usize = 50;
    DEFAULT_INBOX_LIMIT
}
73
/// JSON body accepted by `POST /agents/register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Session id; primary key of the `agents` table, so registering the same
    /// session again replaces the existing row.
    pub session_id: String,
    /// Agent identifier stored with the registration.
    pub agent_id: String,
    /// Process id supplied by the caller (presumably the agent's OS pid —
    /// the server does not check it).
    pub pid: u32,
    /// Arbitrary JSON metadata, persisted as its JSON text. Defaults to
    /// `serde_json::Value::Null` when omitted.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
82
/// JSON response of `GET /messages/unread`, serialized as `{"count": n}`.
#[derive(Serialize)]
pub struct CountResponse {
    /// Number of unread messages for the queried session.
    pub count: u64,
}
87
88pub fn build_router(db_path: &Path) -> Router {
92 let conn = Connection::open(db_path).expect("Failed to open SQLite database");
93 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
94 .expect("Failed to set pragmas");
95 init_db(&conn);
96
97 let state: SharedState = Arc::new(Mutex::new(conn));
98
99 Router::new()
100 .route("/health", get(health))
101 .route("/agents", get(list_agents))
102 .route("/agents/register", post(register_agent))
103 .route("/agents/unregister", post(unregister_agent))
104 .route("/messages/send", post(send_message))
105 .route("/messages/inbox", get(inbox))
106 .route("/messages/unread", get(unread_count))
107 .layer(tower_http::cors::CorsLayer::permissive())
108 .with_state(state)
109}
110
/// Liveness probe for `GET /health`; always resolves to the same fixed string.
async fn health() -> &'static str {
    const HEALTH_BODY: &str = "agent-relay server ok";
    HEALTH_BODY
}
116
117async fn register_agent(
118 State(state): State<SharedState>,
119 Json(req): Json<RegisterRequest>,
120) -> (StatusCode, Json<AgentRegistration>) {
121 let now = crate::Relay::now();
122 let metadata_str = req.metadata.to_string();
123
124 let conn = state.lock().unwrap();
125 conn.execute(
126 "INSERT OR REPLACE INTO agents (session_id, agent_id, pid, registered_at, last_heartbeat, metadata)
127 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
128 rusqlite::params![req.session_id, req.agent_id, req.pid, now, now, metadata_str],
129 )
130 .expect("Failed to insert agent");
131
132 let reg = AgentRegistration {
133 session_id: req.session_id,
134 agent_id: req.agent_id,
135 pid: req.pid,
136 registered_at: now,
137 last_heartbeat: now,
138 metadata: req.metadata,
139 };
140
141 (StatusCode::CREATED, Json(reg))
142}
143
144async fn unregister_agent(
145 State(state): State<SharedState>,
146 Json(req): Json<serde_json::Value>,
147) -> StatusCode {
148 let session = req["session_id"].as_str().unwrap_or("");
149 let conn = state.lock().unwrap();
150 let _ = conn.execute("DELETE FROM agents WHERE session_id = ?1", [session]);
151 StatusCode::OK
152}
153
154async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
155 let conn = state.lock().unwrap();
156 let mut stmt = conn
157 .prepare("SELECT session_id, agent_id, pid, registered_at, last_heartbeat, metadata FROM agents ORDER BY last_heartbeat DESC")
158 .unwrap();
159
160 let agents: Vec<AgentRegistration> = stmt
161 .query_map([], |row| {
162 let metadata_str: String = row.get(5)?;
163 Ok(AgentRegistration {
164 session_id: row.get(0)?,
165 agent_id: row.get(1)?,
166 pid: row.get::<_, u32>(2)?,
167 registered_at: row.get(3)?,
168 last_heartbeat: row.get(4)?,
169 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
170 })
171 })
172 .unwrap()
173 .filter_map(|r| r.ok())
174 .collect();
175
176 Json(agents)
177}
178
179async fn send_message(
180 State(state): State<SharedState>,
181 Json(req): Json<SendRequest>,
182) -> (StatusCode, Json<Message>) {
183 let now = crate::Relay::now();
184 let id = format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]);
185
186 let conn = state.lock().unwrap();
187 conn.execute(
188 "INSERT INTO messages (id, from_session, from_agent, to_session, content, timestamp)
189 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
190 rusqlite::params![id, req.from_session, req.from_agent, req.to_session, req.content, now],
191 )
192 .expect("Failed to insert message");
193
194 let _ = conn.execute(
196 "INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
197 rusqlite::params![id, req.from_session],
198 );
199
200 let msg = Message {
201 id,
202 from_session: req.from_session,
203 from_agent: req.from_agent,
204 to_session: req.to_session,
205 content: req.content,
206 timestamp: now,
207 read_by: vec![],
208 };
209
210 (StatusCode::CREATED, Json(msg))
211}
212
213async fn inbox(
214 State(state): State<SharedState>,
215 Query(q): Query<InboxQuery>,
216) -> Json<Vec<Message>> {
217 let conn = state.lock().unwrap();
218
219 let mut stmt = conn
220 .prepare(
221 "SELECT id, from_session, from_agent, to_session, content, timestamp
222 FROM messages
223 WHERE to_session IS NULL OR to_session = ?1 OR from_session = ?1
224 ORDER BY timestamp DESC
225 LIMIT ?2",
226 )
227 .unwrap();
228
229 let messages: Vec<Message> = stmt
230 .query_map(rusqlite::params![q.session, q.limit], |row| {
231 Ok(Message {
232 id: row.get(0)?,
233 from_session: row.get(1)?,
234 from_agent: row.get(2)?,
235 to_session: row.get(3)?,
236 content: row.get(4)?,
237 timestamp: row.get(5)?,
238 read_by: vec![],
239 })
240 })
241 .unwrap()
242 .filter_map(|r| r.ok())
243 .collect();
244
245 for msg in &messages {
247 let _ = conn.execute(
248 "INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
249 rusqlite::params![msg.id, q.session],
250 );
251 }
252
253 Json(messages)
254}
255
256async fn unread_count(
257 State(state): State<SharedState>,
258 Query(q): Query<InboxQuery>,
259) -> Json<CountResponse> {
260 let conn = state.lock().unwrap();
261
262 let count: u64 = conn
263 .query_row(
264 "SELECT COUNT(*) FROM messages m
265 WHERE (m.to_session IS NULL OR m.to_session = ?1)
266 AND m.from_session != ?1
267 AND NOT EXISTS (
268 SELECT 1 FROM message_reads mr
269 WHERE mr.message_id = m.id AND mr.session_id = ?1
270 )",
271 [&q.session],
272 |row| row.get(0),
273 )
274 .unwrap_or(0);
275
276 Json(CountResponse { count })
277}