1#![doc = ::document_features::document_features!()]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5
6use std::fmt::Write;
7use std::fs;
8use std::io::{self, Read};
9use std::path::Path;
10use std::{env, path::PathBuf};
11
12use prost_build::Config;
13use prost_reflect::{prost::Message as ProstMessage, prost_types::FileDescriptorSet};
14
/// Selects which descriptor metadata [`DescriptorDataConfig::set_up_validators`] collects
/// into the returned [`DescriptorData`].
///
/// Every option is disabled by default. Enable options individually with the builder
/// methods, or use [`DescriptorDataConfig::collect_all_data`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DescriptorDataConfig {
    /// Record a [`Oneof`] entry for every oneof found in the selected packages.
    collect_oneofs_data: bool,
    /// Record an [`Enum`] entry for every enum found in the selected packages.
    collect_enums_data: bool,
    /// Record a [`Message`] entry for every message found in the selected packages.
    collect_messages_data: bool,
}
27
28impl DescriptorDataConfig {
29 pub fn set_up_validators(
32 &self,
33 config: &mut Config,
34 files: &[impl AsRef<Path>],
35 include_paths: &[impl AsRef<Path>],
36 packages: &[&str],
37 ) -> Result<DescriptorData, Box<dyn std::error::Error>> {
38 set_up_validators_inner(self, config, files, include_paths, packages)
39 }
40
41 #[must_use]
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 #[must_use]
49 pub const fn collect_all_data() -> Self {
50 Self {
51 collect_oneofs_data: true,
52 collect_enums_data: true,
53 collect_messages_data: true,
54 }
55 }
56
57 #[must_use]
59 pub const fn collect_oneofs_data(mut self) -> Self {
60 self.collect_oneofs_data = true;
61 self
62 }
63
64 #[must_use]
66 pub const fn collect_enums_data(mut self) -> Self {
67 self.collect_enums_data = true;
68 self
69 }
70
71 #[must_use]
73 pub const fn collect_messages_data(mut self) -> Self {
74 self.collect_messages_data = true;
75 self
76 }
77}
78
/// Descriptor metadata gathered while setting up validators.
///
/// Each list is filled only when the matching `collect_*` option of
/// [`DescriptorDataConfig`] is enabled; otherwise it stays empty.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DescriptorData {
    /// Oneofs declared in messages of the selected packages.
    pub oneofs: Vec<Oneof>,
    /// Enums declared in the selected packages (top-level or nested).
    pub enums: Vec<Enum>,
    /// Messages declared in the selected packages (top-level or nested).
    pub messages: Vec<Message>,
}

/// A protobuf `oneof` declared inside a message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Oneof {
    /// The oneof's own name.
    pub name: String,
    /// Name of the enclosing message relative to `package` (nested messages are
    /// dot-separated, e.g. `Outer.Inner`).
    pub parent_message: String,
    /// The protobuf package; may be empty if the file declares no package.
    pub package: String,
}

/// A protobuf message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// The message's own (unqualified) name.
    pub name: String,
    /// Name of the enclosing message relative to `package`, or `None` for a top-level
    /// message.
    pub parent_message: Option<String>,
    /// The protobuf package; may be empty if the file declares no package.
    pub package: String,
}

/// A protobuf enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Enum {
    /// The enum's own (unqualified) name.
    pub name: String,
    /// Name of the enclosing message relative to `package`, or `None` for a top-level
    /// enum.
    pub parent_message: Option<String>,
    /// The protobuf package; may be empty if the file declares no package.
    pub package: String,
}

/// Joins `package`, the optional `parent` path and `name` with `.`.
///
/// An empty `package` is skipped, so items from files without a `package` statement do
/// not get a leading `.`.
fn qualified_name(package: &str, parent: Option<&str>, name: &str) -> String {
    let mut full = String::new();
    // Writing into a `String` cannot fail, so the `fmt::Result`s are safe to ignore.
    if !package.is_empty() {
        let _ = write!(full, "{package}.");
    }
    if let Some(parent) = parent {
        let _ = write!(full, "{parent}.");
    }
    full.push_str(name);
    full
}

impl Message {
    /// Returns the fully-qualified protobuf name, e.g. `pkg.Outer.Inner`.
    ///
    /// The package segment is omitted when `package` is empty.
    #[must_use]
    pub fn full_name(&self) -> String {
        qualified_name(&self.package, self.parent_message.as_deref(), &self.name)
    }
}

impl Enum {
    /// Returns the fully-qualified protobuf name, e.g. `pkg.Outer.Kind`.
    ///
    /// The package segment is omitted when `package` is empty.
    #[must_use]
    pub fn full_name(&self) -> String {
        qualified_name(&self.package, self.parent_message.as_deref(), &self.name)
    }
}

