proc_ctl/
port_query.rs

1use crate::error::{ProcCtlError, ProcCtlResult};
2use crate::types::{Pid, ProtocolPort};
3use std::process::Child;
4
5/// Find the ports used by a process
6#[derive(Debug)]
7pub struct PortQuery {
8    ipv4_addresses: bool,
9    ipv6_addresses: bool,
10    tcp_addresses: bool,
11    udp_addresses: bool,
12    process_id: Option<Pid>,
13    min_num_ports: Option<usize>,
14}
15
16impl PortQuery {
17    /// Create a new query
18    pub fn new() -> Self {
19        PortQuery {
20            ipv4_addresses: true,
21            ipv6_addresses: true,
22            tcp_addresses: true,
23            udp_addresses: true,
24            process_id: None,
25            min_num_ports: None,
26        }
27    }
28
29    /// Only consider IPv4 addresses
30    pub fn ip_v4_only(mut self) -> Self {
31        self.ipv4_addresses = true;
32        self.ipv6_addresses = false;
33        self
34    }
35
36    /// Only consider IPv6 addresses
37    pub fn ip_v6_only(mut self) -> Self {
38        self.ipv4_addresses = false;
39        self.ipv6_addresses = true;
40        self
41    }
42
43    /// Only consider TCP ports
44    pub fn tcp_only(mut self) -> Self {
45        self.tcp_addresses = true;
46        self.udp_addresses = false;
47        self
48    }
49
50    /// Only consider UDP ports
51    pub fn udp_only(mut self) -> Self {
52        self.tcp_addresses = false;
53        self.udp_addresses = true;
54        self
55    }
56
57    /// Require at least `num_ports` ports to be bound by the matched process for the query to succeed.
58    pub fn expect_min_num_ports(mut self, num_ports: usize) -> Self {
59        self.min_num_ports = Some(num_ports);
60        self
61    }
62
63    /// Set the process ID to match
64    ///
65    /// Either this function or `process_id_from_child` are required to be called before the query is usable.
66    pub fn process_id(mut self, pid: Pid) -> Self {
67        self.process_id = Some(pid);
68        self
69    }
70
71    /// Get the process ID of a child process
72    ///
73    /// Either this function or `process_id` are required to be called before the query is usable.
74    pub fn process_id_from_child(self, child: &Child) -> Self {
75        self.process_id(child.id())
76    }
77
78    /// Execute the query
79    pub fn execute(&self) -> ProcCtlResult<Vec<ProtocolPort>> {
80        #[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))]
81        let ports = list_ports_for_pid(self, crate::common::resolve_pid(self)?)?;
82        #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
83        let ports = Vec::with_capacity(0);
84
85        if let Some(num) = &self.min_num_ports {
86            if ports.len() < *num {
87                return Err(ProcCtlError::TooFewPorts(ports, *num));
88            }
89        }
90
91        Ok(ports)
92    }
93
94    /// Execute the query and retry until it succeeds or exhausts the configured retries
95    #[cfg(feature = "resilience")]
96    pub fn execute_with_retry_sync(
97        &self,
98        delay: std::time::Duration,
99        count: usize,
100    ) -> ProcCtlResult<Vec<ProtocolPort>> {
101        retry::retry(retry::delay::Fixed::from(delay).take(count), || {
102            self.execute()
103        })
104        .map_err(|e| e.error)
105    }
106
107    /// Async equivalent of `execute_with_retry_sync`
108    #[cfg(feature = "async")]
109    #[async_recursion::async_recursion]
110    pub async fn execute_with_retry(
111        &self,
112        delay: std::time::Duration,
113        count: usize,
114    ) -> ProcCtlResult<Vec<ProtocolPort>> {
115        match self.execute() {
116            Ok(ports) => Ok(ports),
117            Err(e) => {
118                if count == 0 {
119                    Err(e)
120                } else {
121                    tokio::time::sleep(delay).await;
122                    self.execute_with_retry(delay, count - 1).await
123                }
124            }
125        }
126    }
127}
128
129#[cfg(target_os = "linux")]
130fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
131    let proc = procfs::process::Process::new(pid as i32)?;
132    let fds = proc.fd()?;
133    let socket_nodes = fds
134        .filter_map(|fd| {
135            if let Ok(fd) = fd {
136                match fd.target {
137                    procfs::process::FDTarget::Socket(inode) => Some(inode),
138                    _ => None,
139                }
140            } else {
141                None
142            }
143        })
144        .collect::<std::collections::HashSet<_>>();
145
146    let mut out = Vec::new();
147
148    if query.tcp_addresses {
149        let mut tcp_entries = proc.tcp()?;
150
151        if query.ipv6_addresses {
152            let tcp6_entries = proc.tcp6()?;
153
154            tcp_entries.extend(tcp6_entries);
155        }
156
157        for entry in tcp_entries {
158            if entry.state == procfs::net::TcpState::Listen && socket_nodes.contains(&entry.inode) {
159                out.push(ProtocolPort::Tcp(entry.local_address.port()));
160            }
161        }
162    }
163
164    if query.udp_addresses {
165        let mut udp_entries = proc.udp()?;
166
167        if query.ipv6_addresses {
168            let udp6_entries = proc.udp6()?;
169            udp_entries.extend(udp6_entries);
170        }
171
172        for entry in udp_entries {
173            if socket_nodes.contains(&entry.inode) {
174                out.push(ProtocolPort::Udp(entry.local_address.port()));
175            }
176        }
177    }
178
179    Ok(out)
180}
181
/// List the ports owned by `pid` on Windows using the IP Helper owner-PID tables.
///
/// Tables are fetched for each enabled protocol and address family, in the order TCP/IPv4,
/// TCP/IPv6, UDP/IPv4, UDP/IPv6. Every row whose `dwOwningPid` equals `pid` contributes its
/// local port, tagged with the protocol of the table it came from.
///
/// # Errors
///
/// Propagates any error from `load_tcp_table` / `load_udp_table`.
#[cfg(target_os = "windows")]
fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
    use windows::Win32::NetworkManagement::IpHelper::{
        MIB_TCP6ROW_OWNER_PID, MIB_TCP6TABLE_OWNER_PID, MIB_TCPROW_OWNER_PID,
        MIB_TCPTABLE_OWNER_PID, MIB_UDP6ROW_OWNER_PID, MIB_UDP6TABLE_OWNER_PID,
        MIB_UDPROW_OWNER_PID, MIB_UDPTABLE_OWNER_PID,
    };
    use windows::Win32::Networking::WinSock::{AF_INET, AF_INET6};

    let mut out = Vec::new();

    // The four owner-PID tables share one shape: a `dwNumEntries` count followed by a
    // variable-length `table` array of rows. Every row type has `dwOwningPid` and `dwLocalPort`
    // fields, but the types are otherwise unrelated. This macro keeps the row walk in one place
    // instead of four copy-pasted loops; the copies had reported UDP ports as TCP.
    macro_rules! collect_owned_ports {
        ($buffer:expr, $table:ty, $row:ty, $variant:path) => {{
            let buffer: Vec<u8> = $buffer;
            let table = buffer.as_ptr() as *const $table;
            // SAFETY: `buffer` was filled by the IP Helper API for the matching protocol and
            // address family, so it starts with a `$table` header followed by `dwNumEntries`
            // contiguous `$row`s. The row pointer is derived from the buffer pointer (not from a
            // reference to the 1-element `table` array), so it is valid for every row.
            // NOTE(review): this relies on the `Vec<u8>` allocation being 4-byte aligned. The
            // system allocator provides that, but `Vec<u8>` itself does not guarantee it.
            let rows: &[$row] = unsafe {
                std::slice::from_raw_parts(
                    std::ptr::addr_of!((*table).table) as *const $row,
                    (*table).dwNumEntries as usize,
                )
            };
            out.extend(
                rows.iter()
                    .filter(|row| row.dwOwningPid == pid)
                    // `dwLocalPort` stores the port in network byte order in its low 16 bits,
                    // so truncate to `u16` and then convert to host order.
                    .map(|row| $variant(u16::from_be(row.dwLocalPort as u16))),
            );
        }};
    }

    if query.tcp_addresses {
        if query.ipv4_addresses {
            collect_owned_ports!(
                load_tcp_table(AF_INET)?,
                MIB_TCPTABLE_OWNER_PID,
                MIB_TCPROW_OWNER_PID,
                ProtocolPort::Tcp
            );
        }
        if query.ipv6_addresses {
            collect_owned_ports!(
                load_tcp_table(AF_INET6)?,
                MIB_TCP6TABLE_OWNER_PID,
                MIB_TCP6ROW_OWNER_PID,
                ProtocolPort::Tcp
            );
        }
    }
    if query.udp_addresses {
        if query.ipv4_addresses {
            collect_owned_ports!(
                load_udp_table(AF_INET)?,
                MIB_UDPTABLE_OWNER_PID,
                MIB_UDPROW_OWNER_PID,
                ProtocolPort::Udp
            );
        }
        if query.ipv6_addresses {
            collect_owned_ports!(
                load_udp_table(AF_INET6)?,
                MIB_UDP6TABLE_OWNER_PID,
                MIB_UDP6ROW_OWNER_PID,
                ProtocolPort::Udp
            );
        }
    }

    Ok(out)
}
249
/// Fetch the extended TCP table (`GetExtendedTcpTable`) for `family` as a raw byte buffer.
///
/// For `AF_INET` the buffer holds a `MIB_TCPTABLE_OWNER_PID`, and for `AF_INET6` a
/// `MIB_TCP6TABLE_OWNER_PID`. Only listening sockets are requested
/// (`TCP_TABLE_OWNER_PID_LISTENER`). This matches the Linux and macOS lookups, which report
/// only TCP sockets in the LISTEN state, so established client connections are not reported
/// as ports of the process.
///
/// The first call discovers the required size. The table can change between calls, so the
/// buffer is resized to the reported size and the call repeated, for up to three attempts.
///
/// # Errors
///
/// Returns [`ProcCtlError::ProcessError`] if the API fails with anything other than
/// `ERROR_INSUFFICIENT_BUFFER`, or if the buffer is still too small after the last attempt.
#[cfg(target_os = "windows")]
fn load_tcp_table(
    family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
    use windows::Win32::Foundation::{ERROR_INSUFFICIENT_BUFFER, NO_ERROR, WIN32_ERROR};
    use windows::Win32::NetworkManagement::IpHelper::{
        GetExtendedTcpTable, TCP_TABLE_OWNER_PID_LISTENER,
    };

    const MAX_ATTEMPTS: usize = 3;

    let mut table = Vec::<u8>::new();
    let mut table_size: u32 = 0;
    for _ in 0..MAX_ATTEMPTS {
        // SAFETY: `table` is valid for writes of `table_size` bytes. Both start at zero, and
        // whenever the API reports a larger size, `table` is resized to exactly that size
        // before the next call.
        let err_code = WIN32_ERROR(unsafe {
            GetExtendedTcpTable(
                Some(table.as_mut_ptr() as *mut _),
                &mut table_size,
                false,
                family.0 as u32,
                TCP_TABLE_OWNER_PID_LISTENER,
                0,
            )
        });

        if err_code == ERROR_INSUFFICIENT_BUFFER {
            table.resize(table_size as usize, 0);
            continue;
        } else if err_code != NO_ERROR {
            return Err(ProcCtlError::ProcessError(format!(
                "Failed to get TCP table: {:?}",
                err_code
            )));
        }

        return Ok(table);
    }

    Err(ProcCtlError::ProcessError(
        "Failed to get TCP table".to_string(),
    ))
}
287
/// Fetch the extended UDP table (`GetExtendedUdpTable`, `UDP_TABLE_OWNER_PID`) for `family`
/// as a raw byte buffer.
///
/// The buffer starts empty. While the API answers `ERROR_INSUFFICIENT_BUFFER`, it is resized
/// to the size the API wrote back and the call is repeated, for at most three calls in total.
///
/// # Errors
///
/// Returns [`ProcCtlError::ProcessError`] for any other failure code, or when all three calls
/// reported an insufficient buffer.
#[cfg(target_os = "windows")]
fn load_udp_table(
    family: windows::Win32::Networking::WinSock::ADDRESS_FAMILY,
) -> ProcCtlResult<Vec<u8>> {
    use windows::Win32::Foundation::{ERROR_INSUFFICIENT_BUFFER, NO_ERROR, WIN32_ERROR};
    use windows::Win32::NetworkManagement::IpHelper::{GetExtendedUdpTable, UDP_TABLE_OWNER_PID};

    let mut buffer: Vec<u8> = Vec::new();
    let mut required_size: u32 = 0;
    let mut attempts_left = 3;

    while attempts_left > 0 {
        attempts_left -= 1;

        // SAFETY: `buffer.len()` always equals `required_size` when the call is made (both
        // start at zero, and the buffer is resized to each size the API reports).
        let status = WIN32_ERROR(unsafe {
            GetExtendedUdpTable(
                Some(buffer.as_mut_ptr() as *mut _),
                &mut required_size,
                false,
                family.0 as u32,
                UDP_TABLE_OWNER_PID,
                0,
            )
        });

        if status == NO_ERROR {
            return Ok(buffer);
        }
        if status != ERROR_INSUFFICIENT_BUFFER {
            return Err(ProcCtlError::ProcessError(format!(
                "Failed to get UDP table: {:?}",
                status
            )));
        }
        buffer.resize(required_size as usize, 0);
    }

    Err(ProcCtlError::ProcessError(
        "Failed to get UDP table".to_string(),
    ))
}
325
/// List the ports used by `pid` on macOS by running `lsof`.
///
/// One `lsof` invocation is made per enabled (address family, protocol) pair, in the order
/// IPv4 TCP, IPv4 UDP, IPv6 TCP, IPv6 UDP. TCP is restricted to sockets in the LISTEN state.
///
/// # Errors
///
/// Returns [`ProcCtlError::ProcessError`] if `lsof` cannot be spawned. A non-zero exit status
/// from `lsof` is not treated as an error; whatever it wrote to stdout is still parsed.
#[cfg(target_os = "macos")]
fn list_ports_for_pid(query: &PortQuery, pid: Pid) -> ProcCtlResult<Vec<ProtocolPort>> {
    // (family enabled?, lsof IP-version selector, parser for that family's address format)
    let families: [(bool, &str, fn(Vec<u8>, Pid) -> Vec<u16>); 2] = [
        (query.ipv4_addresses, "-i4", find_ports_v4),
        (query.ipv6_addresses, "-i6", find_ports_v6),
    ];

    let mut out = Vec::new();
    for &(enabled, family_flag, parse_ports) in &families {
        if !enabled {
            continue;
        }
        if query.tcp_addresses {
            let stdout = run_lsof(pid, &["-iTCP", family_flag, "-sTCP:LISTEN"])?;
            out.extend(parse_ports(stdout, pid).into_iter().map(ProtocolPort::Tcp));
        }
        if query.udp_addresses {
            let stdout = run_lsof(pid, &["-iUDP", family_flag])?;
            out.extend(parse_ports(stdout, pid).into_iter().map(ProtocolPort::Udp));
        }
    }

    Ok(out)
}

