use axum::{
extract::{Query, State},
http::StatusCode,
routing::{get, post},
Json, Router,
};
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::{Arc, Mutex};
use crate::{AgentRegistration, Message};
/// Router state handed to every handler: one SQLite connection behind a mutex,
/// so all database access in this module is serialized through this single lock.
pub type SharedState = Arc<Mutex<Connection>>;
/// Creates the `messages`, `message_reads` and `agents` tables unless they
/// already exist (every statement is `CREATE TABLE IF NOT EXISTS`), so it can be
/// run on every startup.
///
/// `message_reads` is where the handlers record which session has read which
/// message; `agents` is keyed by `session_id`.
///
/// # Panics
///
/// Panics with "Failed to initialize database schema" if SQLite rejects the batch.
pub fn init_db(conn: &Connection) {
    const SCHEMA: &str = "CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
from_session TEXT NOT NULL,
from_agent TEXT NOT NULL,
to_session TEXT,
content TEXT NOT NULL,
timestamp INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS message_reads (
message_id TEXT NOT NULL,
session_id TEXT NOT NULL,
PRIMARY KEY (message_id, session_id)
);
CREATE TABLE IF NOT EXISTS agents (
session_id TEXT PRIMARY KEY,
agent_id TEXT NOT NULL,
pid INTEGER NOT NULL,
registered_at INTEGER NOT NULL,
last_heartbeat INTEGER NOT NULL,
metadata TEXT NOT NULL DEFAULT '{}'
);";

    let outcome = conn.execute_batch(SCHEMA);
    outcome.expect("Failed to initialize database schema");
}
/// JSON body accepted by `POST /messages/send`.
#[derive(Deserialize)]
pub struct SendRequest {
/// Session id of the sender; the send handler also records this session as
/// having read the new message.
pub from_session: String,
/// Agent id of the sender, stored with the message.
pub from_agent: String,
/// Recipient session. `None` stores a NULL recipient, which the inbox and
/// unread queries treat as visible to every session.
pub to_session: Option<String>,
/// Message body, stored as-is.
pub content: String,
}
/// Query string for `GET /messages/inbox` and `GET /messages/unread`.
#[derive(Deserialize)]
pub struct InboxQuery {
/// Session whose messages are being read or counted.
pub session: String,
/// Maximum number of messages the inbox returns; filled in by `default_limit`
/// when absent. The unread count does not use it.
#[serde(default = "default_limit")]
pub limit: usize,
}
/// Inbox page size used when the client omits `limit` (serde `default` hook).
const fn default_limit() -> usize {
    const DEFAULT_INBOX_LIMIT: usize = 50;
    DEFAULT_INBOX_LIMIT
}
/// JSON body accepted by `POST /agents/register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
/// Session id; primary key of the `agents` table, so registering the same
/// session again replaces its row.
pub session_id: String,
/// Agent identifier stored with the registration.
pub agent_id: String,
/// Process id reported by the agent.
pub pid: u32,
/// Free-form JSON metadata, stored as its serialized text. When omitted,
/// `serde_json::Value::default()` (JSON `null`) is used.
#[serde(default)]
pub metadata: serde_json::Value,
}
/// JSON response of `GET /messages/unread`.
#[derive(Serialize)]
pub struct CountResponse {
/// Number of messages visible to the session, not sent by it, and not yet
/// marked as read by it.
pub count: u64,
}
/// Builds the relay's HTTP router backed by the SQLite database at `db_path`.
///
/// The connection is switched to WAL journaling with a 5000 ms busy timeout,
/// given the schema from [`init_db`], and then wrapped as the [`SharedState`]
/// shared by every route.
///
/// # Panics
///
/// Panics if the database cannot be opened, if the pragmas cannot be applied,
/// or if schema creation fails.
pub fn build_router(db_path: &Path) -> Router {
    let state: SharedState = {
        let conn = Connection::open(db_path).expect("Failed to open SQLite database");
        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
            .expect("Failed to set pragmas");
        init_db(&conn);
        Arc::new(Mutex::new(conn))
    };

    // NOTE(review): a fully permissive CORS policy lets any web origin call these
    // endpoints from a browser — confirm that is intended for this service.
    let cors = tower_http::cors::CorsLayer::permissive();

    Router::new()
        .route("/health", get(health))
        // Agent lifecycle.
        .route("/agents", get(list_agents))
        .route("/agents/register", post(register_agent))
        .route("/agents/unregister", post(unregister_agent))
        // Messaging.
        .route("/messages/send", post(send_message))
        .route("/messages/inbox", get(inbox))
        .route("/messages/unread", get(unread_count))
        .layer(cors)
        .with_state(state)
}
/// `GET /health`: fixed liveness banner; touches no state.
async fn health() -> &'static str {
    const BANNER: &str = "agent-relay server ok";
    BANNER
}
/// `POST /agents/register`: stores the agent in `agents` and echoes the stored
/// registration back with `201 Created`.
///
/// `registered_at` and `last_heartbeat` are both set to `crate::Relay::now()`.
/// The write is `INSERT OR REPLACE` keyed on `session_id`, so registering a
/// session again overwrites its previous row, including `registered_at`.
///
/// # Errors
///
/// Responds `500 Internal Server Error` when the insert fails. The error is
/// returned instead of panicking because a panic while the connection mutex is
/// held poisons it, which would make every later request panic too.
async fn register_agent(
    State(state): State<SharedState>,
    Json(req): Json<RegisterRequest>,
) -> Result<(StatusCode, Json<AgentRegistration>), StatusCode> {
    let now = crate::Relay::now();
    let metadata_str = req.metadata.to_string();
    {
        // Keep serving even if some other handler panicked while holding the lock.
        let conn = state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        conn.execute(
            "INSERT OR REPLACE INTO agents (session_id, agent_id, pid, registered_at, last_heartbeat, metadata)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            rusqlite::params![req.session_id, req.agent_id, req.pid, now, now, metadata_str],
        )
        .map_err(|e| {
            eprintln!("register_agent: failed to insert agent: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    }
    let reg = AgentRegistration {
        session_id: req.session_id,
        agent_id: req.agent_id,
        pid: req.pid,
        registered_at: now,
        last_heartbeat: now,
        metadata: req.metadata,
    };
    Ok((StatusCode::CREATED, Json(reg)))
}
/// `POST /agents/unregister`: deletes the `agents` row whose `session_id`
/// equals the string `session_id` field of the JSON body.
///
/// Responds `200 OK` whether or not such a row existed. It responds
/// `400 Bad Request` when the body has no string `session_id`; silently running
/// the delete with an empty id would remove an agent registered as `""`. It
/// responds `500 Internal Server Error` if the delete itself fails.
async fn unregister_agent(
    State(state): State<SharedState>,
    Json(req): Json<serde_json::Value>,
) -> StatusCode {
    let Some(session) = req["session_id"].as_str() else {
        return StatusCode::BAD_REQUEST;
    };
    // Keep serving even if some other handler panicked while holding the lock.
    let conn = state
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    match conn.execute("DELETE FROM agents WHERE session_id = ?1", [session]) {
        Ok(_) => StatusCode::OK,
        Err(e) => {
            eprintln!("unregister_agent: failed to delete agent {session}: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        }
    }
}
async fn list_agents(State(state): State<SharedState>) -> Json<Vec<AgentRegistration>> {
let conn = state.lock().unwrap();
let mut stmt = conn
.prepare("SELECT session_id, agent_id, pid, registered_at, last_heartbeat, metadata FROM agents ORDER BY last_heartbeat DESC")
.unwrap();
let agents: Vec<AgentRegistration> = stmt
.query_map([], |row| {
let metadata_str: String = row.get(5)?;
Ok(AgentRegistration {
session_id: row.get(0)?,
agent_id: row.get(1)?,
pid: row.get::<_, u32>(2)?,
registered_at: row.get(3)?,
last_heartbeat: row.get(4)?,
metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
})
})
.unwrap()
.filter_map(|r| r.ok())
.collect();
Json(agents)
}
async fn send_message(
State(state): State<SharedState>,
Json(req): Json<SendRequest>,
) -> (StatusCode, Json<Message>) {
let now = crate::Relay::now();
let id = format!("msg-{}", &uuid::Uuid::new_v4().to_string()[..8]);
let conn = state.lock().unwrap();
conn.execute(
"INSERT INTO messages (id, from_session, from_agent, to_session, content, timestamp)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
rusqlite::params![id, req.from_session, req.from_agent, req.to_session, req.content, now],
)
.expect("Failed to insert message");
let _ = conn.execute(
"INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
rusqlite::params![id, req.from_session],
);
let msg = Message {
id,
from_session: req.from_session,
from_agent: req.from_agent,
to_session: req.to_session,
content: req.content,
timestamp: now,
read_by: vec![],
};
(StatusCode::CREATED, Json(msg))
}
async fn inbox(
State(state): State<SharedState>,
Query(q): Query<InboxQuery>,
) -> Json<Vec<Message>> {
let conn = state.lock().unwrap();
let mut stmt = conn
.prepare(
"SELECT id, from_session, from_agent, to_session, content, timestamp
FROM messages
WHERE to_session IS NULL OR to_session = ?1 OR from_session = ?1
ORDER BY timestamp DESC
LIMIT ?2",
)
.unwrap();
let messages: Vec<Message> = stmt
.query_map(rusqlite::params![q.session, q.limit], |row| {
Ok(Message {
id: row.get(0)?,
from_session: row.get(1)?,
from_agent: row.get(2)?,
to_session: row.get(3)?,
content: row.get(4)?,
timestamp: row.get(5)?,
read_by: vec![],
})
})
.unwrap()
.filter_map(|r| r.ok())
.collect();
for msg in &messages {
let _ = conn.execute(
"INSERT OR IGNORE INTO message_reads (message_id, session_id) VALUES (?1, ?2)",
rusqlite::params![msg.id, q.session],
);
}
Json(messages)
}
async fn unread_count(
State(state): State<SharedState>,
Query(q): Query<InboxQuery>,
) -> Json<CountResponse> {
let conn = state.lock().unwrap();
let count: u64 = conn
.query_row(
"SELECT COUNT(*) FROM messages m
WHERE (m.to_session IS NULL OR m.to_session = ?1)
AND m.from_session != ?1
AND NOT EXISTS (
SELECT 1 FROM message_reads mr
WHERE mr.message_id = m.id AND mr.session_id = ?1
)",
[&q.session],
|row| row.get(0),
)
.unwrap_or(0);
Json(CountResponse { count })
}