use std::path::PathBuf;
use anyhow::{anyhow, Result};
use clap::Parser;
use const_format::concatcp;
use regex::Regex;
use text_io::read;
/// Version of this crate, captured at compile time. Used as the fallback risc0
/// crate version written into the generated project's dependency specs.
const RISC0_DEFAULT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// `RISC0_DEFAULT_VERSION` prefixed with `v`; default value of the `--tag` option.
const RISC0_RELEASE_TAG: &str = concatcp!("v", RISC0_DEFAULT_VERSION);
// Sources of the `rust-starter` template files that contain `{{ key }}`
// placeholders, embedded into the binary at compile time.
// NOTE(review): manifests are stored as `Cargo-toml` rather than `Cargo.toml`,
// presumably so the template directories are not picked up as Cargo packages —
// confirm. They are written out under their real `Cargo.toml` names below.
const HOST_MAIN: &str = include_str!("../../templates/rust-starter/host/src/main.rs");
const HOST_CARGO_TOML: &str = include_str!("../../templates/rust-starter/host/Cargo-toml");
const PROJECT_CARGO_TOML: &str = include_str!("../../templates/rust-starter/Cargo-toml");
const METHODS_BUILD_SCRIPT: &str = include_str!("../../templates/rust-starter/methods/build.rs");
const METHODS_CARGO_TOML: &str = include_str!("../../templates/rust-starter/methods/Cargo-toml");
const METHODS_LIB: &str = include_str!("../../templates/rust-starter/methods/src/lib.rs");
const GUEST_CARGO_TOML: &str =
include_str!("../../templates/rust-starter/methods/guest/Cargo-toml");
const GUEST_MAIN: &str = include_str!("../../templates/rust-starter/methods/guest/src/main.rs");
/// Files that go through placeholder substitution before being written, as
/// `(path relative to the generated project root, template contents)`.
static PROJECT_TEMPLATED_FILES: &[(&str, &str)] = &[
("host/src/main.rs", HOST_MAIN),
("host/Cargo.toml", HOST_CARGO_TOML),
("Cargo.toml", PROJECT_CARGO_TOML),
("methods/build.rs", METHODS_BUILD_SCRIPT),
("methods/Cargo.toml", METHODS_CARGO_TOML),
("methods/src/lib.rs", METHODS_LIB),
("methods/guest/Cargo.toml", GUEST_CARGO_TOML),
("methods/guest/src/main.rs", GUEST_MAIN),
];
// Template files copied verbatim (no placeholder substitution), embedded at
// compile time.
const RUST_TOOLCHAIN_TOML: &str = include_str!("../../templates/rust-starter/rust-toolchain.toml");
const README: &str = include_str!("../../templates/rust-starter/README.md");
const GIT_IGNORE: &str = include_str!("../../templates/rust-starter/.gitignore");
const LICENSE: &str = include_str!("../../templates/rust-starter/LICENSE");
/// Files written byte-for-byte into the generated project, as
/// `(path relative to the generated project root, file contents)`.
static PROJECT_NON_TEMPLATED_FILES: &[(&str, &str)] = &[
("README.md", README),
("rust-toolchain.toml", RUST_TOOLCHAIN_TOML),
(".gitignore", GIT_IGNORE),
("LICENSE", LICENSE),
];
/// Generate a new project from the bundled `rust-starter` template.
// Note: clap derive uses the `///` comments in this struct as `--help` text.
#[derive(Parser)]
pub struct NewCommand {
    /// Name of the new project; it is created as a directory of this name.
    #[arg()]
    pub name: String,

    /// Release tag (currently not used when generating the project).
    #[arg(long, default_value = RISC0_RELEASE_TAG)]
    pub tag: String,

    /// Branch name (currently not used when generating the project).
    #[arg(long, default_value = "")]
    pub branch: String,

    /// Directory in which the project directory is created (defaults to the current directory).
    #[arg(long)]
    pub dest: Option<PathBuf>,

    /// Depend on the risc0 crates from this branch of https://github.com/risc0/risc0.git
    /// instead of a published version; takes precedence over --path.
    #[arg(long)]
    pub use_git_branch: Option<String>,