/// Run `lsof` for the open files of `pid` that match the network selectors in `selection`.
///
/// Returns the raw stdout in NUL-terminated field format (`-F0pn`), with numeric hosts and
/// ports (`-nP`). `-a` ANDs the process selection (`-p`) with the network selections, so
/// `lsof` only reports that process instead of every socket on the system.
#[cfg(target_os = "macos")]
fn run_lsof(pid: Pid, selection: &[&str]) -> ProcCtlResult<Vec<u8>> {
    std::process::Command::new("lsof")
        .arg("-a")
        .arg("-p")
        .arg(pid.to_string())
        .args(selection)
        .arg("-nP")
        .arg("-F0pn")
        .output()
        .map(|output| output.stdout)
        .map_err(|e| ProcCtlError::ProcessError(e.to_string()))
}
407
/// Collect the local ports held by process `find_pid` from `lsof -F0pn` output for IPv4 sockets.
///
/// With `-F0`, lsof ends every field with a NUL byte and ends each process set and file set
/// with a newline. A field starts with a one-character identifier:
/// - `p` (process ID) begins a new process set.
/// - `n` (name) holds the socket address, e.g. `*:8080`, `127.0.0.1:8080`, or
///   `127.0.0.1:5000->127.0.0.1:6000` for a connected socket.
///
/// Other fields (such as `f`, the descriptor) are ignored. Names whose local port does not
/// parse (e.g. `*:*`) are skipped, as are the files of a process whose `p` field does not parse.
#[cfg(target_os = "macos")]
fn find_ports_v4(output: Vec<u8>, find_pid: Pid) -> Vec<u16> {
    let mut out = Vec::new();
    // PID of the process set currently being read; `None` until a parsable `p` field is seen.
    let mut current_pid: Option<u32> = None;

    // Splitting on both terminators yields one slice per field. Each set-ending newline leaves
    // an empty slice, which falls through to the `_` arm.
    for field in output.split(|&byte| byte == 0 || byte == b'\n') {
        match field.split_first() {
            Some((&b'p', value)) => {
                current_pid = String::from_utf8_lossy(value).parse::<u32>().ok();
            }
            Some((&b'n', value)) if current_pid == Some(find_pid) => {
                let name = String::from_utf8_lossy(value);
                // A connected socket is shown as `local->remote`; only the local end matters.
                let local = name.split("->").next().unwrap_or_default();
                if let Some(port) = local
                    .rsplit_once(':')
                    .and_then(|(_, port)| port.parse::<u16>().ok())
                {
                    out.push(port);
                }
            }
            _ => {}
        }
    }

    out
}
475
/// Collect the local ports held by process `find_pid` from `lsof -F0pn` output for IPv6 sockets.
///
/// With `-F0`, lsof ends every field with a NUL byte and ends each process set and file set
/// with a newline. A field starts with a one-character identifier:
/// - `p` (process ID) begins a new process set.
/// - `n` (name) holds the socket address. Depending on the socket this is `[::1]:8080`,
///   `[fe80::1%lo0]:5353`, `[::1]:5000->[::1]:6000` for a connected socket, or `*:8080` for
///   the wildcard address.
///
/// The wildcard form has no brackets, so the port is taken after the last `:` of the local
/// end rather than after a `]`. Other fields are ignored. Names whose local port does not
/// parse are skipped, as are the files of a process whose `p` field does not parse.
#[cfg(target_os = "macos")]
fn find_ports_v6(output: Vec<u8>, find_pid: Pid) -> Vec<u16> {
    let mut out = Vec::new();
    // PID of the process set currently being read; `None` until a parsable `p` field is seen.
    let mut current_pid: Option<u32> = None;

    // Splitting on both terminators yields one slice per field. Each set-ending newline leaves
    // an empty slice, which falls through to the `_` arm.
    for field in output.split(|&byte| byte == 0 || byte == b'\n') {
        match field.split_first() {
            Some((&b'p', value)) => {
                current_pid = String::from_utf8_lossy(value).parse::<u32>().ok();
            }
            Some((&b'n', value)) if current_pid == Some(find_pid) => {
                let name = String::from_utf8_lossy(value);
                // A connected socket is shown as `local->remote`; only the local end matters.
                let local = name.split("->").next().unwrap_or_default();
                if let Some(port) = local
                    .rsplit_once(':')
                    .and_then(|(_, port)| port.parse::<u16>().ok())
                {
                    out.push(port);
                }
            }
            _ => {}
        }
    }

    out
}
545
546#[cfg(any(
547    target_os = "linux",
548    target_os = "windows",
549    target_os = "macos",
550    feature = "proc"
551))]
552impl crate::common::MaybeHasPid for PortQuery {
553    fn get_pid(&self) -> Option<Pid> {
554        self.process_id
555    }
556}
557
558impl Default for PortQuery {
559    fn default() -> Self {
560        PortQuery::new()
561    }
562}