1use std::io::{self, Read, Write};
2use std::net::{TcpStream, ToSocketAddrs};
3use std::sync::mpsc;
4use std::thread;
5use std::time::Duration;
6
7use crate::query::Query;
8
9#[derive(Debug, thiserror::Error)]
11pub enum FingerError {
12 #[error("could not resolve host '{host}': {source}")]
14 DnsResolution {
15 host: String,
16 #[source]
17 source: io::Error,
18 },
19
20 #[error("could not connect to {host}:{port}: {source}")]
22 ConnectionFailed {
23 host: String,
24 port: u16,
25 #[source]
26 source: io::Error,
27 },
28
29 #[error("connection to {host}:{port} timed out")]
31 Timeout { host: String, port: u16 },
32
33 #[error("failed to send query: {source}")]
35 SendFailed {
36 #[source]
37 source: io::Error,
38 },
39
40 #[error("failed to read response: {source}")]
42 ReadFailed {
43 #[source]
44 source: io::Error,
45 },
46}
47
48pub fn build_query_string(query: &Query) -> String {
56 let mut result = String::new();
57
58 if query.long {
60 result.push_str("/W ");
61 }
62
63 if let Some(ref user) = query.user {
65 result.push_str(user);
66 }
67
68 if query.hosts.len() > 1 {
71 for host in &query.hosts[..query.hosts.len() - 1] {
72 result.push('@');
73 result.push_str(host);
74 }
75 }
76
77 result.push_str("\r\n");
78 result
79}
80
81fn connect_to_addr(
83 addr: std::net::SocketAddr,
84 host: &str,
85 port: u16,
86 timeout: Duration,
87) -> Result<TcpStream, FingerError> {
88 TcpStream::connect_timeout(&addr, timeout).map_err(|e| {
89 if e.kind() == io::ErrorKind::TimedOut {
90 FingerError::Timeout {
91 host: host.to_string(),
92 port,
93 }
94 } else {
95 FingerError::ConnectionFailed {
96 host: host.to_string(),
97 port,
98 source: e,
99 }
100 }
101 })
102}
103
104pub fn finger(query: &Query, timeout: Duration) -> Result<String, FingerError> {
110 let host = query.target_host();
111 let addr_str = format!("{}:{}", host, query.port);
112
113 let addrs: Vec<std::net::SocketAddr> = addr_str
115 .to_socket_addrs()
116 .map_err(|e| FingerError::DnsResolution {
117 host: host.to_string(),
118 source: e,
119 })?
120 .collect();
121
122 if addrs.is_empty() {
123 return Err(FingerError::DnsResolution {
124 host: host.to_string(),
125 source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
126 });
127 }
128
129 let mut stream = if addrs.len() == 1 {
131 connect_to_addr(addrs[0], host, query.port, timeout)?
132 } else {
133 let (tx, rx) = mpsc::channel();
134 let addr_count = addrs.len();
135
136 for addr in addrs {
137 let tx = tx.clone();
138 thread::spawn(move || {
139 let result = TcpStream::connect_timeout(&addr, timeout);
140 let _ = tx.send(result);
141 });
142 }
143 drop(tx);
144
145 let mut last_err = None;
146 let mut winner = None;
147 for _ in 0..addr_count {
148 match rx.recv() {
149 Ok(Ok(s)) => {
150 winner = Some(s);
151 break;
152 }
153 Ok(Err(e)) => {
154 last_err = Some(e);
155 }
156 Err(_) => break,
157 }
158 }
159
160 match winner {
161 Some(s) => s,
162 None => {
163 let e = last_err.unwrap_or_else(|| {
164 io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
165 });
166 if e.kind() == io::ErrorKind::TimedOut {
167 return Err(FingerError::Timeout {
168 host: host.to_string(),
169 port: query.port,
170 });
171 } else {
172 return Err(FingerError::ConnectionFailed {
173 host: host.to_string(),
174 port: query.port,
175 source: e,
176 });
177 }
178 }
179 }
180 };
181
182 stream.set_read_timeout(Some(timeout)).ok();
184 stream.set_write_timeout(Some(timeout)).ok();
185
186 let query_string = build_query_string(query);
188 stream.write_all(query_string.as_bytes()).map_err(|e| {
189 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
190 FingerError::Timeout {
191 host: host.to_string(),
192 port: query.port,
193 }
194 } else {
195 FingerError::SendFailed { source: e }
196 }
197 })?;
198
199 let mut buf = Vec::new();
201 stream.read_to_end(&mut buf).map_err(|e| {
202 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
203 FingerError::Timeout {
204 host: host.to_string(),
205 port: query.port,
206 }
207 } else {
208 FingerError::ReadFailed { source: e }
209 }
210 })?;
211
212 Ok(String::from_utf8_lossy(&buf).into_owned())
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parses `target` with port 79 (the standard finger port) and checks the
    /// request line produced by `build_query_string`.
    fn assert_request(target: Option<&'static str>, long: bool, expected: &str) {
        let query = Query::parse(target, long, 79).unwrap();
        assert_eq!(build_query_string(&query), expected);
    }

    #[test]
    fn query_string_user_at_host() {
        assert_request(Some("user@host"), false, "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_request(Some("@host"), false, "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_request(Some("user@host"), true, "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_request(Some("@host"), true, "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_request(Some("user@host1@host2"), false, "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_request(Some("user@host1@host2"), true, "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_request(Some("@host1@host2"), false, "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_request(Some("user@a@b@c"), false, "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_request(Some("user"), false, "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_request(None, false, "\r\n");
    }
}