nu_command/network/
port.rs

1use nu_engine::command_prelude::*;
2use nu_protocol::shell_error::io::IoError;
3
4use std::{net::TcpListener, ops::RangeInclusive};
5
/// The `port` command: returns a free TCP port, either one picked by the
/// system or the first free port found in a user-supplied range.
#[derive(Clone)]
pub struct Port;
8
9impl Command for Port {
10    fn name(&self) -> &str {
11        "port"
12    }
13
14    fn signature(&self) -> Signature {
15        Signature::build("port")
16            .input_output_types(vec![(Type::Nothing, Type::Int)])
17            .optional(
18                "start",
19                SyntaxShape::Int,
20                "The start port to scan (inclusive).",
21            )
22            .optional("end", SyntaxShape::Int, "The end port to scan (inclusive).")
23            .category(Category::Network)
24    }
25
26    fn description(&self) -> &str {
27        "Get a free TCP port from system."
28    }
29
30    fn search_terms(&self) -> Vec<&str> {
31        vec!["network", "http"]
32    }
33
34    fn run(
35        &self,
36        engine_state: &EngineState,
37        stack: &mut Stack,
38        call: &Call,
39        _input: PipelineData,
40    ) -> Result<PipelineData, ShellError> {
41        get_free_port(engine_state, stack, call)
42    }
43
44    fn examples(&self) -> Vec<Example<'_>> {
45        vec![
46            Example {
47                description: "get a free port between 3121 and 4000",
48                example: "port 3121 4000",
49                result: Some(Value::test_int(3121)),
50            },
51            Example {
52                description: "get a free port from system",
53                example: "port",
54                result: None,
55            },
56        ]
57    }
58}
59
60fn get_free_port(
61    engine_state: &EngineState,
62    stack: &mut Stack,
63    call: &Call,
64) -> Result<PipelineData, ShellError> {
65    let from_io_error = IoError::factory(call.head, None);
66
67    let start_port: Option<Spanned<usize>> = call.opt(engine_state, stack, 0)?;
68    let end_port: Option<Spanned<usize>> = call.opt(engine_state, stack, 1)?;
69
70    let free_port = if start_port.is_none() && end_port.is_none() {
71        system_provided_port().map_err(&from_io_error)?
72    } else {
73        let (start_port, start_span) = match start_port {
74            Some(p) => (p.item, Some(p.span)),
75            None => (1024, None),
76        };
77
78        let start_port = match u16::try_from(start_port) {
79            Ok(p) => p,
80            Err(e) => {
81                return Err(ShellError::CantConvert {
82                    to_type: "u16".into(),
83                    from_type: "usize".into(),
84                    span: start_span.unwrap_or(call.head),
85                    help: Some(format!("{e} (min: {}, max: {})", u16::MIN, u16::MAX)),
86                });
87            }
88        };
89
90        let (end_port, end_span) = match end_port {
91            Some(p) => (p.item, Some(p.span)),
92            None => (65535, None),
93        };
94
95        let end_port = match u16::try_from(end_port) {
96            Ok(p) => p,
97            Err(e) => {
98                return Err(ShellError::CantConvert {
99                    to_type: "u16".into(),
100                    from_type: "usize".into(),
101                    span: end_span.unwrap_or(call.head),
102                    help: Some(format!("{e} (min: {}, max: {})", u16::MIN, u16::MAX)),
103                });
104            }
105        };
106
107        let range_span = match (start_span, end_span) {
108            (Some(start), Some(end)) => Span::new(start.start, end.end),
109            (Some(start), None) => start,
110            (None, Some(end)) => end,
111            (None, None) => call.head,
112        };
113
114        // check input range valid.
115        if start_port > end_port {
116            return Err(ShellError::InvalidRange {
117                left_flank: start_port.to_string(),
118                right_flank: end_port.to_string(),
119                span: range_span,
120            });
121        }
122
123        search_port_in_range((start_port..=end_port).into_spanned(range_span), call.head)?
124    };
125
126    Ok(Value::int(free_port as i64, call.head).into_pipeline_data())
127}
128
/// Ask the operating system for an unused TCP port.
///
/// Binds a listener to `127.0.0.1:0`, which requests an OS-assigned port. It
/// then reports the port that was chosen. The listener is dropped on return,
/// releasing the port again.
fn system_provided_port() -> Result<u16, std::io::Error> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let assigned = listener.local_addr()?;
    Ok(assigned.port())
}
134
135/// Find an open port by binding to every possible port in range.
136#[cfg(not(windows))]
137fn search_port_in_range(
138    range: Spanned<RangeInclusive<u16>>,
139    call_span: Span,
140) -> Result<u16, ShellError> {
141    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
142
143    let listener = 'search: {
144        let mut last_err = None;
145        for port in range.item {
146            let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
147            match TcpListener::bind(addr) {
148                Ok(listener) => break 'search Ok(listener),
149                Err(err) => last_err = Some(err),
150            }
151        }
152
153        Err(IoError::new_with_additional_context(
154            last_err.expect("range not empty, validated before"),
155            range.span,
156            None,
157            "Every port has been tried, but no valid one was found",
158        ))
159    }?;
160
161    Ok(listener
162        .local_addr()
163        .map_err(|err| IoError::new(err, call_span, None))?
164        .port())
165}
166
167#[cfg(windows)]
168mod windows {
169    use super::*;
170
171    use std::{
172        alloc::{Layout, alloc, dealloc},
173        ptr,
174    };
175
176    use ::windows::Win32::{
177        Foundation::{
178            ERROR_INSUFFICIENT_BUFFER, ERROR_INVALID_PARAMETER, ERROR_NOT_SUPPORTED, NO_ERROR,
179            WIN32_ERROR,
180        },
181        NetworkManagement::IpHelper::{GetTcpTable2, MIB_TCPROW2, MIB_TCPTABLE2},
182        Networking::WinSock::ntohs,
183    };
184
185    #[repr(C)]
186    struct TcpTable {
187        pub num_entries: u32,
188        pub table: [MIB_TCPROW2],
189    }
190
191    const _: () = assert!(align_of::<MIB_TCPTABLE2>() == 4);
192
193    impl TcpTable {
194        fn new() -> Result<Box<Self>, WIN32_ERROR> {
195            let mut size = 0;
196            let size_pointer: *mut u32 = &mut size;
197
198            // SAFETY:
199            // - Passing a null table pointer queries the required size (documented behavior).
200            // - `size_pointer` is a valid, non-null out pointer.
201            // - We expect `ERROR_INSUFFICIENT_BUFFER` so that `size` is written.
202            let ret_code = unsafe { GetTcpTable2(None, size_pointer, false) };
203            assert_eq!(WIN32_ERROR(ret_code), ERROR_INSUFFICIENT_BUFFER);
204
205            // SAFETY:
206            // - Alignment is 4: non-zero and a power of two.
207            // - `size` comes from the API and is expected to be reasonable for allocation.
208            let layout = unsafe {
209                Layout::from_size_align_unchecked(size as usize, align_of::<MIB_TCPTABLE2>())
210            };
211
212            // IMPORTANT: This allocation must be freed or transferred to ownership before leaving this scope.
213            // SAFETY: `layout` has non-zero size (at least 4 for one u32).
214            let ptr = unsafe { alloc(layout) as *mut MIB_TCPTABLE2 };
215            assert!(!ptr.is_null());
216
217            // SAFETY:
218            // - `ptr` is non-null, properly aligned, and points to `size` bytes.
219            // - `size_pointer` still points to `size` from the first call.
220            let ret_code = unsafe { GetTcpTable2(Some(ptr), size_pointer, false) };
221            let ret_code = WIN32_ERROR(ret_code);
222            if ret_code != NO_ERROR {
223                // SAFETY:
224                // - `ptr` was allocated with `alloc(layout)` in this function.
225                // - Using the same `layout` to deallocate is correct.
226                unsafe { dealloc(ptr as *mut u8, layout) };
227                return Err(ret_code);
228            }
229
230            // SAFETY: `GetTcpTable2` returned `NO_ERROR`, so the header at `ptr` is initialized.
231            let header = unsafe { &*ptr };
232
233            // SAFETY:
234            // - Memory at `ptr` came from the global allocator and is initialized.
235            // - `TcpTable` is #[repr(C)] and layout-compatible with `MIB_TCPTABLE2` plus trailing rows.
236            // - We build a slice fat pointer only to carry the length; we do not dereference the slice itself here.
237            // - Casts between slice DSTs preserve the length metadata:
238            //     https://github.com/rust-lang/unsafe-code-guidelines/issues/288
239            //     https://github.com/rust-lang/reference/pull/1417
240            // - Casting to `*mut TcpTable` preserves that metadata for our DST.
241            // - `Box::from_raw` takes ownership and will free via the same allocator.
242            let table = unsafe {
243                let ptr = ptr::slice_from_raw_parts_mut(ptr, header.dwNumEntries as usize);
244                Box::from_raw(ptr as *mut TcpTable)
245            };
246
247            Ok(table)
248        }
249    }
250
251    /// Find an open port by checking the TCP table.
252    ///
253    /// On Windows, it is possible to bind to the same port multiple times if it was not
254    /// originally bound as an exclusive port[^so].
255    /// The Rust implementation of [`TcpListener::bind`] currently does not enforce exclusive
256    /// binding, which means the same port can be bound more than once.  
257    /// Because of this, we cannot simply try binding to a port to check if it is free.  
258    /// Instead, we query the [TCP table](https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable2)
259    /// to see which ports are already in use and then pick one that is not listed.
260    ///
261    /// [^so]: <https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse>
262    #[cfg(windows)]
263    pub fn search_port_in_range(
264        range: Spanned<RangeInclusive<u16>>,
265        call_span: Span,
266    ) -> Result<u16, ShellError> {
267        use std::collections::HashSet;
268
269        let table = TcpTable::new()
270            .map_err(|err| {
271                (
272                    err,
273                    match err {
274                        NO_ERROR => unreachable!("handled as Ok variant"),
275                        ERROR_INSUFFICIENT_BUFFER => "The buffer for TcpTable is not large enough",
276                        ERROR_INVALID_PARAMETER => "SizePointer was null or not writable",
277                        ERROR_NOT_SUPPORTED => "GetTcpTable2 is not supported on this OS",
278                        _ => "Unexpected error code from GetTcpTable2",
279                    },
280                )
281            })
282            .map_err(|(err, msg)| {
283                ShellError::Io(IoError::new_with_additional_context(
284                    std::io::Error::from_raw_os_error(err.0 as i32),
285                    call_span,
286                    None,
287                    msg,
288                ))
289            })?;
290
291        let used_ports: HashSet<u16> = table
292            .table
293            .iter()
294            .map(|row| row.dwLocalPort as u16)
295            .map(|raw| {
296                // Convert from network byte order to host byte order.
297                // SAFETY: `raw` is the exact value returned by the API for a port.
298                unsafe { ntohs(raw) }
299            })
300            .collect();
301
302        for port in range.item {
303            if !used_ports.contains(&port) {
304                return Ok(port);
305            }
306        }
307
308        Err(IoError::new_with_additional_context(
309            std::io::Error::from(std::io::ErrorKind::AddrInUse),
310            call_span,
311            None,
312            "All ports in the range were taken",
313        )
314        .into())
315    }
316}
317
318#[cfg(windows)]
319use windows::search_port_in_range;