use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use crate::query::Query;
/// Errors produced by [`finger`] while resolving, connecting to, or talking
/// to a finger server.
///
/// NOTE(review): most variants interpolate `{source}` into their message and
/// also expose it through `#[source]`, so error reporters that walk the
/// `source()` chain may print the underlying I/O error twice — consider
/// keeping only one of the two.
#[derive(Debug, thiserror::Error)]
pub enum FingerError {
/// The target host could not be resolved, or resolved to no addresses.
#[error("could not resolve host '{host}': {source}")]
DnsResolution {
/// The host that was being resolved (`Query::target_host`).
host: String,
/// The resolver's error, or a synthetic `NotFound` error when the lookup
/// succeeded but yielded no addresses.
#[source]
source: io::Error,
},
/// A TCP connection could not be established for a reason other than a
/// timeout.
#[error("could not connect to {host}:{port}: {source}")]
ConnectionFailed {
host: String,
port: u16,
/// The error returned by the connect attempt.
#[source]
source: io::Error,
},
/// Connecting, sending the query, or reading the reply timed out.
#[error("connection to {host}:{port} timed out")]
Timeout { host: String, port: u16 },
/// Writing the query to the server failed for a reason other than a timeout.
#[error("failed to send query: {source}")]
SendFailed {
#[source]
source: io::Error,
},
/// Reading the server's reply failed for a reason other than a timeout.
#[error("failed to read response: {source}")]
ReadFailed {
#[source]
source: io::Error,
},
}
pub fn build_query_string(query: &Query) -> String {
let mut result = String::new();
if query.long {
result.push_str("/W ");
}
if let Some(ref user) = query.user {
result.push_str(user);
}
if query.hosts.len() > 1 {
for host in &query.hosts[..query.hosts.len() - 1] {
result.push('@');
result.push_str(host);
}
}
result.push_str("\r\n");
result
}
/// Opens a TCP connection to a single resolved address, giving up after
/// `timeout`.
///
/// `host` and `port` are only used to label the returned error.
///
/// # Errors
///
/// A connect that fails with `ErrorKind::TimedOut` becomes
/// [`FingerError::Timeout`]. Every other failure becomes
/// [`FingerError::ConnectionFailed`] wrapping the I/O error.
fn connect_to_addr(
    addr: std::net::SocketAddr,
    host: &str,
    port: u16,
    timeout: Duration,
) -> Result<TcpStream, FingerError> {
    match TcpStream::connect_timeout(&addr, timeout) {
        Ok(stream) => Ok(stream),
        Err(err) if err.kind() == io::ErrorKind::TimedOut => Err(FingerError::Timeout {
            host: host.to_string(),
            port,
        }),
        Err(err) => Err(FingerError::ConnectionFailed {
            host: host.to_string(),
            port,
            source: err,
        }),
    }
}
pub fn finger(query: &Query, timeout: Duration) -> Result<String, FingerError> {
let host = query.target_host();
let addr_str = format!("{}:{}", host, query.port);
let addrs: Vec<std::net::SocketAddr> = addr_str
.to_socket_addrs()
.map_err(|e| FingerError::DnsResolution {
host: host.to_string(),
source: e,
})?
.collect();
if addrs.is_empty() {
return Err(FingerError::DnsResolution {
host: host.to_string(),
source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
});
}
let mut stream = if addrs.len() == 1 {
connect_to_addr(addrs[0], host, query.port, timeout)?
} else {
let (tx, rx) = mpsc::channel();
let addr_count = addrs.len();
for addr in addrs {
let tx = tx.clone();
thread::spawn(move || {
let result = TcpStream::connect_timeout(&addr, timeout);
let _ = tx.send(result);
});
}
drop(tx);
let mut last_err = None;
let mut winner = None;
for _ in 0..addr_count {
match rx.recv() {
Ok(Ok(s)) => {
winner = Some(s);
break;
}
Ok(Err(e)) => {
last_err = Some(e);
}
Err(_) => break,
}
}
match winner {
Some(s) => s,
None => {
let e = last_err.unwrap_or_else(|| {
io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
});
if e.kind() == io::ErrorKind::TimedOut {
return Err(FingerError::Timeout {
host: host.to_string(),
port: query.port,
});
} else {
return Err(FingerError::ConnectionFailed {
host: host.to_string(),
port: query.port,
source: e,
});
}
}
}
};
stream.set_read_timeout(Some(timeout)).ok();
stream.set_write_timeout(Some(timeout)).ok();
let query_string = build_query_string(query);
stream.write_all(query_string.as_bytes()).map_err(|e| {
if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
FingerError::Timeout {
host: host.to_string(),
port: query.port,
}
} else {
FingerError::SendFailed { source: e }
}
})?;
let mut buf = Vec::new();
stream.read_to_end(&mut buf).map_err(|e| {
if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
FingerError::Timeout {
host: host.to_string(),
port: query.port,
}
} else {
FingerError::ReadFailed { source: e }
}
})?;
Ok(String::from_utf8_lossy(&buf).into_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parses `target` (default port 79) and checks the exact request line
    /// that `build_query_string` produces for it.
    fn assert_request(target: Option<&str>, long: bool, expected: &str) {
        let query = Query::parse(target, long, 79).unwrap();
        assert_eq!(build_query_string(&query), expected);
    }

    #[test]
    fn query_string_user_at_host() {
        assert_request(Some("user@host"), false, "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_request(Some("@host"), false, "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_request(Some("user@host"), true, "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_request(Some("@host"), true, "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_request(Some("user@host1@host2"), false, "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_request(Some("user@host1@host2"), true, "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_request(Some("@host1@host2"), false, "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_request(Some("user@a@b@c"), false, "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_request(Some("user"), false, "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_request(None, false, "\r\n");
    }
}