Skip to main content

erbium/dhcp/
mod.rs

1/*   Copyright 2024 Perry Lorier
2 *
3 *  Licensed under the Apache License, Version 2.0 (the "License");
4 *  you may not use this file except in compliance with the License.
5 *  You may obtain a copy of the License at
6 *
7 *      http://www.apache.org/licenses/LICENSE-2.0
8 *
9 *  Unless required by applicable law or agreed to in writing, software
10 *  distributed under the License is distributed on an "AS IS" BASIS,
11 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 *  See the License for the specific language governing permissions and
13 *  limitations under the License.
14 *
15 *  SPDX-License-Identifier: Apache-2.0
16 *
17 *  Main DHCP Code.
18 */
19use std::collections;
20use std::convert::TryInto as _;
21use std::net;
22use std::ops::Sub as _;
23use std::sync::Arc;
24use tokio::sync;
25
26use crate::dhcp::dhcppkt::Serialise;
27use erbium_net::addr::{NetAddr, ToNetAddr, UNSPECIFIED4, WithPort as _};
28use erbium_net::packet;
29use erbium_net::raw;
30use erbium_net::udp;
31
32pub mod config;
33pub mod dhcppkt;
34pub mod pool;
35#[cfg(test)]
36mod test;
37
/// UDP socket type used for the service's `listener`.
type UdpSocket = udp::UdpSocket;
/// IPv4 addresses this server has claimed as its DHCP server identifier.
type ServerIds = std::collections::HashSet<net::Ipv4Addr>;
/// `ServerIds` shared between tasks behind an async (tokio) mutex.
pub type SharedServerIds = Arc<sync::Mutex<ServerIds>>;
41
lazy_static::lazy_static! {
    /// Prometheus counter `dhcp_received_packets`.
    ///
    /// Every metric in this block is registered with Prometheus lazily, on first use; the
    /// `.unwrap()` panics at that point if registration fails.
    static ref DHCP_RX_PACKETS: prometheus::IntCounter =
        prometheus::register_int_counter!("dhcp_received_packets", "Number of DHCP packets received")
            .unwrap();
    /// Prometheus counter `dhcp_sent_packets`; incremented in `send_raw`.
    static ref DHCP_TX_PACKETS: prometheus::IntCounter =
        prometheus::register_int_counter!("dhcp_sent_packets", "Number of DHCP packets sent")
            .unwrap();
    /// Prometheus counter vector `dhcp_errors`, labelled by `reason` (for example the value of
    /// `DhcpError::get_variant_name`, or a fixed tag such as `NO_IPV4_ON_INTERFACE`).
    static ref DHCP_ERRORS: prometheus::IntCounterVec = prometheus::register_int_counter_vec!(
        "dhcp_errors",
        "Counts of reasons that replies cannot be sent",
        &["reason"]
    )
    .unwrap();
    /// Prometheus counter vector `dhcp_allocations`; the `reason` label is the `Debug` form of the
    /// allocated lease's type.
    static ref DHCP_ALLOCATIONS: prometheus::IntCounterVec = prometheus::register_int_counter_vec!(
        "dhcp_allocations",
        "Counts of address allocation types",
        &["reason"]
    )
    .unwrap();
    /// Prometheus gauge `dhcp_active_leases`: leases currently in use.
    static ref DHCP_ACTIVE_LEASES: prometheus::IntGauge = prometheus::register_int_gauge!(
        "dhcp_active_leases",
        "Counts of leases that are currently in use"
    )
    .unwrap();
    /// Prometheus gauge `dhcp_expired_leases`: leases currently expired.
    static ref DHCP_EXPIRED_LEASES: prometheus::IntGauge = prometheus::register_int_gauge!(
        "dhcp_expired_leases",
        "Counts of leases that are currently expired"
    )
    .unwrap();
}
72
/// Reasons a DHCP request could not be answered.
#[derive(Debug, PartialEq, Eq)]
pub enum DhcpError {
    /// The packet's message type is not one this server handles.
    UnknownMessageType(dhcppkt::MessageType),
    /// At least one policy applied, but none of them supplied addresses to allocate from.
    NoLeasesConfigured,
    /// The packet could not be parsed (also used when it carries no message type at all).
    ParseError(dhcppkt::ParseError),
    /// The lease pool failed to allocate an address.
    PoolError(pool::Error),
    /// An internal error, with a description.
    InternalError(String),
    /// The request named a server identifier that is not one of ours.
    OtherServer(std::net::Ipv4Addr),
    /// No policy applied to this client.
    NoPolicyConfigured,
}
83
84impl std::error::Error for DhcpError {
85    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
86        None
87    }
88}
89
90impl std::fmt::Display for DhcpError {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        match self {
93            DhcpError::UnknownMessageType(m) => write!(f, "Unknown Message Type: {:?}", m),
94            DhcpError::NoLeasesConfigured => write!(f, "No Leases Configured"),
95            DhcpError::ParseError(e) => write!(f, "Parse Error: {:?}", e),
96            DhcpError::InternalError(e) => write!(f, "Internal Error: {:?}", e),
97            DhcpError::OtherServer(s) => write!(f, "Packet for a different DHCP server: {}", s),
98            DhcpError::NoPolicyConfigured => write!(f, "No policy configured for client"),
99            DhcpError::PoolError(p) => write!(f, "Pool Error: {:?}", p),
100        }
101    }
102}
103
104impl DhcpError {
105    const fn get_variant_name(&self) -> &'static str {
106        use DhcpError::*;
107        match self {
108            UnknownMessageType(_) => "UNKNOWN_MESSAGE_TYPE",
109            NoLeasesConfigured => "NO_LEASES_CONFIGURED",
110            ParseError(_) => "PARSE_ERROR",
111            InternalError(_) => "INTERNAL_ERROR",
112            OtherServer(_) => "OTHER_SERVER",
113            NoPolicyConfigured => "NO_POLICY",
114            PoolError(pool::Error::NoAssignableAddress) => "NO_ADDRESS",
115            PoolError(pool::Error::RequestedAddressInUse) => "ADDRESS_IN_USE",
116            PoolError(_) => "INTERNAL_POOL_ERROR",
117        }
118    }
119}
120
/// A received DHCP packet together with metadata about where it arrived.
#[derive(Debug)]
pub struct DHCPRequest {
    /// The DHCP request packet.
    pub pkt: dhcppkt::Dhcp,
    /// The IP address that the request was received on.
    pub serverip: std::net::Ipv4Addr,
    /// The interface index that the request was received on.
    pub ifindex: u32,
    /// MTU of the receiving interface, if known.
    pub if_mtu: Option<u32>,
    /// Router address to advertise to clients on the receiving interface, if one was determined.
    pub if_router: Option<std::net::Ipv4Addr>,
}
132
#[cfg(test)]
impl std::default::Default for DHCPRequest {
    /// A minimal Ethernet BOOTREQUEST: every address unspecified, no options, received on
    /// interface 0, and a client hardware address from the documentation range.
    fn default() -> Self {
        let unspecified = net::Ipv4Addr::UNSPECIFIED;
        // 00:00:5E:00:53:00 is reserved for documentation, per RFC7042.
        let chaddr = vec![0x00, 0x00, 0x5E, 0x00, 0x53, 0x00];
        let pkt = dhcppkt::Dhcp {
            op: dhcppkt::OP_BOOTREQUEST,
            htype: dhcppkt::HWTYPE_ETHERNET,
            hlen: 6,
            hops: 0,
            xid: 0,
            secs: 0,
            flags: 0,
            ciaddr: unspecified,
            yiaddr: unspecified,
            siaddr: unspecified,
            giaddr: unspecified,
            chaddr,
            sname: Vec::new(),
            file: Vec::new(),
            options: Default::default(),
        };
        DHCPRequest {
            pkt,
            serverip: unspecified,
            ifindex: 0,
            if_mtu: None,
            if_router: None,
        }
    }
}
164
/// Result of evaluating one policy's own matchers (see `check_policy`).
#[derive(Eq, PartialEq, Debug)]
enum PolicyMatch {
    /// The policy has no matchers (`match_all` is false and no `match_*` criteria are set).
    NoMatch,
    /// At least one of the policy's matchers rejected the request.
    MatchFailed,
    /// The policy has at least one matcher, and every matcher it has accepted the request.
    MatchSucceeded,
}
171
172fn check_policy(req: &DHCPRequest, policy: &config::Policy) -> PolicyMatch {
173    let mut outcome = PolicyMatch::NoMatch;
174    //if let Some(policy.match_interface ...
175    if policy.match_all {
176        outcome = PolicyMatch::MatchSucceeded;
177    }
178    if let Some(match_chaddr) = &policy.match_chaddr {
179        outcome = PolicyMatch::MatchSucceeded;
180        if req.pkt.chaddr != *match_chaddr {
181            return PolicyMatch::MatchFailed;
182        }
183    }
184    if let Some(match_subnet) = &policy.match_subnet {
185        outcome = PolicyMatch::MatchSucceeded;
186        if !match_subnet.contains(req.serverip) {
187            return PolicyMatch::MatchFailed;
188        }
189    }
190
191    for (k, m) in policy.match_other.iter() {
192        if match (m, req.pkt.options.other.get(k)) {
193            (None, None) => true, /* Required that option doesn't exist, option doesn't exist */
194            (None, Some(_)) => false, /* Required that option doesn't exist, option exists */
195            (Some(mat), Some(opt)) if &mat.as_bytes() == opt => true, /* Required it has value, and matches */
196            (Some(_), Some(_)) => false, /* Required it has a value, option has some other value */
197            (Some(_), None) => false, /* Required that option has a value, option doesn't exist */
198        } {
199            /* If at least one thing matches, then this is a MatchSucceded */
200            outcome = PolicyMatch::MatchSucceeded;
201        } else {
202            /* If any fail, then fail everything */
203            return PolicyMatch::MatchFailed;
204        }
205    }
206    outcome
207}
208
209fn apply_policy(req: &DHCPRequest, policy: &config::Policy, response: &mut Response) -> bool {
210    /* Check if our policy should match.
211     */
212    match check_policy(req, policy) {
213        /* If the match failed, do not apply. */
214        PolicyMatch::MatchFailed => return false,
215        /* If there are no matches applied for this policy, check if any subpolicies match, and if
216         * so, apply this policy too, otherwise fail.
217         */
218        PolicyMatch::NoMatch => {
219            if !check_policies(req, &policy.policies) {
220                return false;
221            }
222        }
223        /* If there were matchers, and we matched them all, then continue with applying the policy.
224         */
225        PolicyMatch::MatchSucceeded => (),
226    }
227
228    /* If there are addresses provided here, override any from the parent */
229    if let Some(address) = &policy.apply_address {
230        response.address = Some(address.clone()); /* HELP: I tried to make the lifetimes worked, and failed */
231    }
232
233    /* Now get the list of parameters we will apply from the parameter list from the client.
234     */
235    // TODO: This should probably just be a u128 bitvector
236    let pl: std::collections::HashSet<
237        dhcppkt::DhcpOption,
238        std::collections::hash_map::RandomState,
239    > = req
240        .pkt
241        .options
242        .get_option::<Vec<u8>>(&dhcppkt::OPTION_PARAMLIST)
243        .unwrap_or_default()
244        .iter()
245        .copied()
246        .map(dhcppkt::DhcpOption::from)
247        .collect();
248
249    for (k, v) in &policy.apply_other {
250        if pl.contains(k) {
251            response.options.mutate_option(k, v.as_ref());
252        }
253    }
254
255    /* And check to see if a subpolicy also matches */
256    apply_policies(req, &policy.policies, response);
257
258    /* Some of the defaults depend on what other options end up being set, so apply them here. */
259    if let Some(subnet) = &policy.match_subnet {
260        if pl.contains(&dhcppkt::OPTION_NETMASK) {
261            response
262                .options
263                .mutate_option_default(&dhcppkt::OPTION_NETMASK, &subnet.netmask());
264        }
265        if pl.contains(&dhcppkt::OPTION_BROADCAST) {
266            response
267                .options
268                .mutate_option_default(&dhcppkt::OPTION_BROADCAST, &subnet.broadcast());
269        }
270    }
271
272    true
273}
274
275fn check_policies(req: &DHCPRequest, policies: &[config::Policy]) -> bool {
276    for policy in policies {
277        match check_policy(req, policy) {
278            PolicyMatch::MatchSucceeded => return true,
279            PolicyMatch::MatchFailed => continue,
280            PolicyMatch::NoMatch => {
281                if check_policies(req, &policy.policies) {
282                    return true;
283                } else {
284                    continue;
285                }
286            }
287        }
288    }
289    false
290}
291
292fn apply_policies(req: &DHCPRequest, policies: &[config::Policy], response: &mut Response) -> bool {
293    for policy in policies {
294        if apply_policy(req, policy, response) {
295            return true;
296        }
297    }
298    false
299}
300
/// The options to put in a reply, built up as policies are applied.
///
/// Stored values are serialised option payloads.
#[derive(Default, Clone)]
struct ResponseOptions {
    /* Options can be unset (not specified), set to "None" (do not send), or set to a specific
     * value.
     */
    option: collections::HashMap<dhcppkt::DhcpOption, Option<Vec<u8>>>,
}
308
309impl ResponseOptions {
310    fn set_raw_option(mut self, option: &dhcppkt::DhcpOption, value: &[u8]) -> Self {
311        self.option.insert(*option, Some(value.to_vec()));
312        self
313    }
314
315    fn set_option<T: dhcppkt::Serialise>(self, option: &dhcppkt::DhcpOption, value: &T) -> Self {
316        let mut v = Vec::new();
317        value.serialise(&mut v);
318        self.set_raw_option(option, &v)
319    }
320
321    pub fn mutate_option<T: dhcppkt::Serialise>(
322        &mut self,
323        option: &dhcppkt::DhcpOption,
324        maybe_value: Option<&T>,
325    ) {
326        match maybe_value {
327            Some(value) => {
328                let mut v = Vec::new();
329                value.serialise(&mut v);
330                self.option.insert(*option, Some(v));
331            }
332            None => {
333                self.option.insert(*option, None);
334            }
335        }
336    }
337
338    pub fn mutate_option_default<T: dhcppkt::Serialise>(
339        &mut self,
340        option: &dhcppkt::DhcpOption,
341        value: &T,
342    ) {
343        if !self.option.contains_key(option) {
344            self.mutate_option(option, Some(value));
345        }
346    }
347
348    pub fn to_options(&self) -> dhcppkt::DhcpOptions {
349        let mut opt = dhcppkt::DhcpOptions::default();
350        for (k, v) in &self.option {
351            if let Some(d) = v {
352                opt.other.insert(*k, d.to_vec());
353            }
354        }
355        opt
356    }
357}
358
/// A reply under construction while policies are applied.
#[derive(Default)]
struct Response {
    /// Options to include in the reply.
    options: ResponseOptions,
    /// Addresses the client may be allocated from. Set by the most recently applied policy
    /// that has `apply_address`.
    address: Option<pool::PoolAddresses>,
    /// Minimum lease time; `pool::DEFAULT_MIN_LEASE` is used when unset.
    /// NOTE(review): not assigned by any policy code visible in this file — confirm.
    minlease: Option<std::time::Duration>,
    /// Maximum lease time; `pool::DEFAULT_MAX_LEASE` is used when unset.
    /// NOTE(review): not assigned by any policy code visible in this file — confirm.
    maxlease: Option<std::time::Duration>,
}
366
367fn handle_discover(
368    pools: &mut pool::Pool,
369    req: &DHCPRequest,
370    _serverids: &ServerIds,
371    base: &[config::Policy],
372    conf: &super::config::Config,
373) -> Result<dhcppkt::Dhcp, DhcpError> {
374    /* Build the default response we are about to reply with, it will be filled in later */
375    let mut response: Response = Response {
376        options: ResponseOptions::default()
377            .set_option(&dhcppkt::OPTION_MSGTYPE, &dhcppkt::DHCPOFFER)
378            .set_option(&dhcppkt::OPTION_SERVERID, &req.serverip),
379        ..Default::default()
380    };
381
382    /* Now attempt to apply all the policies.*/
383    let base_policy = apply_policies(req, base, &mut response);
384    let conf_policy = apply_policies(req, &conf.dhcp.policies, &mut response);
385    if !base_policy && !conf_policy {
386        /* If none of the policies applied at all, then provide a warning back to the caller */
387        Err(DhcpError::NoPolicyConfigured)
388    } else if let Some(addresses) = response.address {
389        /* At least one policy matched, and provided addresses.  So now go allocate an address */
390        let mut raw_options = Vec::new();
391        req.pkt.options.serialise(&mut raw_options);
392        match pools.allocate_address(
393            &req.pkt.get_client_id(),
394            req.pkt.options.get_address_request(),
395            &addresses,
396            response.minlease.unwrap_or(pool::DEFAULT_MIN_LEASE),
397            response.maxlease.unwrap_or(pool::DEFAULT_MAX_LEASE),
398            &raw_options,
399        ) {
400            /* Now we have an address, build the reply */
401            Ok(lease) => {
402                DHCP_ALLOCATIONS
403                    .with_label_values(&[&format!("{:?}", lease.lease_type)])
404                    .inc();
405                log::info!(
406                    "Allocated Lease: {} for {:?} ({:?})",
407                    lease.ip,
408                    lease.expire,
409                    lease.lease_type
410                );
411
412                Ok(dhcppkt::Dhcp {
413                    op: dhcppkt::OP_BOOTREPLY,
414                    htype: dhcppkt::HWTYPE_ETHERNET,
415                    hlen: 6,
416                    hops: 0,
417                    xid: req.pkt.xid,
418                    secs: 0,
419                    flags: req.pkt.flags,
420                    ciaddr: net::Ipv4Addr::UNSPECIFIED,
421                    yiaddr: lease.ip,
422                    siaddr: net::Ipv4Addr::UNSPECIFIED,
423                    giaddr: req.pkt.giaddr,
424                    chaddr: req.pkt.chaddr.clone(),
425                    sname: vec![],
426                    file: vec![],
427                    options: response
428                        .options
429                        .clone()
430                        .set_option(&dhcppkt::OPTION_SERVERID, &req.serverip)
431                        .to_options(),
432                })
433            }
434            /* Some error occurred, document it. */
435            Err(e) => Err(DhcpError::PoolError(e)),
436        }
437    } else {
438        /* There were no addresses assigned to this match */
439        Err(DhcpError::NoLeasesConfigured)
440    }
441}
442
443fn handle_request(
444    pools: &mut pool::Pool,
445    req: &DHCPRequest,
446    serverids: &ServerIds,
447    base: &[config::Policy],
448    conf: &super::config::Config,
449) -> Result<dhcppkt::Dhcp, DhcpError> {
450    if let Some(si) = req.pkt.options.get_serverid()
451        && !serverids.contains(&si)
452    {
453        return Err(DhcpError::OtherServer(si));
454    }
455    let mut response: Response = Response {
456        options: ResponseOptions::default()
457            .set_option(&dhcppkt::OPTION_MSGTYPE, &dhcppkt::DHCPOFFER)
458            .set_option(&dhcppkt::OPTION_SERVERID, &req.serverip),
459        ..Default::default()
460    };
461    let base_policy = apply_policies(req, base, &mut response);
462    let conf_policy = apply_policies(req, &conf.dhcp.policies, &mut response);
463    if !base_policy && !conf_policy {
464        Err(DhcpError::NoPolicyConfigured)
465    } else if let Some(addresses) = response.address {
466        let mut raw_options = Vec::new();
467        req.pkt.options.serialise(&mut raw_options);
468        match pools.allocate_address(
469            &req.pkt.get_client_id(),
470            if !req.pkt.ciaddr.is_unspecified() {
471                Some(req.pkt.ciaddr)
472            } else {
473                req.pkt.options.get_address_request()
474            },
475            &addresses,
476            response.minlease.unwrap_or(pool::DEFAULT_MIN_LEASE),
477            response.maxlease.unwrap_or(pool::DEFAULT_MAX_LEASE),
478            &raw_options,
479        ) {
480            Ok(lease) => {
481                DHCP_ALLOCATIONS
482                    .with_label_values(&[&format!("{:?}", lease.lease_type)])
483                    .inc();
484                log::info!(
485                    "Allocated Lease: {} for {:?} ({:?})",
486                    lease.ip,
487                    lease.expire,
488                    lease.lease_type
489                );
490                Ok(dhcppkt::Dhcp {
491                    op: dhcppkt::OP_BOOTREPLY,
492                    htype: dhcppkt::HWTYPE_ETHERNET,
493                    hlen: 6,
494                    hops: 0,
495                    xid: req.pkt.xid,
496                    secs: 0,
497                    flags: req.pkt.flags,
498                    ciaddr: req.pkt.ciaddr,
499                    yiaddr: lease.ip,
500                    siaddr: net::Ipv4Addr::UNSPECIFIED,
501                    giaddr: req.pkt.giaddr,
502                    chaddr: req.pkt.chaddr.clone(),
503                    sname: vec![],
504                    file: vec![],
505                    options: response
506                        .options
507                        .set_option(&dhcppkt::OPTION_MSGTYPE, &dhcppkt::DHCPACK)
508                        .set_option(
509                            &dhcppkt::OPTION_SERVERID,
510                            &req.pkt.options.get_serverid().unwrap_or(req.serverip),
511                        )
512                        .set_option(&dhcppkt::OPTION_LEASETIME, &(lease.expire.as_secs() as u32))
513                        .to_options(),
514                })
515            }
516            Err(e) => Err(DhcpError::PoolError(e)),
517        }
518    } else {
519        Err(DhcpError::NoLeasesConfigured)
520    }
521}
522
/// Formats a hardware address as lower-case, colon-separated, two-digit hex octets
/// (e.g. `00:00:5e:00:53:00`).
fn format_mac(v: &[u8]) -> String {
    let mut out = String::with_capacity(v.len() * 3);
    for (idx, octet) in v.iter().enumerate() {
        if idx != 0 {
            out.push(':');
        }
        out.push_str(&format!("{:02x}", octet));
    }
    out
}
529
530fn format_client(req: &dhcppkt::Dhcp) -> String {
531    format!(
532        "{} ({})",
533        format_mac(&req.chaddr),
534        String::from_utf8_lossy(
535            &req.options
536                .get_option::<Vec<u8>>(&dhcppkt::OPTION_HOSTNAME)
537                .unwrap_or_default()
538        ),
539    )
540}
541
542fn log_options(req: &dhcppkt::Dhcp) {
543    log::info!(
544        "{}: Options: {}",
545        format_client(req),
546        req.options
547            .other
548            .iter()
549            // We already decode MSGTYPE and PARAMLIST elsewhere, so don't try and decode
550            // them here.  It just leads to confusing looking messages.
551            .filter(|(k, _)| **k != dhcppkt::OPTION_MSGTYPE && **k != dhcppkt::OPTION_PARAMLIST)
552            .map(|(k, v)| format!(
553                "{k}({})",
554                k.get_type()
555                    .and_then(|x| x.decode(v))
556                    .map(|x| format!("{}", x))
557                    .unwrap_or_else(|| "<decode-failed>".into())
558            ))
559            .collect::<Vec<String>>()
560            .join(" "),
561    );
562}
563
564async fn log_pkt(request: &DHCPRequest, netinfo: &erbium_net::netinfo::SharedNetInfo) {
565    use std::fmt::Write as _;
566    let mut s = "".to_string();
567    write!(
568        s,
569        "{}: {} on {}",
570        format_client(&request.pkt),
571        request
572            .pkt
573            .options
574            .get_messagetype()
575            .map(|x| x.to_string())
576            .unwrap_or_else(|| "[unknown]".into()),
577        netinfo.get_safe_name_by_ifidx(request.ifindex).await
578    )
579    .unwrap();
580    if !request.serverip.is_unspecified() {
581        write!(s, " ({})", request.serverip).unwrap();
582    }
583    if !request.pkt.ciaddr.is_unspecified() {
584        write!(s, ", using {}", request.pkt.ciaddr).unwrap();
585    }
586    if !request.pkt.giaddr.is_unspecified() {
587        write!(
588            s,
589            ", relayed via {} hops from {}",
590            request.pkt.hops, request.pkt.giaddr
591        )
592        .unwrap();
593    }
594    log::info!("{}", s);
595    log_options(&request.pkt);
596    log::info!(
597        "{}: Requested: {}",
598        format_client(&request.pkt),
599        request
600            .pkt
601            .options
602            .get_option::<Vec<u8>>(&dhcppkt::OPTION_PARAMLIST)
603            .map(|v| v
604                .iter()
605                .map(|&x| dhcppkt::DhcpOption::new(x))
606                .map(|o| o.to_string())
607                .collect::<Vec<String>>()
608                .join(" "))
609            .unwrap_or_else(|| "<none>".into())
610    );
611}
612
613/// Produce a default configuration.
614/// This builds a configuration that would look like:
615/// ```yaml
616/// # Always match the base configuration
617/// match-all: true
618/// # For top level settings, apply them.
619/// apply-dns-servers: [ip4s]
620/// apply-dns-search: [domains]
621/// apply-captive-portal: captiveportal
622/// policies:
623///  # For each IPv4 prefix provided in the addresses top level config:
624///  - match-subnet: prefix4
625///    apply-subnet: prefix4 # with the current requestip removed.
626/// ```
627pub fn build_default_config(conf: &crate::config::Config, request: &DHCPRequest) -> config::Policy {
628    let mut default_policy = config::Policy {
629        match_all: true, /* We always want this policy to match. */
630        ..Default::default()
631    };
632    /* Add the default options from the top level configuration */
633    default_policy.apply_other.insert(
634        dhcppkt::OPTION_DOMAINSERVER,
635        Some(dhcppkt::DhcpOptionTypeValue::IpList(
636            conf.dns_servers
637                .iter()
638                .filter_map(|ip| match ip {
639                    ip if *ip == config::INTERFACE4 => Some(request.serverip),
640                    std::net::IpAddr::V4(ip4) => Some(*ip4),
641                    _ => None,
642                })
643                .collect(),
644        )),
645    );
646    default_policy.apply_other.insert(
647        dhcppkt::OPTION_DOMAINSEARCH,
648        Some(dhcppkt::DhcpOptionTypeValue::DomainList(
649            conf.dns_search.clone(),
650        )),
651    );
652    default_policy.apply_other.insert(
653        dhcppkt::OPTION_CAPTIVEPORTAL,
654        conf.captive_portal
655            .clone()
656            .map(dhcppkt::DhcpOptionTypeValue::String),
657    );
658    /* Now build some sub policies, for each address range */
659    let all_addrs = conf.dhcp.get_all_used_addresses();
660    default_policy.policies = conf
661        .addresses
662        .iter()
663        .filter_map(|prefix| {
664            if let super::config::Prefix::V4(p4) = prefix {
665                use crate::config::Match as _;
666                use crate::config::PrefixOps as _;
667                let subnet = erbium_net::Ipv4Subnet::new(p4.network(), p4.prefixlen).ok()?;
668                let mut ret = config::Policy {
669                    match_subnet: Some(subnet),
670                    apply_address: Some(
671                        (1..((1 << (32 - p4.prefixlen)) - 2))
672                            .map(|offset| (u32::from(subnet.network()) + offset).into())
673                            // TODO: This removes one IP from the list, it should also remove any
674                            // others found on the local machine.  Probably fine for now, but
675                            // likely to cause confusion in the future.
676                            .filter(|ip4| *ip4 != request.serverip)
677                            .collect::<pool::PoolAddresses>()
678                            .sub(&all_addrs),
679                    ),
680                    ..Default::default()
681                };
682                /* If this is the interface the request is coming in, then we can do extra stuff */
683                if p4.contains(request.serverip) {
684                    // Add the MTU
685                    // TODO: Perhaps don't send it if it's default?
686                    if let Some(mtu) = request.if_mtu {
687                        ret.apply_other.insert(
688                            dhcppkt::OPTION_MTUIF,
689                            Some(dhcppkt::DhcpOptionTypeValue::U16(mtu as u16)),
690                        );
691                    }
692                    // Add the default route.
693                    if let Some(route) = request.if_router {
694                        ret.apply_other.insert(
695                            dhcppkt::OPTION_ROUTERADDR,
696                            Some(dhcppkt::DhcpOptionTypeValue::Ip(route)),
697                        );
698                    }
699                }
700                Some(ret)
701            } else {
702                None
703            }
704        })
705        .collect();
706    default_policy
707}
708
709pub fn handle_pkt(
710    pools: &mut pool::Pool,
711    request: &DHCPRequest,
712    serverids: ServerIds,
713    conf: &super::config::Config,
714) -> Result<dhcppkt::Dhcp, DhcpError> {
715    match request.pkt.options.get_messagetype() {
716        Some(dhcppkt::DHCPDISCOVER) => {
717            let base = [build_default_config(conf, request)];
718            handle_discover(pools, request, &serverids, &base, conf)
719        }
720        Some(dhcppkt::DHCPREQUEST) => {
721            let base = [build_default_config(conf, request)];
722            handle_request(pools, request, &serverids, &base, conf)
723        }
724        Some(x) => Err(DhcpError::UnknownMessageType(x)),
725        None => Err(DhcpError::ParseError(dhcppkt::ParseError::InvalidPacket)),
726    }
727}
728
729async fn send_raw(raw: Arc<raw::RawSocket>, buf: &[u8], intf: i32) -> Result<(), std::io::Error> {
730    DHCP_TX_PACKETS.inc();
731    raw.send_msg(
732        buf,
733        &raw::ControlMessage::new(),
734        raw::MsgFlags::empty(),
735        Some(
736            &erbium_net::addr::linkaddr_for_ifindex(
737                intf.try_into().unwrap(), /* TODO: Push IfIndex type back through callstack */
738            )
739            .to_net_addr(),
740        ),
741    )
742    .await
743    .map(|_| ())
744}
745
746async fn get_serverids(s: &SharedServerIds) -> ServerIds {
747    s.lock().await.clone()
748}
749
/// Returns the first six bytes of `mac` as an array, or `None` if `mac` is shorter than six
/// bytes.
fn to_array(mac: &[u8]) -> Option<[u8; 6]> {
    // `get` rather than indexing: slicing a short input with `[0..6]` would panic instead of
    // returning `None`.
    mac.get(..6)?.try_into().ok()
}
753
/// Errors from running the DHCP service.
/// NOTE(review): not constructed in the visible part of this file — confirm against the rest
/// of the service code.
enum RunError {
    /// Failed to listen for DHCP.
    ListenError(std::io::Error),
    /// Failed to receive a packet for DHCP.
    RecvError(std::io::Error),
    /// Another I/O error in DHCP.
    Io(std::io::Error),
    /// An error from the lease pool.
    PoolError(pool::Error),
}
760
761impl std::fmt::Display for RunError {
762    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
763        match self {
764            RunError::Io(e) => write!(f, "I/O Error in DHCP: {e}"),
765            RunError::PoolError(e) => write!(f, "DHCP Pool Error: {e}"),
766            RunError::ListenError(e) => write!(f, "Failed to listen on DHCP: {e}"),
767            RunError::RecvError(e) => write!(f, "Failed to receive a packet for DHCP: {e}"),
768        }
769    }
770}
771
/// State shared by the DHCP service while handling packets.
pub struct DhcpService {
    /// Interface information: addresses, MTUs, routes and link addresses, looked up by ifindex.
    netinfo: erbium_net::netinfo::SharedNetInfo,
    /// Shared configuration; read-locked while a packet is handled.
    conf: crate::config::SharedConfig,
    /// Raw socket; presumably used with `send_raw` to transmit replies — confirm in the rest of
    /// `recvdhcp`.
    rawsock: std::sync::Arc<erbium_net::raw::RawSocket>,
    /// The lease pool, locked while a packet is handled.
    pool: std::sync::Arc<sync::Mutex<pool::Pool>>,
    /// Server identifiers this server has placed in replies.
    serverids: SharedServerIds,
    /// UDP socket; presumably the one DHCP requests are received on (not used in the visible
    /// code).
    listener: UdpSocket,
}
780
781impl DhcpService {
782    async fn recvdhcp(&self, pkt: &[u8], src: NetAddr, intf: u32) {
783        let raw = self.rawsock.clone();
784        /* First, lets find the various metadata IP addresses */
785        let ip4 = *src.as_sockaddr_in().unwrap();
786        let optional_dst = self.netinfo.get_ipv4_by_ifidx(intf).await;
787        if optional_dst.is_none() {
788            log::warn!(
789                "No IPv4 found on interface {}",
790                self.netinfo.get_safe_name_by_ifidx(intf).await
791            );
792            DHCP_ERRORS
793                .with_label_values(&["NO_IPV4_ON_INTERFACE"])
794                .inc();
795            return;
796        }
797
798        /* Now lets decode the packet, and if it fails decode, fail the function early */
799        let req = match dhcppkt::parse(pkt) {
800            Err(e) => {
801                log::warn!("Failed to parse packet: {}", e);
802                DHCP_ERRORS.with_label_values(&[e.get_variant_name()]).inc();
803                return;
804            }
805            Ok(req) => req,
806        };
807
808        /* Log what we've got */
809        let if_mtu = self.netinfo.get_mtu_by_ifidx(intf).await;
810        let if_router = match self.netinfo.get_ipv4_default_route().await {
811            /* If the default route points out a different interface, then this is the default route */
812            Some((_, Some(rtridx))) if rtridx != intf => Some(optional_dst.unwrap()),
813            /* If it's the same interface, then the default router should be the nexthop */
814            Some((Some(nexthop), Some(rtridx))) if rtridx == intf => Some(nexthop),
815            _ => None,
816        };
817
818        let request = DHCPRequest {
819            pkt: req,
820            serverip: optional_dst.unwrap(),
821            ifindex: intf,
822            if_mtu,
823            if_router,
824        };
825        log_pkt(&request, &self.netinfo).await;
826
827        /* Now, lets process the packet we've found */
828        let reply;
829        {
830            /* Limit the amount of time we have these locked to just handling the packet */
831            let mut pool = self.pool.lock().await;
832            let lockedconf = self.conf.read().await;
833
834            reply = match handle_pkt(
835                &mut pool,
836                &request,
837                get_serverids(&self.serverids).await,
838                &lockedconf,
839            ) {
840                Err(e) => {
841                    log::warn!(
842                        "{}: Failed to handle {}: {}",
843                        format_client(&request.pkt),
844                        request
845                            .pkt
846                            .options
847                            .get_messagetype()
848                            .map(|x| x.to_string())
849                            .unwrap_or_else(|| "packet".into()),
850                        e
851                    );
852                    DHCP_ERRORS.with_label_values(&[e.get_variant_name()]).inc();
853                    return;
854                }
855                Ok(r) => r,
856            };
857        }
858
859        /* Now, we should have a packet ready to send */
860        /* First, if we're claiming to be particular IP, we should remember that as an IP that is one
861         * of ours
862         */
863        if let Some(si) = reply.options.get_serverid() {
864            self.serverids.lock().await.insert(si);
865        }
866
867        /* Log what we're sending */
868        log::info!(
869            "{}: Sending {} on {} with {} for {}",
870            format_client(&reply),
871            reply
872                .options
873                .get_messagetype()
874                .map(|x| x.to_string())
875                .unwrap_or_else(|| "[unknown]".into()),
876            self.netinfo
877                .get_name_by_ifidx(intf)
878                .await
879                .unwrap_or_else(|| "<unknown if>".into()),
880            reply.yiaddr,
881            reply
882                .options
883                .get_option::<u32>(&dhcppkt::OPTION_LEASETIME)
884                .unwrap_or(0)
885        );
886        log_options(&reply);
887
888        /* Collect metadata ready to send */
889        let srcll = if let Some(erbium_net::netinfo::LinkLayer::Ethernet(srcll)) =
890            self.netinfo.get_linkaddr_by_ifidx(intf).await
891        {
892            srcll
893        } else {
894            log::warn!("{}: Not a usable LinkLayer?!", format_client(&reply));
895            DHCP_ERRORS.with_label_values(&["UNUSABLE_LINKLAYER"]).inc();
896            return;
897        };
898
899        let chaddr = if let Some(chaddr) = to_array(&reply.chaddr) {
900            chaddr
901        } else {
902            log::warn!(
903                "{}: Cannot send reply to invalid client hardware addr {:?}",
904                format_client(&reply),
905                reply.chaddr
906            );
907            DHCP_ERRORS.with_label_values(&["INVALID_CHADDR"]).inc();
908            return;
909        };
910
911        let dst = if request.pkt.get_broadcast_flag() {
912            *std::net::Ipv4Addr::BROADCAST
913                .with_port(ip4.port())
914                .as_sockaddr_in()
915                .unwrap_or(&ip4)
916        } else {
917            *reply
918                .yiaddr
919                .with_port(ip4.port())
920                .as_sockaddr_in()
921                .unwrap_or(&ip4)
922        };
923
924        /* Construct the raw packet from the reply to send */
925        let replybuf = reply.serialise();
926        let etherbuf = packet::Fragment::new_udp4(
927            *request.serverip.with_port(67).as_sockaddr_in().unwrap(),
928            &srcll,
929            dst,
930            &chaddr,
931            packet::Tail::Payload(&replybuf),
932        )
933        .flatten();
934
935        if let Err(e) = send_raw(raw, &etherbuf, intf.try_into().unwrap()).await {
936            log::warn!("{}: Failed to send reply: {:?}", format_client(&reply), e);
937            DHCP_ERRORS.with_label_values(&["SEND_ERROR"]).inc();
938        }
939    }
940
941    async fn new_internal(
942        netinfo: erbium_net::netinfo::SharedNetInfo,
943        conf: super::config::SharedConfig,
944    ) -> Result<Self, RunError> {
945        let rawsock =
946            Arc::new(raw::RawSocket::new(raw::EthProto::ALL).map_err(RunError::ListenError)?);
947        let pool = Arc::new(sync::Mutex::new(
948            pool::Pool::new().map_err(RunError::PoolError)?,
949        ));
950        let serverids: SharedServerIds =
951            Arc::new(sync::Mutex::new(std::collections::HashSet::new()));
952        let listener = UdpSocket::bind(&[UNSPECIFIED4.with_port(67)])
953            .await
954            .map_err(RunError::ListenError)?;
955        listener
956            .set_opt_ipv4_packet_info(true)
957            .map_err(RunError::ListenError)?;
958        listener
959            .set_opt_reuse_port(true)
960            .map_err(RunError::ListenError)?;
961        log::info!(
962            "Listening for DHCP on {}",
963            listener.local_addr().map_err(RunError::Io)?
964        );
965        Ok(Self {
966            netinfo,
967            conf,
968            rawsock,
969            pool,
970            serverids,
971            listener,
972        })
973    }
974
975    pub async fn new(
976        netinfo: erbium_net::netinfo::SharedNetInfo,
977        conf: super::config::SharedConfig,
978    ) -> Result<Self, String> {
979        match Self::new_internal(netinfo, conf).await {
980            Ok(x) => Ok(x),
981            Err(e) => Err(e.to_string()),
982        }
983    }
984
985    async fn run_internal(
986        self: &std::sync::Arc<Self>,
987        listener: &UdpSocket,
988    ) -> Result<(), RunError> {
989        loop {
990            let rm = match listener.recv_msg(65536, udp::MsgFlags::empty()).await {
991                Ok(m) => m,
992                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
993                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
994                Err(e) => return Err(RunError::RecvError(e)),
995            };
996            DHCP_RX_PACKETS.inc();
997            let self2 = self.clone();
998            tokio::spawn(async move {
999                self2
1000                    .recvdhcp(
1001                        &rm.buffer,
1002                        rm.address.unwrap(),
1003                        rm.local_intf().unwrap().try_into().unwrap(),
1004                    )
1005                    .await
1006            });
1007        }
1008    }
1009
1010    pub async fn run(self: std::sync::Arc<Self>) -> Result<(), String> {
1011        match self.run_internal(&self.listener).await {
1012            Ok(_) => Ok(()),
1013            Err(e) => Err(e.to_string()),
1014        }
1015    }
1016
1017    pub async fn update_metrics(self: &std::sync::Arc<Self>) {
1018        match self.pool.lock().await.get_pool_metrics() {
1019            Ok((in_use, expired)) => {
1020                DHCP_ACTIVE_LEASES.set(in_use.into());
1021                DHCP_EXPIRED_LEASES.set(expired.into());
1022            }
1023            Err(e) => log::warn!("Failed to update metrics: {}", e),
1024        }
1025    }
1026
1027    pub async fn get_leases(self: &std::sync::Arc<Self>) -> Vec<pool::LeaseInfo> {
1028        let ret = self.pool.lock().await.get_leases();
1029        match ret {
1030            Ok(l) => l,
1031            Err(e) => {
1032                log::warn!("Failed to get leases: {}", e);
1033                Vec::new()
1034            }
1035        }
1036    }
1037}
1038
#[test]
fn test_policy() {
    // A policy that only matches 192.0.2.0/24 should apply to a request whose server IP is
    // 192.0.2.67.
    let subnet = erbium_net::Ipv4Subnet::new("192.0.2.0".parse().unwrap(), 24).unwrap();
    let policies = [config::Policy {
        match_subnet: Some(subnet),
        ..Default::default()
    }];
    let request = DHCPRequest {
        serverip: "192.0.2.67".parse().unwrap(),
        ..Default::default()
    };
    let mut response = Default::default();

    assert!(apply_policies(&request, &policies, &mut response));
}
1054
1055#[tokio::test]
1056async fn test_config_parse() -> Result<(), Box<dyn std::error::Error>> {
1057    let cfg = crate::config::load_config_from_string_for_test(
1058        "---
1059dhcp-policies:
1060  - match-subnet: 192.168.0.0/24
1061    apply-dns-servers: ['8.8.8.8', '8.8.4.4']
1062    apply-subnet: 192.168.0.0/24
1063    apply-time-offset: 3600
1064    apply-domain-name: erbium.dev
1065    apply-forward: false
1066    apply-mtu: 1500
1067    apply-broadcast: 192.168.255.255
1068    apply-rebind-time: 120
1069    apply-renewal-time: 90s
1070    apply-arp-timeout: 1w
1071
1072
1073    policies:
1074       - { match-host-name: myhost, apply-address: 192.168.0.1 }
1075       - { match-hardware-address: 00:01:02:03:04:05, apply-address: 192.168.0.2 }
1076
1077
1078  - match-interface: dmz
1079    apply-dns-servers: ['8.8.8.8']
1080    apply-subnet: 192.0.2.0/24
1081
1082    # Reserve some space from the pool for servers
1083    policies:
1084      - apply-range: {start: 192.0.2.10, end: 192.0.2.20}
1085
1086        # From the reserved pool, assign a static address.
1087        policies:
1088          - { match-hardware-address: 00:01:02:03:04:05, apply-address: 192.168.0.2 }
1089
1090      # Reserve space for VPN endpoints
1091      - match-user-class: VPN
1092        apply-subnet: 192.0.2.128/25
1093        ",
1094    )?;
1095
1096    let mut resp = Response {
1097        ..Default::default()
1098    };
1099    if !apply_policies(
1100        &DHCPRequest {
1101            pkt: dhcppkt::Dhcp {
1102                op: dhcppkt::OP_BOOTREQUEST,
1103                htype: dhcppkt::HWTYPE_ETHERNET,
1104                hlen: 6,
1105                hops: 0,
1106                xid: 0,
1107                secs: 0,
1108                flags: 0,
1109                ciaddr: net::Ipv4Addr::UNSPECIFIED,
1110                yiaddr: net::Ipv4Addr::UNSPECIFIED,
1111                siaddr: net::Ipv4Addr::UNSPECIFIED,
1112                giaddr: net::Ipv4Addr::UNSPECIFIED,
1113                chaddr: vec![0, 1, 2, 3, 4, 5],
1114                sname: vec![],
1115                file: vec![],
1116                options: dhcppkt::DhcpOptions {
1117                    ..Default::default()
1118                },
1119            },
1120            serverip: "192.168.0.67".parse().unwrap(),
1121            ifindex: 1,
1122            if_mtu: None,
1123            if_router: None,
1124        },
1125        &cfg.read().await.dhcp.policies,
1126        &mut resp,
1127    ) {
1128        panic!("No policies applied");
1129    }
1130
1131    log::info!("{:?}", cfg.read().await);
1132
1133    assert_eq!(
1134        resp.address,
1135        Some(
1136            [std::net::Ipv4Addr::new(192, 168, 0, 2)]
1137                .iter()
1138                .cloned()
1139                .collect()
1140        )
1141    );
1142
1143    Ok(())
1144}
1145
#[test]
fn test_format_client() {
    // An Ethernet request with no options renders as its MAC address followed by an empty
    // parenthesised field (presumably the client host name).
    let unspecified = net::Ipv4Addr::UNSPECIFIED;
    let pkt = dhcppkt::Dhcp {
        op: dhcppkt::OP_BOOTREQUEST,
        htype: dhcppkt::HWTYPE_ETHERNET,
        hlen: 6,
        hops: 0,
        xid: 0,
        secs: 0,
        flags: 0,
        ciaddr: unspecified,
        yiaddr: unspecified,
        siaddr: unspecified,
        giaddr: unspecified,
        chaddr: vec![0, 1, 2, 3, 4, 5],
        sname: vec![],
        file: vec![],
        options: Default::default(),
    };

    let rendered = format_client(&pkt);
    assert_eq!(rendered, "00:01:02:03:04:05 ()");
}
1169
1170#[tokio::test]
1171async fn test_defaults() {
1172    let mut p = pool::Pool::new_in_memory().expect("Failed to create pool");
1173    let mut pkt = test::mk_dhcp_request();
1174    pkt.pkt.options.mutate_option(
1175        &dhcppkt::OPTION_PARAMLIST,
1176        &vec![
1177            dhcppkt::OPTION_DOMAINSERVER,
1178            dhcppkt::OPTION_DOMAINSEARCH,
1179            dhcppkt::OPTION_CAPTIVEPORTAL,
1180        ],
1181    );
1182
1183    let serverids: ServerIds = ServerIds::new();
1184    let conf = crate::config::Config {
1185        dns_servers: vec![
1186            "192.0.2.53".parse().unwrap(),
1187            "2001:db8::53".parse().unwrap(),
1188        ],
1189        dns_search: vec!["example.org".into()],
1190        captive_portal: Some("example.com".into()),
1191        ..test::mk_default_config()
1192    };
1193    let base = [build_default_config(&conf, &pkt)];
1194    println!("base={:?}", base);
1195    let resp =
1196        handle_discover(&mut p, &pkt, &serverids, &base, &conf).expect("Failed to handle request");
1197    assert_eq!(
1198        resp.options
1199            .get_option::<Vec<std::net::Ipv4Addr>>(&dhcppkt::OPTION_DOMAINSERVER),
1200        Some(vec!["192.0.2.53".parse::<std::net::Ipv4Addr>().unwrap()])
1201    );
1202    println!(
1203        "{:?}",
1204        resp.options.get_raw_option(&dhcppkt::OPTION_CAPTIVEPORTAL)
1205    );
1206    assert_eq!(
1207        resp.options
1208            .get_option::<Vec<u8>>(&dhcppkt::OPTION_CAPTIVEPORTAL),
1209        Some("example.com".as_bytes().to_vec())
1210    );
1211    assert_eq!(
1212        resp.options
1213            .get_option::<Vec<String>>(&dhcppkt::OPTION_DOMAINSEARCH),
1214        Some(vec![String::from("example.org")])
1215    );
1216}
1217
1218#[tokio::test]
1219async fn test_base() {
1220    let mut pool = pool::Pool::new_in_memory().expect("Failed to create pool");
1221    let mut pkt = test::mk_dhcp_request();
1222    pkt.pkt.options.mutate_option(
1223        &dhcppkt::OPTION_PARAMLIST,
1224        &vec![
1225            dhcppkt::OPTION_DOMAINSERVER,
1226            dhcppkt::OPTION_DOMAINSEARCH,
1227            dhcppkt::OPTION_CAPTIVEPORTAL,
1228        ],
1229    );
1230
1231    let serverids: ServerIds = ServerIds::new();
1232    let mut apply_address: pool::PoolAddresses = Default::default();
1233    apply_address.insert("192.0.2.3".parse().unwrap());
1234    let conf = crate::config::Config {
1235        dns_servers: vec![
1236            "192.0.2.53".parse().unwrap(),
1237            "2001:db8::53".parse().unwrap(),
1238        ],
1239        dns_search: vec!["example.org".into()],
1240        captive_portal: Some("example.com".into()),
1241        addresses: vec![config::Prefix::V4(config::Prefix4 {
1242            addr: "192.0.2.0".parse().unwrap(),
1243            prefixlen: 24,
1244        })],
1245        dhcp: config::Config {
1246            policies: vec![config::Policy {
1247                match_chaddr: Some(vec![0x0, 0x1, 0x2, 0x3, 0x4, 0x5]),
1248                apply_address: Some(apply_address),
1249                ..Default::default()
1250            }],
1251        },
1252        ..Default::default()
1253    };
1254    let base = build_default_config(&conf, &pkt);
1255    /* The generated policy should not allocate 192.0.2.3, because that is allocated in the
1256     * custom dhcp policy provided.
1257     */
1258    assert!(
1259        !base.policies[0]
1260            .apply_address
1261            .as_ref()
1262            .unwrap()
1263            .contains(&"192.0.2.3".parse().unwrap())
1264    );
1265    println!("base={:#?}", base);
1266    println!("pkt={:?}", pkt);
1267    let resp = handle_discover(&mut pool, &pkt, &serverids, &[base], &conf)
1268        .expect("Failed to handle request");
1269    assert_eq!(
1270        resp.options
1271            .get_option::<Vec<std::net::Ipv4Addr>>(&dhcppkt::OPTION_DOMAINSERVER),
1272        Some(vec!["192.0.2.53".parse::<std::net::Ipv4Addr>().unwrap()])
1273    );
1274    println!(
1275        "{:?}",
1276        resp.options.get_raw_option(&dhcppkt::OPTION_CAPTIVEPORTAL)
1277    );
1278    assert_eq!(
1279        resp.options
1280            .get_option::<Vec<u8>>(&dhcppkt::OPTION_CAPTIVEPORTAL),
1281        Some("example.com".as_bytes().to_vec())
1282    );
1283    assert_eq!(
1284        resp.options
1285            .get_option::<Vec<String>>(&dhcppkt::OPTION_DOMAINSEARCH),
1286        Some(vec![String::from("example.org")])
1287    );
1288}
1289
1290#[tokio::test]
1291/* There was a bug that if the suffix bits in the prefix were not 0, then it would silently ignore
1292 * the address.  Now we set them to 0 ourselves explicitly.
1293 */
1294async fn test_non_network_prefix() {
1295    let conf = crate::config::load_config_from_string_for_test(
1296        "---
1297addresses: [192.0.2.53/24]
1298",
1299    )
1300    .unwrap();
1301
1302    let pkt = test::mk_dhcp_request();
1303    let base = build_default_config(&*conf.read().await, &pkt);
1304
1305    let network = erbium_net::Ipv4Subnet::new("192.0.2.0".parse().unwrap(), 24).unwrap();
1306
1307    assert_eq!(base.policies[0].match_subnet.unwrap().addr, network.addr);
1308    assert_eq!(
1309        base.policies[0].match_subnet.unwrap().prefixlen,
1310        network.prefixlen
1311    );
1312}