mod browse;
mod daemon;
pub mod error;
pub mod events;
pub mod http;
pub mod protocol;
mod registry;
pub use self::browse::BrowseHandle;
pub use self::error::{MdnsError, Result};
pub use self::events::MdnsEvent;
pub use self::registry::LeasePolicy;
use std::sync::Arc;
use std::time::Instant;
use self::daemon::MdnsDaemon;
use self::registry::{InsertOutcome, Registry};
use koi_common::capability::{Capability, CapabilityStatus};
use koi_common::firewall::{FirewallPort, FirewallProtocol};
use koi_common::id::generate_short_id;
use koi_common::types::{ServiceRecord, ServiceType, SessionId, META_QUERY};
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
use crate::protocol::{
AdminRegistration, DaemonStatus, LeaseMode, RegisterPayload, RegistrationResult,
};
/// Capacity of the broadcast channel behind [`MdnsCore::subscribe`].
const BROADCAST_CHANNEL_CAPACITY: usize = 256;
/// Period between two passes of the background lease reaper.
const REAPER_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Port used for multicast DNS traffic (UDP).
pub const MDNS_PORT: u16 = 5353;
/// Firewall openings required by the mDNS capability: a single UDP rule
/// labelled `"mDNS"` on [`MDNS_PORT`].
pub fn firewall_ports() -> Vec<FirewallPort> {
    let mdns_rule = FirewallPort::new("mDNS", FirewallProtocol::Udp, MDNS_PORT);
    vec![mdns_rule]
}
/// Core of the mDNS capability.
///
/// Owns the shared mDNS daemon handle, the registration [`Registry`], the
/// broadcast sender handed to every [`BrowseHandle`] and exposed through
/// [`MdnsCore::subscribe`], and a background reaper task that withdraws
/// registrations returned by `Registry::reap`.
pub struct MdnsCore {
    daemon: Arc<MdnsDaemon>,
    registry: Arc<Registry>,
    event_tx: broadcast::Sender<MdnsEvent>,
    started_at: Instant,
    /// Stops the background reaper. It is a child of the token passed to
    /// [`MdnsCore::with_cancel`]: cancelling the caller's token still stops
    /// the reaper, but cancelling this one never propagates back to the caller.
    reaper_cancel: CancellationToken,
}

impl MdnsCore {
    /// Creates a core whose reaper runs until [`MdnsCore::shutdown`] is called
    /// or the core is dropped.
    ///
    /// # Errors
    /// Returns an error if the mDNS daemon cannot be created.
    ///
    /// # Panics
    /// Must be called from within a Tokio runtime, because the reaper task is
    /// started with `tokio::spawn`.
    pub fn new() -> Result<Self> {
        Self::with_cancel(CancellationToken::new())
    }

    /// Like [`MdnsCore::new`], but the background reaper also stops when
    /// `cancel` is cancelled.
    ///
    /// # Errors
    /// Returns an error if the mDNS daemon cannot be created.
    ///
    /// # Panics
    /// Must be called from within a Tokio runtime (uses `tokio::spawn`).
    pub fn with_cancel(cancel: CancellationToken) -> Result<Self> {
        let daemon = Arc::new(MdnsDaemon::new()?);
        let registry = Arc::new(Registry::new());
        let (event_tx, _) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
        let started_at = Instant::now();
        // A child token lets shutdown/drop stop the reaper without cancelling
        // anything the caller owns, while the caller's token still stops it.
        let reaper_cancel = cancel.child_token();
        Self::spawn_reaper(
            Arc::clone(&registry),
            Arc::clone(&daemon),
            reaper_cancel.clone(),
        );
        Ok(Self {
            daemon,
            registry,
            event_tx,
            started_at,
            reaper_cancel,
        })
    }

