use std::{
env, fs,
path::{Path, PathBuf},
process::Command,
};
use anyhow::{Context, Result, bail};
/// Entry point for the workspace task runner.
///
/// Subcommands:
/// - `web-proto [--check]` — regenerate (or verify) the web TypeScript proto output.
/// - `audit-idempotency` — check that state-changing RPCs carry `client_operation_id = 15`.
fn main() -> Result<()> {
    let mut argv = env::args().skip(1);
    let subcommand = argv.next();
    match subcommand.as_deref() {
        None => bail!("expected a command (for example: web-proto)"),
        Some("web-proto") => run_web_proto(argv.collect()),
        Some("audit-idempotency") => run_audit_idempotency(),
        Some(command) => bail!("unknown command '{command}'"),
    }
}
/// Audits `proto/heddle/v1/service.proto`. It fails unless every state-changing RPC's
/// request message declares `string client_operation_id = 15;`.
///
/// A request type may be a `oneof body` stream envelope. In that case the message
/// carried in the envelope's `request = 1` field is audited instead (see
/// `stream_envelope_target`). A declaration that only appears inside a `//` comment
/// does not count.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the workspace root cannot be located;
/// - the proto file cannot be read;
/// - any audited RPC's message is missing or lacks the field. Each offender is printed
///   to stderr before the error is returned.
fn run_audit_idempotency() -> Result<()> {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let workspace_root = manifest_dir
        .parent()
        .and_then(|path| path.parent())
        .context("failed to locate workspace root")?;
    let proto_path = workspace_root.join("proto/heddle/v1/service.proto");
    let source = fs::read_to_string(&proto_path)
        .with_context(|| format!("read {}", proto_path.display()))?;
    let rpcs = extract_rpcs(&source);
    let messages = extract_messages(&source);

    let mut missing: Vec<String> = Vec::new();
    let mut audited = 0usize;
    for rpc in rpcs.iter().filter(|rpc| is_state_changing(&rpc.name)) {
        audited += 1;
        // Audit the wrapped message when the request type is a stream envelope.
        let target_message = stream_envelope_target(&rpc.request_message, &messages)
            .unwrap_or_else(|| rpc.request_message.clone());
        let Some(body) = messages.get(&target_message) else {
            missing.push(format!(
                "{}::{} -> request message {} not found in proto",
                rpc.service, rpc.name, target_message
            ));
            continue;
        };
        if !declares_client_operation_id(body) {
            missing.push(format!(
                "{}::{} -> {} is missing `string client_operation_id = 15;`",
                rpc.service, rpc.name, target_message
            ));
        }
    }

    if !missing.is_empty() {
        eprintln!(
            "audit-idempotency: {} state-changing RPC(s) missing client_operation_id = 15:",
            missing.len()
        );
        for m in &missing {
            eprintln!(" {m}");
        }
        bail!("audit-idempotency failed");
    }
    println!(
        "audit-idempotency: {} state-changing RPC(s) carry `client_operation_id = 15`.",
        audited
    );
    Ok(())
}

/// Returns true when `body` declares `string client_operation_id = 15;` in live code.
///
/// Text after `//` on each line is ignored, so a commented-out field does not satisfy
/// the audit. Trailing comments after a real declaration, and modifiers such as
/// `optional` before it, are still accepted.
fn declares_client_operation_id(body: &str) -> bool {
    const FIELD: &str = "string client_operation_id = 15;";
    body.lines().any(|line| {
        let code = line.split_once("//").map_or(line, |(code, _)| code);
        code.contains(FIELD)
    })
}
/// One `rpc` declaration found inside a `service` block.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ProtoRpc {
    /// Name of the enclosing `service`.
    service: String,
    /// RPC method name, e.g. `CreateThing`.
    name: String,
    /// Request message type, with any leading `stream` keyword removed.
    request_message: String,
}

