1use nu_engine::command_prelude::*;
2use nu_protocol::shell_error::io::IoError;
3
4use std::{net::TcpListener, ops::RangeInclusive};
5
/// The `port` command: reports a free TCP port, either one assigned by the system or the first
/// free one found in a user-supplied inclusive range.
#[derive(Clone)]
pub struct Port;
8
9impl Command for Port {
10 fn name(&self) -> &str {
11 "port"
12 }
13
14 fn signature(&self) -> Signature {
15 Signature::build("port")
16 .input_output_types(vec![(Type::Nothing, Type::Int)])
17 .optional(
18 "start",
19 SyntaxShape::Int,
20 "The start port to scan (inclusive).",
21 )
22 .optional("end", SyntaxShape::Int, "The end port to scan (inclusive).")
23 .category(Category::Network)
24 }
25
26 fn description(&self) -> &str {
27 "Get a free TCP port from system."
28 }
29
30 fn search_terms(&self) -> Vec<&str> {
31 vec!["network", "http"]
32 }
33
34 fn run(
35 &self,
36 engine_state: &EngineState,
37 stack: &mut Stack,
38 call: &Call,
39 _input: PipelineData,
40 ) -> Result<PipelineData, ShellError> {
41 get_free_port(engine_state, stack, call)
42 }
43
44 fn examples(&self) -> Vec<Example<'_>> {
45 vec![
46 Example {
47 description: "get a free port between 3121 and 4000",
48 example: "port 3121 4000",
49 result: Some(Value::test_int(3121)),
50 },
51 Example {
52 description: "get a free port from system",
53 example: "port",
54 result: None,
55 },
56 ]
57 }
58}
59
60fn get_free_port(
61 engine_state: &EngineState,
62 stack: &mut Stack,
63 call: &Call,
64) -> Result<PipelineData, ShellError> {
65 let from_io_error = IoError::factory(call.head, None);
66
67 let start_port: Option<Spanned<usize>> = call.opt(engine_state, stack, 0)?;
68 let end_port: Option<Spanned<usize>> = call.opt(engine_state, stack, 1)?;
69
70 let free_port = if start_port.is_none() && end_port.is_none() {
71 system_provided_port().map_err(&from_io_error)?
72 } else {
73 let (start_port, start_span) = match start_port {
74 Some(p) => (p.item, Some(p.span)),
75 None => (1024, None),
76 };
77
78 let start_port = match u16::try_from(start_port) {
79 Ok(p) => p,
80 Err(e) => {
81 return Err(ShellError::CantConvert {
82 to_type: "u16".into(),
83 from_type: "usize".into(),
84 span: start_span.unwrap_or(call.head),
85 help: Some(format!("{e} (min: {}, max: {})", u16::MIN, u16::MAX)),
86 });
87 }
88 };
89
90 let (end_port, end_span) = match end_port {
91 Some(p) => (p.item, Some(p.span)),
92 None => (65535, None),
93 };
94
95 let end_port = match u16::try_from(end_port) {
96 Ok(p) => p,
97 Err(e) => {
98 return Err(ShellError::CantConvert {
99 to_type: "u16".into(),
100 from_type: "usize".into(),
101 span: end_span.unwrap_or(call.head),
102 help: Some(format!("{e} (min: {}, max: {})", u16::MIN, u16::MAX)),
103 });
104 }
105 };
106
107 let range_span = match (start_span, end_span) {
108 (Some(start), Some(end)) => Span::new(start.start, end.end),
109 (Some(start), None) => start,
110 (None, Some(end)) => end,
111 (None, None) => call.head,
112 };
113
114 if start_port > end_port {
116 return Err(ShellError::InvalidRange {
117 left_flank: start_port.to_string(),
118 right_flank: end_port.to_string(),
119 span: range_span,
120 });
121 }
122
123 search_port_in_range((start_port..=end_port).into_spanned(range_span), call.head)?
124 };
125
126 Ok(Value::int(free_port as i64, call.head).into_pipeline_data())
127}
128
/// Asks the operating system for a free port by binding a listener to `127.0.0.1:0` and reading
/// back the port it was assigned.
///
/// The listener is closed when this function returns, so the port is only presumed to still be
/// free by the time the caller uses it.
fn system_provided_port() -> Result<u16, std::io::Error> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let bound_addr = listener.local_addr()?;
    Ok(bound_addr.port())
}
134
135#[cfg(not(windows))]
137fn search_port_in_range(
138 range: Spanned<RangeInclusive<u16>>,
139 call_span: Span,
140) -> Result<u16, ShellError> {
141 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
142
143 let listener = 'search: {
144 let mut last_err = None;
145 for port in range.item {
146 let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
147 match TcpListener::bind(addr) {
148 Ok(listener) => break 'search Ok(listener),
149 Err(err) => last_err = Some(err),
150 }
151 }
152
153 Err(IoError::new_with_additional_context(
154 last_err.expect("range not empty, validated before"),
155 range.span,
156 None,
157 "Every port has been tried, but no valid one was found",
158 ))
159 }?;
160
161 Ok(listener
162 .local_addr()
163 .map_err(|err| IoError::new(err, call_span, None))?
164 .port())
165}
166
#[cfg(windows)]
mod windows {
    use super::*;

    use std::{collections::HashSet, ptr, slice};

