// koi_mdns/protocol.rs

1use serde::ser::SerializeMap;
2use serde::{Deserialize, Serialize, Serializer};
3
4use koi_common::api::{error_body, ErrorBody};
5use koi_common::error::ErrorCode;
6use koi_common::types::{EventKind, ServiceRecord};
7
8use crate::error::MdnsError;
9use crate::events::MdnsEvent;
10
11// ── mDNS-specific wire types ─────────────────────────────────────────
12//
13// The registration/admin wire-contract types now live in the kernel
14// (`koi_common::mdns_protocol`) so clients can speak the contract without the mDNS
15// engine. Re-exported here so `koi_mdns::protocol::*` paths are unchanged.
16pub use koi_common::mdns_protocol::{
17    AdminRegistration, DaemonStatus, LeaseMode, LeaseState, RegisterPayload, RegistrationCounts,
18    RegistrationResult, RenewalResult,
19};
20
21// ── Request ──────────────────────────────────────────────────────────
22
23/// All possible inbound operations for mDNS.
24/// The top-level JSON key determines the variant.
25#[derive(Debug, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum Request {
28    Browse(String),
29    Register(RegisterPayload),
30    Unregister(String),
31    Resolve(String),
32    Subscribe(String),
33    Heartbeat(String),
34}
35
36// ── Response ─────────────────────────────────────────────────────────
37
38/// All possible outbound messages for the mDNS domain.
39/// Custom Serialize ensures the correct JSON shape for each variant:
40/// - Found/Registered/Unregistered/Resolved: `{"found": {...}}`
41/// - Error: `{"error": "code", "message": "..."}`  (flat)
42/// - Event: `{"event": "kind", "service": {...}}`   (flat)
43#[derive(Debug, Clone)]
44pub enum Response {
45    Found(ServiceRecord),
46    Registered(RegistrationResult),
47    Unregistered(String),
48    Resolved(ServiceRecord),
49    Event {
50        event: EventKind,
51        service: ServiceRecord,
52    },
53    Renewed(RenewalResult),
54    Error(ErrorBody),
55}
56
57impl Serialize for Response {
58    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
59        match self {
60            Response::Found(record) => {
61                let mut map = serializer.serialize_map(Some(1))?;
62                map.serialize_entry("found", record)?;
63                map.end()
64            }
65            Response::Registered(result) => {
66                let mut map = serializer.serialize_map(Some(1))?;
67                map.serialize_entry("registered", result)?;
68                map.end()
69            }
70            Response::Unregistered(id) => {
71                let mut map = serializer.serialize_map(Some(1))?;
72                map.serialize_entry("unregistered", id)?;
73                map.end()
74            }
75            Response::Resolved(record) => {
76                let mut map = serializer.serialize_map(Some(1))?;
77                map.serialize_entry("resolved", record)?;
78                map.end()
79            }
80            Response::Event { event, service } => {
81                let mut map = serializer.serialize_map(Some(2))?;
82                map.serialize_entry("event", event)?;
83                map.serialize_entry("service", service)?;
84                map.end()
85            }
86            Response::Renewed(result) => {
87                let mut map = serializer.serialize_map(Some(1))?;
88                map.serialize_entry("renewed", result)?;
89                map.end()
90            }
91            Response::Error(body) => {
92                let mut map = serializer.serialize_map(Some(2))?;
93                map.serialize_entry("error", &body.error)?;
94                map.serialize_entry("message", &body.message)?;
95                map.end()
96            }
97        }
98    }
99}
100
101// ── Pipeline helpers ─────────────────────────────────────────────────
102
103use koi_common::pipeline::PipelineResponse;
104
105/// Type alias for mDNS pipeline responses.
106pub type MdnsPipelineResponse = PipelineResponse<Response>;
107
108/// Convert a browse event into a pipeline response.
109pub fn browse_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
110    match event {
111        MdnsEvent::Resolved(record) | MdnsEvent::Found(record) => {
112            PipelineResponse::clean(Response::Found(record))
113        }
114        MdnsEvent::Removed { name, service_type } => PipelineResponse::clean(Response::Event {
115            event: EventKind::Removed,
116            service: ServiceRecord {
117                name,
118                service_type,
119                host: None,
120                ip: None,
121                port: None,
122                txt: Default::default(),
123            },
124        }),
125    }
126}
127
128/// Convert a subscribe event into a pipeline response.
129pub fn subscribe_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
130    let (kind, record) = match event {
131        MdnsEvent::Found(record) => (EventKind::Found, record),
132        MdnsEvent::Resolved(record) => (EventKind::Resolved, record),
133        MdnsEvent::Removed { name, service_type } => (
134            EventKind::Removed,
135            ServiceRecord {
136                name,
137                service_type,
138                host: None,
139                ip: None,
140                port: None,
141                txt: Default::default(),
142            },
143        ),
144    };
145    PipelineResponse::clean(Response::Event {
146        event: kind,
147        service: record,
148    })
149}
150
151/// Convert an MdnsError into a pipeline error response.
152pub fn error_to_pipeline(e: &MdnsError) -> MdnsPipelineResponse {
153    PipelineResponse::clean(Response::Error(error_body(
154        ErrorCode::from(e),
155        e.to_string(),
156    )))
157}
158
159// ── Tests ────────────────────────────────────────────────────────────
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A fully populated record used across response/pipeline tests.
    fn test_record() -> ServiceRecord {
        ServiceRecord {
            name: "Server A".into(),
            service_type: "_http._tcp".into(),
            host: Some("server.local".into()),
            ip: Some("192.168.1.42".into()),
            port: Some(8080),
            txt: HashMap::from([("version".into(), "2.1".into())]),
        }
    }

    // ── RegisterPayload tests ────────────────────────────────────────

    #[test]
    fn register_payload_deserializes_from_json() {
        let json =
            r#"{"name": "My App", "type": "_http._tcp", "port": 8080, "txt": {"version": "1.0"}}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert_eq!(payload.name, "My App");
        assert_eq!(payload.service_type, "_http._tcp");
        assert_eq!(payload.port, 8080);
        assert_eq!(payload.txt.get("version").unwrap(), "1.0");
    }

    #[test]
    fn register_payload_defaults_txt_to_empty() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert!(payload.txt.is_empty());
    }

    #[test]
    fn register_payload_defaults_lease_to_none() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert!(payload.lease_secs.is_none());
    }

    #[test]
    fn register_payload_accepts_lease_secs() {
        let json = r#"{"name": "Bare", "type": "_http._tcp", "port": 80, "lease_secs": 300}"#;
        let payload: RegisterPayload = serde_json::from_str(json).unwrap();
        assert_eq!(payload.lease_secs, Some(300));
    }

    #[test]
    fn lease_mode_serializes_to_lowercase() {
        assert_eq!(serde_json::to_value(LeaseMode::Session).unwrap(), "session");
        assert_eq!(
            serde_json::to_value(LeaseMode::Heartbeat).unwrap(),
            "heartbeat"
        );
        assert_eq!(
            serde_json::to_value(LeaseMode::Permanent).unwrap(),
            "permanent"
        );
    }

    #[test]
    fn lease_state_serializes_to_lowercase() {
        assert_eq!(serde_json::to_value(LeaseState::Alive).unwrap(), "alive");
        assert_eq!(
            serde_json::to_value(LeaseState::Draining).unwrap(),
            "draining"
        );
    }

    #[test]
    fn renewal_result_roundtrips() {
        let r = RenewalResult {
            id: "abc".into(),
            lease_secs: 300,
        };
        let json = serde_json::to_string(&r).unwrap();
        let r2: RenewalResult = serde_json::from_str(&json).unwrap();
        assert_eq!(r, r2);
    }

    // ── Request tests ────────────────────────────────────────────────

    #[test]
    fn browse_request_parses() {
        let json = r#"{"browse": "_http._tcp"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Browse(ref s) if s == "_http._tcp"));
    }

    #[test]
    fn register_request_parses() {
        let json = r#"{"register": {"name": "My App", "type": "_http._tcp", "port": 8080}}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Register(ref p) if p.name == "My App"));
    }

    #[test]
    fn unregister_request_parses() {
        let json = r#"{"unregister": "abc123"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Unregister(ref id) if id == "abc123"));
    }

    #[test]
    fn resolve_request_parses() {
        let json = r#"{"resolve": "My App._http._tcp.local."}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Resolve(ref s) if s == "My App._http._tcp.local."));
    }

    #[test]
    fn subscribe_request_parses() {
        let json = r#"{"subscribe": "_http._tcp"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Subscribe(ref s) if s == "_http._tcp"));
    }

    #[test]
    fn heartbeat_request_parses() {
        let json = r#"{"heartbeat": "a1b2c3d4"}"#;
        let req: Request = serde_json::from_str(json).unwrap();
        assert!(matches!(req, Request::Heartbeat(ref id) if id == "a1b2c3d4"));
    }

    #[test]
    fn unknown_verb_fails() {
        let json = r#"{"explode": "boom"}"#;
        let result = serde_json::from_str::<Request>(json);
        assert!(result.is_err());
    }

    #[test]
    fn verb_matching_is_case_sensitive() {
        // `rename_all = "lowercase"` accepts only the lowercase spelling.
        let json = r#"{"Browse": "_http._tcp"}"#;
        let result = serde_json::from_str::<Request>(json);
        assert!(result.is_err());
    }

    #[test]
    fn request_with_two_verbs_fails() {
        // Externally tagged enums require exactly one top-level key.
        let json = r#"{"browse": "_http._tcp", "subscribe": "_http._tcp"}"#;
        let result = serde_json::from_str::<Request>(json);
        assert!(result.is_err());
    }

    // ── Response tests ───────────────────────────────────────────────

    #[test]
    fn clean_response_has_no_pipeline_properties() {
        let resp = MdnsPipelineResponse::clean(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        let obj = json.as_object().unwrap();
        assert!(!obj.contains_key("status"));
        assert!(!obj.contains_key("warning"));
        assert!(obj.contains_key("found"));
    }

    #[test]
    fn flatten_produces_flat_json_not_nested() {
        let resp = MdnsPipelineResponse::clean(Response::Found(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("found").is_some());
        assert!(json.get("body").is_none());
    }

    #[test]
    fn renewed_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Renewed(RenewalResult {
            id: "a1b2c3".into(),
            lease_secs: 300,
        }));
        let json = serde_json::to_value(&resp).unwrap();
        let renewed = json.get("renewed").unwrap();
        assert_eq!(renewed.get("id").unwrap(), "a1b2c3");
        assert_eq!(renewed.get("lease_secs").unwrap(), 300);
    }

    #[test]
    fn error_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Error(error_body(
            ErrorCode::NotFound,
            "No registration with id 'xyz'",
        )));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("error").unwrap(), "not_found");
        assert_eq!(
            json.get("message").unwrap(),
            "No registration with id 'xyz'"
        );
    }

    #[test]
    fn registered_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Registered(RegistrationResult {
            id: "a1b2c3".into(),
            name: "My App".into(),
            service_type: "_http._tcp".into(),
            port: 8080,
            mode: LeaseMode::Permanent,
            lease_secs: None,
        }));
        let json = serde_json::to_value(&resp).unwrap();
        let reg = json.get("registered").unwrap();
        assert_eq!(reg.get("id").unwrap(), "a1b2c3");
        assert_eq!(reg.get("name").unwrap(), "My App");
    }

    #[test]
    fn unregistered_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Unregistered("a1b2c3".into()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("unregistered").unwrap(), "a1b2c3");
    }

    #[test]
    fn resolved_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Resolved(test_record()));
        let json = serde_json::to_value(&resp).unwrap();
        let resolved = json.get("resolved").expect("should have 'resolved' key");
        assert_eq!(resolved.get("name").unwrap(), "Server A");
        assert!(json.get("found").is_none());
    }

    #[test]
    fn event_response_serializes_correctly() {
        let resp = MdnsPipelineResponse::clean(Response::Event {
            event: EventKind::Found,
            service: test_record(),
        });
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    // ── Pipeline helper free function tests ─────────────────────────

    #[test]
    fn browse_event_resolved_produces_found() {
        let event = MdnsEvent::Resolved(test_record());
        let resp = browse_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("found").is_some(), "should have 'found' key");
        assert_eq!(json.get("found").unwrap().get("name").unwrap(), "Server A");
    }

    #[test]
    fn browse_event_found_produces_found() {
        let event = MdnsEvent::Found(test_record());
        let resp = browse_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("found").unwrap().get("name").unwrap(), "Server A");
        assert!(json.get("event").is_none(), "browse found is not an event");
    }

    #[test]
    fn browse_event_removed_produces_event_removed() {
        let event = MdnsEvent::Removed {
            name: "Gone._http._tcp.local.".into(),
            service_type: "_http._tcp".into(),
        };
        let resp = browse_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "removed");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Gone._http._tcp.local."
        );
    }

    #[test]
    fn subscribe_event_found_produces_event_found() {
        let event = MdnsEvent::Found(test_record());
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    #[test]
    fn subscribe_event_resolved_produces_event_resolved() {
        let event = MdnsEvent::Resolved(test_record());
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "resolved");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Server A"
        );
    }

    #[test]
    fn subscribe_event_removed_produces_event_removed() {
        let event = MdnsEvent::Removed {
            name: "Gone._http._tcp.local.".into(),
            service_type: "_http._tcp".into(),
        };
        let resp = subscribe_event_to_pipeline(event);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("event").unwrap(), "removed");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Gone._http._tcp.local."
        );
    }

    #[test]
    fn error_to_pipeline_not_found() {
        let err = MdnsError::RegistrationNotFound("xyz".into());
        let resp = error_to_pipeline(&err);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json.get("error").unwrap(), "not_found");
        let msg = json.get("message").unwrap().as_str().unwrap();
        assert!(msg.contains("xyz"), "message should contain id: {msg}");
    }
}