Skip to main content

digit/
protocol.rs

1use std::io::{self, Read, Write};
2use std::net::{TcpStream, ToSocketAddrs};
3use std::sync::mpsc;
4use std::thread;
5use std::time::Duration;
6
7use crate::query::Query;
8
9/// Errors that can occur during a finger protocol exchange.
10#[derive(Debug, thiserror::Error)]
11pub enum FingerError {
12    /// Failed to resolve the hostname.
13    #[error("could not resolve host '{host}': {source}")]
14    DnsResolution {
15        host: String,
16        #[source]
17        source: io::Error,
18    },
19
20    /// Failed to connect to the remote host.
21    #[error("could not connect to {host}:{port}: {source}")]
22    ConnectionFailed {
23        host: String,
24        port: u16,
25        #[source]
26        source: io::Error,
27    },
28
29    /// Connection timed out.
30    #[error("connection to {host}:{port} timed out")]
31    Timeout { host: String, port: u16 },
32
33    /// Failed to send the query.
34    #[error("failed to send query: {source}")]
35    SendFailed {
36        #[source]
37        source: io::Error,
38    },
39
40    /// Failed to read the response.
41    #[error("failed to read response: {source}")]
42    ReadFailed {
43        #[source]
44        source: io::Error,
45    },
46}
47
48/// Build the wire-format query string to send over the TCP connection.
49///
50/// Per RFC 1288:
51/// - Verbose queries prepend `/W ` (with trailing space).
52/// - Forwarding appends `@host1@host2...` for all hosts except the last
53///   (the last host is the connection target, not part of the query string).
54/// - The query is terminated with `\r\n`.
55pub fn build_query_string(query: &Query) -> String {
56    let mut result = String::new();
57
58    // Verbose prefix per RFC 1288.
59    if query.long {
60        result.push_str("/W ");
61    }
62
63    // User portion.
64    if let Some(ref user) = query.user {
65        result.push_str(user);
66    }
67
68    // Forwarding: include all hosts except the last (the connection target).
69    // These become @host1@host2... in the query string.
70    if query.hosts.len() > 1 {
71        for host in &query.hosts[..query.hosts.len() - 1] {
72            result.push('@');
73            result.push_str(host);
74        }
75    }
76
77    result.push_str("\r\n");
78    result
79}
80
81/// Attempt to connect to a single socket address with a timeout.
82fn connect_to_addr(
83    addr: std::net::SocketAddr,
84    host: &str,
85    port: u16,
86    timeout: Duration,
87) -> Result<TcpStream, FingerError> {
88    TcpStream::connect_timeout(&addr, timeout).map_err(|e| {
89        if e.kind() == io::ErrorKind::TimedOut {
90            FingerError::Timeout {
91                host: host.to_string(),
92                port,
93            }
94        } else {
95            FingerError::ConnectionFailed {
96                host: host.to_string(),
97                port,
98                source: e,
99            }
100        }
101    })
102}
103
104/// Execute a finger query over TCP.
105///
106/// Connects to the target host, sends the query string, reads the full
107/// response until the server closes the connection, and returns the
108/// response as a string. Invalid UTF-8 bytes are replaced with U+FFFD.
109pub fn finger(query: &Query, timeout: Duration) -> Result<String, FingerError> {
110    let host = query.target_host();
111    let addr_str = format!("{}:{}", host, query.port);
112
113    // Resolve hostname to socket addresses.
114    let addrs: Vec<std::net::SocketAddr> = addr_str
115        .to_socket_addrs()
116        .map_err(|e| FingerError::DnsResolution {
117            host: host.to_string(),
118            source: e,
119        })?
120        .collect();
121
122    if addrs.is_empty() {
123        return Err(FingerError::DnsResolution {
124            host: host.to_string(),
125            source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
126        });
127    }
128
129    // Connect: single address connects directly, multiple race in parallel.
130    let mut stream = if addrs.len() == 1 {
131        connect_to_addr(addrs[0], host, query.port, timeout)?
132    } else {
133        let (tx, rx) = mpsc::channel();
134        let addr_count = addrs.len();
135
136        for addr in addrs {
137            let tx = tx.clone();
138            thread::spawn(move || {
139                let result = TcpStream::connect_timeout(&addr, timeout);
140                let _ = tx.send(result);
141            });
142        }
143        drop(tx);
144
145        let mut last_err = None;
146        let mut winner = None;
147        for _ in 0..addr_count {
148            match rx.recv() {
149                Ok(Ok(s)) => {
150                    winner = Some(s);
151                    break;
152                }
153                Ok(Err(e)) => {
154                    last_err = Some(e);
155                }
156                Err(_) => break,
157            }
158        }
159
160        match winner {
161            Some(s) => s,
162            None => {
163                let e = last_err.unwrap_or_else(|| {
164                    io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
165                });
166                if e.kind() == io::ErrorKind::TimedOut {
167                    return Err(FingerError::Timeout {
168                        host: host.to_string(),
169                        port: query.port,
170                    });
171                } else {
172                    return Err(FingerError::ConnectionFailed {
173                        host: host.to_string(),
174                        port: query.port,
175                        source: e,
176                    });
177                }
178            }
179        }
180    };
181
182    // Set read/write timeouts on the connected socket.
183    stream.set_read_timeout(Some(timeout)).ok();
184    stream.set_write_timeout(Some(timeout)).ok();
185
186    // Send the query.
187    let query_string = build_query_string(query);
188    stream.write_all(query_string.as_bytes()).map_err(|e| {
189        if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
190            FingerError::Timeout {
191                host: host.to_string(),
192                port: query.port,
193            }
194        } else {
195            FingerError::SendFailed { source: e }
196        }
197    })?;
198
199    // Read the full response.
200    let mut buf = Vec::new();
201    stream.read_to_end(&mut buf).map_err(|e| {
202        if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
203            FingerError::Timeout {
204                host: host.to_string(),
205                port: query.port,
206            }
207        } else {
208            FingerError::ReadFailed { source: e }
209        }
210    })?;
211
212    Ok(String::from_utf8_lossy(&buf).into_owned())
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parse `spec` exactly as every test here does (with `79`, the standard
    /// finger port, as the last argument) and return its wire-format line.
    fn wire(spec: Option<&'static str>, verbose: bool) -> String {
        let query = Query::parse(spec, verbose, 79).expect("test query should parse");
        build_query_string(&query)
    }

    #[test]
    fn query_string_user_at_host() {
        assert_eq!(wire(Some("user@host"), false), "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_eq!(wire(Some("@host"), false), "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_eq!(wire(Some("user@host"), true), "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_eq!(wire(Some("@host"), true), "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_eq!(wire(Some("user@host1@host2"), false), "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_eq!(wire(Some("user@host1@host2"), true), "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_eq!(wire(Some("@host1@host2"), false), "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_eq!(wire(Some("user@a@b@c"), false), "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_eq!(wire(Some("user"), false), "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_eq!(wire(None, false), "\r\n");
    }
}