use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Extension, Path, Query};
use axum::response::sse::{Event, Sse};
use axum::response::{IntoResponse, Json};
use axum::routing::{delete, get, post, put};
use axum::Router;
use tokio_stream::Stream;
use koi_common::error::ErrorCode;
use koi_common::pipeline::PipelineResponse;
use utoipa::IntoParams;
use crate::error::MdnsError;
use crate::protocol::{
AdminRegistration, DaemonStatus, RegisterPayload, RegistrationCounts, RegistrationResult,
RenewalResult, Response,
};
use crate::{LeasePolicy, MdnsCore};
/// Heartbeat lease applied to HTTP registrations that omit `lease_secs`.
const DEFAULT_HEARTBEAT_LEASE: Duration = Duration::new(90, 0);
/// Grace period attached to every heartbeat lease created over HTTP.
const DEFAULT_HEARTBEAT_GRACE: Duration = Duration::new(30, 0);
/// Idle window for SSE streams when the client does not pass `idle_for`.
const DEFAULT_SSE_IDLE: Duration = Duration::new(5, 0);
/// Query parameters accepted by the browse (`/discover`) SSE endpoint.
#[derive(Debug, serde::Deserialize, IntoParams)]
pub struct BrowseParams {
    /// Service type to browse, sent as the `type` query key (e.g. `_http._tcp`).
    /// When omitted, the handler browses `koi_common::types::META_QUERY` instead.
    #[serde(rename = "type", default)]
    pub service_type: Option<String>,
    /// Seconds to wait for the next event before the stream is closed.
    /// Omitted: server default (currently 5 seconds). `0`: never close for idleness.
    pub idle_for: Option<u64>,
}
/// Query parameters accepted by the `/resolve` endpoint.
#[derive(Debug, serde::Deserialize, IntoParams)]
pub struct ResolveParams {
    /// Name of the service to resolve (for example `My NAS._http._tcp.local.`).
    pub name: String,
}
/// Query parameters accepted by the `/subscribe` lifecycle-event SSE endpoint.
#[derive(Debug, serde::Deserialize, IntoParams)]
pub struct EventsParams {
    /// Service type to watch, sent as the `type` query key. Required here
    /// (unlike the browse endpoint, there is no fallback type).
    #[serde(rename = "type")]
    pub service_type: String,
    /// Seconds to wait for the next event before the stream is closed.
    /// Omitted: server default (currently 5 seconds). `0`: never close for idleness.
    pub idle_for: Option<u64>,
}
/// Translates the `idle_for` query value (seconds) into the idle timeout
/// applied while waiting for the next SSE event.
///
/// * absent → `Some(DEFAULT_SSE_IDLE)`
/// * `0` → `None`, meaning no idle timeout at all
/// * `n` → `Some(n seconds)`
fn idle_duration(idle_for: Option<u64>) -> Option<Duration> {
    let secs = match idle_for {
        Some(secs) => secs,
        None => return Some(DEFAULT_SSE_IDLE),
    };
    (secs != 0).then_some(Duration::from_secs(secs))
}
/// Absolute URL paths of every mDNS HTTP endpoint.
pub mod paths {
    /// Common prefix shared by all mDNS endpoints.
    pub const PREFIX: &str = "/v1/mdns";
    pub const DISCOVER: &str = "/v1/mdns/discover";
    pub const ANNOUNCE: &str = "/v1/mdns/announce";
    pub const UNREGISTER: &str = "/v1/mdns/unregister/{id}";
    pub const RESOLVE: &str = "/v1/mdns/resolve";
    pub const SUBSCRIBE: &str = "/v1/mdns/subscribe";
    pub const HEARTBEAT: &str = "/v1/mdns/heartbeat/{id}";
    pub const ADMIN_STATUS: &str = "/v1/mdns/admin/status";
    pub const ADMIN_LS: &str = "/v1/mdns/admin/ls";
    pub const ADMIN_INSPECT: &str = "/v1/mdns/admin/inspect/{id}";
    pub const ADMIN_UNREGISTER: &str = "/v1/mdns/admin/unregister/{id}";
    pub const ADMIN_DRAIN: &str = "/v1/mdns/admin/drain/{id}";
    pub const ADMIN_REVIVE: &str = "/v1/mdns/admin/revive/{id}";

