//! digit-cli 0.2.0
//!
//! A finger protocol client (RFC 1288 / RFC 742).
use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

use crate::query::Query;

/// Everything that can go wrong during one finger exchange, from resolving
/// the target host through reading the server's reply.
#[derive(Debug, thiserror::Error)]
pub enum FingerError {
    /// The target hostname could not be turned into any socket address.
    #[error("could not resolve host '{host}': {source}")]
    DnsResolution {
        /// Hostname whose lookup failed.
        host: String,
        /// Underlying resolver error.
        #[source]
        source: io::Error,
    },

    /// A TCP connection to the remote host could not be established.
    #[error("could not connect to {host}:{port}: {source}")]
    ConnectionFailed {
        /// Host that could not be reached.
        host: String,
        /// TCP port the connection was attempted on.
        port: u16,
        /// Underlying connect error.
        #[source]
        source: io::Error,
    },

    /// Connecting, sending or reading did not finish within the timeout.
    #[error("connection to {host}:{port} timed out")]
    Timeout {
        /// Host being contacted when the timeout fired.
        host: String,
        /// TCP port being contacted when the timeout fired.
        port: u16,
    },

    /// Writing the query onto the socket failed.
    #[error("failed to send query: {source}")]
    SendFailed {
        /// Underlying write error.
        #[source]
        source: io::Error,
    },

    /// Reading the server's reply failed.
    #[error("failed to read response: {source}")]
    ReadFailed {
        /// Underlying read error.
        #[source]
        source: io::Error,
    },
}

/// Build the wire-format query string to send over the TCP connection.
///
/// Per RFC 1288:
/// - Verbose queries prepend `/W ` (with trailing space).
/// - Forwarding appends `@host1@host2...` for all hosts except the last
///   (the last host is the connection target, not part of the query string).
/// - The query is terminated with `\r\n`.
pub fn build_query_string(query: &Query) -> String {
    let mut result = String::new();

    // Verbose prefix per RFC 1288.
    if query.long {
        result.push_str("/W ");
    }

    // User portion.
    if let Some(ref user) = query.user {
        result.push_str(user);
    }

    // Forwarding: include all hosts except the last (the connection target).
    // These become @host1@host2... in the query string.
    if query.hosts.len() > 1 {
        for host in &query.hosts[..query.hosts.len() - 1] {
            result.push('@');
            result.push_str(host);
        }
    }

    result.push_str("\r\n");
    result
}

/// Open a TCP connection to one resolved address, giving up after `timeout`.
///
/// `host` and `port` are used only to label the error. A connect that fails
/// with `io::ErrorKind::TimedOut` becomes [`FingerError::Timeout`]. Any other
/// failure becomes [`FingerError::ConnectionFailed`], carrying the original
/// I/O error.
fn connect_to_addr(
    addr: std::net::SocketAddr,
    host: &str,
    port: u16,
    timeout: Duration,
) -> Result<TcpStream, FingerError> {
    match TcpStream::connect_timeout(&addr, timeout) {
        Ok(stream) => Ok(stream),
        Err(err) if err.kind() == io::ErrorKind::TimedOut => Err(FingerError::Timeout {
            host: host.to_owned(),
            port,
        }),
        Err(err) => Err(FingerError::ConnectionFailed {
            host: host.to_owned(),
            port,
            source: err,
        }),
    }
}