/// Collects every `rpc` declared inside a `service` block of `source`.
///
/// This is a line-oriented scan, not a full proto parser. Its limits:
/// - The `rpc` keyword and its parenthesised request type must sit on one line.
/// - `rpc` lines outside any service are ignored.
/// - `/* ... */` comments are not understood; braces inside them must stay balanced.
///
/// The enclosing service is tracked by brace depth. As a result, rpc option blocks
/// (`rpc Foo(A) returns (B) { option ...; }`) do not end the service early. Braces
/// inside quoted strings or after `//` are ignored.
fn extract_rpcs(source: &str) -> Vec<ProtoRpc> {
    let mut out = Vec::new();
    // Enclosing service name plus the brace depth at which it was declared.
    let mut current_service: Option<(String, usize)> = None;
    let mut depth = 0usize;
    for line in source.lines() {
        let trimmed = line.trim();
        if let Some(rest) = trimmed.strip_prefix("service ") {
            current_service = rest
                .split_whitespace()
                .next()
                .map(|s| (s.trim_end_matches('{').to_string(), depth));
        } else if let (Some(rest), Some((service, _))) =
            (trimmed.strip_prefix("rpc "), &current_service)
        {
            out.extend(parse_rpc_line(service, rest));
        }
        let (opens, closes) = brace_counts(trimmed);
        depth = (depth + opens).saturating_sub(closes);
        // The service ends once a `}` brings depth back to where it was declared.
        let service_depth = current_service.as_ref().map(|(_, d)| *d);
        if closes > 0 && service_depth.is_some_and(|d| depth <= d) {
            current_service = None;
        }
    }
    out
}

/// Parses the text that follows `rpc `, e.g. `Foo(stream FooRequest) returns (FooReply);`.
/// Returns `None` when the line lacks a parenthesised request type.
fn parse_rpc_line(service: &str, rest: &str) -> Option<ProtoRpc> {
    let (name, after_paren) = rest.split_once('(')?;
    let (request, _) = after_paren.split_once(')')?;
    let request = request.trim();
    // Only drop `stream` when it is a standalone keyword, so a type such as
    // `streamlinedRequest` keeps its full name.
    let request = match request.strip_prefix("stream") {
        Some(tail) if tail.starts_with(char::is_whitespace) => tail.trim_start(),
        _ => request,
    };
    Some(ProtoRpc {
        service: service.to_string(),
        name: name.trim().to_string(),
        request_message: request.to_string(),
    })
}