    /// Returns `full` with the leading [`PREFIX`] removed, or `full` itself
    /// when it does not start with the prefix.
    pub fn rel(full: &str) -> &str {
        match full.strip_prefix(PREFIX) {
            Some(rest) => rest,
            None => full,
        }
    }
}
/// Builds the mDNS router. Route paths are the prefix-stripped forms of the
/// constants in [`paths`], and `core` is shared with every handler through an
/// `Extension` layer.
pub fn routes(core: Arc<MdnsCore>) -> Router {
    use paths::rel;

    // Client-facing endpoints: discovery, registration lifecycle, resolution.
    let client = Router::new()
        .route(rel(paths::DISCOVER), get(browse_handler))
        .route(rel(paths::ANNOUNCE), post(register_handler))
        .route(rel(paths::UNREGISTER), delete(unregister_handler))
        .route(rel(paths::HEARTBEAT), put(heartbeat_handler))
        .route(rel(paths::RESOLVE), get(resolve_handler))
        .route(rel(paths::SUBSCRIBE), get(events_handler));

    // Operator endpoints under `/admin/...`.
    let with_admin = client
        .route(rel(paths::ADMIN_STATUS), get(admin_status_handler))
        .route(rel(paths::ADMIN_LS), get(admin_registrations_handler))
        .route(rel(paths::ADMIN_INSPECT), get(admin_inspect_handler))
        .route(rel(paths::ADMIN_UNREGISTER), delete(admin_unregister_handler))
        .route(rel(paths::ADMIN_DRAIN), post(admin_drain_handler))
        .route(rel(paths::ADMIN_REVIVE), post(admin_revive_handler));

    // Every handler extracts `Extension<Arc<MdnsCore>>`, so attach it after
    // all routes are registered.
    with_admin.layer(Extension(core))
}
// Browses for the requested service type (or `koi_common::types::META_QUERY`
// when `type` is omitted) and relays each browse event as an SSE frame whose
// data is the JSON-serialized pipeline response and whose id is a fresh
// UUIDv7. If starting the browse fails, the error is sent as a single SSE
// event instead. The stream ends when the handle yields `None`, when an event
// fails to serialize, or when no event arrives within the idle window.
#[utoipa::path(get, path = "/discover", tag = "mdns",
    summary = "Browse for mDNS services (SSE stream)",
    params(BrowseParams),
    responses((status = 200, description = "SSE stream", content_type = "text/event-stream")))]
async fn browse_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Query(params): Query<BrowseParams>,
) -> impl IntoResponse {
    let query_type = match params.service_type.as_deref() {
        Some(requested) => requested,
        None => koi_common::types::META_QUERY,
    };
    let browser = match core.browse(query_type).await {
        Ok(started) => Arc::new(started),
        Err(err) => return Sse::new(error_event_stream(err)).into_response(),
    };
    let idle_limit = idle_duration(params.idle_for);

    let frames = async_stream::stream! {
        loop {
            let received = if let Some(limit) = idle_limit {
                match tokio::time::timeout(limit, browser.recv()).await {
                    Ok(outcome) => outcome,
                    // Nothing arrived inside the idle window: close the stream.
                    Err(_) => break,
                }
            } else {
                browser.recv().await
            };
            let event = match received {
                Some(event) => event,
                None => break,
            };
            let pipeline = crate::protocol::browse_event_to_pipeline(event);
            let data = match serde_json::to_string(&pipeline) {
                Ok(json) => json,
                Err(e) => {
                    tracing::warn!(error = %e, "SSE browse event serialization failed");
                    break;
                }
            };
            let frame_id = uuid::Uuid::now_v7().to_string();
            yield Ok::<_, std::convert::Infallible>(Event::default().id(frame_id).data(data));
        }
    };
    Sse::new(frames).into_response()
}
// Registers a service under a heartbeat lease derived from the payload's
// `lease_secs` (via `policy_from_lease_secs`). On success it replies
// `201 Created` with the result wrapped in a clean `PipelineResponse`. Either
// failure (invalid lease or registration error) is rendered by `error_json`.
#[utoipa::path(post, path = "/announce", tag = "mdns",
    summary = "Register a new mDNS service",
    request_body = RegisterPayload,
    responses((status = 201, body = RegistrationResult)))]
