1use serde::ser::SerializeMap;
2use serde::{Deserialize, Serialize, Serializer};
3
4use koi_common::api::{error_body, ErrorBody};
5use koi_common::error::ErrorCode;
6use koi_common::types::{EventKind, ServiceRecord};
7
8use crate::error::MdnsError;
9use crate::events::MdnsEvent;
10
11pub use koi_common::mdns_protocol::{
17 AdminRegistration, DaemonStatus, LeaseMode, LeaseState, RegisterPayload, RegistrationCounts,
18 RegistrationResult, RenewalResult,
19};
20
21#[derive(Debug, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum Request {
28 Browse(String),
29 Register(RegisterPayload),
30 Unregister(String),
31 Resolve(String),
32 Subscribe(String),
33 Heartbeat(String),
34}
35
36#[derive(Debug, Clone)]
44pub enum Response {
45 Found(ServiceRecord),
46 Registered(RegistrationResult),
47 Unregistered(String),
48 Resolved(ServiceRecord),
49 Event {
50 event: EventKind,
51 service: ServiceRecord,
52 },
53 Renewed(RenewalResult),
54 Error(ErrorBody),
55}
56
57impl Serialize for Response {
58 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
59 match self {
60 Response::Found(record) => {
61 let mut map = serializer.serialize_map(Some(1))?;
62 map.serialize_entry("found", record)?;
63 map.end()
64 }
65 Response::Registered(result) => {
66 let mut map = serializer.serialize_map(Some(1))?;
67 map.serialize_entry("registered", result)?;
68 map.end()
69 }
70 Response::Unregistered(id) => {
71 let mut map = serializer.serialize_map(Some(1))?;
72 map.serialize_entry("unregistered", id)?;
73 map.end()
74 }
75 Response::Resolved(record) => {
76 let mut map = serializer.serialize_map(Some(1))?;
77 map.serialize_entry("resolved", record)?;
78 map.end()
79 }
80 Response::Event { event, service } => {
81 let mut map = serializer.serialize_map(Some(2))?;
82 map.serialize_entry("event", event)?;
83 map.serialize_entry("service", service)?;
84 map.end()
85 }
86 Response::Renewed(result) => {
87 let mut map = serializer.serialize_map(Some(1))?;
88 map.serialize_entry("renewed", result)?;
89 map.end()
90 }
91 Response::Error(body) => {
92 let mut map = serializer.serialize_map(Some(2))?;
93 map.serialize_entry("error", &body.error)?;
94 map.serialize_entry("message", &body.message)?;
95 map.end()
96 }
97 }
98 }
99}
100
101use koi_common::pipeline::PipelineResponse;
104
105pub type MdnsPipelineResponse = PipelineResponse<Response>;
107
108pub fn browse_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
110 match event {
111 MdnsEvent::Resolved(record) | MdnsEvent::Found(record) => {
112 PipelineResponse::clean(Response::Found(record))
113 }
114 MdnsEvent::Removed { name, service_type } => PipelineResponse::clean(Response::Event {
115 event: EventKind::Removed,
116 service: ServiceRecord {
117 name,
118 service_type,
119 host: None,
120 ip: None,
121 port: None,
122 txt: Default::default(),
123 },
124 }),
125 }
126}
127
128pub fn subscribe_event_to_pipeline(event: MdnsEvent) -> MdnsPipelineResponse {
130 let (kind, record) = match event {
131 MdnsEvent::Found(record) => (EventKind::Found, record),
132 MdnsEvent::Resolved(record) => (EventKind::Resolved, record),
133 MdnsEvent::Removed { name, service_type } => (
134 EventKind::Removed,
135 ServiceRecord {
136 name,
137 service_type,
138 host: None,
139 ip: None,
140 port: None,
141 txt: Default::default(),
142 },
143 ),
144 };
145 PipelineResponse::clean(Response::Event {
146 event: kind,
147 service: record,
148 })
149}
150
151pub fn error_to_pipeline(e: &MdnsError) -> MdnsPipelineResponse {
153 PipelineResponse::clean(Response::Error(error_body(
154 ErrorCode::from(e),
155 e.to_string(),
156 )))
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // ── Shared fixtures and helpers ─────────────────────────────────

    /// Builds a fully populated record used by the serialization tests.
    fn test_record() -> ServiceRecord {
        ServiceRecord {
            name: "Server A".into(),
            service_type: "_http._tcp".into(),
            host: Some("server.local".into()),
            ip: Some("192.168.1.42".into()),
            port: Some(8080),
            txt: HashMap::from([("version".into(), "2.1".into())]),
        }
    }

    /// Serializes `value` to a JSON tree, panicking on failure.
    fn to_json<T: Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// Parses a JSON document as a [`Request`], panicking on failure.
    fn parse_request(json: &str) -> Request {
        serde_json::from_str(json).unwrap()
    }

    /// Parses a JSON document as a [`RegisterPayload`], panicking on failure.
    fn parse_payload(json: &str) -> RegisterPayload {
        serde_json::from_str(json).unwrap()
    }

    // ── Shared protocol types ───────────────────────────────────────

    #[test]
    fn register_payload_deserializes_from_json() {
        let payload = parse_payload(
            r#"{"name": "My App", "type": "_http._tcp", "port": 8080, "txt": {"version": "1.0"}}"#,
        );
        assert_eq!(payload.name, "My App");
        assert_eq!(payload.service_type, "_http._tcp");
        assert_eq!(payload.port, 8080);
        assert_eq!(payload.txt.get("version").unwrap(), "1.0");
    }

    #[test]
    fn register_payload_defaults_txt_to_empty() {
        let payload = parse_payload(r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#);
        assert!(payload.txt.is_empty());
    }

    #[test]
    fn register_payload_defaults_lease_to_none() {
        let payload = parse_payload(r#"{"name": "Bare", "type": "_http._tcp", "port": 80}"#);
        assert!(payload.lease_secs.is_none());
    }

    #[test]
    fn register_payload_accepts_lease_secs() {
        let payload = parse_payload(
            r#"{"name": "Bare", "type": "_http._tcp", "port": 80, "lease_secs": 300}"#,
        );
        assert_eq!(payload.lease_secs, Some(300));
    }

    #[test]
    fn lease_mode_serializes_to_lowercase() {
        assert_eq!(to_json(&LeaseMode::Session), "session");
        assert_eq!(to_json(&LeaseMode::Heartbeat), "heartbeat");
        assert_eq!(to_json(&LeaseMode::Permanent), "permanent");
    }

    #[test]
    fn lease_state_serializes_to_lowercase() {
        assert_eq!(to_json(&LeaseState::Alive), "alive");
        assert_eq!(to_json(&LeaseState::Draining), "draining");
    }

    #[test]
    fn renewal_result_roundtrips() {
        let original = RenewalResult {
            id: "abc".into(),
            lease_secs: 300,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RenewalResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // ── Request parsing ─────────────────────────────────────────────

    #[test]
    fn browse_request_parses() {
        match parse_request(r#"{"browse": "_http._tcp"}"#) {
            Request::Browse(service_type) => assert_eq!(service_type, "_http._tcp"),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn register_request_parses() {
        let json = r#"{"register": {"name": "My App", "type": "_http._tcp", "port": 8080}}"#;
        match parse_request(json) {
            Request::Register(payload) => assert_eq!(payload.name, "My App"),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn unregister_request_parses() {
        match parse_request(r#"{"unregister": "abc123"}"#) {
            Request::Unregister(id) => assert_eq!(id, "abc123"),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn resolve_request_parses() {
        match parse_request(r#"{"resolve": "My App._http._tcp.local."}"#) {
            Request::Resolve(instance) => assert_eq!(instance, "My App._http._tcp.local."),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn subscribe_request_parses() {
        match parse_request(r#"{"subscribe": "_http._tcp"}"#) {
            Request::Subscribe(service_type) => assert_eq!(service_type, "_http._tcp"),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn heartbeat_request_parses() {
        match parse_request(r#"{"heartbeat": "a1b2c3d4"}"#) {
            Request::Heartbeat(id) => assert_eq!(id, "a1b2c3d4"),
            other => panic!("unexpected request: {other:?}"),
        }
    }

    #[test]
    fn unknown_verb_fails() {
        let outcome = serde_json::from_str::<Request>(r#"{"explode": "boom"}"#);
        assert!(outcome.is_err());
    }

    // ── Response serialization ──────────────────────────────────────

    #[test]
    fn clean_response_has_no_pipeline_properties() {
        let json = to_json(&MdnsPipelineResponse::clean(Response::Found(test_record())));
        let obj = json.as_object().unwrap();
        assert!(!obj.contains_key("status"));
        assert!(!obj.contains_key("warning"));
        assert!(obj.contains_key("found"));
    }

    #[test]
    fn flatten_produces_flat_json_not_nested() {
        let json = to_json(&MdnsPipelineResponse::clean(Response::Found(test_record())));
        assert!(json.get("found").is_some());
        assert!(json.get("body").is_none());
    }

    #[test]
    fn renewed_response_serializes_correctly() {
        let renewal = RenewalResult {
            id: "a1b2c3".into(),
            lease_secs: 300,
        };
        let json = to_json(&MdnsPipelineResponse::clean(Response::Renewed(renewal)));
        let renewed = json.get("renewed").unwrap();
        assert_eq!(renewed.get("id").unwrap(), "a1b2c3");
        assert_eq!(renewed.get("lease_secs").unwrap(), 300);
    }

    #[test]
    fn error_response_serializes_correctly() {
        let body = error_body(ErrorCode::NotFound, "No registration with id 'xyz'");
        let json = to_json(&MdnsPipelineResponse::clean(Response::Error(body)));
        assert_eq!(json.get("error").unwrap(), "not_found");
        assert_eq!(
            json.get("message").unwrap(),
            "No registration with id 'xyz'"
        );
    }

    #[test]
    fn registered_response_serializes_correctly() {
        let registration = RegistrationResult {
            id: "a1b2c3".into(),
            name: "My App".into(),
            service_type: "_http._tcp".into(),
            port: 8080,
            mode: LeaseMode::Permanent,
            lease_secs: None,
        };
        let json = to_json(&MdnsPipelineResponse::clean(Response::Registered(
            registration,
        )));
        let reg = json.get("registered").unwrap();
        assert_eq!(reg.get("id").unwrap(), "a1b2c3");
        assert_eq!(reg.get("name").unwrap(), "My App");
    }

    #[test]
    fn unregistered_response_serializes_correctly() {
        let json = to_json(&MdnsPipelineResponse::clean(Response::Unregistered(
            "a1b2c3".into(),
        )));
        assert_eq!(json.get("unregistered").unwrap(), "a1b2c3");
    }

    #[test]
    fn event_response_serializes_correctly() {
        let json = to_json(&MdnsPipelineResponse::clean(Response::Event {
            event: EventKind::Found,
            service: test_record(),
        }));
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    // ── Event and error conversion helpers ──────────────────────────

    #[test]
    fn browse_event_resolved_produces_found() {
        let json = to_json(&browse_event_to_pipeline(MdnsEvent::Resolved(test_record())));
        assert!(json.get("found").is_some(), "should have 'found' key");
        assert_eq!(json.get("found").unwrap().get("name").unwrap(), "Server A");
    }

    #[test]
    fn browse_event_removed_produces_event_removed() {
        let removed = MdnsEvent::Removed {
            name: "Gone._http._tcp.local.".into(),
            service_type: "_http._tcp".into(),
        };
        let json = to_json(&browse_event_to_pipeline(removed));
        assert_eq!(json.get("event").unwrap(), "removed");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Gone._http._tcp.local."
        );
    }

    #[test]
    fn subscribe_event_found_produces_event_found() {
        let json = to_json(&subscribe_event_to_pipeline(MdnsEvent::Found(test_record())));
        assert_eq!(json.get("event").unwrap(), "found");
        assert!(json.get("service").is_some());
    }

    #[test]
    fn subscribe_event_resolved_produces_event_resolved() {
        let json = to_json(&subscribe_event_to_pipeline(MdnsEvent::Resolved(
            test_record(),
        )));
        assert_eq!(json.get("event").unwrap(), "resolved");
        assert_eq!(
            json.get("service").unwrap().get("name").unwrap(),
            "Server A"
        );
    }

    #[test]
    fn error_to_pipeline_not_found() {
        let failure = MdnsError::RegistrationNotFound("xyz".into());
        let json = to_json(&error_to_pipeline(&failure));
        assert_eq!(json.get("error").unwrap(), "not_found");
        let msg = json.get("message").unwrap().as_str().unwrap();
        assert!(msg.contains("xyz"), "message should contain id: {msg}");
    }
}