opslag/
server.rs

1use core::net::{IpAddr, SocketAddr};
2
3use crate::dns::{Flags, Message, QClass, QType, Query, Request, Response};
4use crate::time::Time;
5use crate::vec::Vec;
6use crate::writer::Writer;
7use crate::ServiceInfo;
8
9/// A server for broadcasting/discovering peers.
10///
11/// * `QLEN` - Max number of queries in a single mDNS packet. Only used if not **std**.
12///            Typically 4 for SRV, PTR, TXT and A (or AAAA).
13/// * `ALEN` - Max number of answers in a single mDNS packet. Only used if not **std**.
14///            Typically 4 for SRV, PTR, TXT and A (or AAAA).
15/// * `LLEN` - Max number of segments for a parsed Label.
16///            All services have max 4 segments: martin_test._myservice._udp.local.
17/// * `SLEN` - Number of service infos to handle in the [`Server`].
18/// * `LK`   – List size for DNS label compression. 10 is a good value.
19///
20/// Specifying too small QLEN, ALEN, LLEN or SLEN does not make the server fail, but rather
21/// reject messages that can't be parsed.
22///
23/// ```
24/// use opslag::{Server, ServiceInfo};
25///
26/// let info = ServiceInfo::<4>::new(
27///     "_midiriff._udp.local", // name of service
28///     "martin_test",          // instance name, in case multiple services on same host
29///     "mini.local",           // host
30///     [192, 168, 0, 1],       // IP address of host
31///     [255, 255, 255, 0],     // Netmask for the IP
32///     1234,                   // port of service
33///  );
34///
35/// // Max 4 queries
36/// // Max 4 answers
37/// // Max 4 segments in a label.
38/// // 1 handled service
39/// // 10 entries for dns label compression
40/// let server = Server::<4, 4, 4, 1, 10>::new([info].into_iter());
41/// ```
42pub struct Server<
43    'a,
44    const QLEN: usize,
45    const ALEN: usize,
46    const LLEN: usize,
47    const SLEN: usize,
48    const LK: usize,
49> {
50    last_now: Time,
51    services: Vec<ServiceInfo<'a, LLEN>, SLEN>,
52    local_ips: Vec<LocalIp, SLEN>,
53    next_advertise: Time,
54    next_advertise_idx: usize,
55    next_query: Time,
56    next_query_idx: usize,
57    txid_query: u16,
58    next_txid: u16,
59}
60
/// A local address paired with its netmask.
///
/// Equality is what lets several services that share an address collapse into one entry.
#[derive(Copy, Clone, Eq, PartialEq)]
struct LocalIp {
    /// Address outgoing packets are sent from.
    addr: IpAddr,
    /// Netmask belonging to `addr`.
    mask: IpAddr,
}
66
/// Delay between advertisement rounds, added to a `Time` (presumably milliseconds).
const ADVERTISE_INTERVAL: u64 = 15 * 1_000;
/// Delay between query rounds, added to a `Time` (presumably milliseconds).
const QUERY_INTERVAL: u64 = 19 * 1_000;
69
/// Describes how an outgoing packet is to be sent.
#[derive(Debug)]
pub enum Cast {
    /// Multicast the packet.
    Multi {
        /// Local IP address to send from.
        from: IpAddr,
    },
    /// Unicast the packet to a single peer.
    Uni {
        /// Local IP address to send from.
        from: IpAddr,
        /// Destination socket address.
        target: SocketAddr,
    },
}
86
87/// Input to [`Server`].
88#[derive(Debug)]
89pub enum Input<'x> {
90    /// A timeout.
91    ///
92    /// It's fine to send timeouts when there is nothing else to do.
93    /// The service expects a timeout for the [`Output::Timeout`] indicated.
94    Timeout(Time),
95
96    /// Some data coming from the network.
97    Packet(&'x [u8], SocketAddr),
98}
99
100/// Output from the [`Server`].
101pub enum Output<'x, const LLEN: usize, const SLEN: usize> {
102    /// A packet to send somewhere.
103    ///
104    /// The data is in the buffer given to [`Server::handle`] and the amount of the
105    /// buffer use is the first argument of the tuple.å
106    Packet(usize, Cast),
107
108    /// Next time the service expects a timeout.
109    ///
110    /// It is fine to send more timeouts before this.
111    Timeout(Time),
112
113    /// The [`Server`] discovered a remote instance of a declared [`ServiceInfo`].
114    Remote(ServiceInfo<'x, LLEN>),
115}
116
117impl<
118        'a,
119        const QLEN: usize,
120        const ALEN: usize,
121        const LLEN: usize,
122        const SLEN: usize,
123        const LK: usize,
124    > Server<'a, QLEN, ALEN, LLEN, SLEN, LK>
125{
126    /// Creates a new server instance.
127    pub fn new(
128        iter: impl Iterator<Item = ServiceInfo<'a, LLEN>>,
129    ) -> Server<'a, QLEN, ALEN, LLEN, SLEN, LK> {
130        let mut services = Vec::new();
131        services.extend(iter);
132
133        let mut local_ips = Vec::new();
134        for s in services.iter() {
135            let loc = LocalIp {
136                addr: s.ip_address(),
137                mask: s.netmask(),
138            };
139            let has_ip = local_ips.iter().any(|l| *l == loc);
140            if !has_ip {
141                // unwrap: this should be fine since local_ips is as long as services.
142                local_ips.push(loc).unwrap();
143            }
144        }
145
146        Server {
147            last_now: Time::from_millis(0),
148            services,
149            local_ips,
150            next_advertise: Time::from_millis(3000),
151            next_advertise_idx: 0,
152            next_query: Time::from_millis(5000),
153            next_query_idx: 0,
154            txid_query: 0,
155            next_txid: 1,
156        }
157    }
158
159    fn poll_timeout(&self) -> Time {
160        self.next_advertise.min(self.next_query)
161    }
162
163    /// Handle some input and produce output.
164    ///
165    /// You can send [`Input::Timeout`] whenenver. The `buffer` is for outgoing packets.
166    /// Upon [`Output::Packet`] the buffer will be filled to some point with data to transmit.
167    pub fn handle<'x>(&mut self, input: Input<'x>, buffer: &mut [u8]) -> Output<'x, LLEN, SLEN> {
168        match input {
169            Input::Timeout(now) => self.handle_timeout(now, buffer),
170            Input::Packet(data, from) => self.handle_packet(data, from, buffer),
171        }
172    }
173
174    fn handle_timeout(&mut self, now: Time, buffer: &mut [u8]) -> Output<'static, LLEN, SLEN> {
175        self.last_now = now;
176
177        if now >= self.next_advertise {
178            let send_from = self.local_ips[self.next_advertise_idx];
179
180            let ret = self.do_advertise(buffer, send_from);
181
182            self.next_advertise_idx += 1;
183
184            if self.next_advertise_idx == self.local_ips.len() {
185                self.next_advertise_idx = 0;
186                self.next_advertise = now + ADVERTISE_INTERVAL;
187            }
188
189            ret
190        } else if now >= self.next_query {
191            let send_from = self.local_ips[self.next_query_idx];
192
193            let ret = self.do_query(buffer, send_from);
194
195            self.next_query_idx += 1;
196
197            if self.next_query_idx == self.local_ips.len() {
198                self.next_query_idx = 0;
199                self.next_query = now + QUERY_INTERVAL;
200            }
201
202            ret
203        } else {
204            Output::Timeout(self.poll_timeout())
205        }
206    }
207
208    fn next_txid(&mut self) -> u16 {
209        let x = self.next_txid;
210        self.next_txid = self.next_txid.wrapping_add(1);
211        x
212    }
213
214    fn do_advertise(&mut self, buffer: &mut [u8], local: LocalIp) -> Output<'static, LLEN, SLEN> {
215        let mut response: Response<QLEN, ALEN, LLEN> = Response {
216            id: 0,
217            flags: Flags::standard_response(),
218            queries: Vec::new(),
219            answers: Vec::new(),
220        };
221
222        let to_consider = self
223            .services
224            .iter()
225            .filter(|s| s.ip_address() == local.addr && s.netmask() == local.mask);
226
227        for service in to_consider {
228            response
229                .answers
230                .extend(service.as_answers(QClass::Multicast));
231        }
232
233        debug!("Advertise response (from {}): {:?}", local.addr, response);
234
235        let mut buf = Writer::<LK>::new(buffer);
236
237        response.serialize(&mut buf);
238
239        Output::Packet(buf.len(), Cast::Multi { from: local.addr })
240    }
241
242    fn do_query(&mut self, buffer: &mut [u8], local: LocalIp) -> Output<'static, LLEN, SLEN> {
243        let mut request: Request<QLEN, LLEN> = Request {
244            id: self.next_txid(),
245            flags: Flags::standard_request(),
246            queries: Vec::new(),
247        };
248
249        self.txid_query = request.id;
250
251        let to_consider = self
252            .services
253            .iter()
254            .filter(|s| s.ip_address() == local.addr && s.netmask() == local.mask);
255
256        for service in to_consider {
257            let query = Query {
258                name: service.service_type().clone(),
259                qtype: QType::PTR,
260                qclass: QClass::IN,
261            };
262            request.queries.push(query).unwrap();
263        }
264
265        debug!("Send request (from {}): {:?}", local.addr, request);
266
267        let mut buf = Writer::<LK>::new(buffer);
268        request.serialize(&mut buf);
269
270        Output::Packet(buf.len(), Cast::Multi { from: local.addr })
271    }
272
273    fn handle_packet<'x>(
274        &mut self,
275        data: &'x [u8],
276        from: SocketAddr,
277        buffer: &mut [u8],
278    ) -> Output<'x, LLEN, SLEN> {
279        match Message::parse(data) {
280            Ok((_, Message::Request(request))) => self.handle_request(request, from, buffer),
281            Ok((_, Message::Response(response))) => self.handle_response(response, from, buffer),
282            Err(_) => Output::Timeout(self.poll_timeout()),
283        }
284    }
285
286    fn handle_request<'x>(
287        &mut self,
288        request: Request<'x, QLEN, LLEN>,
289        from: SocketAddr,
290        buffer: &mut [u8],
291    ) -> Output<'x, LLEN, SLEN> {
292        if request.queries.is_empty() {
293            return Output::Timeout(self.poll_timeout());
294        }
295
296        // Ignore requests from self
297        if request.id == self.txid_query {
298            return Output::Timeout(self.poll_timeout());
299        }
300
301        // We check for empty above
302        let qclass = request.queries[0].qclass;
303
304        let queries = request.queries.iter();
305
306        let mut answers = Vec::new();
307
308        for query in queries {
309            for service in self.services.iter() {
310                if query.qtype == QType::PTR
311                    && &query.name == service.service_type()
312                    && is_same_network(service.ip_address(), service.netmask(), from.ip())
313                {
314                    answers.extend(service.as_answers(qclass));
315                }
316            }
317        }
318
319        if answers.is_empty() {
320            return Output::Timeout(self.poll_timeout());
321        }
322
323        debug!("Incoming request: {:?} {:?}", from, request);
324
325        let response: Response<QLEN, ALEN, LLEN> = Response {
326            id: request.id,
327            flags: Flags::standard_response(),
328            queries: request.queries,
329            answers,
330        };
331
332        debug!("Send response: {:?}", response);
333        let mut buf = Writer::<LK>::new(buffer);
334        response.serialize(&mut buf);
335
336        let send_from = self
337            .local_ips
338            .iter()
339            .find(|l| is_same_network(l.addr, l.mask, from.ip()))
340            // unwrap: is ok because above answers.is_empty() check means we must have had
341            // a match between incoming query and service records.
342            .unwrap()
343            .addr;
344
345        let cast = match qclass {
346            QClass::IN => Cast::Uni {
347                from: send_from,
348                target: from,
349            },
350            _ => Cast::Multi { from: send_from },
351        };
352
353        Output::Packet(buf.len(), cast)
354    }
355
356    fn handle_response<'x>(
357        &mut self,
358        response: Response<'x, QLEN, ALEN, LLEN>,
359        _from: SocketAddr,
360        _buffer: &mut [u8],
361    ) -> Output<'x, LLEN, SLEN> {
362        let mut services = Vec::new();
363
364        trace!("Handle response: {:?} {:?}", _from, response);
365
366        ServiceInfo::from_answers::<SLEN>(&response.answers, &mut services);
367
368        services.retain(|s| is_matching_service(s, &self.services));
369
370        if services.len() > 1 {
371            warn!("More than one service in answers. This is not currently handled");
372        }
373
374        if services.is_empty() {
375            Output::Timeout(self.poll_timeout())
376        } else {
377            Output::Remote(services.remove(0))
378        }
379    }
380}
381
/// Returns whether `other` lies in the same network as `ip` under `netmask`.
///
/// Both `ip` and `other` are masked with `netmask` and compared. The addresses must be of
/// the same family, with one exception: when `ip`/`netmask` are IPv4 and `other` is an
/// IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, which is how dual-stack sockets report IPv4
/// peers), `other` is compared as the IPv4 address it wraps. Any other mix yields `false`.
fn is_same_network(ip: IpAddr, netmask: IpAddr, other: IpAddr) -> bool {
    match (ip, netmask, other) {
        (IpAddr::V4(ip), IpAddr::V4(mask), IpAddr::V4(other)) => {
            (u32::from(ip) & u32::from(mask)) == (u32::from(other) & u32::from(mask))
        }
        // `ip` and `netmask` here are the function parameters (the pattern binds `_`).
        (IpAddr::V4(_), IpAddr::V4(_), IpAddr::V6(other)) => match other.to_ipv4_mapped() {
            Some(mapped) => is_same_network(ip, netmask, IpAddr::V4(mapped)),
            None => false,
        },
        (IpAddr::V6(ip), IpAddr::V6(mask), IpAddr::V6(other)) => ip
            .segments()
            .iter()
            .zip(mask.segments().iter())
            .zip(other.segments().iter())
            .all(|((&ip_seg, &mask_seg), &other_seg)| {
                (ip_seg & mask_seg) == (other_seg & mask_seg)
            }),
        _ => false,
    }
}
398
399fn is_matching_service<const LLEN: usize, const SLEN: usize>(
400    s1: &ServiceInfo<'_, LLEN>,
401    services: &Vec<ServiceInfo<'_, LLEN>, SLEN>,
402) -> bool {
403    let mut handled_service = false;
404    let mut is_self = false;
405
406    for s2 in services.iter() {
407        handled_service |= s1.service_type() == s2.service_type();
408
409        is_self |= s1.instance_name() == s2.instance_name()
410            && s1.ip_address() == s2.ip_address()
411            && s1.port() == s2.port();
412    }
413
414    handled_service && !is_self
415}
416
#[cfg(feature = "defmt")]
impl defmt::Format for Input<'_> {
    // Packets are rendered by length and sender only; the payload bytes are not logged.
    fn format(&self, fmt: defmt::Formatter) {
        use crate::format::FormatSocketAddr;

        match self {
            Input::Packet(bytes, peer) => {
                let peer = FormatSocketAddr(*peer);
                defmt::write!(fmt, "Packet([..{} bytes], {:?})", bytes.len(), peer);
            }
            Input::Timeout(at) => {
                defmt::write!(fmt, "Timeout({:?})", at);
            }
        }
    }
}
436
#[cfg(feature = "defmt")]
impl defmt::Format for Cast {
    // Mirrors the struct-like shape of the derived `Debug` output using defmt helpers.
    fn format(&self, fmt: defmt::Formatter) {
        use crate::format::{FormatIpAddr, FormatSocketAddr};

        match self {
            Cast::Uni { from, target } => {
                let src = FormatIpAddr(*from);
                let dst = FormatSocketAddr(*target);
                defmt::write!(fmt, "Uni {{ from:{:?}, target:{:?} }}", src, dst);
            }
            Cast::Multi { from } => {
                let src = FormatIpAddr(*from);
                defmt::write!(fmt, "Multi {{ from:{:?} }}", src);
            }
        }
    }
}