    /// Generate a `#![no_std]` guest and leave out the `std` feature.
    #[arg(long)]
    pub no_std: bool,

    /// Depend on the risc0 crates from a local checkout at this path
    /// (uses `<PATH>/risc0/build` and `<PATH>/risc0/zkvm`).
    #[arg(long)]
    pub path: Option<PathBuf>,

    /// Package name for the guest; must be a valid Rust identifier.
    /// Prompted for interactively when omitted (an empty answer selects `method`).
    #[arg(long)]
    pub guest_name: Option<String>,
}
impl NewCommand {
pub fn run(&self) -> Result<()> {
let dest_dir = if let Some(dest_dir) = self.dest.clone() {
dest_dir
} else {
std::env::current_dir().expect("Failed to fetch cwd")
};
let risc0_version = std::env::var("CARGO_PKG_VERSION")
.unwrap_or_else(|_| RISC0_DEFAULT_VERSION.to_string());
let mut template_variables = Vec::new();
if let Some(branch) = self.use_git_branch.as_ref() {
let spec =
format!("git = \"https://github.com/risc0/risc0.git\", branch = \"{branch}\"");
template_variables.push((Regex::new(r"\{\{ *risc0_build *\}\}")?, spec.clone()));
template_variables.push((Regex::new(r"\{\{ *risc0_zkvm *\}\}")?, spec));
} else if let Some(path) = self.path.as_ref() {
let path = path.to_str().unwrap();
let build = format!("path = \"{path}/risc0/build\"");
let zkvm = format!("path = \"{path}/risc0/zkvm\"");
template_variables.push((Regex::new(r"\{\{ *risc0_build *\}\}")?, build));
template_variables.push((Regex::new(r"\{\{ *risc0_zkvm *\}\}")?, zkvm));
} else {
let spec = format!("version = \"{risc0_version}\"");
template_variables.push((Regex::new(r"\{\{ *risc0_build *\}\}")?, spec.clone()));
template_variables.push((Regex::new(r"\{\{ *risc0_zkvm *\}\}")?, spec));
}
let guest_name = match &self.guest_name {
Some(name) => name.clone(),
None => {
eprint!(
"Guest name was not supplied through the --guest-name option. Please enter\x20\
package name for your template or press [enter] to use default guest package\x20\
name \"method\"\n\
Enter package name > "
);
let input_name: String = read!("{}\n");
if input_name.is_empty() {
"method".to_string()
} else {
input_name.clone()
}
}
};
syn::parse_str::<syn::Ident>(guest_name.as_str()).map_err(|_e| {
anyhow!("guest name [{guest_name}] must be a rust valid rust identifier")
})?;
let guest_name_const = guest_name.replace('-', "_").to_ascii_uppercase();
template_variables.push((
Regex::new(r"\{\{ *guest_package_name *\}\}")?,
format!("\"{guest_name}\""),
));
template_variables.push((
Regex::new(r"\{\{ *guest_id *\}\}")?,
format!("{guest_name_const}_ID"),
));
template_variables.push((
Regex::new(r"\{\{ *guest_elf *\}\}")?,
format!("{guest_name_const}_ELF"),
));
if !self.no_std {
template_variables.push((
Regex::new(r"\{\{ *risc0_feature_std *\}\}")?,
", features = ['std']".to_string(),
));
template_variables.push((Regex::new(r"\{\{ *no_std_preamble *\}\}")?, "".to_string()));
} else {
let no_std_preamble = "#![no_main]\n\
#![no_std]\n\
risc0_zkvm::guest::entry!(main);\n";
template_variables.push((
Regex::new(r"\{\{ *no_std_preamble *\}\}")?,
no_std_preamble.to_string(),
));
template_variables.push((
Regex::new(r"\{\{ *risc0_feature_std *\}\}")?,
"".to_string(),
));
}
self.gen_template(dest_dir, template_variables)?;
Ok(())
}
fn gen_template(&self, dest: PathBuf, template_variables: Vec<(Regex, String)>) -> Result<()> {
let root = dest.join(self.name.clone());
std::fs::create_dir_all(root.join("host/src"))?;
std::fs::create_dir_all(root.join("methods/src"))?;
std::fs::create_dir_all(root.join("methods/guest/src"))?;
for (filepath, content) in PROJECT_TEMPLATED_FILES {
std::fs::write(
root.join(filepath),
&Self::gen_file(content, template_variables.clone()),
)?;
}
for (filepath, content) in PROJECT_NON_TEMPLATED_FILES {
std::fs::write(root.join(filepath), content)?;
}
Ok(())
}
fn gen_file(haystack: &str, patterns: Vec<(Regex, String)>) -> String {
let mut haystack: String = haystack.to_string();
for (pattern, replace) in patterns {
haystack = pattern.replace_all(&haystack, replace).to_string();
}
haystack
}
}
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        io::{BufRead, BufReader},
        path::Path,
    };

    use tempfile::{tempdir, TempDir};

    use super::*;

    /// Returns a fresh temporary destination directory and the project name
    /// used by the generation tests.
    ///
    /// Keep the returned `TempDir` alive for the whole test, because dropping
    /// it removes the directory.
    fn make_test_env() -> (TempDir, &'static str) {
        let tmpdir = tempdir().expect("Failed to create tempdir");
        (tmpdir, "my_project")
    }

    /// Reports whether any single line of `file` contains `needle`.
    fn find_in_file(needle: &str, file: &Path) -> bool {
        let file = File::open(file).unwrap();
        let reader = BufReader::new(file);
        for line in reader.lines() {
            let line_data = line.expect("Failed to readline");
            if line_data.contains(needle) {
                return true;
            }
        }
        false
    }

    #[test]
    fn basic_new() {
        let new = NewCommand::parse_from(["new", "--guest-name", "method", "my_project"]);
        assert_eq!(new.name, "my_project");
    }

    #[test]
    fn basic_generate() {
        let (tmpdir, proj_name) = make_test_env();
        let new = NewCommand::parse_from([
            "new",
            "--dest",
            &tmpdir.path().to_string_lossy(),
            "--guest-name",
            "method",
            proj_name,
        ]);
        new.run().unwrap();
        let proj_path = tmpdir.path().join(proj_name);
        assert!(proj_path.exists());
        assert!(find_in_file(
            &format!("risc0-zkvm = {{ version = \"{RISC0_DEFAULT_VERSION}\" }}"),
            &proj_path.join("host/Cargo.toml")
        ));
        assert!(!find_in_file(
            "#![no_std]",
            &proj_path.join("methods/guest/src/main.rs")
        ));
    }

    #[test]
    fn generate_no_git_branch() {
        let (tmpdir, proj_name) = make_test_env();
        let new = NewCommand::parse_from([
            "new",
            "--dest",
            &tmpdir.path().to_string_lossy(),
            "--use-git-branch",
            "main",
            "--guest-name",
            "method",
            proj_name,
        ]);
        new.run().unwrap();
        let proj_path = tmpdir.path().join(proj_name);
        assert!(proj_path.exists());
        // The generator does not initialise a git repository.
        assert!(!proj_path.join(".git").exists());
        assert!(find_in_file(
            "risc0-zkvm = { git = \"https://github.com/risc0/risc0.git\", branch = \"main\"",
            &proj_path.join("host/Cargo.toml")
        ));
    }

    #[test]
    fn generate_no_std() {
        let (tmpdir, proj_name) = make_test_env();
        let new = NewCommand::parse_from([
            "new",
            "--dest",
            &tmpdir.path().to_string_lossy(),
            "--no-std",
            "--guest-name",
            "method",
            proj_name,
        ]);
        new.run().unwrap();
        let proj_path = tmpdir.path().join(proj_name);
        assert!(find_in_file(
            "#![no_std]",
            &proj_path.join("methods/guest/src/main.rs")
        ));
        // The needle must match the text substituted for `{{ risc0_feature_std }}`
        // in std mode (`, features = ['std']`). Note "feature" + "s".
        assert!(!find_in_file(
            "features = ['std']",
            &proj_path.join("methods/guest/Cargo.toml")
        ));
    }
}