Skip to main content

koi_mdns/
lib.rs

//! Koi mDNS: mDNS/DNS-SD service discovery domain.
//!
//! This crate implements the mDNS capability for Koi. It exposes a domain
//! boundary via [`MdnsCore`] with three faces:
//!
//! - **Commands**: methods that drive domain actions (`register`, `browse`,
//!   `unregister`, `resolve`, …).
//! - **State**: read-only snapshots (`admin_status`, `admin_registrations`,
//!   `admin_inspect`).
//! - **Events**: a broadcast channel of service lifecycle events, obtained via
//!   `subscribe`.
9
10mod browse;
11mod daemon;
12pub mod error;
13pub mod events;
14pub mod http;
15pub mod protocol;
16mod registry;
17
18pub use self::browse::BrowseHandle;
19pub use self::error::{MdnsError, Result};
20pub use self::events::MdnsEvent;
21pub use self::registry::LeasePolicy;
22
23use std::sync::Arc;
24use std::time::Instant;
25
26use self::daemon::MdnsDaemon;
27use self::registry::{InsertOutcome, Registry};
28
29use koi_common::capability::{Capability, CapabilityStatus};
30use koi_common::firewall::{FirewallPort, FirewallProtocol};
31use koi_common::id::generate_short_id;
32use koi_common::types::{ServiceRecord, ServiceType, SessionId, META_QUERY};
33use tokio::sync::broadcast;
34use tokio_util::sync::CancellationToken;
35
36use crate::protocol::{
37    AdminRegistration, DaemonStatus, LeaseMode, RegisterPayload, RegistrationResult,
38};
39
/// Buffer size of the broadcast channel that fans service events out to
/// subscribers.
const BROADCAST_CHANNEL_CAPACITY: usize = 256;

/// Period between two reaper sweeps over expired registrations.
const REAPER_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);

