// kapiti 0.0.3 - The Kapiti DNS Server
use std::convert::TryFrom;
use std::fs::{self, File};
use std::io;
use std::path::Path;
use std::time::{Duration, SystemTime};

use anyhow::{anyhow, Context, Result};
use async_io::Timer;
use chrono::{DateTime, NaiveDateTime};
use futures_lite::FutureExt;
use hyper::header;
use hyper::{Client, HeaderMap, Method};
use tracing::{debug, info, level_enabled, trace, warn, Level};

use crate::fetcher::Fetcher;
use crate::filter::path;
use crate::hyper_smol;

/// Downloads the specified URL to the specified path using the provided client.
/// If the local path exists and has an mtime newer than the URL, the download is skipped.
/// The client uses the external resolver, ensuring that the query is NOT affected by local filters.
/// The path meanwhile must have a ".zstd" extension or an error will be returned.
/// Returns a boolean for whether the file was actually uploaded (true) or skipped (false)
pub async fn update_file(
    client: &Client<hyper_smol::SmolConnector>,
    fetcher: &Fetcher,
    url: &String,
    path: &Path,
    timeout_ms: u64,
) -> Result<bool> {
    let file_mtime = get_file_mtime_ms(path)
        .with_context(|| format!("Failed to check local copy of {} at {:?}", url, path))?;
    let (head_redirect_url, needs_update) = match file_mtime {
        Some(file_mtime_ms_u128) => {
            let file_mtime_ms = i64::try_from(file_mtime_ms_u128)
                .with_context(|| format!("Invalid file mtime for {:?}", path))?;
            file_needs_update(client, fetcher, url, file_mtime_ms).await?
        }
        None => {
            // Local file not found, do download
            (None, true)
        }
    };
    if !needs_update {
        // File exists and is up to date
        info!("Skipping download of {}: Local copy is up to date", url);
        return Ok(false);
    }

    // File doesn't exist, or file is out of date. Get a new version.
    info!("Downloading {} to {:?}", url, path);

    let get_url = match &head_redirect_url {
        // If the HEAD query hit a redirect, follow that same redirect for the following GET query
        Some(u) => u,
        // Otherwise use the original URL for the GET query
        None => url,
    };

    let mut resp = client
        .request(fetcher.build_request(&Method::GET, get_url)?)
        .or(async {
            Timer::after(Duration::from_millis(timeout_ms)).await;
            // hyper keeps error types crate-private. Jump through hoops to produce an Ok response with an error.
            let response = hyper::Response::new(hyper::Body::empty());
            let (mut parts, body) = response.into_parts();
            parts.status = http::StatusCode::GATEWAY_TIMEOUT;
            Ok(hyper::Response::from_parts(parts, body))
        })
        .await
        .with_context(|| format!("HTTP GET to {} failed", get_url))?;

    // Only allow redirect for GET query if we didn't already follow a redirect for the HEAD query.
    if head_redirect_url.is_none() && resp.status().is_redirection() {
        // Basic support for redirects: Just allow at most one redirect, and don't change request content for the new destination.
        // Intentionally basic for now, can improve later if needed.
        let loc = header_to_str(resp.headers(), &header::LOCATION, url)?;
        trace!("Following redirect: {} => {}", url, loc);
        resp = client
            .request(fetcher.build_request(&Method::GET, &loc)?)
            .await
            .with_context(|| format!("HTTP GET to {} failed", loc))?;
    }

    // Note that we just pass the original url: log the original requested value rather than the redirected value

    // Write to "file.tmp" then rename to "file[.ext]"
    let tmp_path = path.with_extension("tmp");
    // If anything fails with the download then try to delete "file.tmp" automatically.
    let _tmp_guard = scopeguard::guard(tmp_path.clone(), |path| {
        trace!("Cleaning up {:?} if it exists", path);
        let _ = fs::remove_file(path);
    });
    trace!("Downloading to {:?}", tmp_path);

    {
        let mut tmp_file = File::create(&tmp_path)?;
        if path::is_zstd_extension(path) {
            // ZSTD compression for file output
            let mut encoder =
                zstd::stream::Encoder::new(tmp_file, zstd::DEFAULT_COMPRESSION_LEVEL)?
                    .auto_finish();
            fetcher
                .write_response(&url, &mut encoder, &mut resp)
                .await?;
        } else {
            // No compression for file output
            fetcher
                .write_response(&url, &mut tmp_file, &mut resp)
                .await?;
        }
    }

    trace!("Renaming {:?} => {:?}", tmp_path, path);
    fs::rename(&tmp_path, &path).with_context(|| {
        format!(
            "Failed to rename downloaded filter file from {:?} to {:?}",
            tmp_path, path
        )
    })?;
    Ok(true)
}

