1use core::net::{IpAddr, SocketAddr};
2
3use crate::dns::{Flags, Message, QClass, QType, Query, Request, Response};
4use crate::time::Time;
5use crate::vec::Vec;
6use crate::writer::Writer;
7use crate::ServiceInfo;
8
9pub struct Server<
43 'a,
44 const QLEN: usize,
45 const ALEN: usize,
46 const LLEN: usize,
47 const SLEN: usize,
48 const LK: usize,
49> {
50 last_now: Time,
51 services: Vec<ServiceInfo<'a, LLEN>, SLEN>,
52 local_ips: Vec<LocalIp, SLEN>,
53 next_advertise: Time,
54 next_advertise_idx: usize,
55 next_query: Time,
56 next_query_idx: usize,
57 txid_query: u16,
58 next_txid: u16,
59}
60
/// One local interface the server sends from: an address and its netmask.
///
/// Instances are collected from each service's `ip_address()` / `netmask()`.
#[derive(PartialEq, Eq, Clone, Copy)]
struct LocalIp {
    /// Interface address.
    addr: IpAddr,
    /// Netmask belonging to `addr`.
    mask: IpAddr,
}
66
/// Delay between advertisement rounds (presumably milliseconds, matching the
/// `Time::from_millis` values used for the initial schedule — confirm).
const ADVERTISE_INTERVAL: u64 = 15 * 1_000;
/// Delay between query rounds (same unit as `ADVERTISE_INTERVAL`).
const QUERY_INTERVAL: u64 = 19 * 1_000;
69
/// How a packet produced by the server must be transmitted.
#[derive(Debug)]
pub enum Cast {
    /// Send as multicast.
    Multi {
        /// Local interface address to send from.
        from: IpAddr,
    },
    /// Send directly to a single remote socket.
    Uni {
        /// Local interface address to send from.
        from: IpAddr,
        /// Remote socket the packet is addressed to.
        target: SocketAddr,
    },
}
86
87#[derive(Debug)]
89pub enum Input<'x> {
90 Timeout(Time),
95
96 Packet(&'x [u8], SocketAddr),
98}
99
100pub enum Output<'x, const LLEN: usize, const SLEN: usize> {
102 Packet(usize, Cast),
107
108 Timeout(Time),
112
113 Remote(ServiceInfo<'x, LLEN>),
115}
116
117impl<
118 'a,
119 const QLEN: usize,
120 const ALEN: usize,
121 const LLEN: usize,
122 const SLEN: usize,
123 const LK: usize,
124 > Server<'a, QLEN, ALEN, LLEN, SLEN, LK>
125{
126 pub fn new(
128 iter: impl Iterator<Item = ServiceInfo<'a, LLEN>>,
129 ) -> Server<'a, QLEN, ALEN, LLEN, SLEN, LK> {
130 let mut services = Vec::new();
131 services.extend(iter);
132
133 let mut local_ips = Vec::new();
134 for s in services.iter() {
135 let loc = LocalIp {
136 addr: s.ip_address(),
137 mask: s.netmask(),
138 };
139 let has_ip = local_ips.iter().any(|l| *l == loc);
140 if !has_ip {
141 local_ips.push(loc).unwrap();
143 }
144 }
145
146 Server {
147 last_now: Time::from_millis(0),
148 services,
149 local_ips,
150 next_advertise: Time::from_millis(3000),
151 next_advertise_idx: 0,
152 next_query: Time::from_millis(5000),
153 next_query_idx: 0,
154 txid_query: 0,
155 next_txid: 1,
156 }
157 }
158
159 fn poll_timeout(&self) -> Time {
160 self.next_advertise.min(self.next_query)
161 }
162
163 pub fn handle<'x>(&mut self, input: Input<'x>, buffer: &mut [u8]) -> Output<'x, LLEN, SLEN> {
168 match input {
169 Input::Timeout(now) => self.handle_timeout(now, buffer),
170 Input::Packet(data, from) => self.handle_packet(data, from, buffer),
171 }
172 }
173
174 fn handle_timeout(&mut self, now: Time, buffer: &mut [u8]) -> Output<'static, LLEN, SLEN> {
175 self.last_now = now;
176
177 if now >= self.next_advertise {
178 let send_from = self.local_ips[self.next_advertise_idx];
179
180 let ret = self.do_advertise(buffer, send_from);
181
182 self.next_advertise_idx += 1;
183
184 if self.next_advertise_idx == self.local_ips.len() {
185 self.next_advertise_idx = 0;
186 self.next_advertise = now + ADVERTISE_INTERVAL;
187 }
188
189 ret
190 } else if now >= self.next_query {
191 let send_from = self.local_ips[self.next_query_idx];
192
193 let ret = self.do_query(buffer, send_from);
194
195 self.next_query_idx += 1;
196
197 if self.next_query_idx == self.local_ips.len() {
198 self.next_query_idx = 0;
199 self.next_query = now + QUERY_INTERVAL;
200 }
201
202 ret
203 } else {
204 Output::Timeout(self.poll_timeout())
205 }
206 }
207
208 fn next_txid(&mut self) -> u16 {
209 let x = self.next_txid;
210 self.next_txid = self.next_txid.wrapping_add(1);
211 x
212 }
213
214 fn do_advertise(&mut self, buffer: &mut [u8], local: LocalIp) -> Output<'static, LLEN, SLEN> {
215 let mut response: Response<QLEN, ALEN, LLEN> = Response {
216 id: 0,
217 flags: Flags::standard_response(),
218 queries: Vec::new(),
219 answers: Vec::new(),
220 };
221
222 let to_consider = self
223 .services
224 .iter()
225 .filter(|s| s.ip_address() == local.addr && s.netmask() == local.mask);
226
227 for service in to_consider {
228 response
229 .answers
230 .extend(service.as_answers(QClass::Multicast));
231 }
232
233 debug!("Advertise response (from {}): {:?}", local.addr, response);
234
235 let mut buf = Writer::<LK>::new(buffer);
236
237 response.serialize(&mut buf);
238
239 Output::Packet(buf.len(), Cast::Multi { from: local.addr })
240 }
241
242 fn do_query(&mut self, buffer: &mut [u8], local: LocalIp) -> Output<'static, LLEN, SLEN> {
243 let mut request: Request<QLEN, LLEN> = Request {
244 id: self.next_txid(),
245 flags: Flags::standard_request(),
246 queries: Vec::new(),
247 };
248
249 self.txid_query = request.id;
250
251 let to_consider = self
252 .services
253 .iter()
254 .filter(|s| s.ip_address() == local.addr && s.netmask() == local.mask);
255
256 for service in to_consider {
257 let query = Query {
258 name: service.service_type().clone(),
259 qtype: QType::PTR,
260 qclass: QClass::IN,
261 };
262 request.queries.push(query).unwrap();
263 }
264
265 debug!("Send request (from {}): {:?}", local.addr, request);
266
267 let mut buf = Writer::<LK>::new(buffer);
268 request.serialize(&mut buf);
269
270 Output::Packet(buf.len(), Cast::Multi { from: local.addr })
271 }
272
273 fn handle_packet<'x>(
274 &mut self,
275 data: &'x [u8],
276 from: SocketAddr,
277 buffer: &mut [u8],
278 ) -> Output<'x, LLEN, SLEN> {
279 match Message::parse(data) {
280 Ok((_, Message::Request(request))) => self.handle_request(request, from, buffer),
281 Ok((_, Message::Response(response))) => self.handle_response(response, from, buffer),
282 Err(_) => Output::Timeout(self.poll_timeout()),
283 }
284 }
285
286 fn handle_request<'x>(
287 &mut self,
288 request: Request<'x, QLEN, LLEN>,
289 from: SocketAddr,
290 buffer: &mut [u8],
291 ) -> Output<'x, LLEN, SLEN> {
292 if request.queries.is_empty() {
293 return Output::Timeout(self.poll_timeout());
294 }
295
296 if request.id == self.txid_query {
298 return Output::Timeout(self.poll_timeout());
299 }
300
301 let qclass = request.queries[0].qclass;
303
304 let queries = request.queries.iter();
305
306 let mut answers = Vec::new();
307
308 for query in queries {
309 for service in self.services.iter() {
310 if query.qtype == QType::PTR
311 && &query.name == service.service_type()
312 && is_same_network(service.ip_address(), service.netmask(), from.ip())
313 {
314 answers.extend(service.as_answers(qclass));
315 }
316 }
317 }
318
319 if answers.is_empty() {
320 return Output::Timeout(self.poll_timeout());
321 }
322
323 debug!("Incoming request: {:?} {:?}", from, request);
324
325 let response: Response<QLEN, ALEN, LLEN> = Response {
326 id: request.id,
327 flags: Flags::standard_response(),
328 queries: request.queries,
329 answers,
330 };
331
332 debug!("Send response: {:?}", response);
333 let mut buf = Writer::<LK>::new(buffer);
334 response.serialize(&mut buf);
335
336 let send_from = self
337 .local_ips
338 .iter()
339 .find(|l| is_same_network(l.addr, l.mask, from.ip()))
340 .unwrap()
343 .addr;
344
345 let cast = match qclass {
346 QClass::IN => Cast::Uni {
347 from: send_from,
348 target: from,
349 },
350 _ => Cast::Multi { from: send_from },
351 };
352
353 Output::Packet(buf.len(), cast)
354 }
355
356 fn handle_response<'x>(
357 &mut self,
358 response: Response<'x, QLEN, ALEN, LLEN>,
359 _from: SocketAddr,
360 _buffer: &mut [u8],
361 ) -> Output<'x, LLEN, SLEN> {
362 let mut services = Vec::new();
363
364 trace!("Handle response: {:?} {:?}", _from, response);
365
366 ServiceInfo::from_answers::<SLEN>(&response.answers, &mut services);
367
368 services.retain(|s| is_matching_service(s, &self.services));
369
370 if services.len() > 1 {
371 warn!("More than one service in answers. This is not currently handled");
372 }
373
374 if services.is_empty() {
375 Output::Timeout(self.poll_timeout())
376 } else {
377 Output::Remote(services.remove(0))
378 }
379 }
380}
381
/// Reports whether `other` falls in the same subnet as `ip` under `netmask`.
///
/// All three addresses must belong to the same family. Any mix of IPv4 and IPv6
/// (including IPv4-mapped IPv6 addresses) is treated as a different network.
fn is_same_network(ip: IpAddr, netmask: IpAddr, other: IpAddr) -> bool {
    match (ip, netmask, other) {
        (IpAddr::V4(a), IpAddr::V4(m), IpAddr::V4(b)) => {
            let mask = u32::from(m);
            (u32::from(a) & mask) == (u32::from(b) & mask)
        }
        (IpAddr::V6(a), IpAddr::V6(m), IpAddr::V6(b)) => {
            // Masking the whole 128-bit value is equivalent to masking each
            // 16-bit segment separately.
            let mask = u128::from(m);
            (u128::from(a) & mask) == (u128::from(b) & mask)
        }
        _ => false,
    }
}
398
399fn is_matching_service<const LLEN: usize, const SLEN: usize>(
400 s1: &ServiceInfo<'_, LLEN>,
401 services: &Vec<ServiceInfo<'_, LLEN>, SLEN>,
402) -> bool {
403 let mut handled_service = false;
404 let mut is_self = false;
405
406 for s2 in services.iter() {
407 handled_service |= s1.service_type() == s2.service_type();
408
409 is_self |= s1.instance_name() == s2.instance_name()
410 && s1.ip_address() == s2.ip_address()
411 && s1.port() == s2.port();
412 }
413
414 handled_service && !is_self
415}
416
#[cfg(feature = "defmt")]
impl defmt::Format for Input<'_> {
    /// Writes the input for defmt logging; packet payloads are shown only by length.
    fn format(&self, fmt: defmt::Formatter) {
        use crate::format::FormatSocketAddr;

        match self {
            Input::Packet(payload, remote) => {
                let size = payload.len();
                defmt::write!(
                    fmt,
                    "Packet([..{} bytes], {:?})",
                    size,
                    FormatSocketAddr(*remote)
                );
            }
            Input::Timeout(at) => {
                defmt::write!(fmt, "Timeout({:?})", at);
            }
        }
    }
}
436
#[cfg(feature = "defmt")]
impl defmt::Format for Cast {
    /// Writes the cast for defmt logging, wrapping addresses in the crate's
    /// defmt-capable formatters.
    fn format(&self, fmt: defmt::Formatter) {
        use crate::format::{FormatIpAddr, FormatSocketAddr};

        match *self {
            Cast::Uni { from: src, target: dst } => {
                defmt::write!(
                    fmt,
                    "Uni {{ from:{:?}, target:{:?} }}",
                    FormatIpAddr(src),
                    FormatSocketAddr(dst)
                );
            }
            Cast::Multi { from: src } => {
                defmt::write!(fmt, "Multi {{ from:{:?} }}", FormatIpAddr(src));
            }
        }
    }
}