clap_port_flag/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!("../README.md")]
3#![cfg_attr(test, deny(warnings))]
4
5use std::net::TcpListener;
6
7#[cfg(feature = "fd")]
8use std::os::raw::c_int;
9#[cfg(feature = "fd")]
10use std::os::unix::io::FromRawFd;
11
12/// Easily add a `--port` flag to clap.
13///
14/// ## Usage
15/// ```rust
16/// #[derive(Debug, clap::Parser)]
17/// struct Cli {
18///   #[clap(flatten)]
19///   port: clap_port_flag::Port,
20/// }
21/// #
22/// # fn main() {}
23/// ```
24#[derive(clap::Args, Debug)]
25pub struct Port {
26    /// The network address and port to listen to.
27    #[cfg(feature = "addr_with_port")]
28    #[clap(short = 'a', long = "address", default_value = "127.0.0.1:80")]
29    address: std::net::SocketAddr,
30
31    /// The network address to listen to.
32    #[cfg(not(feature = "addr_with_port"))]
33    #[clap(short = 'a', long = "address", default_value = "127.0.0.1")]
34    address: String,
35
36    /// The network port to listen to.
37    #[cfg(not(features = "addr_with_port"))]
38    #[clap(short = 'p', long = "port", env = "PORT", group = "bind")]
39    port: Option<u16>,
40
41    /// A previously opened network socket.
42    #[cfg(feature = "fd")]
43    #[clap(long = "listen-fd", env = "LISTEN_FD", group = "bind")]
44    fd: Option<c_int>,
45}
46
47/// Create a TCP socket.
48///
49/// ## Panics
50/// If a file descriptor was passed directly, we call the unsafe
51/// `TcpListener::from_raw_fd()` method, which may panic if a non-existent file
52/// descriptor was passed.
53/// (This only applies if the "fd" feature is enabled)
54impl Port {
55    /// Create a TCP socket from the passed in port or file descriptor.
56    pub fn bind(&self) -> std::io::Result<TcpListener> {
57        #[cfg(feature = "fd")]
58        if let Some(fd) = self.fd {
59            return unsafe { Ok(TcpListener::from_raw_fd(fd)) };
60        }
61
62        #[cfg(feature = "addr_with_port")]
63        {
64            let addr: std::net::SocketAddr = self.address;
65            TcpListener::bind(addr)
66        }
67        #[cfg(not(feature = "addr_with_port"))]
68        {
69            let addr: &str = self.address.as_str();
70            if let Some(port) = self.port {
71                TcpListener::bind((addr, port))
72            } else {
73                Err(std::io::Error::new(
74                    std::io::ErrorKind::Other,
75                    "No port supplied.",
76                ))
77            }
78        }
79    }
80
81    /// Create a TCP socket by calling to `.bind()`. If it fails, create a socket
82    /// on `port`.
83    ///
84    /// Useful to create a default socket to listen to if none was passed.
85    #[cfg(feature = "addr_with_port")]
86    pub fn bind_or(&self, port: u16) -> std::io::Result<TcpListener> {
87        let mut addr = self.address;
88        addr.set_port(port);
89
90        self.bind().or_else(|_| TcpListener::bind(addr))
91    }
92
93    /// Create a TCP socket by calling to `.bind()`. If it fails, create a socket
94    /// on `port`.
95    ///
96    /// Useful to create a default socket to listen to if none was passed.
97    #[cfg(not(feature = "addr_with_port"))]
98    pub fn bind_or(&self, port: u16) -> std::io::Result<TcpListener> {
99        self.bind()
100            .or_else(|_| TcpListener::bind((self.address.as_str(), port)))
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    use clap::Parser;

    // Minimal top-level command that flattens `Port`, the same way downstream
    // crates embed it.
    #[derive(Debug, Parser)]
    struct Cli {
        #[clap(flatten)]
        port: Port,
    }

    #[cfg(not(feature = "addr_with_port"))]
    #[test]
    fn test_cli() {
        let argv = ["test", "--address", "1.2.3.4", "--port", "1234"];
        let cli = match Cli::try_parse_from(&argv) {
            Ok(parsed) => parsed,
            Err(err) => panic!("Not ok: {:?}", err),
        };

        assert_eq!(cli.port.address, "1.2.3.4");
        assert_eq!(cli.port.port, Some(1234));
    }

    #[cfg(feature = "addr_with_port")]
    #[test]
    fn test_cli() {
        use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

        let argv = ["test", "--address", "1.2.3.4:8080"];
        let cli = match Cli::try_parse_from(&argv) {
            Ok(parsed) => parsed,
            Err(err) => panic!("Not ok: {:?}", err),
        };

        let expected = SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 8080));
        assert_eq!(cli.port.address, expected);
    }
}
138}