async fn file_needs_update(
    client: &Client<hyper_smol::SmolConnector>,
    fetcher: &Fetcher,
    url: &String,
    file_mtime_ms: i64,
) -> Result<(Option<String>, bool)> {
    let mut redirect_url: Option<String> = None;

    // Local file already exists, see if it makes sense to update.
    // Check the server's Last-Modified response and compare it to our mtime.
    // We avoid dealing with If-Modified-Since, since support for it is apparently very inconsistent.
    // see also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/HEAD
    let mut head_resp = client
        .request(fetcher.build_request(&Method::HEAD, &url)?)
        .await
        .with_context(|| format!("HTTP HEAD query to {} failed", url))?;

    if head_resp.status().is_redirection() {
        // Basic support for redirects: Just allow at most one redirect, and don't change request content for the new destination.
        // Intentionally basic for now, can improve later if needed.
        let loc = header_to_str(head_resp.headers(), &header::LOCATION, &url)?;
        head_resp = client
            .request(fetcher.build_request(&Method::GET, &loc)?)
            .await
            .with_context(|| format!("HTTP HEAD query to {} failed", loc))?;
        redirect_url = Some(loc);
    }

    // Check if local file is older than server's Last-Modified
    match header_to_str(head_resp.headers(), &header::LAST_MODIFIED, &url) {
        Ok(url_mtime_header) => {
            if level_enabled!(Level::DEBUG) {
                debug!(
                    "Existing file mtime='{:?}' vs {} Last-Modified='{}'",
                    NaiveDateTime::from_timestamp_opt(file_mtime_ms / 1000, 0),
                    url,
                    url_mtime_header
                );
            }
            let url_mtime_ms = DateTime::parse_from_rfc2822(url_mtime_header.as_str())
                .with_context(|| format!("Failed to parse Last-Modified header from {}", url))?
                .timestamp_millis();
            return Ok((redirect_url, file_mtime_ms < url_mtime_ms));
        }
        Err(_e) => {
            // No Last-Modified, continue below...
        }
    };

    // Check if local file is older than the server's expire period defined by Date+Expires.
    // (Seen with downloads from raw.githubusercontent.com)
    match (
        header_to_str(head_resp.headers(), &header::DATE, &url),
        header_to_str(head_resp.headers(), &header::EXPIRES, &url),
    ) {
        (Ok(url_date_header), Ok(url_expires_header)) => {
            if level_enabled!(Level::DEBUG) {
                debug!(
                    "Existing file mtime='{:?}' vs {} Date='{}' + Expires='{}'",
                    NaiveDateTime::from_timestamp_opt(file_mtime_ms / 1000, 0),
                    url,
                    url_date_header,
                    url_expires_header
                );
            }
            let url_date_ms = DateTime::parse_from_rfc2822(url_date_header.as_str())
                .with_context(|| format!("Failed to parse Date header from {}", url))?
                .timestamp_millis();
            let url_expires_ms = DateTime::parse_from_rfc2822(url_expires_header.as_str())
                .with_context(|| format!("Failed to parse Expires header from {}", url))?
                .timestamp_millis();
            if url_date_ms > url_expires_ms {
                // The server's Date timestamp is ahead of their own Expires timestamp.
                // Give up and redownload the file.
                warn!(
                    "Server Date={}/{} is older than server Expires={}/{}",
                    url_date_ms, url_date_header, url_expires_ms, url_expires_header
                );
                return Ok((redirect_url, true));
            }

            // We want to download if the file is older than the difference of Expires-Date (aka the expire duration)
            // For example, if the expire duration is 5 hours, then we should only download if the local file is more than 5 hours old.
            // This allows us to avoid full parsing support of Cache-Control headers.
            // Also, we are careful to avoid comparing server timestamps against local timestamps. We only compare relative durations.
            match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
                Ok(epoch_duration) => {
                    let now_ms = i64::try_from(epoch_duration.as_millis())
                        .with_context(|| "current time is invalid")?;
                    if now_ms < file_mtime_ms {
                        // Give up and redownload the file.
                        warn!(
                            "File was modified in the future: mtime={} now={}",
                            file_mtime_ms, now_ms
                        );
                        return Ok((redirect_url, true));
                    }
                    let expire_duration_ms = url_expires_ms - url_date_ms;
                    let file_age_ms = now_ms - file_mtime_ms;
                    debug!(
                        "File age {} vs expire duration {}",
                        file_age_ms, expire_duration_ms
                    );
                    Ok((redirect_url, file_age_ms > expire_duration_ms))
                }
                Err(_) => {
                    // Give up and redownload the file.
                    warn!("Current time is before 1970");
                    Ok((redirect_url, true))
                }
            }
        }
        _ => {
            // Missing Last-Modified and missing either Date and/or Expires. Give up and do the download.
            Ok((redirect_url, true))
        }
    }
}

/// Reads the modification time of `path` as milliseconds since the Unix epoch.
///
/// - `Ok(None)`: `path` does not exist.
/// - `Ok(Some(0))`: the mtime predates the epoch (a warning is logged).
/// - `Err(_)`: any other failure reading the metadata or the mtime.
fn get_file_mtime_ms(path: &Path) -> Result<Option<u128>> {
    let metadata = match fs::metadata(path) {
        Ok(meta) => meta,
        // A missing file is an expected state, not an error.
        Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(None),
        Err(err) => {
            return Err(err).with_context(|| format!("Failed to get metadata for {:?}", path))
        }
    };

    let modified_at = metadata
        .modified()
        .with_context(|| format!("Failed to get modified time for {:?}", path))?;

    let millis = modified_at
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_millis())
        .unwrap_or_else(|_| {
            // mtime is before epoch: treat it as being exactly AT epoch.
            warn!("File was created before 1970: {:?}", path);
            0
        });
    Ok(Some(millis))
}

/// Looks up `header` in `headers` and returns its value as an owned `String`.
///
/// `origin` (the URL the response came from) is only used to make error messages identifiable.
///
/// # Errors
/// Fails when the header is absent (the message then includes the full header map for
/// debugging), or when `HeaderValue::to_str` rejects the value.
fn header_to_str(
    headers: &HeaderMap,
    header: &header::HeaderName,
    origin: &String,
) -> Result<String> {
    let raw_value = headers.get(header).ok_or_else(|| {
        anyhow!(
            "{} response has missing {:?}: {:?}",
            origin,
            header,
            headers
        )
    })?;

    raw_value.to_str().map(String::from).map_err(|err| {
        anyhow!(
            "Failed to convert {} {:?} to string: {:?}",
            origin,
            header,
            err
        )
    })
}