/// Well-known UDP port used by multicast DNS.
pub const MDNS_PORT: u16 = 5353;
48
49/// Firewall ports required by the mDNS capability.
50pub fn firewall_ports() -> Vec<FirewallPort> {
51    vec![FirewallPort::new("mDNS", FirewallProtocol::Udp, MDNS_PORT)]
52}
53
54/// The core mDNS facade. All adapters interact through this.
55pub struct MdnsCore {
56    daemon: Arc<MdnsDaemon>,
57    registry: Arc<Registry>,
58    event_tx: broadcast::Sender<MdnsEvent>,
59    started_at: Instant,
60}
61
62impl MdnsCore {
63    /// Create a new core with a default (never-cancelled) token.
64    /// Used by standalone commands where the runtime drops on exit.
65    pub fn new() -> Result<Self> {
66        Self::with_cancel(CancellationToken::new())
67    }
68
69    /// Create a new core with a shared cancellation token.
70    /// Used by daemon mode for ordered shutdown.
71    pub fn with_cancel(cancel: CancellationToken) -> Result<Self> {
72        let daemon = Arc::new(MdnsDaemon::new()?);
73        let registry = Arc::new(Registry::new());
74        let (event_tx, _) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
75        let started_at = Instant::now();
76
77        // Spawn reaper task - sweeps expired registrations every 5 seconds
78        let reaper_registry = registry.clone();
79        let reaper_daemon = daemon.clone();
80        let reaper_cancel = cancel.clone();
81        tokio::spawn(async move {
82            let mut interval = tokio::time::interval(REAPER_INTERVAL);
83            loop {
84                tokio::select! {
85                    _ = interval.tick() => {
86                        let expired = reaper_registry.reap();
87                        for (id, payload) in &expired {
88                            tracing::info!(
89                                name = %payload.name, id,
90                                reason = "expired",
91                                "Service unregistered"
92                            );
93                            if let Ok(st) = ServiceType::parse(&payload.service_type) {
94                                let _ = reaper_daemon.unregister(&payload.name, st.as_str());
95                            }
96                        }
97                    }
98                    _ = reaper_cancel.cancelled() => {
99                        tracing::debug!("Reaper task stopped");
100                        break;
101                    }
102                }
103            }
104        });
105
106        Ok(Self {
107            daemon,
108            registry,
109            event_tx,
110            started_at,
111        })
112    }
113
114    // ── Commands ──────────────────────────────────────────────────────
115
116    /// Start browsing for services of the given type.
117    /// Pass `META_QUERY` to discover all service types on the network.
118    ///
119    /// The returned `BrowseHandle` calls `stop_browse` on drop, so the
120    /// underlying daemon resource is always cleaned up.
121    pub async fn browse(&self, service_type: &str) -> Result<BrowseHandle> {
122        let is_meta = service_type == META_QUERY;
123        let browse_type = if is_meta {
124            META_QUERY.to_string()
125        } else {
126            ServiceType::parse(service_type)?.as_str().to_string()
127        };
128        let receiver = self.daemon.browse(&browse_type).await?;
129        let event_tx = self.event_tx.clone();
130        Ok(BrowseHandle::new(
131            receiver,
132            event_tx,
133            is_meta,
134            browse_type,
135            self.daemon.clone(),
136        ))
137    }
138
139    /// Register a service with a default permanent policy.
140    pub fn register(&self, payload: RegisterPayload) -> Result<RegistrationResult> {
141        self.register_with_policy(payload, LeasePolicy::Permanent, None)
142    }
143
144    /// The single registration entry point. Every adapter explicitly chooses a policy.
145    pub fn register_with_policy(
146        &self,
147        payload: RegisterPayload,
148        policy: LeasePolicy,
149        session_id: Option<SessionId>,
150    ) -> Result<RegistrationResult> {
151        let st = ServiceType::parse(&payload.service_type)?;
152        let new_id = generate_short_id();
153
154        let outcome =
155            self.registry
156                .insert_or_reconnect(new_id, payload.clone(), policy.clone(), session_id);
157
158        match &outcome {
159            InsertOutcome::New { id } => {
160                if let Err(e) = self.daemon.register(
161                    &payload.name,
162                    st.as_str(),
163                    payload.port,
164                    payload.ip.as_deref(),
165                    &payload.txt,
166                ) {
167                    let _ = self.registry.remove(id);
168                    return Err(e);
169                }
170            }
171            InsertOutcome::Reconnected { old_payload, .. } => {
172                if old_payload.port != payload.port || old_payload.txt != payload.txt {
173                    let _ = self.daemon.unregister(&old_payload.name, st.as_str());
174                    if let Err(e) = self.daemon.register(
175                        &payload.name,
176                        st.as_str(),
177                        payload.port,
178                        payload.ip.as_deref(),
179                        &payload.txt,
180                    ) {
181                        tracing::warn!(
182                            name = %payload.name,
183                            error = %e,
184                            "Failed to re-register with updated payload during reconnection"
185                        );
186                    }
187                }
188            }
189        }
190
191        let id = outcome.id().to_string();
192        let (mode, lease_secs) = match &policy {
193            LeasePolicy::Session { .. } => (LeaseMode::Session, None),
194            LeasePolicy::Heartbeat { lease, .. } => (LeaseMode::Heartbeat, Some(lease.as_secs())),
195            LeasePolicy::Permanent => (LeaseMode::Permanent, None),
196        };
197
198        let result = RegistrationResult {
199            id,
200            name: payload.name.clone(),
201            service_type: st.short().to_string(),
202            port: payload.port,
203            mode,
204            lease_secs,
205        };
206
207        tracing::info!(
208            name = %result.name,
209            service_type = %result.service_type,
210            port = result.port,
211            id = %result.id,
212            "Service registered"
213        );
214
215        Ok(result)
216    }
217
218    /// Record a heartbeat for a registration. Resets last_seen; revives if draining.
219    /// Returns the lease duration in seconds (0 for non-heartbeat policies).
220    pub fn heartbeat(&self, id: &str) -> Result<u64> {
221        self.registry.heartbeat(id)
222    }
223
224    /// Notify the core that a session has disconnected.
225    /// All non-permanent registrations for this session begin draining.
226    pub fn session_disconnected(&self, session_id: &SessionId) {
227        let drained = self.registry.drain_session(session_id);
228        for id in &drained {
229            tracing::info!(
230                id,
231                session = %session_id.as_str(),
232                "Session disconnected, registration draining"
233            );
234        }
235    }
236
237    /// Unregister a previously registered service.
238    pub fn unregister(&self, id: &str) -> Result<()> {
239        let payload = self.registry.remove(id)?;
240        let st = ServiceType::parse(&payload.service_type)?;
241        self.daemon.unregister(&payload.name, st.as_str())?;
242        tracing::info!(name = %payload.name, id, reason = "explicit", "Service unregistered");
243        Ok(())
244    }
245
246    /// Resolve a specific service instance by its full name.
247    pub async fn resolve(&self, instance: &str) -> Result<ServiceRecord> {
248        self.daemon.resolve(instance).await
249    }
250
251    /// Subscribe to all service events. Returns a broadcast receiver.
252    pub fn subscribe(&self) -> broadcast::Receiver<MdnsEvent> {
253        self.event_tx.subscribe()
254    }
255
256    /// Shut down gracefully: unregister all services, then stop the daemon.
257    pub async fn shutdown(&self) -> Result<()> {
258        let ids: Vec<String> = self.registry.all_ids();
259        for id in &ids {
260            if let Err(e) = self.unregister(id) {
261                tracing::warn!(id, error = %e, "Failed to unregister service during shutdown");
262            }
263        }
264        self.daemon.shutdown().await?;
265        tracing::info!("mDNS core shut down");
266        Ok(())
267    }
268
269    // ── State (read-only snapshots) ──────────────────────────────────
270
271    /// Daemon status overview.
272    pub fn admin_status(&self) -> DaemonStatus {
273        DaemonStatus {
274            version: env!("CARGO_PKG_VERSION").to_string(),
275            uptime_secs: self.started_at.elapsed().as_secs(),
276            platform: std::env::consts::OS.to_string(),
277            registrations: self.registry.counts(),
278        }
279    }
280
281    /// Snapshot all registrations for admin display.
282    pub fn admin_registrations(&self) -> Vec<(String, AdminRegistration)> {
283        self.registry.snapshot()
284    }
285
286    /// Snapshot one registration by ID or prefix.
287    pub fn admin_inspect(&self, id_or_prefix: &str) -> Result<AdminRegistration> {
288        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
289        self.registry.snapshot_one(&full_id)
290    }
291
292    /// Admin: force-unregister a registration by ID or prefix.
293    pub fn admin_force_unregister(&self, id_or_prefix: &str) -> Result<()> {
294        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
295        let payload = self.registry.remove(&full_id)?;
296        let st = ServiceType::parse(&payload.service_type)?;
297        let _ = self.daemon.unregister(&payload.name, st.as_str());
298        tracing::info!(
299            name = %payload.name,
300            id = %full_id,
301            reason = "admin_force",
302            "Service unregistered"
303        );
304        Ok(())
305    }
306
307    /// Admin: force-drain a registration by ID or prefix.
308    pub fn admin_drain(&self, id_or_prefix: &str) -> Result<()> {
309        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
310        self.registry.force_drain(&full_id)
311    }
312
313    /// Admin: force-revive a draining registration by ID or prefix.
314    pub fn admin_revive(&self, id_or_prefix: &str) -> Result<()> {
315        let full_id = self.registry.resolve_prefix(id_or_prefix)?;
316        self.registry.force_revive(&full_id)
317    }
318}
319
320impl Capability for MdnsCore {
321    fn name(&self) -> &str {
322        "mdns"
323    }
324
325    fn status(&self) -> CapabilityStatus {
326        let counts = self.registry.counts();
327        let summary = format!(
328            "{} registered ({} alive, {} draining)",
329            counts.total, counts.alive, counts.draining
330        );
331        CapabilityStatus {
332            name: "mdns".to_string(),
333            summary,
334            healthy: true,
335        }
336    }
337}