async fn register_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Json(payload): Json<RegisterPayload>,
) -> impl IntoResponse {
    let outcome = policy_from_lease_secs(payload.lease_secs)
        .and_then(|policy| core.register_with_policy(payload, policy, None));
    match outcome {
        Ok(registered) => (
            axum::http::StatusCode::CREATED,
            Json(PipelineResponse::clean(Response::Registered(registered))),
        )
            .into_response(),
        Err(e) => error_json(e).into_response(),
    }
}
// Removes the registration with the given id and echoes the id back inside
// a clean `PipelineResponse`. Failures are rendered by `error_json`.
#[utoipa::path(delete, path = "/unregister/{id}", tag = "mdns",
    summary = "Unregister a service",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200)))]
async fn unregister_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = core.unregister(&id) {
        return error_json(e).into_response();
    }
    Json(PipelineResponse::clean(Response::Unregistered(id))).into_response()
}
// Resolves a single service by the `name` query parameter and returns the
// resolved record inside a clean `PipelineResponse`. Failures are rendered
// by `error_json` (HTTP status from `ErrorCode`).
#[utoipa::path(get, path = "/resolve", tag = "mdns",
    summary = "Resolve a service by name",
    params(ResolveParams),
    responses((status = 200)))]
async fn resolve_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Query(params): Query<ResolveParams>,
) -> impl IntoResponse {
    match core.resolve(&params.name).await {
        Ok(record) => {
            let resp = PipelineResponse::clean(Response::Resolved(record));
            Json(resp).into_response()
        }
        Err(e) => error_json(e).into_response(),
    }
}
// Subscribes to lifecycle events for the required `type` query parameter and
// relays each one as an SSE frame. The frame data is the JSON-serialized
// pipeline response and the frame id is a fresh UUIDv7. If starting the
// browse fails, the error is sent as a single SSE event. The stream ends when
// the handle yields `None`, when an event fails to serialize, or when no
// event arrives within the idle window derived from `idle_for`.
#[utoipa::path(get, path = "/subscribe", tag = "mdns",
    summary = "Subscribe to mDNS lifecycle events (SSE stream)",
    params(EventsParams),
    responses((status = 200, description = "SSE stream", content_type = "text/event-stream")))]
async fn events_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Query(params): Query<EventsParams>,
) -> impl IntoResponse {
    let handle = match core.browse(&params.service_type).await {
        Ok(h) => h,
        Err(e) => return Sse::new(error_event_stream(e)).into_response(),
    };
    let idle = idle_duration(params.idle_for);
    let handle = Arc::new(handle);
    let stream = async_stream::stream! {
        loop {
            let next = match idle {
                Some(dur) => match tokio::time::timeout(dur, handle.recv()).await {
                    Ok(result) => result,
                    // Idle window elapsed without an event: end the stream.
                    Err(_) => break,
                },
                None => handle.recv().await,
            };
            match next {
                Some(event) => {
                    let resp = crate::protocol::subscribe_event_to_pipeline(event);
                    match serde_json::to_string(&resp) {
                        Ok(data) => {
                            let id = uuid::Uuid::now_v7().to_string();
                            yield Ok::<_, std::convert::Infallible>(Event::default().id(id).data(data));
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "SSE subscribe event serialization failed");
                            break;
                        }
                    }
                }
                None => break,
            }
        }
    };
    Sse::new(stream).into_response()
}
// Renews the heartbeat lease of a registration. The response carries the id
// and the `lease_secs` value returned by `core.heartbeat`, wrapped in a clean
// `PipelineResponse`. Failures are rendered by `error_json`.
#[utoipa::path(put, path = "/heartbeat/{id}", tag = "mdns",
    summary = "Renew a service heartbeat lease",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200, body = RenewalResult)))]
async fn heartbeat_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    let lease_secs = match core.heartbeat(&id) {
        Ok(secs) => secs,
        Err(e) => return error_json(e).into_response(),
    };
    let renewed = RenewalResult { id, lease_secs };
    Json(PipelineResponse::clean(Response::Renewed(renewed))).into_response()
}
// Returns the daemon status snapshot produced by `core.admin_status()` as JSON.
#[utoipa::path(get, path = "/admin/status", tag = "mdns",
    summary = "Daemon status overview",
    responses((status = 200, body = DaemonStatus)))]
