compio_dns 0.1.1

An async DNS resolver for compio that uses caching and the system configuration for zero-cost resolution. / compio 的异步 DNS 解析器，支持缓存和系统配置。
use std::{
  env, fs, io,
  path::{Path, PathBuf},
  process::Command,
  result,
};

use serde::Deserialize;
use toml_edit::{DocumentMut, Item, Table, Value};

/// Errors raised while locating the `compio-net` package and patching it.
#[derive(thiserror::Error, Debug)]
enum Error {
  /// `cargo metadata` ran but reported a non-success exit status; the payload
  /// is its stderr, lossily decoded as UTF-8.
  #[error("cargo metadata exited with error: {0}")]
  Metadata(String),
  /// An `io::Error` propagated with `?` (via `#[from]`); in this file that is
  /// the failure to run the `cargo metadata` process.
  #[error("Failed to execute cargo metadata: {0}")]
  MetadataExec(#[from] io::Error),
  /// The stdout of `cargo metadata` could not be deserialized.
  #[error("Failed to parse cargo metadata JSON: {0}")]
  MetadataParse(#[from] sonic_rs::Error),
  /// No package with the given name is listed in the metadata.
  #[error("Dependency {0} not found in metadata")]
  DependencyNotFound(String),
  /// The package's `manifest_path` has no parent directory.
  #[error("No parent directory for manifest_path: {0}")]
  NoParentDir(String),
  /// Reading the given file or directory failed.
  #[error("Failed to read file {0}: {1}")]
  ReadFile(PathBuf, io::Error),
  /// Creating, copying to, or writing the given path failed.
  #[error("Failed to write file {0}: {1}")]
  WriteFile(PathBuf, io::Error),
  /// A manifest could not be parsed as a TOML document.
  #[error("Failed to parse Cargo.toml: {0}")]
  ParseToml(#[from] toml_edit::TomlError),
}

/// Directory containing this crate's manifest, captured at compile time.
const DIR: &str = env!("CARGO_MANIFEST_DIR");

/// Module declaration injected into the patched `mod.rs`; its presence is also
/// what marks the file as already patched.
const MOD_SYS: &str = "pub(crate) mod sys;";
/// Manifest key of the dependency table.
const DEPS: &str = "dependencies";
/// Manifest key used both for the `[workspace]` table and for the
/// `workspace = true` flag on a dependency.
const WORKSPACE: &str = "workspace";
/// Manifest key of the `[features]` table and of a dependency's feature list.
const FEATURES: &str = "features";
/// Key used when a bare version string is expanded into an inline table.
const VERSION: &str = "version";
/// File name of a Cargo manifest.
const TOML: &str = "Cargo.toml";
/// Name of the package that gets patched; a dependency with this name is
/// skipped when dependencies are merged into its manifest.
const COMPIO_NET: &str = "compio-net";
/// Text that identifies the `cfg_if!` invocation; the line after it is rewritten.
const CFG_IF: &str = "cfg_if::cfg_if! {";
/// Path of the patched module file, relative to the `compio-net` package root.
const MOD_RS: &str = "src/resolve/mod.rs";
/// Directory, relative to the `compio-net` package root, that receives a copy
/// of this crate's `src` directory.
const SYS: &str = "src/resolve/sys";
/// Program name used when the `CARGO` environment variable cannot be read.
const CARGO_CMD: &str = "cargo";

/// Result type whose error is this file's [`Error`].
type Result<T> = result::Result<T, Error>;

/// The fields of the `cargo metadata` output that this file reads.
#[derive(Deserialize, Debug)]
struct Metadata {
  /// Packages listed in the metadata; searched by name.
  packages: Vec<Package>,
  /// Workspace root directory; the workspace manifest is looked up inside it.
  workspace_root: String,
}

/// One package entry of the `cargo metadata` output.
#[derive(Deserialize, Debug)]
struct Package {
  /// Package name, compared against the name being looked up.
  name: String,
  /// Path of the package's manifest; its parent directory is used as the
  /// package root.
  manifest_path: String,
}

fn metadata() -> Result<Metadata> {
  let output = Command::new(env::var("CARGO").unwrap_or_else(|_| CARGO_CMD.to_string()))
    .args(["metadata", "--format-version=1"])
    .output()?;

  if !output.status.success() {
    return Err(Error::Metadata(
      String::from_utf8_lossy(&output.stderr).to_string(),
    ));
  }

  Ok(sonic_rs::from_slice(&output.stdout)?)
}

/// Applies the DNS patch to a checked-out `compio-net` package.
#[derive(Debug)]
struct Patcher<'a> {
  /// Root directory of the package being patched.
  net: PathBuf,
  /// Cargo metadata, used to find the workspace manifest.
  meta: &'a Metadata,
}

impl<'a> Patcher<'a> {
  fn new(net: PathBuf, meta: &'a Metadata) -> Self {
    Self { net, meta }
  }

  /// Patches the package unless its module file already contains the marker.
  ///
  /// When a patch is needed the steps run in this order: copy the sources,
  /// rewrite the module file, then rewrite the manifest. The first failing
  /// step aborts the rest.
  ///
  /// NOTE(review): the marker is written by the module rewrite, i.e. before
  /// the manifest is touched; if the manifest step fails, later runs will see
  /// the marker and skip everything.
  fn run(&self) -> Result<()> {
    let mod_path = self.net.join(MOD_RS);

    match self.is_patched(&mod_path)? {
      (true, _) => Ok(()),
      (false, content) => {
        self.copy()?;
        self.patch_mod(&mod_path, content)?;
        self.patch_toml()
      }
    }
  }

  fn is_patched(&self, mod_path: &Path) -> Result<(bool, String)> {
    let content =
      fs::read_to_string(mod_path).map_err(|e| Error::ReadFile(mod_path.to_path_buf(), e))?;
    Ok((content.contains(MOD_SYS), content))
  }

  /// Replaces the package's [`SYS`] directory with a copy of this crate's
  /// `src` directory.
  ///
  /// Anything already present at the destination is removed first; a failure
  /// of that removal is ignored and the copy is attempted regardless.
  ///
  /// # Errors
  ///
  /// Whatever [`copy_dir`] reports.
  fn copy(&self) -> Result<()> {
    let dest = self.net.join(SYS);

    if dest.exists() {
      // Best-effort cleanup: the result is intentionally discarded.
      let _ = if dest.is_dir() {
        fs::remove_dir_all(&dest)
      } else {
        fs::remove_file(&dest)
      };
    }

    copy_dir(&Path::new(DIR).join("src"), &dest)
  }

  fn patch_mod(&self, path: &Path, content: String) -> Result<()> {
    let mut lines: Vec<String> = content.lines().map(|s| s.to_string()).collect();
    let start_idx = lines.iter().position(|line| line.contains(CFG_IF));

    if let Some(start) = start_idx {
      let start = start + 1;
      lines[start] = format!(
        r#"
    if #[cfg(feature="dns")] {{
      {MOD_SYS}
    }} else {}"#,
        lines[start].trim()
      );

      fs::write(path, lines.join("\n")).map_err(|e| Error::WriteFile(path.to_path_buf(), e))?;
    }

    Ok(())
  }

