shoes 0.2.2

A multi-protocol proxy server.
use std::net::IpAddr;
use std::sync::Arc;

use json::JsonValue;
use log::{debug, warn};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

use crate::async_tls::AsyncTlsConnector;
use crate::line_reader::LineReader;
use crate::resolver::Resolver;
use crate::util::allocate_vec;

/// A [`Resolver`] that looks host names up with DNS-over-HTTPS queries.
///
/// The only state held is the TLS connector used to secure connections to the
/// DoH server; it is shared behind an `Arc`.
pub struct DohResolver {
    /// Connector that wraps the TCP connection to the DoH server in TLS.
    tls_connector: Arc<Box<dyn AsyncTlsConnector>>,
}

impl DohResolver {
    /// Creates a resolver that establishes its TLS sessions through `tls_connector`.
    pub fn new(tls_connector: Arc<Box<dyn AsyncTlsConnector>>) -> Self {
        DohResolver { tls_connector }
    }
}

impl Resolver for DohResolver {
    /// Resolves `host` over DNS-over-HTTPS using this resolver's TLS connector.
    async fn resolve_host(&self, host: &str) -> std::io::Result<Vec<IpAddr>> {
        let connector = &self.tls_connector;
        lookup_host(host, connector).await
    }
}

async fn lookup_host(
    host: &str,
    tls_connector: &Arc<Box<dyn AsyncTlsConnector>>,
) -> std::io::Result<Vec<IpAddr>> {
    let stream = TcpStream::connect("8.8.8.8:443").await?;
    let mut stream = tls_connector.connect("dns.google", stream).await?;

    let mut request = String::with_capacity(4096);
    request.push_str("GET /resolve?name=");
    request.push_str(host);
    request.push_str(" HTTP/1.1\r\n");
    request.push_str("Host: dns.google\r\n");
    request.push_str("User-Agent: curl/7.68.0\r\n");
    request.push_str("Accept: application/dns-json\r\n");
    request.push_str("Accept-Encoding: identity\r\n");
    request.push_str("Connection: close\r\n\r\n");

    debug!("request: {}", &request);

    stream.write_all(&request.into_bytes()).await?;

    let mut line_reader = LineReader::new();
    let line = line_reader.read_line(&mut stream).await?;
    if !line.starts_with("HTTP/1.1 200 OK") {
        loop {
            let line = line_reader.read_line(&mut stream).await?;
            debug!("FAILURE LINE: {}", line);
            if line.is_empty() {
                break;
            }
        }

        if let Ok(s) = std::str::from_utf8(line_reader.unparsed_data()) {
            debug!("MSGUP: {}", s);
        }

        let mut buf = allocate_vec(40960);
        let len = stream.read(&mut buf).await?;
        debug!("READ {}", len);
        if let Ok(s) = std::str::from_utf8(&buf[0..len]) {
            debug!("MSG: {}", s);
        }

        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("DoH request failed: {}", "blurghaaa"),
        ));
    }

    let mut content_length: Option<usize> = None;
    loop {
        let line = line_reader.read_line(&mut stream).await?;
        if line.is_empty() {
            break;
        }
        if line.to_ascii_lowercase().starts_with("content-length: ") {
            let len = line[16..].parse::<usize>().map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("failed to parse content length: {}", e),
                )
            })?;
            if content_length.is_some() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!(
                        "Got content-length header but it was already received: {}",
                        line
                    ),
                ));
            }
            content_length = Some(len);
        }
    }

    let json_val = match content_length {
        Some(len) => {
            if len > 255_000 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("JSON content length is too big ({})", len),
                ));
            }

            let mut json_bytes = allocate_vec(len);
            stream.read_exact(&mut json_bytes).await?;

            // TODO: we could set connection: keep-alive and reuse the stream.
            let json_str = String::from_utf8(json_bytes).map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("failed to parse dns JSON: {}", e),
                )
            })?;

            JsonValue::from(json_str)
        }
        None => {
            panic!("TODO: handle transfer-encoding chunked");
        }
    };

    let answer = &json_val["Answer"];
    if !answer.is_array() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "failed to read json answer",
        ));
    }

    let mut results = vec![];
    for item in answer.members() {
        // 1 is A, 28 is AAAA
        match item["type"].as_str() {
            Some("1") | Some("28") => {
                // A or AAAA
                // ref: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4
                let data = match item["data"].as_str() {
                    Some(s) => s,
                    None => {
                        warn!("Answer item did not have data field");
                        continue;
                    }
                };
                let ip_addr = data.parse::<IpAddr>().map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("failed to parse ip address: {}", e),
                    )
                })?;
                results.push(ip_addr);
            }
            Some(_) => {
                continue;
            }
            None => {
                warn!("Answer item did not have type field");
                continue;
            }
        }
    }

    if results.is_empty() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("Failed to resolve {}", host),
        ));
    }

    debug!("Successfully resolved {}: {:?}", host, &results);

    Ok(results)
}