async fn admin_status_handler(Extension(core): Extension<Arc<MdnsCore>>) -> impl IntoResponse {
    let snapshot = core.admin_status();
    Json(snapshot)
}
// Lists every registration as a JSON array. `core.admin_registrations()`
// yields pairs, and only the second element (the admin view) is returned.
#[utoipa::path(get, path = "/admin/ls", tag = "mdns",
    summary = "List all registrations",
    responses((status = 200, body = Vec<AdminRegistration>)))]
async fn admin_registrations_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
) -> impl IntoResponse {
    let mut listing = Vec::new();
    for (_, entry) in core.admin_registrations() {
        listing.push(entry);
    }
    Json(listing)
}
// Returns the admin view of one registration as JSON, or an `error_json`
// response when the lookup fails.
#[utoipa::path(get, path = "/admin/inspect/{id}", tag = "mdns",
    summary = "Inspect a single registration",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200, body = AdminRegistration)))]
async fn admin_inspect_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    core.admin_inspect(&id)
        .map(|admin| Json(admin).into_response())
        .unwrap_or_else(|e| error_json(e).into_response())
}
// Removes a registration through `core.admin_force_unregister` and echoes the
// id back inside a clean `PipelineResponse` (same shape as the client-facing
// unregister endpoint). Failures are rendered by `error_json`.
#[utoipa::path(delete, path = "/admin/unregister/{id}", tag = "mdns",
    summary = "Force-unregister a service",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200)))]
async fn admin_unregister_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    match core.admin_force_unregister(&id) {
        Err(e) => return error_json(e).into_response(),
        Ok(()) => {}
    }
    let body = PipelineResponse::clean(Response::Unregistered(id));
    Json(body).into_response()
}
// Starts draining a registration via `core.admin_drain` and replies with
// `{"drained": <id>}`. Failures are rendered by `error_json`.
#[utoipa::path(post, path = "/admin/drain/{id}", tag = "mdns",
    summary = "Begin grace period for a registration",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200)))]
async fn admin_drain_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = core.admin_drain(&id) {
        return error_json(e).into_response();
    }
    let body = serde_json::json!({ "drained": id });
    Json(body).into_response()
}
// Cancels draining via `core.admin_revive` and replies with
// `{"revived": <id>}`. Failures are rendered by `error_json`.
#[utoipa::path(post, path = "/admin/revive/{id}", tag = "mdns",
    summary = "Cancel draining and revive a registration",
    params(("id" = String, Path, description = "Registration ID")),
    responses((status = 200)))]
