Skip to main content

koi_mdns/
protocol.rs

1use std::collections::HashMap;
2
3use serde::ser::SerializeMap;
4use serde::{Deserialize, Serialize, Serializer};
5use utoipa::ToSchema;
6
7use koi_common::api::{error_body, ErrorBody};
8use koi_common::error::ErrorCode;
9use koi_common::types::{EventKind, ServiceRecord};
10
11use crate::error::MdnsError;
12use crate::events::MdnsEvent;
13
14// ── mDNS-specific wire types ─────────────────────────────────────────
15
16/// Payload for registering a new service.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
18pub struct RegisterPayload {
19    pub name: String,
20    #[serde(rename = "type")]
21    pub service_type: String,
22    pub port: u16,
23    /// Pin the A/AAAA record to a specific IP address.
24    /// When absent, all machine IPs are advertised (auto-detect).
25    #[serde(default)]
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub ip: Option<String>,
28    #[serde(default)]
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub lease_secs: Option<u64>,
31    #[serde(default)]
32    pub txt: HashMap<String, String>,
33}
34
35/// Result of a successful registration.
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
37pub struct RegistrationResult {
38    pub id: String,
39    pub name: String,
40    #[serde(rename = "type")]
41    pub service_type: String,
42    pub port: u16,
43    pub mode: LeaseMode,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub lease_secs: Option<u64>,
46}
47
48/// Result of a successful lease renewal (heartbeat).
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
50pub struct RenewalResult {
51    pub id: String,
52    pub lease_secs: u64,
53}
54
55/// How a registration stays alive (wire representation).
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
57#[serde(rename_all = "lowercase")]
58pub enum LeaseMode {
59    Session,
60    Heartbeat,
61    Permanent,
62}
63
64/// Wire-level registration state (display-only projection).
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
66#[serde(rename_all = "lowercase")]
67pub enum LeaseState {
68    Alive,
69    Draining,
70}
71
72/// Full registration state as exposed to admin queries.
73#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
74pub struct AdminRegistration {
75    pub id: String,
76    pub name: String,
77    #[serde(rename = "type")]
78    pub service_type: String,
79    pub port: u16,
80    pub mode: LeaseMode,
81    pub state: LeaseState,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub lease_secs: Option<u64>,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub remaining_secs: Option<u64>,
86    pub grace_secs: u64,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub session_id: Option<String>,
89    pub registered_at: String,
90    pub last_seen: String,
91    #[serde(default)]
92    pub txt: HashMap<String, String>,
93}
94
95/// Daemon status overview for admin queries.
96#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
97pub struct DaemonStatus {
98    pub version: String,
99    pub uptime_secs: u64,
100    pub platform: String,
101    pub registrations: RegistrationCounts,
102}
103
104/// Registration counts by state.
105#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
106pub struct RegistrationCounts {
107    pub alive: usize,
108    pub draining: usize,
109    pub permanent: usize,
110    pub total: usize,
111}
112
113// ── Request ──────────────────────────────────────────────────────────
114
115/// All possible inbound operations for mDNS.
116/// The top-level JSON key determines the variant.
117#[derive(Debug, Deserialize)]
118#[serde(rename_all = "lowercase")]
119pub enum Request {
120    Browse(String),
121    Register(RegisterPayload),
122    Unregister(String),
123    Resolve(String),
124    Subscribe(String),
125    Heartbeat(String),
126}
127
128// ── Response ─────────────────────────────────────────────────────────
129
130/// All possible outbound messages for the mDNS domain.
131/// Custom Serialize ensures the correct JSON shape for each variant:
132/// - Found/Registered/Unregistered/Resolved: `{"found": {...}}`
133/// - Error: `{"error": "code", "message": "..."}`  (flat)
134/// - Event: `{"event": "kind", "service": {...}}`   (flat)
135#[derive(Debug, Clone)]
136pub enum Response {
137    Found(ServiceRecord),
138    Registered(RegistrationResult),
139    Unregistered(String),
140    Resolved(ServiceRecord),
141    Event {
142        event: EventKind,
143        service: ServiceRecord,
144    },
145    Renewed(RenewalResult),
146    Error(ErrorBody),
147}
148
149impl Serialize for Response {
150    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
151        match self {
152            Response::Found(record) => {
153                let mut map = serializer.serialize_map(Some(1))?;
154                map.serialize_entry("found", record)?;
155                map.end()
156            }
157            Response::Registered(result) => {
158                let mut map = serializer.serialize_map(Some(1))?;
159                map.serialize_entry("registered", result)?;
160                map.end()
161            }
162            Response::Unregistered(id) => {
163                let mut map = serializer.serialize_map(Some(1))?;
164                map.serialize_entry("unregistered", id)?;
165                map.end()
166            }
167            Response::Resolved(record) => {
168                let mut map = serializer.serialize_map(Some(1))?;
169                map.serialize_entry("resolved", record)?;
170                map.end()
171            }
172            Response::Event { event, service } => {
173                let mut map = serializer.serialize_map(Some(2))?;
174                map.serialize_entry("event", event)?;
175                map.serialize_entry("service", service)?;
176                map.end()
177            }
178            Response::Renewed(result) => {
179                let mut map = serializer.serialize_map(Some(1))?;
180                map.serialize_entry("renewed", result)?;
181                map.end()
182            }
183            Response::Error(body) => {
184                let mut map = serializer.serialize_map(Some(2))?;
185                map.serialize_entry("error", &body.error)?;
186                map.serialize_entry("message", &body.message)?;
187                map.end()
188            }
189        }
190    }
191}
192
193// ── Pipeline helpers ─────────────────────────────────────────────────
194
195use koi_common::pipeline::PipelineResponse;
196
197/// Type alias for mDNS pipeline responses.
198pub type MdnsPipelineResponse = PipelineResponse<Response>;
199
200/// Convert a browse event into a pipeline response.
201pub fn browse_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
202    match event {
203        MdnsEvent::Resolved(record) | MdnsEvent::Found(record) => {
204            PipelineResponse::clean(Response::Found(record))
205        }
206        MdnsEvent::Removed { name, service_type } => PipelineResponse::clean(Response::Event {
207            event: EventKind::Removed,
208            service: ServiceRecord {
209                name,
210                service_type,
211                host: None,
212                ip: None,
213                port: None,
214                txt: Default::default(),
215            },
216        }),
217    }
218}
219
220/// Convert a subscribe event into a pipeline response.
221pub fn subscribe_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
222    let (kind, record) = match event {
223        MdnsEvent::Found(record) => (EventKind::Found, record),
224        MdnsEvent::Resolved(record) => (EventKind::Resolved, record),
225        MdnsEvent::Removed { name, service_type } => (
226            EventKind::Removed,
227            ServiceRecord {
228                name,
229                service_type,
230                host: None,
231                ip: None,
232                port: None,
233                txt: Default::default(),
234            },
235        ),
236    };
237    PipelineResponse::clean(Response::Event {
238        event: kind,
239        service: record,
240    })
241}
242
243/// Convert an MdnsError into a pipeline error response.
244pub fn error_to_pipeline(e: &MdnsError) -> MdnsPipelineResponse {
245    PipelineResponse::clean(Response::Error(error_body(
246        ErrorCode::from(e),
247        e.to_string(),
248    )))
249}
250
251// ── Tests ────────────────────────────────────────────────────────────
252
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated record used as a fixture across the tests below.
    fn test_record() -> ServiceRecord {
        ServiceRecord {
            name: "Server A".into(),
            service_type: "_http._tcp".into(),
            host: Some("server.local".into()),
            ip: Some("192.168.1.42".into()),
            port: Some(8080),
            txt: HashMap::from([("version".into(), "2.1".into())]),
        }
    }

    // ── RegisterPayload tests ────────────────────────────────────────

    #[test]
    fn register_payload_deserializes_from_json() {
        let json =
            r#"{"name": "My App", "type": "_http._tcp", "port": 8080, "txt": {"version": "1.0"}}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert_eq!(payload.name, "My App");
        assert_eq!(payload.service_type, "_http._tcp");
        assert_eq!(payload.port, 8080);
        assert_eq!(payload.txt.get("version").unwrap(), "1.0");
    }

    #[test]
    fn register_payload_defaults_txt_to_empty() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert!(payload.txt.is_empty());
    }

    #[test]
    fn register_payload_defaults_lease_to_none() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert!(payload.lease_secs.is_none());
    }

    #[test]
    fn register_payload_accepts_lease_secs() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80, "lease_secs": 300}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert_eq!(payload.lease_secs, Some(300));
    }

    #[test]
    fn register_payload_defaults_ip_to_none() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert!(payload.ip.is_none());
    }

    #[test]
    fn register_payload_accepts_pinned_ip() {
        let json = r#"{"name": "Pinned", "type": "_http._tcp", "port": 80, "ip": "10.0.0.5"}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert_eq!(payload.ip.as_deref(), Some("10.0.0.5"));
    }

    #[test]
    fn register_payload_omits_absent_optionals_when_serialized() {
        let payload = RegisterPayload {
            name: "Bare".into(),
            service_type: "_http._tcp".into(),
            port: 80,
            ip: None,
            lease_secs: None,
            txt: HashMap::new(),
        };
        let json = serde_json::to_value(&payload).unwrap();
        let obj = json.as_object().unwrap();
        assert!(!obj.contains_key("ip"));
        assert!(!obj.contains_key("lease_secs"));
        assert_eq!(obj.get("type").unwrap(), "_http._tcp");
    }

    #[test]
    fn lease_mode_serializes_to_lowercase() {
        assert_eq!(serde_json::to_value(LeaseMode::Session).unwrap(), "session");
        assert_eq!(
            serde_json::to_value(LeaseMode::Heartbeat).unwrap(),
            "heartbeat"
        );
        assert_eq!(
            serde_json::to_value(LeaseMode::Permanent).unwrap(),
            "permanent"
        );
    }

    #[test]
    fn lease_state_serializes_to_lowercase() {
        assert_eq!(serde_json::to_value(LeaseState::Alive).unwrap(), "alive");
        assert_eq!(
            serde_json::to_value(LeaseState::Draining).unwrap(),
            "draining"
        );
    }

    #[test]
    fn renewal_result_roundtrips() {
        let r = RenewalResult {
            id: "abc".into(),
            lease_secs: 300,
        };
        let json = serde_json::to_string(&r).unwrap();
        let r2: RenewalResult = serde_json::from_str(&json).unwrap();
        assert_eq!(r, r2);
    }

    // ── Request tests ────────────────────────────────────────────────

    #[test]
    fn browse_request_parses() {
        let json = r#"{"browse": "_http._tcp"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Browse(ref s) if s == "_http._tcp"));
    }

    #[test]
    fn register_request_parses() {
        let json = r#"{"register": {"name": "My App", "type": "_http._tcp", "port": 8080}}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Register(ref p) if p.name == "My App"));
    }

    #[test]
    fn unregister_request_parses() {
        let json = r#"{"unregister": "abc123"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Unregister(ref id) if id == "abc123"));
    }

    #[test]
    fn resolve_request_parses() {
        let json = r#"{"resolve": "My App._http._tcp.local."}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Resolve(ref s) if s == "My App._http._tcp.local."));
    }

    #[test]
    fn subscribe_request_parses() {
        let json = r#"{"subscribe": "_http._tcp"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Subscribe(ref s) if s == "_http._tcp"));
    }

    #[test]
    fn heartbeat_request_parses() {
        let json = r#"{"heartbeat": "a1b2c3d4"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Heartbeat(ref id) if id == "a1b2c3d4"));
    }

    #[test]
    fn unknown_verb_fails() {
        let json = r#"{"explode": "boom"}"#;
        let result = serde_json::from_str::<Request>(json);
        assert!(result.is_err());
    }

    // ── Response tests ───────────────────────────────────────────────

    #[test]
    fn clean_response_has_no_pipeline_properties() {
        let resp = MdnsPipelineResponse::clean(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        let obj = json.as_object().unwrap();
        assert!(!obj.contains_key("status"));
        assert!(!obj.contains_key("warning"));
        assert!(obj.contains_key("found"));
    }

    #[test]
    fn ongoing_response_includes_status() {
        let resp = MdnsPipelineResponse::ongoing(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        let obj = json.as_object().unwrap();
        assert_eq!(obj.get("status").unwrap(), "ongoing");
        assert!(obj.contains_key("found"));
    }

    #[test]
    fn finished_response_includes_status() {
        let resp = MdnsPipelineResponse::finished(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("status").unwrap(), "finished");
    }

    #[test]
    fn warning_attaches_to_response() {
        let resp = MdnsPipelineResponse::finished(Response::Found(test_record()))
            .with_warning("TXT empty");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("warning").unwrap(), "TXT empty");
        assert_eq!(json.get("status").unwrap(), "finished");
    }

    #[test]
    fn flatten_produces_flat_json_not_nested() {
        let resp = MdnsPipelineResponse::clean(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("found").is_some());
        assert!(json.get("body").is_none());
    }

    #[test]
    fn resolved_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Resolved(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("found").is_none(), "resolved must not reuse the 'found' key");
        assert_eq!(
            json.get("resolved").unwrap().get("name").unwrap(),
            "Server A"
        );
    }

    #[test]
    fn renewed_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Renewed(RenewalResult {
            id: "a1b2c3".into(),
            lease_secs: 300,
        }));
        let json = serde_json::to_value(&resp).unwrap();
        let renewed = json.get("renewed").unwrap();
        assert_eq!(renewed.get("id").unwrap(), "a1b2c3");
        assert_eq!(renewed.get("lease_secs").unwrap(), 300);
    }

    #[test]
    fn error_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Error(error_body(
            ErrorCode::NotFound,
            "No registration with id 'xyz'",
        )));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("error").unwrap(), "not_found");
        assert_eq!(
            json.get("message").unwrap(),
            "No registration with id 'xyz'"
        );
    }

    #[test]
    fn registered_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Registered(RegistrationResult {
            id: "a1b2c3".into(),
            name: "My App".into(),
            service_type: "_http._tcp".into(),
            port: 8080,
            mode: LeaseMode::Permanent,
            lease_secs: None,
        }));
        let json = serde_json::to_value(&resp).unwrap();
        let reg = json.get("registered").unwrap();
        assert_eq!(reg.get("id").unwrap(), "a1b2c3");
        assert_eq!(reg.get("name").unwrap(), "My App");
    }

    #[test]
    fn unregistered_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Unregistered("a1b2c3".into()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("unregistered").unwrap(), "a1b2c3");
    }

    #[test]
    fn event_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Event {
            event: EventKind::Found,
            service: test_record(),
        });
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    // ── Pipeline helper free function tests ─────────────────────────

    #[test]
    fn browse_event_found_produces_found() {
        let resp = browse_event_to_pipeline(MdnsEvent::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("found").unwrap().get("name").unwrap(), "Server A");
        assert!(json.get("event").is_none());
    }

    #[test]
    fn browse_event_resolved_produces_found() {
        let event = MdnsEvent::Resolved(test_record());
        let resp = browse_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("found").is_some(), "should have 'found' key");
        assert_eq!(json.get("found").unwrap().get("name").unwrap(), "Server A");
    }

    #[test]
    fn browse_event_removed_produces_event_removed() {
        let event = MdnsEvent::Removed {
            name: "Gone._http._tcp.local.".into(),
            service_type: "_http._tcp".into(),
        };
        let resp = browse_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "removed");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Gone._http._tcp.local."
        );
    }

    #[test]
    fn subscribe_event_found_produces_event_found() {
        let event = MdnsEvent::Found(test_record());
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    #[test]
    fn subscribe_event_resolved_produces_event_resolved() {
        let event = MdnsEvent::Resolved(test_record());
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "resolved");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Server A"
        );
    }

    #[test]
    fn subscribe_event_removed_produces_event_removed() {
        let event = MdnsEvent::Removed {
            name: "Gone._http._tcp.local.".into(),
            service_type: "_http._tcp".into(),
        };
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "removed");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Gone._http._tcp.local."
        );
    }

    #[test]
    fn error_to_pipeline_not_found() {
        let err = MdnsError::RegistrationNotFound("xyz".into());
        let resp = error_to_pipeline(&err);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("error").unwrap(), "not_found");
        let msg = json.get("message").unwrap().as_str().unwrap();
        assert!(msg.contains("xyz"), "message should contain id: {msg}");
    }
}