1use crate::error::{ProcCtlError, ProcCtlResult};
2use crate::types::{Pid, ProtocolPort};
3use std::process::Child;
4
5#[derive(Debug)]
7pub struct PortQuery {
8 ipv4_addresses: bool,
9 ipv6_addresses: bool,
10 tcp_addresses: bool,
11 udp_addresses: bool,
12 process_id: Option<Pid>,
13 min_num_ports: Option<usize>,
14}
15
16impl PortQuery {
17 pub fn new() -> Self {
19 PortQuery {
20 ipv4_addresses: true,
21 ipv6_addresses: true,
22 tcp_addresses: true,
23 udp_addresses: true,
24 process_id: None,
25 min_num_ports: None,
26 }
27 }
28
29 pub fn ip_v4_only(mut self) -> Self {
31 self.ipv4_addresses = true;
32 self.ipv6_addresses = false;
33 self
34 }
35
36 pub fn ip_v6_only(mut self) -> Self {
38 self.ipv4_addresses = false;
39 self.ipv6_addresses = true;
40 self
41 }
42
43 pub fn tcp_only(mut self) -> Self {
45 self.tcp_addresses = true;
46 self.udp_addresses = false;
47 self
48 }
49
50 pub fn udp_only(mut self) -> Self {
52 self.tcp_addresses = false;
53 self.udp_addresses = true;
54 self
55 }
56
57 pub fn expect_min_num_ports(mut self, num_ports: usize) -> Self {
59 self.min_num_ports = Some(num_ports);
60 self
61 }
62
63 pub fn process_id(mut self, pid: Pid) -> Self {
67 self.process_id = Some(pid);
68 self
69 }
70
71 pub fn process_id_from_child(self, child: &Child) -> Self {
75 self.process_id(child.id())
76 }
77
78 pub fn execute(&self) -> ProcCtlResult<Vec<ProtocolPort>> {
80 #[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))]
81 let ports = list_ports_for_pid(self, crate::common::resolve_pid(self)?)?;
82 #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
83 let ports = Vec::with_capacity(0);
84
85 if let Some(num) = &self.min_num_ports {
86 if ports.len() < *num {
87 return Err(ProcCtlError::TooFewPorts(ports, *num));
88 }
89 }
90
91 Ok(ports)
92 }
93
94 #[cfg(feature = "resilience")]
96 pub fn execute_with_retry_sync(
97 &self,
98 delay: std::time::Duration,
99 count: usize,
100 ) -> ProcCtlResult<Vec<ProtocolPort>> {
101 retry::retry(retry::delay::Fixed::from(delay).take(count), || {
102 self.execute()
103 })
104 .map_err(|e| e.error)
105 }
106
107 #[cfg(feature = "async")]
109 #[async_recursion::async_recursion]
110 pub async fn execute_with_retry(
111 &self,
112 delay: std::time::Duration,
113 count: usize,
114 ) -> ProcCtlResult<Vec<ProtocolPort>> {
115 match self.execute() {
116 Ok(ports) => Ok(ports),
117 Err(e) => {
118 if count == 0 {
119 Err(e)
120 } else {
121 tokio::time::sleep(delay).await;
122 self.execute_with_retry(delay, count - 1).await
123 }
124 }
125 }
126 }
127}
128
129#[cfg(target_os = "linux")]
130fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
131 let proc = procfs::process::Process::new(pid as i32)?;
132 let fds = proc.fd()?;
133 let socket_nodes = fds
134 .filter_map(|fd| {
135 if let Ok(fd) = fd {
136 match fd.target {
137 procfs::process::FDTarget::Socket(inode) => Some(inode),
138 _ => None,
139 }
140 } else {
141 None
142 }
143 })
144 .collect::<std::collections::HashSet<_>>();
145
146 let mut out = Vec::new();
147
148 if query.tcp_addresses {
149 let mut tcp_entries = proc.tcp()?;
150
151 if query.ipv6_addresses {
152 let tcp6_entries = proc.tcp6()?;
153
154 tcp_entries.extend(tcp6_entries);
155 }
156
157 for entry in tcp_entries {
158 if entry.state == procfs::net::TcpState::Listen && socket_nodes.contains(&entry.inode) {
159 out.push(ProtocolPort::Tcp(entry.local_address.port()));
160 }
161 }
162 }
163
164 if query.udp_addresses {
165 let mut udp_entries = proc.udp()?;
166
167 if query.ipv6_addresses {
168 let udp6_entries = proc.udp6()?;
169 udp_entries.extend(udp6_entries);
170 }
171
172 for entry in udp_entries {
173 if socket_nodes.contains(&entry.inode) {
174 out.push(ProtocolPort::Udp(entry.local_address.port()));
175 }
176 }
177 }
178
179 Ok(out)
180}
181
/// Windows implementation: lists the local ports of TCP/UDP table rows owned by `pid`.
///
/// Tables are fetched with `load_tcp_table` / `load_udp_table` for `AF_INET` and/or `AF_INET6`,
/// depending on the query. UDP rows are reported as [`ProtocolPort::Udp`] and TCP rows as
/// [`ProtocolPort::Tcp`].
///
/// # Errors
///
/// Propagates failures from loading the tables.
#[cfg(target_os = "windows")]
fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
    use windows::Win32::NetworkManagement::IpHelper::{
        MIB_TCP6ROW_OWNER_PID, MIB_TCP6TABLE_OWNER_PID, MIB_TCPROW_OWNER_PID,
        MIB_TCPTABLE_OWNER_PID, MIB_UDP6ROW_OWNER_PID, MIB_UDP6TABLE_OWNER_PID,
        MIB_UDPROW_OWNER_PID, MIB_UDPTABLE_OWNER_PID,
    };
    use windows::Win32::Networking::WinSock::{AF_INET, AF_INET6};

    // Reads an `*_OWNER_PID` table out of the raw byte buffer returned by the loaders and
    // evaluates to the local ports (host byte order) of the rows owned by `$pid`.
    //
    // All four table layouts are a `dwNumEntries: u32` header followed by a `table` array that is
    // declared with one row but actually holds `dwNumEntries` rows. Every row type has
    // `dwLocalPort` and `dwOwningPid`, which is why one macro serves all of them.
    macro_rules! owned_local_ports {
        ($buffer:expr, $table:ty, $row:ty, $pid:expr) => {{
            let buffer: Vec<u8> = $buffer;
            let mut ports: Vec<u16> = Vec::new();
            // Bytes in front of the first row: `$table` is declared with exactly one row.
            let header_len = std::mem::size_of::<$table>() - std::mem::size_of::<$row>();
            if buffer.len() >= header_len {
                let table_ptr = buffer.as_ptr().cast::<$table>();
                // SAFETY: the buffer holds at least `header_len` bytes, so the header and the
                // start of the row array lie inside the allocation. The row count is clamped to
                // the rows that actually fit in the buffer. `read_unaligned` is used because a
                // `Vec<u8>` only guarantees byte alignment. All pointers are derived from the
                // buffer's own pointer (no reference to the declared one-element array is
                // created), so every read stays within the buffer's allocation.
                unsafe {
                    let declared =
                        std::ptr::addr_of!((*table_ptr).dwNumEntries).read_unaligned() as usize;
                    let fitting = (buffer.len() - header_len) / std::mem::size_of::<$row>();
                    let first_row = std::ptr::addr_of!((*table_ptr).table).cast::<$row>();
                    for i in 0..declared.min(fitting) {
                        let row = first_row.add(i).read_unaligned();
                        if row.dwOwningPid == $pid {
                            // Only the low 16 bits of `dwLocalPort` are meaningful, and they
                            // hold the port in network byte order.
                            ports.push(u16::from_be(row.dwLocalPort as u16));
                        }
                    }
                }
            }
            ports
        }};
    }

    let mut out = Vec::new();

    if query.tcp_addresses {
        if query.ipv4_addresses {
            let ports = owned_local_ports!(
                load_tcp_table(AF_INET)?,
                MIB_TCPTABLE_OWNER_PID,
                MIB_TCPROW_OWNER_PID,
                pid
            );
            out.extend(ports.into_iter().map(ProtocolPort::Tcp));
        }
        if query.ipv6_addresses {
            let ports = owned_local_ports!(
                load_tcp_table(AF_INET6)?,
                MIB_TCP6TABLE_OWNER_PID,
                MIB_TCP6ROW_OWNER_PID,
                pid
            );
            out.extend(ports.into_iter().map(ProtocolPort::Tcp));
        }
    }

    if query.udp_addresses {
        if query.ipv4_addresses {
            let ports = owned_local_ports!(
                load_udp_table(AF_INET)?,
                MIB_UDPTABLE_OWNER_PID,
                MIB_UDPROW_OWNER_PID,
                pid
            );
            out.extend(ports.into_iter().map(ProtocolPort::Udp));
        }
        if query.ipv6_addresses {
            let ports = owned_local_ports!(
                load_udp_table(AF_INET6)?,
                MIB_UDP6TABLE_OWNER_PID,
                MIB_UDP6ROW_OWNER_PID,
                pid
            );
            out.extend(ports.into_iter().map(ProtocolPort::Udp));
        }
    }

    Ok(out)
}
249
/// Fetches the raw bytes of the system's listening-TCP table with owning pids
/// (`TCP_TABLE_OWNER_PID_LISTENER`) for the given address family.
///
/// Only listening sockets are requested, matching the Linux (`TcpState::Listen`) and macOS
/// (`-sTCP:LISTEN`) implementations. Established connections are not reported.
///
/// # Errors
///
/// Returns `ProcCtlError::ProcessError` if `GetExtendedTcpTable` reports an error other than
/// `ERROR_INSUFFICIENT_BUFFER`, or if the buffer is still too small after three attempts.
#[cfg(target_os = "windows")]
fn load_tcp_table(
    family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
    use windows::Win32::Foundation::{ERROR_INSUFFICIENT_BUFFER, NO_ERROR, WIN32_ERROR};
    use windows::Win32::NetworkManagement::IpHelper::{
        GetExtendedTcpTable, TCP_TABLE_OWNER_PID_LISTENER,
    };

    // Each ERROR_INSUFFICIENT_BUFFER resizes the buffer to the size the API asked for and retries
    // (e.g. when the table grew between calls); give up after this many attempts.
    const MAX_ATTEMPTS: usize = 3;

    let mut table = Vec::<u8>::new();
    let mut table_size: u32 = 0;
    for _ in 0..MAX_ATTEMPTS {
        // SAFETY: `table` is valid for writes of `table_size` bytes. Both are zero on the first
        // call, and afterwards `table` has been resized to exactly `table_size` bytes.
        let err_code = WIN32_ERROR(unsafe {
            GetExtendedTcpTable(
                Some(table.as_mut_ptr().cast()),
                &mut table_size,
                false,
                family.0 as u32,
                TCP_TABLE_OWNER_PID_LISTENER,
                0,
            )
        });

        if err_code == ERROR_INSUFFICIENT_BUFFER {
            table.resize(table_size as usize, 0);
            continue;
        }
        if err_code != NO_ERROR {
            return Err(ProcCtlError::ProcessError(format!(
                "Failed to get TCP table: {:?}",
                err_code
            )));
        }
        return Ok(table);
    }

    Err(ProcCtlError::ProcessError(
        "Failed to get TCP table".to_string(),
    ))
}
287
/// Fetches the raw bytes of the system's UDP table with owning pids (`UDP_TABLE_OWNER_PID`)
/// for the given address family.
///
/// # Errors
///
/// Returns `ProcCtlError::ProcessError` if `GetExtendedUdpTable` reports an error other than
/// `ERROR_INSUFFICIENT_BUFFER`, or if the buffer is still too small after three attempts.
#[cfg(target_os = "windows")]
fn load_udp_table(
    family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
    use windows::Win32::Foundation as win;
    use windows::Win32::NetworkManagement::IpHelper::{GetExtendedUdpTable, UDP_TABLE_OWNER_PID};

    // Number of calls allowed; every ERROR_INSUFFICIENT_BUFFER grows the buffer and tries again.
    const MAX_ATTEMPTS: usize = 3;

    let mut buffer: Vec<u8> = Vec::new();
    let mut required_len: u32 = 0;

    for _attempt in 0..MAX_ATTEMPTS {
        // SAFETY: `buffer` is valid for writes of `required_len` bytes. Both start at zero, and
        // after a resize the buffer is exactly `required_len` bytes long.
        let status = win::WIN32_ERROR(unsafe {
            GetExtendedUdpTable(
                Some(buffer.as_mut_ptr().cast()),
                &mut required_len,
                false,
                family.0 as u32,
                UDP_TABLE_OWNER_PID,
                0,
            )
        });

        match status {
            win::ERROR_INSUFFICIENT_BUFFER => buffer.resize(required_len as usize, 0),
            win::NO_ERROR => return Ok(buffer),
            other => {
                return Err(ProcCtlError::ProcessError(format!(
                    "Failed to get UDP table: {:?}",
                    other
                )));
            }
        }
    }

    Err(ProcCtlError::ProcessError(
        "Failed to get UDP table".to_string(),
    ))
}
325
/// macOS implementation: lists the ports of sockets owned by `pid` by running `lsof`.
///
/// One `lsof` run is made per requested (address family, protocol) pair, in the order
/// IPv4/TCP, IPv4/UDP, IPv6/TCP, IPv6/UDP. TCP runs only select sockets in the LISTEN state.
///
/// # Errors
///
/// Returns `ProcCtlError::ProcessError` if `lsof` cannot be spawned.
#[cfg(target_os = "macos")]
fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
    // Runs `lsof -a <selectors> -nP -F0pn` and returns its raw stdout. `-a` ANDs the selection
    // options, `-nP` disables host and port name lookups, and `-F0pn` requests NUL-terminated
    // field output with pid and name fields. The exit status is not inspected, so a run that
    // matches nothing simply yields no ports.
    fn run_lsof(selectors: &[&str]) -> ProcCtlResult<Vec<u8>> {
        std::process::Command::new("lsof")
            .arg("-a")
            .args(selectors)
            .arg("-nP")
            .arg("-F0pn")
            .output()
            .map(|output| output.stdout)
            .map_err(|e| ProcCtlError::ProcessError(e.to_string()))
    }

    let mut out = Vec::new();

    if query.ipv4_addresses {
        if query.tcp_addresses {
            let stdout = run_lsof(&["-iTCP", "-i4", "-sTCP:LISTEN"])?;
            out.extend(find_ports_v4(stdout, pid).into_iter().map(ProtocolPort::Tcp));
        }
        if query.udp_addresses {
            let stdout = run_lsof(&["-iUDP", "-i4"])?;
            out.extend(find_ports_v4(stdout, pid).into_iter().map(ProtocolPort::Udp));
        }
    }
    if query.ipv6_addresses {
        if query.tcp_addresses {
            let stdout = run_lsof(&["-iTCP", "-i6", "-sTCP:LISTEN"])?;
            out.extend(find_ports_v6(stdout, pid).into_iter().map(ProtocolPort::Tcp));
        }
        if query.udp_addresses {
            let stdout = run_lsof(&["-iUDP", "-i6"])?;
            out.extend(find_ports_v6(stdout, pid).into_iter().map(ProtocolPort::Udp));
        }
    }

    Ok(out)
}
407
/// Parses `lsof -F0pn` output for IPv4 sockets and returns the local ports named by the `n`
/// fields that belong to process `find_pid`, in output order.
///
/// Name values look like `*:8080`, `127.0.0.1:8080` or, for connected sockets,
/// `127.0.0.1:5000->127.0.0.1:6000`. The local port is the one reported. Processes whose `p` field
/// is not a valid pid are skipped, and names without a numeric local port are ignored.
#[cfg(target_os = "macos")]
fn find_ports_v4(output: Vec<u8>, find_pid: Pid) -> Vec<u16> {
    // Local port of an lsof name value, i.e. the digits after the last ':' of the local endpoint.
    fn local_port(name: &[u8]) -> Option<u16> {
        // Connected sockets carry a `->remote` suffix; keep only the local endpoint.
        let local = match name.windows(2).position(|pair| pair == b"->") {
            Some(arrow) => &name[..arrow],
            None => name,
        };
        let colon = local.iter().rposition(|&byte| byte == b':')?;
        std::str::from_utf8(&local[colon + 1..]).ok()?.parse().ok()
    }

    let mut ports = Vec::new();
    // Whether the process set currently being read belongs to `find_pid`.
    let mut in_target_process = false;

    // With `-F0`, every field is `<identifier char><value>` terminated by NUL, and each process
    // or file set is additionally terminated by a newline, which then leads the next field.
    for raw_field in output.split(|&byte| byte == 0) {
        let Some(start) = raw_field.iter().position(|&byte| byte != b'\n') else {
            continue;
        };
        let (id, value) = (raw_field[start], &raw_field[start + 1..]);
        match id {
            b'p' => {
                in_target_process = std::str::from_utf8(value)
                    .ok()
                    .and_then(|text| text.parse::<u32>().ok())
                    == Some(find_pid);
            }
            b'n' if in_target_process => {
                if let Some(port) = local_port(value) {
                    ports.push(port);
                }
            }
            _ => {}
        }
    }

    ports
}
475
/// Parses `lsof -F0pn` output for IPv6 sockets and returns the local ports named by the `n`
/// fields that belong to process `find_pid`, in output order.
///
/// Name values look like `[::1]:8080`, `*:8080` (lsof prints the IPv6 wildcard without
/// brackets) or, for connected sockets, `[fe80::1]:5000->[fe80::2]:6000`. The local port is the
/// one reported. Processes whose `p` field is not a valid pid are skipped, and names without a
/// numeric local port are ignored.
#[cfg(target_os = "macos")]
fn find_ports_v6(output: Vec<u8>, find_pid: Pid) -> Vec<u16> {
    // Local port of an lsof name value, i.e. the digits after the last ':' of the local endpoint.
    // The colons of a bracketed IPv6 address all come before that last one.
    fn local_port(name: &[u8]) -> Option<u16> {
        // Connected sockets carry a `->remote` suffix; keep only the local endpoint.
        let local = match name.windows(2).position(|pair| pair == b"->") {
            Some(arrow) => &name[..arrow],
            None => name,
        };
        let colon = local.iter().rposition(|&byte| byte == b':')?;
        std::str::from_utf8(&local[colon + 1..]).ok()?.parse().ok()
    }

    let mut ports = Vec::new();
    // Whether the process set currently being read belongs to `find_pid`.
    let mut in_target_process = false;

    // With `-F0`, every field is `<identifier char><value>` terminated by NUL, and each process
    // or file set is additionally terminated by a newline, which then leads the next field.
    for raw_field in output.split(|&byte| byte == 0) {
        let Some(start) = raw_field.iter().position(|&byte| byte != b'\n') else {
            continue;
        };
        let (id, value) = (raw_field[start], &raw_field[start + 1..]);
        match id {
            b'p' => {
                in_target_process = std::str::from_utf8(value)
                    .ok()
                    .and_then(|text| text.parse::<u32>().ok())
                    == Some(find_pid);
            }
            b'n' if in_target_process => {
                if let Some(port) = local_port(value) {
                    ports.push(port);
                }
            }
            _ => {}
        }
    }

    ports
}
545
546#[cfg(any(
547 target_os = "linux",
548 target_os = "windows",
549 target_os = "macos",
550 feature = "proc"
551))]
552impl crate::common::MaybeHasPid for PortQuery {
553 fn get_pid(&self) -> Option<Pid> {
554 self.process_id
555 }
556}
557
558impl Default for PortQuery {
559 fn default() -> Self {
560 PortQuery::new()
561 }
562}