async fn admin_revive_handler(
    Extension(core): Extension<Arc<MdnsCore>>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = core.admin_revive(&id) {
        return error_json(e).into_response();
    }
    let body = serde_json::json!({ "revived": id });
    Json(body).into_response()
}
/// Renders an [`MdnsError`] as an HTTP response.
///
/// The status code comes from the error's [`ErrorCode`]. An out-of-range code
/// falls back to `500 Internal Server Error`. The body is the JSON produced by
/// `crate::protocol::error_to_pipeline`.
fn error_json(e: MdnsError) -> impl IntoResponse {
    let numeric = ErrorCode::from(&e).http_status();
    let status = match axum::http::StatusCode::from_u16(numeric) {
        Ok(valid) => valid,
        Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
    };
    let body = Json(crate::protocol::error_to_pipeline(&e));
    (status, body)
}
/// Builds a one-shot SSE stream that emits `e` as a single event (with a
/// UUIDv7 id) and then ends.
///
/// The event data is the JSON form of `crate::protocol::error_to_pipeline`.
/// If that fails to serialize, a minimal hand-built JSON object is sent
/// instead.
fn error_event_stream(
    e: MdnsError,
) -> impl Stream<Item = std::result::Result<Event, std::convert::Infallible>> {
    let data = match serde_json::to_string(&crate::protocol::error_to_pipeline(&e)) {
        Ok(json) => json,
        Err(_) => {
            // Going through `serde_json::Value::String` keeps the message a
            // correctly escaped JSON string literal.
            let msg = serde_json::Value::String(e.to_string());
            format!(r#"{{"error":"serialization_failed","message":{msg}}}"#)
        }
    };
    async_stream::stream! {
        yield Ok(Event::default().id(uuid::Uuid::now_v7().to_string()).data(data));
    }
}
/// Derives the lease policy for an HTTP registration.
///
/// Every HTTP registration gets a heartbeat lease with
/// `DEFAULT_HEARTBEAT_GRACE`. The lease is `lease_secs` seconds, or
/// `DEFAULT_HEARTBEAT_LEASE` when absent.
///
/// # Errors
///
/// Returns [`MdnsError::InvalidPayload`] when `lease_secs` is `Some(0)`.
fn policy_from_lease_secs(lease_secs: Option<u64>) -> Result<LeasePolicy, MdnsError> {
    let lease = match lease_secs {
        None => DEFAULT_HEARTBEAT_LEASE,
        Some(0) => {
            return Err(MdnsError::InvalidPayload(
                "lease_secs=0 is not allowed via HTTP".into(),
            ))
        }
        Some(secs) => Duration::from_secs(secs),
    };
    Ok(LeasePolicy::Heartbeat {
        lease,
        grace: DEFAULT_HEARTBEAT_GRACE,
    })
}
// OpenAPI document for the mDNS HTTP surface: every handler in this module
// plus the schema types they reference. The `path = ...` values on the
// handlers are the prefix-stripped forms produced by `paths::rel`. Whoever
// mounts `routes()` presumably nests it under `paths::PREFIX`, which is not
// visible here (confirm at the call site).
#[derive(utoipa::OpenApi)]
#[openapi(
    paths(
        browse_handler,
        register_handler,
        unregister_handler,
        resolve_handler,
        events_handler,
        heartbeat_handler,
        admin_status_handler,
        admin_registrations_handler,
        admin_inspect_handler,
        admin_unregister_handler,
        admin_drain_handler,
        admin_revive_handler,
    ),
    components(schemas(
        RegisterPayload,
        RegistrationResult,
        RenewalResult,
        AdminRegistration,
        DaemonStatus,
        RegistrationCounts,
        crate::protocol::LeaseMode,
        crate::protocol::LeaseState,
    ))
)]
pub struct MdnsApiDoc;
// Unit tests for the pure helpers in this module: idle-window and lease-policy
// derivation, query-parameter deserialization, error-to-HTTP mapping, the
// default constants, and sanity checks on the UUIDv7 values used as SSE event
// ids. No handler is exercised end-to-end here.
#[cfg(test)]
mod tests {
    use super::*;

    // --- idle_duration: absent -> default, 0 -> no timeout, n -> n seconds ---

    #[test]
    fn idle_duration_absent_returns_default_5s() {
        let d = idle_duration(None);
        assert_eq!(d, Some(Duration::from_secs(5)));
    }
    #[test]
    fn idle_duration_zero_returns_none_infinite() {
        let d = idle_duration(Some(0));
        assert_eq!(d, None);
    }
    #[test]
    fn idle_duration_explicit_value() {
        let d = idle_duration(Some(15));
        assert_eq!(d, Some(Duration::from_secs(15)));
    }
    #[test]
    fn idle_duration_one_second() {
        let d = idle_duration(Some(1));
        assert_eq!(d, Some(Duration::from_secs(1)));
    }

    // --- policy_from_lease_secs ---

    #[test]
    fn policy_from_none_returns_default_heartbeat() {
        let policy = policy_from_lease_secs(None).unwrap();
        assert!(matches!(
            policy,
            LeasePolicy::Heartbeat { lease, grace }
            if lease == Duration::from_secs(90) && grace == Duration::from_secs(30)
        ));
    }
    #[test]
    fn policy_from_zero_returns_error() {
        let result = policy_from_lease_secs(Some(0));
        assert!(result.is_err(), "lease_secs=0 should be rejected via HTTP");
    }
    #[test]
    fn policy_from_custom_returns_heartbeat_with_custom_lease() {
        let policy = policy_from_lease_secs(Some(300)).unwrap();
        assert!(matches!(
            policy,
            LeasePolicy::Heartbeat { lease, grace }
            if lease == Duration::from_secs(300) && grace == Duration::from_secs(30)
        ));
    }

    // --- query-parameter deserialization (exercised through serde_json) ---