  fn patch_toml(&self) -> Result<()> {
    let dns_toml = Path::new(DIR).join(TOML);
    let net_toml = self.net.join(TOML);
    let ws_toml = Path::new(&self.meta.workspace_root).join(TOML);

    let read = |p: &Path| fs::read_to_string(p).map_err(|e| Error::ReadFile(p.to_path_buf(), e));

    let dns_content = read(&dns_toml)?;
    let net_content = read(&net_toml)?;
    let ws_content = fs::read_to_string(&ws_toml).ok();

    let dns_doc = dns_content.parse::<DocumentMut>()?;
    let mut net_doc = net_content.parse::<DocumentMut>()?;
    let ws_doc = ws_content.and_then(|c| c.parse::<DocumentMut>().ok());

    {
      let net_features = net_doc.entry(FEATURES).or_insert_with(toml_edit::table);
      if let Some(t) = net_features.as_table_mut() {
        t.entry("dns")
          .or_insert(Item::Value(Value::Array(toml_edit::Array::new())));
        let def_arr = t
          .entry("default")
          .or_insert(Item::Value(Value::Array(toml_edit::Array::new())));
        if let Some(arr) = def_arr.as_array_mut()
          && !arr.iter().any(|v| v.as_str() == Some("dns"))
        {
          arr.push("dns");
        }
      }
    }

    if let Some(dns_deps) = dns_doc.get(DEPS).and_then(|d| d.as_table()) {
      let net_deps = net_doc.entry(DEPS).or_insert_with(toml_edit::table);

      if let Some(net_table) = net_deps.as_table_mut() {
        for (name, item) in dns_deps.iter() {
          if name == COMPIO_NET {
            continue;
          }

          if !net_table.contains_key(name) {
            let item_to_insert = resolve(name, item, &ws_doc);
            net_table.insert(name, item_to_insert);
          }

          merge_feature(name, item, net_table, &ws_doc);
        }
      }
    }

    fs::write(
      &net_toml,
      net_doc.to_string()
        + r#"
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(compio_dns)'] }
"#,
    )
    .map_err(|e| Error::WriteFile(net_toml, e))?;

    Ok(())
  }
}

/// Entry point: prints the Cargo directives, then locates the
/// [`COMPIO_NET`] package through `cargo metadata` and patches it.
///
/// # Errors
///
/// Any error from [`metadata`], [`pkg`] or [`Patcher::run`].
fn main() -> Result<()> {
  let directives = [
    format!("cargo:rerun-if-changed={TOML}"),
    "cargo:rerun-if-changed=build.rs".to_owned(),
    "cargo:rerun-if-changed=src".to_owned(),
    "cargo:rustc-check-cfg=cfg(compio_dns)".to_owned(),
    "cargo:rustc-cfg=compio_dns".to_owned(),
  ];
  for directive in &directives {
    println!("{directive}");
  }

  let meta = metadata()?;
  let net_root = pkg(&meta, COMPIO_NET)?;

  Patcher::new(net_root, &meta).run()
}

