Skip to main content

digit/
protocol.rs

1//! Finger protocol wire format and TCP communication.
2//!
3//! Provides functions to build RFC 1288 query strings and execute finger
4//! queries over TCP with configurable timeouts and response size limits.
5
6use std::io::{self, Read, Write};
7use std::net::{TcpStream, ToSocketAddrs};
8use std::sync::mpsc;
9use std::thread;
10use std::time::Duration;
11
12use crate::query::Query;
13
/// Errors that can occur during a finger protocol exchange.
///
/// Variants that wrap an [`io::Error`] mark it with `#[source]` (so it is
/// reachable through [`std::error::Error::source`]) and also interpolate it
/// into their `Display` message.
#[derive(Debug, thiserror::Error)]
pub enum FingerError {
    /// Failed to resolve the hostname, or resolution produced no addresses.
    #[error("could not resolve host '{host}': {source}")]
    DnsResolution {
        /// The host name that was looked up.
        host: String,
        /// The underlying resolver error.
        #[source]
        source: io::Error,
    },

    /// Failed to connect to the remote host for a reason other than a timeout.
    #[error("could not connect to {host}:{port}: {source}")]
    ConnectionFailed {
        /// The host that was dialled.
        host: String,
        /// The TCP port that was dialled.
        port: u16,
        /// The underlying connect error.
        #[source]
        source: io::Error,
    },

    /// A connect, send, or read operation exceeded the configured timeout.
    ///
    /// Unlike the other variants, the underlying I/O error is not retained.
    #[error("connection to {host}:{port} timed out")]
    Timeout { host: String, port: u16 },

    /// Failed to send the query (for a reason other than a timeout).
    #[error("failed to send query: {source}")]
    SendFailed {
        /// The underlying write error.
        #[source]
        source: io::Error,
    },