    #[test]
    fn browse_params_deserializes_type_field() {
        let json = r#"{"type": "_http._tcp"}"#;
        let params: BrowseParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.service_type.as_deref(), Some("_http._tcp"));
        assert!(params.idle_for.is_none());
    }
    #[test]
    fn browse_params_type_is_optional() {
        let json = r#"{}"#;
        let params: BrowseParams = serde_json::from_str(json).unwrap();
        assert!(params.service_type.is_none());
    }
    #[test]
    fn browse_params_deserializes_idle_for() {
        let json = r#"{"type": "_http._tcp", "idle_for": 10}"#;
        let params: BrowseParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.idle_for, Some(10));
    }
    #[test]
    fn resolve_params_deserializes_name() {
        let json = r#"{"name": "My NAS._http._tcp.local."}"#;
        let params: ResolveParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.name, "My NAS._http._tcp.local.");
    }

    // --- error_json: status mapping and body shape ---

    #[tokio::test]
    async fn error_json_not_found_maps_to_404() {
        let resp = error_json(MdnsError::RegistrationNotFound("abc".into())).into_response();
        assert_eq!(resp.status(), axum::http::StatusCode::NOT_FOUND);
    }
    #[tokio::test]
    async fn error_json_invalid_type_maps_to_400() {
        let resp = error_json(MdnsError::InvalidServiceType("bad".into())).into_response();
        assert_eq!(resp.status(), axum::http::StatusCode::BAD_REQUEST);
    }
    #[tokio::test]
    async fn error_json_body_is_json_with_error_field() {
        let resp = error_json(MdnsError::RegistrationNotFound("xyz".into())).into_response();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(json.get("error").is_some());
        assert!(json.get("message").is_some());
    }
    #[test]
    fn events_params_deserializes() {
        let json = r#"{"type": "_http._tcp", "idle_for": 0}"#;
        let params: EventsParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.service_type, "_http._tcp");
        assert_eq!(params.idle_for, Some(0));
    }
    #[test]
    fn events_params_without_idle_for() {
        let json = r#"{"type": "_ssh._tcp"}"#;
        let params: EventsParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.service_type, "_ssh._tcp");
        assert!(params.idle_for.is_none());
    }

    // --- default constants (guards against accidental changes) ---

    #[test]
    fn default_heartbeat_lease_is_90s() {
        assert_eq!(DEFAULT_HEARTBEAT_LEASE, Duration::from_secs(90));
    }
    #[test]
    fn default_heartbeat_grace_is_30s() {
        assert_eq!(DEFAULT_HEARTBEAT_GRACE, Duration::from_secs(30));
    }
    #[test]
    fn default_sse_idle_is_5s() {
        assert_eq!(DEFAULT_SSE_IDLE, Duration::from_secs(5));
    }
    #[test]
    fn policy_from_one_second_returns_heartbeat() {
        let policy = policy_from_lease_secs(Some(1)).unwrap();
        assert!(matches!(
            policy,
            LeasePolicy::Heartbeat { lease, .. }
            if lease == Duration::from_secs(1)
        ));
    }
    // Pins that no upper bound is enforced at the HTTP layer.
    #[test]
    fn policy_from_u64_max_returns_heartbeat() {
        let policy = policy_from_lease_secs(Some(u64::MAX)).unwrap();
        assert!(matches!(policy, LeasePolicy::Heartbeat { .. }));
    }

    // --- UUIDv7 values used as SSE `id:` fields ---

    #[test]
    fn uuid_v7_is_valid_sse_id() {
        let id = uuid::Uuid::now_v7().to_string();
        assert_eq!(id.len(), 36, "UUIDv7 string should be 36 chars: {id}");
        assert!(!id.contains('\n'), "must not contain newlines");
        assert!(!id.contains('\r'), "must not contain carriage returns");
    }
    // NOTE(review): `a <= b` on two back-to-back `now_v7()` strings relies on
    // the uuid crate keeping v7 ids monotonic within the same millisecond.
    // Confirm the uuid version in use guarantees this; otherwise the test can flake.
    #[test]
    fn uuid_v7_is_monotonic() {
        let a = uuid::Uuid::now_v7().to_string();
        let b = uuid::Uuid::now_v7().to_string();
        assert!(a <= b, "UUIDv7 should be monotonic: {a} <= {b}");
    }
    #[test]
    fn uuid_v7_is_unique() {
        let ids: std::collections::HashSet<String> =
            (0..100).map(|_| uuid::Uuid::now_v7().to_string()).collect();
        assert_eq!(ids.len(), 100, "100 UUIDv7 IDs should all be unique");
    }
}