use std::env;
use std::fmt::Write as _;
use std::io::Result;
use std::path::PathBuf;
use prost_reflect::{DescriptorPool, Value};
fn main() -> Result<()> {
const ALLOW_ATTR: &str = "#[allow(clippy::all, clippy::arithmetic_side_effects, clippy::panic, clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]";
let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("protos");
let proto_files = vec![
root.join("confidence/flags/admin/v1/types.proto"),
root.join("confidence/flags/admin/v1/resolver.proto"),
root.join("confidence/flags/resolver/v1/api.proto"),
root.join("confidence/flags/resolver/v1/internal_api.proto"),
root.join("confidence/flags/resolver/v1/wasm_api.proto"),
root.join("confidence/flags/resolver/v1/events/events.proto"),
];
for proto_file in &proto_files {
println!("cargo:rerun-if-changed={}", proto_file.display());
}
let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
let mut config = prost_build::Config::new();
config.type_attribute(".", ALLOW_ATTR);
[
"confidence.flags.admin.v1.ClientResolveInfo.EvaluationContextSchemaInstance",
"confidence.flags.admin.v1.ContextFieldSemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.type",
"confidence.flags.admin.v1.ContextFieldSemanticType.VersionSemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.CountrySemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.TimestampSemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.DateSemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.EntitySemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.EnumSemanticType",
"confidence.flags.admin.v1.ContextFieldSemanticType.EnumSemanticType.EnumValue",
]
.iter()
.for_each(|&p| {
config.type_attribute(p, "#[derive(Eq, Hash)]");
});
config
.file_descriptor_set_path(&descriptor_path)
.btree_map(["."]);
#[cfg(feature = "json")]
{
config
.compile_well_known_types()
.extern_path(".google.protobuf", "::pbjson_types");
}
config.compile_protos(&proto_files, &[root])?;
#[cfg(feature = "json")]
{
let descriptor_set = std::fs::read(&descriptor_path)?;
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.ignore_unknown_fields()
.btree_map(["."])
.build(&[
".confidence.flags.admin.v1",
".confidence.flags.resolver.v1",
".confidence.flags.resolver.v1.events",
".confidence.flags.types.v1",
".confidence.auth.v1",
".confidence.iam.v1",
".google.type",
])?;
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
for entry in std::fs::read_dir(&out_dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().is_some_and(|e| e == "rs")
&& path
.file_name()
.is_some_and(|n| n.to_str().unwrap().contains(".serde.rs"))
{
let content = std::fs::read_to_string(&path)?;
let mut new_content = content
.replace("\nimpl ", &format!("\n{}\nimpl ", ALLOW_ATTR))
.replace("\nimpl<", &format!("\n{}\nimpl<", ALLOW_ATTR));
if new_content.starts_with("impl ") || new_content.starts_with("impl<") {
new_content = format!("{}\n{}", ALLOW_ATTR, new_content);
}
std::fs::write(&path, new_content)?;
}
}
}
generate_telemetry_config(&descriptor_path)?;
Ok(())
}
fn generate_telemetry_config(descriptor_path: &std::path::Path) -> Result<()> {
let descriptor_bytes = std::fs::read(descriptor_path)?;
let pool = DescriptorPool::decode(descriptor_bytes.as_ref())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let histogram_ext = pool
.get_extension_by_name("confidence.api.histogram")
.expect("confidence.api.histogram extension not found in descriptor pool");
let mut code = String::new();
for msg in pool.all_messages() {
let opts = msg.options();
if !opts.has_extension(&histogram_ext) {
continue;
}
let val = opts.get_extension(&histogram_ext);
let histogram_msg = match val.as_ref() {
Value::Message(m) => m,
_ => continue,
};
let min_value = histogram_msg
.get_field_by_name("min_value")
.and_then(|v| v.as_i32())
.unwrap_or(0);
let max_value = histogram_msg
.get_field_by_name("max_value")
.and_then(|v| v.as_i32())
.unwrap_or(0);
let bucket_count = histogram_msg
.get_field_by_name("bucket_count")
.and_then(|v| v.as_i32())
.unwrap_or(0);
let unit = histogram_msg
.get_field_by_name("unit")
.and_then(|v| v.as_str().map(|s| s.to_owned()))
.unwrap_or_default();
assert!(
min_value > 0,
"histogram min_value must be positive, got {min_value} on {}",
msg.full_name()
);
assert!(
max_value > min_value,
"histogram max_value must be greater than min_value, got max={max_value} min={min_value} on {}",
msg.full_name()
);
assert!(
bucket_count >= 3,
"histogram bucket_count must be >= 3, got {bucket_count} on {}",
msg.full_name()
);
let rust_type = proto_name_to_rust_type(msg.full_name());
writeln!(
code,
"impl crate::telemetry::HistogramConfig for crate::proto::{rust_type} {{"
)
.unwrap();
writeln!(code, " const MIN_VALUE: u32 = {min_value};").unwrap();
writeln!(code, " const MAX_VALUE: u32 = {max_value};").unwrap();
writeln!(code, " const BUCKET_COUNT: usize = {bucket_count};").unwrap();
writeln!(code, " const UNIT: &'static str = \"{unit}\";").unwrap();
writeln!(code, "}}").unwrap();
writeln!(code).unwrap();
}
let reason_enum = pool
.get_enum_by_name("confidence.flags.resolver.v1.ResolveReason")
.expect("ResolveReason enum not found in descriptor pool");
let mut discriminants: Vec<i32> = reason_enum.values().map(|v| v.number()).collect();
discriminants.sort();
for (i, &d) in discriminants.iter().enumerate() {
assert!(
d == i as i32,
"ResolveReason enum is sparse: expected discriminant {i} but found {d}. \
Dense layout is required for the telemetry resolve_rates array."
);
}
let reason_count = discriminants.len();
writeln!(
code,
"/// Number of ResolveReason variants (auto-generated from proto enum)."
)
.unwrap();
writeln!(code, "pub const REASON_COUNT: usize = {reason_count};").unwrap();
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
std::fs::write(out_dir.join("telemetry_config.rs"), code)?;
Ok(())
}
fn proto_name_to_rust_type(full_name: &str) -> String {
let parts: Vec<&str> = full_name.split('.').collect();
let mut result = String::new();
for (i, part) in parts.iter().enumerate() {
if i > 0 {
result.push_str("::");
}
if i < parts.len() - 1 && part.chars().next().is_some_and(|c| c.is_uppercase()) {
result.push_str(&to_snake_case(part));
} else {
result.push_str(part);
}
}
result
}
fn to_snake_case(s: &str) -> String {
let mut result = String::new();
for (i, ch) in s.chars().enumerate() {
if ch.is_uppercase() && i > 0 {
result.push('_');
}
result.push(ch.to_ascii_lowercase());
}
result
}