/// Counts the `{` and `}` on one line of proto source.
///
/// Braces inside single- or double-quoted strings are skipped, as is everything after
/// a `//` comment marker.
fn brace_counts(line: &str) -> (usize, usize) {
    let (mut opens, mut closes) = (0, 0);
    let mut quote: Option<char> = None;
    let mut chars = line.chars().peekable();
    while let Some(c) = chars.next() {
        if let Some(q) = quote {
            if c == '\\' {
                // Skip the escaped character so `\"` does not close the string.
                chars.next();
            } else if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '"' | '\'' => quote = Some(c),
            '/' if chars.peek() == Some(&'/') => break,
            '{' => opens += 1,
            '}' => closes += 1,
            _ => {}
        }
    }
    (opens, closes)
}
/// Maps every top-level `message` name in `source` to the text of its body.
///
/// Body extent:
/// - It starts at the message's opening `{` and stops just before the matching `}`,
///   found with `match_close_brace`.
/// - When the braces never balance, it runs to the end of `source`.
/// - Scanning resumes after the closing brace. Nested messages therefore stay embedded
///   in their parent's body and are not recorded on their own.
///
/// An occurrence of `message ` only counts as a declaration when:
/// - it is not the tail of a longer identifier (e.g. `my_message `), and
/// - the next token is a non-empty name separated from `{` by whitespace only.
///
/// The second rule keeps prose such as `// Request message for Foo.` from being read as
/// a declaration named `for`. Without it, that bogus declaration would take the body of
/// the real `message Foo { ... }` that follows.
fn extract_messages(source: &str) -> std::collections::HashMap<String, String> {
    let mut out = std::collections::HashMap::new();
    let bytes = source.as_bytes();
    let needle = "message ";
    let mut cursor = 0usize;
    while let Some(rel) = source[cursor..].find(needle) {
        let start = cursor + rel;
        let after = start + needle.len();
        let glued_to_identifier = start > 0 && {
            let prev = bytes[start - 1];
            prev.is_ascii_alphanumeric() || prev == b'_'
        };
        if glued_to_identifier {
            cursor = after;
            continue;
        }
        // Tolerate extra whitespace between the keyword and the name.
        let rest = &source[after..];
        let name_start = after + (rest.len() - rest.trim_start().len());
        let name_end = name_start
            + source[name_start..]
                .find(|c: char| c.is_whitespace() || c == '{')
                .unwrap_or(source.len() - name_start);
        let name = &source[name_start..name_end];
        let Some(brace_rel) = source[name_end..].find('{') else {
            // No `{` appears anywhere later, so no further declaration can have a body.
            break;
        };
        let brace_open = name_end + brace_rel;
        // Anything but whitespace between the name and `{` means this was not a declaration.
        if name.is_empty() || !source[name_end..brace_open].trim().is_empty() {
            cursor = after;
            continue;
        }
        let brace_close = match_close_brace(bytes, brace_open).unwrap_or(bytes.len());
        out.insert(name.to_string(), source[brace_open..brace_close].to_string());
        cursor = brace_close;
    }
    out
}
/// Resolves the message wrapped by a streaming request envelope.
///
/// If message `name` exists and its body mentions `oneof body`, this returns the type
/// from the first body line that (once trimmed) ends in ` request = 1;`. In every other
/// case it returns `None`, and the caller audits `name` itself.
fn stream_envelope_target(
    name: &str,
    messages: &std::collections::HashMap<String, String>,
) -> Option<String> {
    let envelope = messages.get(name)?;
    if envelope.contains("oneof body") {
        envelope
            .lines()
            .find_map(|raw| raw.trim().strip_suffix(" request = 1;"))
            .map(|wrapped| wrapped.trim().to_string())
    } else {
        None
    }
}
/// Finds the index of the `}` that balances the braces counted from `open` onwards.
///
/// Counting starts at `open` itself, which callers point at a `{`. Returns `None` when
/// the braces never return to depth zero on a `}`, or when `open` is past the end.
fn match_close_brace(bytes: &[u8], open: usize) -> Option<usize> {
    let tail = bytes.get(open..)?;
    let mut depth: i32 = 0;
    tail.iter()
        .position(|&byte| match byte {
            b'{' => {
                depth += 1;
                false
            }
            b'}' => {
                depth -= 1;
                depth == 0
            }
            _ => false,
        })
        .map(|offset| open + offset)
}
/// Decides from its name alone whether an RPC should be treated as state-changing.
///
/// A name qualifies when it starts with one of the verb prefixes below and is not one
/// of the explicitly exempted `Begin*` auth-flow RPCs.
fn is_state_changing(name: &str) -> bool {
    // `Begin*` auth-flow RPCs excluded from the audit even though they match `Begin`.
    const EXEMPT_AUTH_BEGINS: &[&str] = &[
        "BeginWebAuthnRegistration",
        "BeginWebAuthnAuthentication",
        "BeginDeviceAuthorization",
        "BeginOAuthLogin",
        "BeginOAuthLink",
        "BeginInvitationFlow",
    ];
    const MUTATING_PREFIXES: &[&str] = &[
        "Update",
        "Push",
        "Pull",
        "Mint",
        "Issue",
        "Revoke",
        "Rotate",
        "Sign",
        "Begin",
        "Commit",
        "Abort",
        "Create",
        "Delete",
        "Add",
        "Remove",
        "Approve",
        "Register",
        "Deregister",
        "ResolveDiscussion",
        "RespondToHook",
        "OpenDiscussion",
        "AppendTurn",
        "Finish",
        "Complete",
        "Cancel",
        "Set",
    ];

    let exempt = EXEMPT_AUTH_BEGINS.iter().any(|&candidate| candidate == name);
    !exempt && MUTATING_PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
fn run_web_proto(args: Vec<String>) -> Result<()> {
let check = args.iter().any(|arg| arg == "--check");
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let workspace_root = manifest_dir
.parent()
.and_then(|path| path.parent())
.context("failed to locate workspace root")?;
let proto_dir = workspace_root.join("proto");
let proto_file = proto_dir.join("heddle/v1/service.proto");
let web_dir = workspace_root.join("web");
let output_root = web_dir.join("src/lib/gen/proto");
let relative_output = PathBuf::from("heddle/v1/service_pb.ts");
let protoc = protoc_bin_vendored::protoc_bin_path()?;
let plugin = resolve_plugin_path(&web_dir)?;
if check {
let temp = tempfile::tempdir().context("failed to create temp directory")?;
generate_service_descriptor(&protoc, &plugin, &proto_dir, &proto_file, temp.path())?;
let generated = temp.path().join(&relative_output);
let checked_in = output_root.join(&relative_output);
assert_file_matches(&generated, &checked_in)?;
println!("web proto output is up to date");
return Ok(());
}
generate_service_descriptor(&protoc, &plugin, &proto_dir, &proto_file, &output_root)?;
println!("generated {}", output_root.join(relative_output).display());
Ok(())
}
/// Runs `protoc` with the `protoc-gen-es` plugin to emit TypeScript for `proto_file`
/// into `output_root`, creating that directory first.
///
/// # Errors
/// Fails if the output directory cannot be created, `protoc` cannot be spawned, or
/// `protoc` exits unsuccessfully.
fn generate_service_descriptor(
    protoc: &Path,
    plugin: &Path,
    proto_dir: &Path,
    proto_file: &Path,
    output_root: &Path,
) -> Result<()> {
    fs::create_dir_all(output_root).with_context(|| {
        format!(
            "failed to create output directory '{}'",
            output_root.display()
        )
    })?;

    let plugin_flag = format!("--plugin=protoc-gen-es={}", plugin.display());
    let include_flag = format!("--proto_path={}", proto_dir.display());
    let output_flag = format!("--es_out=target=ts:{}", output_root.display());
    let status = Command::new(protoc)
        .args([plugin_flag, include_flag, output_flag])
        .arg(proto_file)
        .status()
        .with_context(|| format!("failed to run protoc at '{}'", protoc.display()))?;

    if status.success() {
        Ok(())
    } else {
        bail!("protoc exited with status {status}")
    }
}
/// Locates the `protoc-gen-es` plugin executable.
///
/// A `PROTOC_GEN_ES` environment variable takes precedence and must name an existing
/// path. Otherwise `web/node_modules/.bin` is searched for `protoc-gen-es`, then
/// `protoc-gen-es.cmd`.
fn resolve_plugin_path(web_dir: &Path) -> Result<PathBuf> {
    match env::var("PROTOC_GEN_ES") {
        Ok(value) => {
            let path = PathBuf::from(value);
            if !path.exists() {
                bail!(
                    "PROTOC_GEN_ES was set, but '{}' does not exist",
                    path.display()
                );
            }
            Ok(path)
        }
        Err(_) => {
            let found = [
                "node_modules/.bin/protoc-gen-es",
                "node_modules/.bin/protoc-gen-es.cmd",
            ]
            .iter()
            .map(|relative| web_dir.join(relative))
            .find(|candidate| candidate.exists());
            match found {
                Some(path) => Ok(path),
                None => bail!(
                    "could not find protoc-gen-es in web/node_modules/.bin.\n\
                     Install web dependencies first (for example: `cd web && npm install`) or set PROTOC_GEN_ES."
                ),
            }
        }
    }
}
/// Fails unless the freshly generated file is byte-for-byte identical (as UTF-8 text)
/// to the checked-in copy.
///
/// # Errors
/// Returns an error if either file cannot be read, or if their contents differ.
fn assert_file_matches(generated: &Path, checked_in: &Path) -> Result<()> {
    let fresh = fs::read_to_string(generated)
        .with_context(|| format!("failed to read generated file '{}'", generated.display()))?;
    let committed = fs::read_to_string(checked_in)
        .with_context(|| format!("failed to read checked-in file '{}'", checked_in.display()))?;
    if fresh == committed {
        return Ok(());
    }
    bail!(
        "generated proto output differs from '{}'. Run `npm run proto:gen` in web.",
        checked_in.display()
    )
}