/// Execute a finger query over TCP.
///
/// Connects to the target host, sends the query string, reads the full
/// response until the server closes the connection, and returns the
/// response as a string. Invalid UTF-8 bytes are replaced with U+FFFD.
pub fn finger(query: &Query, timeout: Duration) -> Result<String, FingerError> {
    let host = query.target_host();
    let addr_str = format!("{}:{}", host, query.port);

    // Resolve hostname to socket addresses.
    let addrs: Vec<std::net::SocketAddr> = addr_str
        .to_socket_addrs()
        .map_err(|e| FingerError::DnsResolution {
            host: host.to_string(),
            source: e,
        })?
        .collect();

    if addrs.is_empty() {
        return Err(FingerError::DnsResolution {
            host: host.to_string(),
            source: io::Error::new(io::ErrorKind::NotFound, "no addresses found"),
        });
    }

    // Connect: single address connects directly, multiple race in parallel.
    let mut stream = if addrs.len() == 1 {
        connect_to_addr(addrs[0], host, query.port, timeout)?
    } else {
        let (tx, rx) = mpsc::channel();
        let addr_count = addrs.len();

        for addr in addrs {
            let tx = tx.clone();
            thread::spawn(move || {
                let result = TcpStream::connect_timeout(&addr, timeout);
                let _ = tx.send(result);
            });
        }
        drop(tx);

        let mut last_err = None;
        let mut winner = None;
        for _ in 0..addr_count {
            match rx.recv() {
                Ok(Ok(s)) => {
                    winner = Some(s);
                    break;
                }
                Ok(Err(e)) => {
                    last_err = Some(e);
                }
                Err(_) => break,
            }
        }

        match winner {
            Some(s) => s,
            None => {
                let e = last_err.unwrap_or_else(|| {
                    io::Error::new(io::ErrorKind::ConnectionRefused, "all addresses failed")
                });
                if e.kind() == io::ErrorKind::TimedOut {
                    return Err(FingerError::Timeout {
                        host: host.to_string(),
                        port: query.port,
                    });
                } else {
                    return Err(FingerError::ConnectionFailed {
                        host: host.to_string(),
                        port: query.port,
                        source: e,
                    });
                }
            }
        }
    };

    // Set read/write timeouts on the connected socket.
    stream.set_read_timeout(Some(timeout)).ok();
    stream.set_write_timeout(Some(timeout)).ok();

    // Send the query.
    let query_string = build_query_string(query);
    stream.write_all(query_string.as_bytes()).map_err(|e| {
        if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
            FingerError::Timeout {
                host: host.to_string(),
                port: query.port,
            }
        } else {
            FingerError::SendFailed { source: e }
        }
    })?;

    // Read the full response.
    let mut buf = Vec::new();
    stream.read_to_end(&mut buf).map_err(|e| {
        if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
            FingerError::Timeout {
                host: host.to_string(),
                port: query.port,
            }
        } else {
            FingerError::ReadFailed { source: e }
        }
    })?;

    Ok(String::from_utf8_lossy(&buf).into_owned())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Query;

    /// Parse `target` with `Query::parse` (79, the standard finger port, as
    /// the third argument) and render the result in wire format.
    fn wire(target: Option<&str>, verbose: bool) -> String {
        let parsed = Query::parse(target, verbose, 79).unwrap();
        build_query_string(&parsed)
    }

    #[test]
    fn query_string_user_at_host() {
        assert_eq!(wire(Some("user@host"), false), "user\r\n");
    }

    #[test]
    fn query_string_list_users() {
        assert_eq!(wire(Some("@host"), false), "\r\n");
    }

    #[test]
    fn query_string_verbose_user() {
        assert_eq!(wire(Some("user@host"), true), "/W user\r\n");
    }

    #[test]
    fn query_string_verbose_list() {
        assert_eq!(wire(Some("@host"), true), "/W \r\n");
    }

    #[test]
    fn query_string_forwarding() {
        assert_eq!(wire(Some("user@host1@host2"), false), "user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_verbose() {
        assert_eq!(wire(Some("user@host1@host2"), true), "/W user@host1\r\n");
    }

    #[test]
    fn query_string_forwarding_no_user() {
        assert_eq!(wire(Some("@host1@host2"), false), "@host1\r\n");
    }

    #[test]
    fn query_string_three_host_chain() {
        assert_eq!(wire(Some("user@a@b@c"), false), "user@a@b\r\n");
    }

    #[test]
    fn query_string_localhost_user() {
        assert_eq!(wire(Some("user"), false), "user\r\n");
    }

    #[test]
    fn query_string_localhost_list() {
        assert_eq!(wire(None, false), "\r\n");
    }
}