1use std::io::{self, Read, Write};
7use std::net::{TcpStream, ToSocketAddrs};
8use std::sync::mpsc;
9use std::thread;
10use std::time::Duration;
11
12use crate::query::Query;
13
14#[derive(Debug, thiserror::Error)]
16pub enum FingerError {
17 #[error("could not resolve host '{host}': {source}")]
19 DnsResolution {
20 host: String,
21 #[source]
22 source: io::Error,
23 },
24
25 #[error("could not connect to {host}:{port}: {source}")]
27 ConnectionFailed {
28 host: String,
29 port: u16,
30 #[source]
31 source: io::Error,
32 },
33
34 #[error("connection to {host}:{port} timed out")]
36 Timeout { host: String, port: u16 },
37
38 #[error("failed to send query: {source}")]
40 SendFailed {
41 #[source]
42 source: io::Error,
43 },
44
45 #[error("failed to read response: {source}")]
47 ReadFailed {
48 #[source]
49 source: io::Error,
50 },
51}
52
53pub fn build_query_string(query: &Query) -> String {
71 let mut result = String::new();
72
73 if query.long {
75 result.push_str("/W ");
76 }
77
78 if let Some(ref user) = query.user {
80 result.push_str(user);
81 }
82
83 if query.hosts.len() > 1 {
86 for host in &query.hosts[..query.hosts.len() - 1] {
87 result.push('@');
88 result.push_str(host);
89 }
90 }
91
92 result.push_str("\r\n");
93 result
94}
95
96fn connect_to_addr(
98 addr: std::net::SocketAddr,
99 host: &str,
100 port: u16,
101 timeout: Duration,
102) -> Result<TcpStream, FingerError> {
103 TcpStream::connect_timeout(&addr, timeout).map_err(|e| {
104 if e.kind() == io::ErrorKind::TimedOut {
105 FingerError::Timeout {
106 host: host.to_string(),
107 port,
108 }
109 } else {
110 FingerError::ConnectionFailed {
111 host: host.to_string(),
112 port,
113 source: e,
114 }
115 }
116 })
117}
118
119pub fn finger_raw(
137 query: &Query,
138 timeout: Duration,
139 max_response_size: u64,
140) -> Result<Vec<u8>, FingerError> {
141 let host = query.target_host();
142 let addr_str = format!("{}:{}", host, query.port);
143
144 let addrs: Vec<std::net::SocketAddr> = addr_str
146 .to_socket_addrs()
147 .map_err(|e| FingerError::DnsResolution {
148 host: host.to_string(),
149 source: e,
150 })?
151 .collect();
152
153 if addrs.is_empty() {
154 return Err(FingerError::DnsResolution {
155 host: host.to_string(),
156 source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
157 });
158 }
159
160 let mut stream = if addrs.len() == 1 {
162 connect_to_addr(addrs[0], host, query.port, timeout)?
163 } else {
164 let (tx, rx) = mpsc::channel();
165 let addr_count = addrs.len();
166
167 for addr in addrs {
168 let tx = tx.clone();
169 thread::spawn(move || {
170 let result = TcpStream::connect_timeout(&addr, timeout);
171 let _ = tx.send(result);
172 });
173 }
174 drop(tx);
175
176 let mut last_err = None;
177 let mut winner = None;
178 for _ in 0..addr_count {
179 match rx.recv() {
180 Ok(Ok(s)) => {
181 winner = Some(s);
182 break;
183 }
184 Ok(Err(e)) => {
185 last_err = Some(e);
186 }
187 Err(_) => break,
188 }
189 }
190
191 match winner {
192 Some(s) => s,
193 None => {
194 let e = last_err.unwrap_or_else(|| {
195 io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
196 });
197 if e.kind() == io::ErrorKind::TimedOut {
198 return Err(FingerError::Timeout {
199 host: host.to_string(),
200 port: query.port,
201 });
202 } else {
203 return Err(FingerError::ConnectionFailed {
204 host: host.to_string(),
205 port: query.port,
206 source: e,
207 });
208 }
209 }
210 }
211 };
212
213 stream.set_read_timeout(Some(timeout)).ok();
215 stream.set_write_timeout(Some(timeout)).ok();
216
217 let query_string = build_query_string(query);
219 stream.write_all(query_string.as_bytes()).map_err(|e| {
220 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
221 FingerError::Timeout {
222 host: host.to_string(),
223 port: query.port,
224 }
225 } else {
226 FingerError::SendFailed { source: e }
227 }
228 })?;
229
230 let mut buf = Vec::new();
232 stream
233 .take(max_response_size)
234 .read_to_end(&mut buf)
235 .map_err(|e| {
236 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
237 FingerError::Timeout {
238 host: host.to_string(),
239 port: query.port,
240 }
241 } else {
242 FingerError::ReadFailed { source: e }
243 }
244 })?;
245
246 Ok(buf)
247}
248
249pub fn finger(
267 query: &Query,
268 timeout: Duration,
269 max_response_size: u64,
270) -> Result<String, FingerError> {
271 let bytes = finger_raw(query, timeout, max_response_size)?;
272 Ok(String::from_utf8_lossy(&bytes).into_owned())
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parses `target` with port 79 and returns the query line built for it.
    fn wire(target: Option<&str>, long: bool) -> String {
        let query = Query::parse(target, long, 79).unwrap();
        build_query_string(&query)
    }

    #[test]
    fn query_string_user_at_host() {
        assert_eq!(wire(Some("user@host"), false), "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_eq!(wire(Some("@host"), false), "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_eq!(wire(Some("user@host"), true), "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_eq!(wire(Some("@host"), true), "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_eq!(wire(Some("user@host1@host2"), false), "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_eq!(wire(Some("user@host1@host2"), true), "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_eq!(wire(Some("@host1@host2"), false), "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_eq!(wire(Some("user@a@b@c"), false), "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_eq!(wire(Some("user"), false), "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_eq!(wire(None, false), "\r\n");
    }
}