impl Oneof {
    /// Returns the fully-qualified protobuf name, e.g. `pkg.Outer.choice`.
    ///
    /// The package segment is omitted when `package` is empty.
    #[must_use]
    pub fn full_name(&self) -> String {
        qualified_name(&self.package, Some(self.parent_message.as_str()), &self.name)
    }
}
167
/// Strips a leading `"{package}."` from the fully-qualified protobuf name `item`.
///
/// Returns `item` unchanged when it does not start with `package` immediately followed
/// by `.`. The result borrows only from `item`.
fn full_ish_name<'a>(item: &'a str, package: &str) -> &'a str {
    // Stripping the package and the separator separately matches exactly the same inputs
    // as stripping `"{package}."`, without allocating that string on every call.
    item.strip_prefix(package)
        .and_then(|rest| rest.strip_prefix('.'))
        .unwrap_or(item)
}
172
173pub fn set_up_validators(
185 config: &mut Config,
186 files: &[impl AsRef<Path>],
187 include_paths: &[impl AsRef<Path>],
188 packages: &[&str],
189) -> Result<DescriptorData, Box<dyn std::error::Error>> {
190 set_up_validators_inner(
191 &DescriptorDataConfig::default(),
192 config,
193 files,
194 include_paths,
195 packages,
196 )
197}
198
199fn set_up_validators_inner(
200 desc_data_config: &DescriptorDataConfig,
201 config: &mut Config,
202 files: &[impl AsRef<Path>],
203 include_paths: &[impl AsRef<Path>],
204 packages: &[&str],
205) -> Result<DescriptorData, Box<dyn std::error::Error>> {
206 let out_dir = env::var("OUT_DIR")
207 .map(PathBuf::from)
208 .unwrap_or(env::temp_dir());
209
210 config
211 .extern_path(".google.protobuf", "::protify::proto_types")
212 .extern_path(".buf.validate", "::protify::proto_types::protovalidate")
213 .compile_well_known_types();
214
215 let temp_descriptor_path = out_dir.join("__temp_file_descriptor_set.bin");
216 {
217 let mut temp_config = prost_build::Config::new();
218 temp_config.file_descriptor_set_path(&temp_descriptor_path);
219 temp_config.out_dir(&out_dir);
220 temp_config.compile_protos(files, include_paths)?;
221 }
222
223 let mut fds_file = std::fs::File::open(&temp_descriptor_path)?;
224 let mut fds_bytes = Vec::new();
225 fds_file.read_to_end(&mut fds_bytes)?;
226 let fds = FileDescriptorSet::decode(fds_bytes.as_slice())?;
227 let pool = prost_reflect::DescriptorPool::from_file_descriptor_set(fds)?;
228
229 let mut desc_data = DescriptorData::default();
230
231 for message_desc in pool.all_messages() {
232 let package = message_desc.package_name();
233
234 if packages.contains(&package) {
235 let message_name = message_desc.full_name();
236
237 if desc_data_config.collect_messages_data {
238 desc_data.messages.push(Message {
239 name: message_desc.name().to_string(),
240 parent_message: message_desc
241 .parent_message()
242 .map(|p| full_ish_name(p.full_name(), package).to_string()),
243 package: package.to_string(),
244 });
245 }
246
247 config.message_attribute(message_name, "#[derive(::protify::ValidatedMessage)]");
248 #[cfg(feature = "cel")]
249 {
250 config.message_attribute(message_name, "#[derive(::protify::CelValue)]");
251 }
252 config.message_attribute(
253 message_name,
254 format!(r#"#[proto(name = "{message_name}")]"#),
255 );
256
257 for oneof in message_desc.oneofs() {
258 let parent_message = oneof.parent_message().full_name();
259
260 if desc_data_config.collect_oneofs_data {
261 desc_data.oneofs.push(Oneof {
262 name: oneof.name().to_string(),
263 parent_message: full_ish_name(parent_message, package).to_string(),
264 package: package.to_string(),
265 });
266 }
267
268 config.enum_attribute(oneof.full_name(), "#[derive(::protify::ValidatedOneof)]");
269 #[cfg(feature = "cel")]
270 {
271 config.enum_attribute(oneof.full_name(), "#[derive(::protify::CelOneof)]");
272 }
273 config.enum_attribute(
274 oneof.full_name(),
275 format!(r#"#[proto(parent_message = "{parent_message}")]"#),
276 );
277 }
278 }
279 }
280
281 for enum_desc in pool.all_enums() {
282 let package = enum_desc.package_name();
283
284 if packages.contains(&package) {
285 let enum_full_ish_name = full_ish_name(enum_desc.full_name(), package);
286
287 if desc_data_config.collect_enums_data {
288 desc_data.enums.push(Enum {
289 name: enum_desc.name().to_string(),
290 parent_message: enum_desc
291 .parent_message()
292 .map(|p| full_ish_name(p.full_name(), package).to_string()),
293 package: package.to_string(),
294 });
295 }
296
297 config.enum_attribute(enum_desc.full_name(), "#[derive(::protify::ProtoEnum)]");
298 config.enum_attribute(
299 enum_desc.full_name(),
300 format!(r#"#[proto(name = "{enum_full_ish_name}")]"#),
301 );
302 }
303 }
304
305 Ok(desc_data)
306}
307
/// Lists the `.proto` files located directly inside `base_dir` (non-recursive).
///
/// Each returned path is `base_dir` joined with the entry's file name. The list is
/// sorted, so it does not depend on the order in which the filesystem yields directory
/// entries.
///
/// # Errors
///
/// - [`io::ErrorKind::InvalidInput`] if `base_dir` is not a directory.
/// - [`io::ErrorKind::InvalidData`] if a matching path is not valid Unicode.
/// - Any I/O error raised while reading the directory or its entries.
pub fn get_proto_files(base_dir: impl Into<PathBuf>) -> io::Result<Vec<String>> {
    let base_dir: PathBuf = base_dir.into();

    if !base_dir.is_dir() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Path {} is not a directory.", base_dir.display()),
        ));
    }

    let mut proto_files = Vec::new();

    for entry in fs::read_dir(&base_dir)? {
        let path = entry?.path();

        // `is_file` follows symlinks, so symlinked `.proto` files are included as well.
        let is_proto_file = path.is_file() && path.extension().is_some_and(|ext| ext == "proto");
        if !is_proto_file {
            continue;
        }

        let path_str = path.to_str().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Path {} contains invalid Unicode.", path.display()),
            )
        })?;
        proto_files.push(path_str.to_owned());
    }

    // `read_dir` yields entries in a platform- and filesystem-dependent order; sorting
    // keeps the file list (and thus the build) identical across machines.
    proto_files.sort_unstable();

    Ok(proto_files)
}