fn copy_dir(src: &Path, dst: &Path) -> Result<()> {
  if !dst.exists() {
    fs::create_dir_all(dst).map_err(|e| Error::WriteFile(dst.to_path_buf(), e))?;
  }
  for entry in fs::read_dir(src).map_err(|e| Error::ReadFile(src.to_path_buf(), e))? {
    let entry = entry.map_err(|e| Error::ReadFile(src.to_path_buf(), e))?;
    let ty = entry
      .file_type()
      .map_err(|e| Error::ReadFile(entry.path(), e))?;
    let dst_path = dst.join(entry.file_name());
    if ty.is_dir() {
      copy_dir(&entry.path(), &dst_path)?;
    } else {
      fs::copy(entry.path(), &dst_path).map_err(|e| Error::WriteFile(dst_path, e))?;
    }
  }
  Ok(())
}

fn pkg(meta: &Metadata, name: &str) -> Result<PathBuf> {
  let pkg = meta
    .packages
    .iter()
    .find(|p| p.name == name)
    .ok_or_else(|| Error::DependencyNotFound(name.to_string()))?;

  Path::new(&pkg.manifest_path)
    .parent()
    .ok_or_else(|| Error::NoParentDir(pkg.manifest_path.clone()))
    .map(|p| p.to_path_buf())
}

fn workspace_dep<'a>(dep_name: &str, workspace_doc: &'a Option<DocumentMut>) -> Option<&'a Item> {
  workspace_doc
    .as_ref()?
    .get(WORKSPACE)?
    .get(DEPS)?
    .as_table()?
    .get(dep_name)
}

/// Produces the manifest entry to insert for dependency `name`.
///
/// When `item` carries `workspace = true` and the workspace manifest defines
/// the dependency, a clone of the workspace definition is returned. In every
/// other case `item` itself is cloned.
fn resolve(name: &str, item: &Item, doc: &Option<DocumentMut>) -> Item {
  let inherits = matches!(
    item.get(WORKSPACE).and_then(|flag| flag.as_bool()),
    Some(true)
  );

  let source = if inherits {
    workspace_dep(name, doc).unwrap_or(item)
  } else {
    item
  };

  source.clone()
}

/// Adds the features requested for `dep_name` to its entry in `target_table`.
///
/// The requested features are those listed on `dep_item` plus, when
/// `dep_item` carries `workspace = true`, those listed on the workspace
/// definition of the dependency. Nothing happens when that list is empty or
/// when `target_table` has no entry for the dependency.
///
/// A target entry that is a bare version string is first expanded into an
/// inline table with a `version` key so a feature list can be attached.
/// Features already present on the target are not added twice.
fn merge_feature(
  dep_name: &str,
  dep_item: &Item,
  target_table: &mut Table,
  workspace_doc: &Option<DocumentMut>,
) {
  // String entries of an item's `features` array (empty if absent).
  let listed = |item: &Item| -> Vec<String> {
    item
      .get(FEATURES)
      .and_then(|f| f.as_array())
      .map(|arr| {
        arr
          .iter()
          .filter_map(|v| v.as_str())
          .map(str::to_owned)
          .collect()
      })
      .unwrap_or_default()
  };

  let mut wanted = listed(dep_item);

  let inherits = dep_item.get(WORKSPACE).and_then(|v| v.as_bool()) == Some(true);
  if inherits && let Some(shared) = workspace_dep(dep_name, workspace_doc) {
    wanted.extend(listed(shared));
  }

  if wanted.is_empty() {
    return;
  }

  let Some(target) = target_table.get_mut(dep_name) else {
    return;
  };

  // `dep = "1.0"` becomes `dep = { version = "1.0" }`.
  if let Some(ver) = target.as_str().map(str::to_owned) {
    let mut expanded = toml_edit::InlineTable::new();
    expanded.insert(VERSION, ver.into());
    *target = Item::Value(Value::InlineTable(expanded));
  }

  let Some(existing) = features_mut(target) else {
    return;
  };

  for feat in wanted {
    let present = existing.iter().any(|x| x.as_str() == Some(feat.as_str()));
    if !present {
      existing.push(feat);
    }
  }
}

/// Returns the `features` array of a dependency entry, creating an empty one
/// when the entry has none.
///
/// Returns `None` when `item` is not table-like, or when it already has a
/// `features` key whose value is not an array. The latter used to panic; it
/// now yields `None` so the caller can skip the entry.
fn features_mut(item: &mut Item) -> Option<&mut toml_edit::Array> {
  let table = item.as_table_like_mut()?;

  if table.get(FEATURES).is_none() {
    table.insert(FEATURES, Item::Value(Value::Array(toml_edit::Array::new())));
  }

  table.get_mut(FEATURES)?.as_array_mut()
}