use std::collections::HashSet;
use std::io::{self, Read, Write};
use anyhow::{Context, Result};
use prost::Message;
use prost_reflect::DescriptorPool;
use prost_types::compiler::{CodeGeneratorRequest, CodeGeneratorResponse};
const VERSION: &str = concat!(env!("CARGO_PKG_NAME"), " ", env!("CARGO_PKG_VERSION"),);
fn main() -> Result<()> {
if let Some(flag) = std::env::args().nth(1) {
match flag.as_str() {
"--version" | "-V" => {
println!("{VERSION}");
return Ok(());
}
"--help" | "-h" => {
println!(
"{VERSION}\n\nUsage: invoked by `protoc` via stdin/stdout.\n\
See https://github.com/nu-sync/protoc-gen-rust-temporal for buf.gen.yaml examples."
);
return Ok(());
}
_ => {}
}
}
let mut input = Vec::new();
io::stdin().read_to_end(&mut input).context("read stdin")?;
let response = match build_response(&input) {
Ok(files) => CodeGeneratorResponse {
file: files,
error: None,
supported_features: Some(
prost_types::compiler::code_generator_response::Feature::Proto3Optional as u64,
),
},
Err(e) => CodeGeneratorResponse {
error: Some(format!("{e:#}")),
..Default::default()
},
};
let mut buf = Vec::new();
response
.encode(&mut buf)
.context("encode CodeGeneratorResponse")?;
io::stdout().write_all(&buf).context("write stdout")?;
Ok(())
}
/// Decodes a raw `CodeGeneratorRequest` and runs the generator over the
/// files `protoc` asked for, returning the generated output files.
///
/// Descriptors reach `prost_reflect` by re-wrapping the untouched
/// `proto_file` byte blobs as a `FileDescriptorSet`, not by re-encoding
/// the decoded `prost_types` structs, so extension data is preserved.
///
/// # Errors
/// Fails if the request cannot be decoded, if the descriptor blobs cannot be
/// extracted or loaded into the pool, or if generation fails. A generation
/// error is annotated with the sorted list of requested files.
fn build_response(raw: &[u8]) -> Result<Vec<prost_types::compiler::code_generator_response::File>> {
    let request = CodeGeneratorRequest::decode(raw).context("decode CodeGeneratorRequest")?;
    let targets = request.file_to_generate.into_iter().collect::<HashSet<String>>();

    // Each blob becomes one entry of FileDescriptorSet field 1 (`repeated file`).
    let descriptor_set = extract_proto_file_blobs(raw)?
        .iter()
        .fold(Vec::new(), |mut acc, blob| {
            encode_tagged(&mut acc, 1, blob);
            acc
        });

    let mut pool = DescriptorPool::new();
    pool.decode_file_descriptor_set(descriptor_set.as_slice())
        .context("decode_file_descriptor_set (extensions preserved)")?;

    protoc_gen_rust_temporal::run_with_pool(&pool, &targets).with_context(|| {
        let mut names = targets.iter().map(String::as_str).collect::<Vec<_>>();
        names.sort_unstable();
        format!("generating from [{}]", names.join(", "))
    })
}
fn extract_proto_file_blobs(mut raw: &[u8]) -> Result<Vec<Vec<u8>>> {
use prost::bytes::Buf;
use prost::encoding::{WireType, decode_key, decode_varint};
let mut out = Vec::new();
while raw.has_remaining() {
let (tag, wire_type) = decode_key(&mut raw).context("decode key")?;
match (tag, wire_type) {
(15, WireType::LengthDelimited) => {
let len = decode_varint(&mut raw).context("decode proto_file len")? as usize;
if raw.remaining() < len {
anyhow::bail!("truncated proto_file blob");
}
out.push(raw[..len].to_vec());
raw = &raw[len..];
}
(_, WireType::Varint) => {
let _ = decode_varint(&mut raw).context("skip varint")?;
}
(_, WireType::SixtyFourBit) => {
if raw.remaining() < 8 {
anyhow::bail!("truncated 64-bit");
}
raw = &raw[8..];
}
(_, WireType::LengthDelimited) => {
let len = decode_varint(&mut raw).context("skip ld len")? as usize;
if raw.remaining() < len {
anyhow::bail!("truncated ld");
}
raw = &raw[len..];
}
(_, WireType::ThirtyTwoBit) => {
if raw.remaining() < 4 {
anyhow::bail!("truncated 32-bit");
}
raw = &raw[4..];
}
(_, WireType::StartGroup | WireType::EndGroup) => {
anyhow::bail!("unexpected group wire type");
}
}
}
Ok(out)
}
fn encode_tagged(out: &mut Vec<u8>, field: u32, payload: &[u8]) {
use prost::encoding::{WireType, encode_key, encode_varint};
encode_key(field, WireType::LengthDelimited, out);
encode_varint(payload.len() as u64, out);
out.extend_from_slice(payload);
}