    /// Failed to read the response (for a reason other than a timeout).
    #[error("failed to read response: {source}")]
    ReadFailed {
        /// The underlying read error.
        #[source]
        source: io::Error,
    },
}
52
53/// Build the wire-format query string to send over the TCP connection.
54///
55/// Per RFC 1288:
56/// - Verbose queries prepend `/W ` (with trailing space).
57/// - Forwarding appends `@host1@host2...` for all hosts except the last
58///   (the last host is the connection target, not part of the query string).
59/// - The query is terminated with `\r\n`.
60///
61/// # Example
62///
63/// ```
64/// use digit::query::Query;
65/// use digit::protocol::build_query_string;
66///
67/// let q = Query::parse(Some("user@host"), true, 79).unwrap();
68/// assert_eq!(build_query_string(&q), "/W user\r\n");
69/// ```
70pub fn build_query_string(query: &Query) -> String {
71    let mut result = String::new();
72
73    // Verbose prefix per RFC 1288.
74    if query.long {
75        result.push_str("/W ");
76    }
77
78    // User portion.
79    if let Some(ref user) = query.user {
80        result.push_str(user);
81    }
82
83    // Forwarding: include all hosts except the last (the connection target).
84    // These become @host1@host2... in the query string.
85    if query.hosts.len() > 1 {
86        for host in &query.hosts[..query.hosts.len() - 1] {
87            result.push('@');
88            result.push_str(host);
89        }
90    }
91
92    result.push_str("\r\n");
93    result
94}
95
96/// Attempt to connect to a single socket address with a timeout.
97fn connect_to_addr(
98    addr: std::net::SocketAddr,
99    host: &str,
100    port: u16,
101    timeout: Duration,
102) -> Result<TcpStream, FingerError> {
103    TcpStream::connect_timeout(&addr, timeout).map_err(|e| {
104        if e.kind() == io::ErrorKind::TimedOut {
105            FingerError::Timeout {
106                host: host.to_string(),
107                port,
108            }
109        } else {
110            FingerError::ConnectionFailed {
111                host: host.to_string(),
112                port,
113                source: e,
114            }
115        }
116    })
117}
118
119/// Execute a finger query over TCP, returning raw bytes.
120///
121/// Like [`finger`], but returns the raw response bytes without UTF-8 decoding.
122/// The response is capped at `max_response_size` bytes. Useful for piping
123/// output to other tools or when the response contains non-UTF-8 data.
124///
125/// # Example
126///
127/// ```no_run
128/// use std::time::Duration;
129/// use digit::query::Query;
130/// use digit::protocol::finger_raw;
131///
132/// let q = Query::parse(Some("user@example.com"), false, 79).unwrap();
133/// let bytes = finger_raw(&q, Duration::from_secs(10), 1_048_576).unwrap();
134/// std::io::Write::write_all(&mut std::io::stdout(), &bytes).unwrap();
135/// ```
136pub fn finger_raw(
137    query: &Query,
138    timeout: Duration,
139    max_response_size: u64,
140) -> Result<Vec<u8>, FingerError> {
141    let host = query.target_host();
142    let addr_str = format!("{}:{}", host, query.port);
143
144    // Resolve hostname to socket addresses.
145    let addrs: Vec<std::net::SocketAddr> = addr_str
146        .to_socket_addrs()
147        .map_err(|e| FingerError::DnsResolution {
148            host: host.to_string(),
149            source: e,
150        })?
151        .collect();
152
153    if addrs.is_empty() {
154        return Err(FingerError::DnsResolution {
155            host: host.to_string(),
156            source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
157        });
158    }
159
160    // Connect: single address connects directly, multiple race in parallel.
161    let mut stream = if addrs.len() == 1 {
162        connect_to_addr(addrs[0], host, query.port, timeout)?
163    } else {
164        let (tx, rx) = mpsc::channel();
165        let addr_count = addrs.len();
166
167        for addr in addrs {
168            let tx = tx.clone();
169            thread::spawn(move || {
170                let result = TcpStream::connect_timeout(&addr, timeout);
171                let _ = tx.send(result);
172            });
173        }
174        drop(tx);
175
176        let mut last_err = None;
177        let mut winner = None;
178        for _ in 0..addr_count {
179            match rx.recv() {
180                Ok(Ok(s)) => {
181                    winner = Some(s);
182                    break;
183                }
184                Ok(Err(e)) => {
185                    last_err = Some(e);
186                }
187                Err(_) => break,
188            }
189        }
190
191        match winner {
192            Some(s) => s,
193            None => {
194                let e = last_err.unwrap_or_else(|| {
195                    io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
196                });
197                if e.kind() == io::ErrorKind::TimedOut {
198                    return Err(FingerError::Timeout {
199                        host: host.to_string(),
200                        port: query.port,
201                    });
202                } else {
203                    return Err(FingerError::ConnectionFailed {
204                        host: host.to_string(),
205                        port: query.port,
206                        source: e,
207                    });
208                }
209            }
210        }
211    };
212
213    // Set read/write timeouts on the connected socket.
214    stream.set_read_timeout(Some(timeout)).ok();
215    stream.set_write_timeout(Some(timeout)).ok();
216
217    // Send the query.
218    let query_string = build_query_string(query);
219    stream.write_all(query_string.as_bytes()).map_err(|e| {
220        if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
221            FingerError::Timeout {
222                host: host.to_string(),
223                port: query.port,
224            }
225        } else {
226            FingerError::SendFailed { source: e }
227        }
228    })?;
229
230    // Read the response, capped at max_response_size bytes.
231    let mut buf = Vec::new();
232    stream
233        .take(max_response_size)
234        .read_to_end(&mut buf)
235        .map_err(|e| {
236            if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
237                FingerError::Timeout {
238                    host: host.to_string(),
239                    port: query.port,
240                }
241            } else {
242                FingerError::ReadFailed { source: e }
243            }
244        })?;
245
246    Ok(buf)
247}
248
249/// Execute a finger query over TCP, returning the response as a string.
250///
251/// Calls [`finger_raw`] and decodes the response with [`String::from_utf8_lossy`].
252/// Invalid UTF-8 bytes are replaced with U+FFFD. The response is capped at
253/// `max_response_size` bytes.
254///
255/// # Example
256///
257/// ```no_run
258/// use std::time::Duration;
259/// use digit::query::Query;
260/// use digit::protocol::finger;
261///
262/// let q = Query::parse(Some("user@example.com"), false, 79).unwrap();
263/// let response = finger(&q, Duration::from_secs(10), 1_048_576).unwrap();
264/// println!("{}", response);
265/// ```
266pub fn finger(
267    query: &Query,
268    timeout: Duration,
269    max_response_size: u64,
270) -> Result<String, FingerError> {
271    let bytes = finger_raw(query, timeout, max_response_size)?;
272    Ok(String::from_utf8_lossy(&bytes).into_owned())
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parse `spec` on the standard finger port and render its wire string.
    fn wire(spec: Option<&str>, verbose: bool) -> String {
        let query = Query::parse(spec, verbose, 79).unwrap();
        build_query_string(&query)
    }

    #[test]
    fn query_string_user_at_host() {
        assert_eq!(wire(Some("user@host"), false), "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_eq!(wire(Some("@host"), false), "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_eq!(wire(Some("user@host"), true), "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_eq!(wire(Some("@host"), true), "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_eq!(wire(Some("user@host1@host2"), false), "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_eq!(wire(Some("user@host1@host2"), true), "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_eq!(wire(Some("@host1@host2"), false), "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_eq!(wire(Some("user@a@b@c"), false), "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_eq!(wire(Some("user"), false), "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_eq!(wire(None, false), "\r\n");
    }
}