use cargo_metadata::{Metadata, MetadataCommand};
use chip::{get_board_chip_map, get_chip_options};
use clap::Parser;
use futures::stream::StreamExt;
use inquire::ui::{Attributes, Color, RenderConfig, StyleSheet, Styled};
use inquire::{Select, Text};
use keyboard_toml::{parse_keyboard_toml, ProjectInfo};
use reqwest::Client;
use std::error::Error;
use std::fs;
use std::fs::File;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use zip::ZipArchive;
mod args;
mod chip;
mod keyboard_toml;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
inquire::set_global_render_config(get_render_config());
let args = args::Args::parse();
match args.command {
args::Commands::Create {
keyboard_toml_path,
vial_json_path,
target_dir,
} => create_project(keyboard_toml_path, vial_json_path, target_dir).await,
args::Commands::Init {
project_name,
chip,
split,
local_path,
} => init_project(project_name, chip, split, local_path).await,
args::Commands::GetChip { keyboard_toml_path } => {
let project_info = parse_keyboard_toml(&keyboard_toml_path, None)?;
println!("{}", project_info.chip);
Ok(())
}
args::Commands::GetProjectName { keyboard_toml_path } => {
let project_info = parse_keyboard_toml(&keyboard_toml_path, None)?;
println!("{}", project_info.project_name);
Ok(())
}
}
}
async fn create_project(
keyboard_toml_path: Option<String>,
vial_json_path: Option<String>,
target_dir: Option<String>,
) -> Result<(), Box<dyn Error>> {
let keyboard_toml_path = if keyboard_toml_path.is_none() {
Text::new("Path to keyboard.toml:")
.with_default("./keyboard.toml")
.prompt()?
} else {
keyboard_toml_path.unwrap()
};
let vial_json_path = if vial_json_path.is_none() {
Text::new("Path to vial.json")
.with_default(&"./vial.json")
.prompt()?
} else {
vial_json_path.unwrap()
};
let project_info = parse_keyboard_toml(&keyboard_toml_path, target_dir)?;
download_project_template(&project_info).await?;
fs::copy(
&keyboard_toml_path,
project_info.target_dir.join("keyboard.toml"),
)?;
fs::copy(&vial_json_path, project_info.target_dir.join("vial.json"))?;
post_process(project_info)?;
Ok(())
}
fn post_process(project_info: ProjectInfo) -> Result<(), Box<dyn Error>> {
replace_in_folder(
&project_info,
"toml",
"{{ project_name }}",
&project_info.project_name,
)?;
replace_in_folder(
&project_info,
"json",
"{{ project_name }}",
&project_info.project_name,
)?;
replace_in_folder(&project_info, "toml", "{{ chip_name }}", &project_info.chip)?;
replace_in_folder(
&project_info,
"toml",
"{{ uf2_key }}",
&project_info.uf2_key,
)?;
if project_info.disabled_default_feature.len() > 0 {
let metadata = MetadataCommand::new()
.current_dir(&project_info.target_dir)
.exec()?;
disable_rmk_default_features(&project_info.target_dir, &metadata, project_info.disabled_default_feature)?;
}
if project_info.enabled_feature.len() > 0 {
enable_rmk_features(&project_info.target_dir, project_info.enabled_feature)?;
}
Ok(())
}
fn replace_in_folder(
project_info: &ProjectInfo,
ext: &str,
from: &str,
to: &str,
) -> Result<(), Box<dyn Error>> {
let walker = walkdir::WalkDir::new(&project_info.target_dir)
.into_iter()
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().map_or(false, |e| e == ext));
for entry in walker {
let path = entry.path();
let content = fs::read_to_string(path)?;
let new_content = content.replace(from, to);
fs::write(path, new_content)?;
}
Ok(())
}
async fn download_project_template(project_info: &ProjectInfo) -> Result<(), Box<dyn Error>> {
let user = "HaoboGu";
let repo = "rmk-template";
let branch = "main";
let url = format!(
"https://github.com/{}/{}/archive/refs/heads/{}.zip",
user, repo, branch
);
download_with_progress(&url, &project_info.target_dir, &project_info.remote_folder).await
}
async fn init_project(
project_name: Option<String>,
chip: Option<String>,
split: Option<bool>,
local_path: Option<String>,
) -> Result<(), Box<dyn Error>> {
let project_name = if project_name.is_none() {
Text::new("Project Name:").prompt()?.replace(" ", "_")
} else {
project_name.unwrap().replace(" ", "_")
};
let split = if split.is_none() {
Select::new("Choose your keyboard type?", vec!["normal", "split"]).prompt()? == "split"
} else {
split.unwrap()
};
let mut chip_or_board = if chip.is_none() {
Select::new("Choose your microcontroller or board", get_chip_options(split))
.prompt()?
.to_string()
} else {
chip.unwrap()
};
let target_dir = PathBuf::from(&project_name);
fs::create_dir_all(&target_dir)?;
let board_chip_map = get_board_chip_map();
if let Some(c) = board_chip_map.get(chip_or_board.as_str()) {
chip_or_board = c.to_string();
};
let remote_folder = if split {
format!("{}_{}", chip_or_board, "split")
} else {
chip_or_board.clone()
};
let uf2_key = if chip_or_board.starts_with("stm32") {
chip_or_board[..7].to_string()
} else {
if chip_or_board == "pico_w" {
"rp2040".to_string()
} else {
chip_or_board.clone()
}
};
let project_info = ProjectInfo {
project_name,
target_dir,
remote_folder,
chip: chip_or_board,
uf2_key,
disabled_default_feature: Vec::new(),
enabled_feature: Vec::new(),
};
match local_path {
Some(p) => {
copy_dir_recursive(Path::new(&p), &project_info.target_dir)?;
}
None => {
download_project_template(&project_info).await?;
}
}
post_process(project_info)?;
Ok(())
}
async fn download_with_progress<P>(
download_url: &str,
output_path: P,
folder: &str,
) -> Result<(), Box<dyn Error>>
where
P: AsRef<Path>,
{
let output_path = output_path.as_ref();
if output_path.exists() {
fs::remove_dir_all(output_path)?;
}
fs::create_dir_all(output_path)?;
println!("⇣ Download project template for {}...", folder);
let client = Client::new();
let response = client.get(download_url).send().await?;
if !response.status().is_success() {
return Err(format!("Download failed: {}", response.status()).into());
}
let temp_file_path = output_path.join("temp.zip");
let mut temp_file = File::create(&temp_file_path)?;
struct TempFileCleanup<'a> {
path: &'a Path,
}
impl<'a> Drop for TempFileCleanup<'a> {
fn drop(&mut self) {
if self.path.exists() {
if let Err(e) = fs::remove_file(self.path) {
eprintln!(
"Failed to remove temp file '{}': {}",
self.path.display(),
e
);
}
}
}
}
let _cleanup_guard = TempFileCleanup {
path: &temp_file_path,
};
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
temp_file.write_all(&chunk)?;
}
let zip_file = File::open(&temp_file_path)?;
let mut zip = ZipArchive::new(zip_file)?;
let mut folder_found = false;
for i in 0..zip.len() {
let mut file = zip.by_index(i)?;
let file_name = file.enclosed_name().ok_or("Invalid file path")?;
let segments: Vec<_> = file_name.iter().collect();
if segments.len() > 1 && segments[1] == folder {
folder_found = true;
let relative_name = file_name.iter().skip(2).collect::<PathBuf>();
let out_path = output_path.join(relative_name);
if file.is_dir() {
fs::create_dir_all(&out_path)?;
} else {
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent)?;
}
let mut outfile = File::create(&out_path)?;
io::copy(&mut file, &mut outfile)?;
}
}
}
if !folder_found {
if folder.starts_with("stm32") {
if folder.len() > 7 {
let stm32_series = &folder[..7];
for i in 0..zip.len() {
let mut file = zip.by_index(i)?;
let file_name = file.enclosed_name().ok_or("Invalid file path")?;
let segments: Vec<_> = file_name.iter().collect();
if segments.len() > 1 && segments[1] == stm32_series {
folder_found = true;
let relative_name = file_name.iter().skip(2).collect::<PathBuf>();
let out_path = output_path.join(relative_name);
if file.is_dir() {
fs::create_dir_all(&out_path)?;
} else {
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent)?;
}
let mut outfile = File::create(&out_path)?;
io::copy(&mut file, &mut outfile)?;
}
}
}
}
if !folder_found {
println!("️️🚨 There's no template available for [{folder}], using the default stm32 template. You may need to make further edit.");
for i in 0..zip.len() {
let mut file = zip.by_index(i)?;
let file_name = file.enclosed_name().ok_or("Invalid file path")?;
let segments: Vec<_> = file_name.iter().collect();
if segments.len() > 1 && segments[1] == "stm32" {
folder_found = true;
let relative_name = file_name.iter().skip(2).collect::<PathBuf>();
let out_path = output_path.join(relative_name);
if file.is_dir() {
fs::create_dir_all(&out_path)?;
} else {
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent)?;
}
let mut outfile = File::create(&out_path)?;
io::copy(&mut file, &mut outfile)?;
}
}
}
}
}
if !folder_found {
return Err(format!(
"The specified chip/board '{}' does not exist in the template repo",
folder
)
.into());
}
}
println!("✅ Project created, path: {}", output_path.display());
Ok(())
}
fn get_render_config() -> RenderConfig<'static> {
let mut render_config = RenderConfig::default();
render_config.prompt_prefix = Styled::new("?").with_fg(Color::LightRed);
render_config.error_message = render_config
.error_message
.with_prefix(Styled::new("❌").with_fg(Color::LightRed));
render_config.answer = StyleSheet::new()
.with_attr(Attributes::ITALIC)
.with_fg(Color::LightGreen);
render_config
}
fn copy_dir_recursive(src: &Path, dest: &Path) -> io::Result<()> {
if !src.is_dir() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Source is not a directory",
));
}
fs::create_dir_all(dest)?;
for entry in fs::read_dir(src)? {
let entry = entry?;
let file_type = entry.file_type()?;
let src_path = entry.path();
let dest_path = dest.join(entry.file_name());
if file_type.is_dir() {
copy_dir_recursive(&src_path, &dest_path)?;
} else {
fs::copy(&src_path, &dest_path)?;
}
}
Ok(())
}
fn disable_rmk_default_features(
target_dir: &PathBuf,
metadata: &Metadata,
features: Vec<String>,
) -> Result<(), String> {
println!("Disabling default features: {:?}", features);
let cargo_toml_path = Path::new(target_dir).join("Cargo.toml");
let mut manifest =
cargo_toml::Manifest::from_path(&cargo_toml_path).map_err(|e| e.to_string())?;
if let Some(cargo_toml::Dependency::Detailed(rmk_dep)) = manifest.dependencies.get_mut("rmk") {
let mut default_features = get_dependency_default_features("rmk", metadata)?;
default_features.retain(|s| !features.contains(s));
rmk_dep.features.append(&mut default_features);
rmk_dep.features.sort_unstable();
rmk_dep.features.dedup();
rmk_dep.default_features = false;
} else {
return Err("No valid rmk dependency found".to_string());
}
let updated_toml = toml::to_string(&manifest)
.map_err(|e| format!("Failed to serialize updated Cargo.toml: {}", e))?;
fs::write(&cargo_toml_path, updated_toml)
.map_err(|e| format!("Failed to write updated Cargo.toml: {}", e))?;
Ok(())
}
fn get_dependency_default_features(
dependency: &str,
metadata: &Metadata,
) -> Result<Vec<String>, String> {
let dep = metadata
.packages
.iter()
.find(|p| p.name.to_string() == dependency)
.ok_or(format!("Failed to find {} in dependencies", dependency))?;
dep.features
.get("default")
.cloned()
.ok_or(format!("Failed to get default {} features", dependency))
}
fn enable_rmk_features(
target_dir: &PathBuf,
features: Vec<String>,
) -> Result<(), String> {
println!("Enabling features: {:?}", features);
let cargo_toml_path = Path::new(target_dir).join("Cargo.toml");
let mut manifest =
cargo_toml::Manifest::from_path(&cargo_toml_path).map_err(|e| e.to_string())?;
if let Some(cargo_toml::Dependency::Detailed(rmk_dep)) = manifest.dependencies.get_mut("rmk") {
for feature in features {
if !rmk_dep.features.contains(&feature) {
rmk_dep.features.push(feature);
}
}
rmk_dep.features.sort_unstable();
rmk_dep.features.dedup();
} else {
return Err("No valid rmk dependency found".to_string());
}
let updated_toml = toml::to_string(&manifest)
.map_err(|e| format!("Failed to serialize updated Cargo.toml: {}", e))?;
fs::write(&cargo_toml_path, updated_toml)
.map_err(|e| format!("Failed to write updated Cargo.toml: {}", e))?;
Ok(())
}