    use ::windows::Win32::{
        Foundation::{
            ERROR_INSUFFICIENT_BUFFER, ERROR_INVALID_PARAMETER, ERROR_NOT_SUPPORTED, NO_ERROR,
            WIN32_ERROR,
        },
        NetworkManagement::IpHelper::{GetTcpTable2, MIB_TCPROW2, MIB_TCPTABLE2},
        Networking::WinSock::ntohs,
    };

    /// How many times to resize and refetch when the TCP table grows between the size query
    /// and the actual read.
    const MAX_FETCH_ATTEMPTS: usize = 4;

    // The table is read out of a `Vec<u32>`, so its header must not need stricter alignment.
    const _: () = assert!(align_of::<MIB_TCPTABLE2>() == align_of::<u32>());

    /// Returns the local ports used by the entries of the system's IPv4 TCP table.
    ///
    /// On failure the raw `GetTcpTable2` error code is returned; `NO_ERROR` is never an `Err`.
    fn used_local_ports() -> Result<HashSet<u16>, WIN32_ERROR> {
        let mut size: u32 = 0;
        // Backing the table with `u32` words gives it the required 4-byte alignment. Dropping
        // the `Vec` frees it with exactly the layout it was allocated with, whatever row
        // count the API ends up reporting.
        let mut buffer: Vec<u32> = Vec::new();
        let mut attempts = 0;

        loop {
            let table = if buffer.is_empty() {
                None
            } else {
                Some(buffer.as_mut_ptr().cast::<MIB_TCPTABLE2>())
            };

            // SAFETY: `size` is a valid, writable `u32`. `table` is either `None` (a pure size
            // query) or points to a writable, suitably aligned buffer of at least `size` bytes.
            let ret_code = WIN32_ERROR(unsafe { GetTcpTable2(table, &mut size, false) });
            match ret_code {
                NO_ERROR => break,
                // First pass, or the table grew since the last query: resize and try again.
                ERROR_INSUFFICIENT_BUFFER if attempts < MAX_FETCH_ATTEMPTS => {
                    attempts += 1;
                    buffer = vec![0; (size as usize).div_ceil(size_of::<u32>())];
                }
                err => return Err(err),
            }
        }

        if buffer.is_empty() {
            // A size query alone reported success, so there is no table to read.
            return Ok(HashSet::new());
        }

        let table = buffer.as_ptr().cast::<MIB_TCPTABLE2>();
        // SAFETY: `GetTcpTable2` succeeded, so the buffer starts with an initialized header.
        let num_entries = unsafe { (*table).dwNumEntries } as usize;

        // Never form a slice that reaches past the buffer, even if the count were inconsistent.
        let row_bytes = (buffer.len() * size_of::<u32>())
            .saturating_sub(std::mem::offset_of!(MIB_TCPTABLE2, table));
        assert!(
            num_entries <= row_bytes / size_of::<MIB_TCPROW2>(),
            "GetTcpTable2 reported more rows than fit in its buffer"
        );

        // SAFETY:
        // - The rows start right after the header and are 4-byte aligned inside the `u32`
        //   buffer.
        // - They are fully initialized: the buffer was zero-filled, then written by the API.
        // - `num_entries` rows fit in the buffer, as asserted above.
        // - `buffer` outlives `rows`.
        let rows = unsafe {
            let first_row = ptr::addr_of!((*table).table).cast::<MIB_TCPROW2>();
            slice::from_raw_parts(first_row, num_entries)
        };

        let ports = rows
            .iter()
            // Only the low 16 bits of `dwLocalPort` are meaningful, and they are in network
            // byte order.
            // SAFETY: `ntohs` is a pure byte-order conversion on a plain integer.
            .map(|row| unsafe { ntohs(row.dwLocalPort as u16) })
            .collect();
        Ok(ports)
    }

    /// Windows variant of the port search. Instead of trial-binding, it reads the system's IPv4
    /// TCP table once and returns the first port in `range` that no table entry uses.
    ///
    /// Port 0 is never returned, because it does not name a usable port.
    ///
    /// # Errors
    /// - An I/O error at `call_span` when the TCP table cannot be read.
    /// - An `AddrInUse` I/O error when every port in the range is taken.
    pub fn search_port_in_range(
        range: Spanned<RangeInclusive<u16>>,
        call_span: Span,
    ) -> Result<u16, ShellError> {
        let used_ports = used_local_ports().map_err(|err| {
            let msg = match err {
                NO_ERROR => unreachable!("handled as Ok variant"),
                ERROR_INSUFFICIENT_BUFFER => "The buffer for TcpTable is not large enough",
                ERROR_INVALID_PARAMETER => "SizePointer was null or not writable",
                ERROR_NOT_SUPPORTED => "GetTcpTable2 is not supported on this OS",
                _ => "Unexpected error code from GetTcpTable2",
            };
            ShellError::Io(IoError::new_with_additional_context(
                std::io::Error::from_raw_os_error(err.0 as i32),
                call_span,
                None,
                msg,
            ))
        })?;

        range
            .item
            .filter(|&port| port != 0)
            .find(|port| !used_ports.contains(port))
            .ok_or_else(|| {
                ShellError::Io(IoError::new_with_additional_context(
                    std::io::Error::from(std::io::ErrorKind::AddrInUse),
                    call_span,
                    None,
                    "All ports in the range were taken",
                ))
            })
    }
}
317
318#[cfg(windows)]
319use windows::search_port_in_range;