    /// Spawns the task that, every [`REAPER_INTERVAL`], takes the entries
    /// returned by `registry.reap()` and withdraws them from the mDNS daemon.
    /// The task exits when `cancel` is cancelled.
    fn spawn_reaper(registry: Arc<Registry>, daemon: Arc<MdnsDaemon>, cancel: CancellationToken) {
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(REAPER_INTERVAL);
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        let expired = registry.reap();
                        for (id, payload) in &expired {
                            tracing::info!(
                                name = %payload.name, id,
                                reason = "expired",
                                "Service unregistered"
                            );
                            // Best effort: failures are logged, not retried.
                            match ServiceType::parse(&payload.service_type) {
                                Ok(st) => {
                                    if let Err(e) = daemon.unregister(&payload.name, st.as_str()) {
                                        tracing::warn!(
                                            name = %payload.name, id,
                                            error = %e,
                                            "Failed to withdraw expired service from mDNS"
                                        );
                                    }
                                }
                                Err(_) => {
                                    tracing::warn!(
                                        name = %payload.name, id,
                                        service_type = %payload.service_type,
                                        "Expired service has an invalid service type; not withdrawn from mDNS"
                                    );
                                }
                            }
                        }
                    }
                    _ = cancel.cancelled() => {
                        tracing::debug!("Reaper task stopped");
                        break;
                    }
                }
            }
        });
    }

    /// Starts browsing for `service_type`.
    ///
    /// The special [`META_QUERY`] string is passed to the daemon verbatim and
    /// marks the handle as a meta browse; any other value must parse as a
    /// [`ServiceType`] and is browsed under `ServiceType::as_str()`.
    ///
    /// # Errors
    /// Returns an error if `service_type` does not parse or the daemon browse
    /// fails.
    pub async fn browse(&self, service_type: &str) -> Result<BrowseHandle> {
        let is_meta = service_type == META_QUERY;
        let browse_type = if is_meta {
            META_QUERY.to_string()
        } else {
            ServiceType::parse(service_type)?.as_str().to_string()
        };
        let receiver = self.daemon.browse(&browse_type).await?;
        let event_tx = self.event_tx.clone();
        Ok(BrowseHandle::new(
            receiver,
            event_tx,
            is_meta,
            browse_type,
            self.daemon.clone(),
        ))
    }

    /// Registers a service with a permanent lease and no owning session.
    ///
    /// # Errors
    /// See [`MdnsCore::register_with_policy`].
    pub fn register(&self, payload: RegisterPayload) -> Result<RegistrationResult> {
        self.register_with_policy(payload, LeasePolicy::Permanent, None)
    }

    /// Registers (or reconnects) a service under the given lease `policy`.
    ///
    /// For a brand-new registration the service is advertised via the daemon;
    /// if that fails, the registry entry is rolled back and the error returned.
    /// For a reconnection the service is re-advertised only when its port,
    /// TXT records, or IP changed; a failure there is logged and the call
    /// still succeeds.
    ///
    /// # Errors
    /// Returns an error if `payload.service_type` does not parse, or if the
    /// daemon rejects a brand-new registration.
    pub fn register_with_policy(
        &self,
        payload: RegisterPayload,
        policy: LeasePolicy,
        session_id: Option<SessionId>,
    ) -> Result<RegistrationResult> {
        let st = ServiceType::parse(&payload.service_type)?;
        let new_id = generate_short_id();
        let outcome =
            self.registry
                .insert_or_reconnect(new_id, payload.clone(), policy.clone(), session_id);
        match &outcome {
            InsertOutcome::New { id } => {
                if let Err(e) = self.advertise(&st, &payload) {
                    // Never keep a registry entry that was never advertised.
                    let _ = self.registry.remove(id);
                    return Err(e);
                }
            }
            InsertOutcome::Reconnected { old_payload, .. } => {
                // `ip` is part of the advertisement, so a change there must
                // trigger a re-advertise just like port or TXT changes.
                let advertisement_changed = old_payload.port != payload.port
                    || old_payload.txt != payload.txt
                    || old_payload.ip != payload.ip;
                if advertisement_changed {
                    let _ = self.daemon.unregister(&old_payload.name, st.as_str());
                    if let Err(e) = self.advertise(&st, &payload) {
                        tracing::warn!(
                            name = %payload.name,
                            error = %e,
                            "Failed to re-register with updated payload during reconnection"
                        );
                    }
                }
            }
        }
        let id = outcome.id().to_string();
        let (mode, lease_secs) = match &policy {
            LeasePolicy::Session { .. } => (LeaseMode::Session, None),
            LeasePolicy::Heartbeat { lease, .. } => (LeaseMode::Heartbeat, Some(lease.as_secs())),
            LeasePolicy::Permanent => (LeaseMode::Permanent, None),
        };
        let result = RegistrationResult {
            id,
            name: payload.name.clone(),
            service_type: st.short().to_string(),
            port: payload.port,
            mode,
            lease_secs,
        };
        tracing::info!(
            name = %result.name,
            service_type = %result.service_type,
            port = result.port,
            id = %result.id,
            "Service registered"
        );
        Ok(result)
    }

    /// Publishes `payload` on the network under service type `st` via the
    /// mDNS daemon.
    fn advertise(&self, st: &ServiceType, payload: &RegisterPayload) -> Result<()> {
        self.daemon.register(
            &payload.name,
            st.as_str(),
            payload.port,
            payload.ip.as_deref(),
            &payload.txt,
        )?;
        Ok(())
    }

    /// Records a heartbeat for registration `id`.
    ///
    /// Returns whatever `Registry::heartbeat` returns (presumably the lease
    /// length in seconds — confirm against the registry).
    pub fn heartbeat(&self, id: &str) -> Result<u64> {
        self.registry.heartbeat(id)
    }

    /// Marks every registration owned by `session_id` as draining.
    ///
    /// Nothing is withdrawn from mDNS here; the entries stay advertised until
    /// the registry later hands them to the reaper (assumed to happen after
    /// a grace period — see `Registry::drain_session`).
    pub fn session_disconnected(&self, session_id: &SessionId) {
        let drained = self.registry.drain_session(session_id);
        for id in &drained {
            tracing::info!(
                id,
                session = %session_id.as_str(),
                "Session disconnected, registration draining"
            );
        }
    }

    /// Removes registration `id` and withdraws it from mDNS.
    ///
    /// The registry entry is removed *before* the daemon call, so if the
    /// daemon call fails the entry is already gone from the registry.
    ///
    /// # Errors
    /// Returns an error if `id` is unknown, the stored service type does not
    /// parse, or the daemon fails to unregister.
    pub fn unregister(&self, id: &str) -> Result<()> {
        let payload = self.registry.remove(id)?;
        let st = ServiceType::parse(&payload.service_type)?;
        self.daemon.unregister(&payload.name, st.as_str())?;
        tracing::info!(name = %payload.name, id, reason = "explicit", "Service unregistered");
        Ok(())
    }

    /// Resolves a service instance through the daemon.
    pub async fn resolve(&self, instance: &str) -> Result<ServiceRecord> {
        self.daemon.resolve(instance).await
    }

    /// Returns a new receiver on the core's event broadcast channel.
    pub fn subscribe(&self) -> broadcast::Receiver<MdnsEvent> {
        self.event_tx.subscribe()
    }

    /// Stops the reaper, unregisters every registration, then shuts down the
    /// daemon.
    ///
    /// Per-registration failures are logged and do not abort shutdown.
    ///
    /// # Errors
    /// Returns an error only if the daemon shutdown itself fails.
    pub async fn shutdown(&self) -> Result<()> {
        // Signal the reaper before tearing anything down so it stops reaping
        // alongside (and after) the explicit unregisters below; a reap pass
        // already in progress may still finish.
        self.reaper_cancel.cancel();
        let ids: Vec<String> = self.registry.all_ids();
        for id in &ids {
            if let Err(e) = self.unregister(id) {
                tracing::warn!(id, error = %e, "Failed to unregister service during shutdown");
            }
        }
        self.daemon.shutdown().await?;
        tracing::info!("mDNS core shut down");
        Ok(())
    }

    /// Snapshot of daemon metadata: crate version, uptime, OS and
    /// registration counts.
    pub fn admin_status(&self) -> DaemonStatus {
        DaemonStatus {
            version: env!("CARGO_PKG_VERSION").to_string(),
            uptime_secs: self.started_at.elapsed().as_secs(),
            platform: std::env::consts::OS.to_string(),
            registrations: self.registry.counts(),
        }
    }

    /// Returns the registry's snapshot of all registrations as `(id, entry)` pairs.
    pub fn admin_registrations(&self) -> Vec<(String, AdminRegistration)> {
        self.registry.snapshot()
    }

    /// Looks up one registration by full id or id prefix.
    ///
    /// # Errors
    /// Returns an error if the prefix does not resolve to a registration.
    pub fn admin_inspect(&self, id_or_prefix: &str) -> Result<AdminRegistration> {
        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
        self.registry.snapshot_one(&full_id)
    }

    /// Forcibly removes a registration (by full id or prefix).
    ///
    /// Unlike [`MdnsCore::unregister`], a daemon failure does not fail the
    /// call; it is only logged.
    ///
    /// # Errors
    /// Returns an error if the prefix does not resolve, the entry cannot be
    /// removed, or its stored service type does not parse.
    pub fn admin_force_unregister(&self, id_or_prefix: &str) -> Result<()> {
        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
        let payload = self.registry.remove(&full_id)?;
        let st = ServiceType::parse(&payload.service_type)?;
        if let Err(e) = self.daemon.unregister(&payload.name, st.as_str()) {
            tracing::warn!(
                name = %payload.name,
                id = %full_id,
                error = %e,
                "Daemon unregister failed during forced removal"
            );
        }
        tracing::info!(
            name = %payload.name,
            id = %full_id,
            reason = "admin_force",
            "Service unregistered"
        );
        Ok(())
    }

    /// Forces a registration (by full id or prefix) into the draining state.
    pub fn admin_drain(&self, id_or_prefix: &str) -> Result<()> {
        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
        self.registry.force_drain(&full_id)
    }

    /// Revives a registration (by full id or prefix) via `Registry::force_revive`.
    pub fn admin_revive(&self, id_or_prefix: &str) -> Result<()> {
        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
        self.registry.force_revive(&full_id)
    }
}

impl Drop for MdnsCore {
    /// Stops the background reaper so it does not outlive the core (and keep
    /// the registry and daemon alive) when `shutdown` was never called.
    fn drop(&mut self) {
        self.reaper_cancel.cancel();
    }
}
impl Capability for MdnsCore {
    /// Stable capability identifier.
    fn name(&self) -> &str {
        "mdns"
    }

    /// Summarises registration counts; this capability always reports itself
    /// as healthy.
    fn status(&self) -> CapabilityStatus {
        let tally = self.registry.counts();
        CapabilityStatus {
            name: self.name().to_owned(),
            summary: format!(
                "{} registered ({} alive, {} draining)",
                tally.total, tally.alive, tally.draining
            ),
            healthy: true,
        }
    }
}