python_proto_importer/postprocess/
fds.rs1use anyhow::{Context, Result};
2use prost::Message;
3use prost_reflect::DescriptorPool;
4use prost_types::FileDescriptorSet;
5use std::collections::HashSet;
6use std::path::Path;
7
8#[allow(dead_code)]
10pub fn load_fds_from_bytes(bytes: &[u8]) -> Result<DescriptorPool> {
11 let pool = DescriptorPool::decode(bytes).context("failed to decode FileDescriptorSet")?;
12 Ok(pool)
13}
14
/// Reports whether `module` ends with one of the suffixes used for generated
/// protobuf Python modules: `_pb2` or `_pb2_grpc`.
///
/// Only the tail of the string is inspected, so dotted paths such as
/// `api.v1.service_pb2` match, while `service_pb2.something` does not.
// NOTE(review): the lint is silenced presumably because this helper has no
// caller in the regular build yet — confirm it is still needed.
#[allow(dead_code)]
pub fn is_proto_generated_module(module: &str) -> bool {
    const GENERATED_SUFFIXES: [&str; 2] = ["_pb2", "_pb2_grpc"];
    GENERATED_SUFFIXES
        .iter()
        .any(|&suffix| module.ends_with(suffix))
}
22
23pub fn collect_generated_basenames_from_bytes(bytes: &[u8]) -> Result<HashSet<String>> {
26 let fds = FileDescriptorSet::decode(bytes).context("decode FDS via prost-types failed")?;
27 let mut set = HashSet::new();
28 for file in fds.file {
29 if let Some(stem) = file
30 .name
31 .as_deref()
32 .and_then(|name| Path::new(name).file_stem().and_then(|s| s.to_str()))
33 {
34 set.insert(format!("{stem}_pb2"));
35 set.insert(format!("{stem}_pb2_grpc"));
36 }
37 }
38 Ok(set)
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;
    use prost_types::{FileDescriptorProto, FileDescriptorSet};

    /// Builds a descriptor entry that carries nothing but an optional file name.
    fn descriptor(name: Option<&str>) -> FileDescriptorProto {
        FileDescriptorProto {
            name: name.map(|n| n.to_string()),
            ..Default::default()
        }
    }

    /// Serializes the given descriptor entries as a `FileDescriptorSet`.
    fn encode(files: Vec<FileDescriptorProto>) -> Vec<u8> {
        FileDescriptorSet { file: files }.encode_to_vec()
    }

    /// Decodes `bytes` and checks that exactly the `wanted` basenames come back.
    #[track_caller]
    fn assert_basenames(bytes: &[u8], wanted: &[&str]) {
        let actual = collect_generated_basenames_from_bytes(bytes).unwrap();
        let expected = wanted.iter().map(|s| s.to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn is_proto_generated_module_pb2() {
        for accepted in ["service_pb2", "api.v1.service_pb2"] {
            assert!(is_proto_generated_module(accepted));
        }
        for rejected in ["service", "service_pb2.something"] {
            assert!(!is_proto_generated_module(rejected));
        }
    }

    #[test]
    fn is_proto_generated_module_grpc() {
        for accepted in ["service_pb2_grpc", "api.v1.service_pb2_grpc"] {
            assert!(is_proto_generated_module(accepted));
        }
        for rejected in ["service_grpc", "service_pb2_grpc.something"] {
            assert!(!is_proto_generated_module(rejected));
        }
    }

    #[test]
    fn collect_generated_basenames_empty() {
        let bytes = encode(Vec::new());

        let basenames = collect_generated_basenames_from_bytes(&bytes).unwrap();
        assert!(basenames.is_empty());
    }

    #[test]
    fn collect_generated_basenames_single_file() {
        let bytes = encode(vec![descriptor(Some("service/api.proto"))]);

        assert_basenames(&bytes, &["api_pb2", "api_pb2_grpc"]);
    }

    #[test]
    fn collect_generated_basenames_multiple_files() {
        let bytes = encode(vec![
            descriptor(Some("service/user.proto")),
            descriptor(Some("api/payment.proto")),
            descriptor(Some("common.proto")),
        ]);

        assert_basenames(
            &bytes,
            &[
                "user_pb2",
                "user_pb2_grpc",
                "payment_pb2",
                "payment_pb2_grpc",
                "common_pb2",
                "common_pb2_grpc",
            ],
        );
    }

    #[test]
    fn collect_generated_basenames_file_without_name() {
        // The nameless entry must be skipped without affecting the named one.
        let bytes = encode(vec![descriptor(Some("valid.proto")), descriptor(None)]);

        assert_basenames(&bytes, &["valid_pb2", "valid_pb2_grpc"]);
    }

    #[test]
    fn collect_generated_basenames_nested_paths() {
        let bytes = encode(vec![descriptor(Some("deeply/nested/path/service.proto"))]);

        assert_basenames(&bytes, &["service_pb2", "service_pb2_grpc"]);
    }

    #[test]
    fn collect_generated_basenames_invalid_bytes() {
        let outcome = collect_generated_basenames_from_bytes(b"invalid protobuf data");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("decode FDS via prost-types failed"));
    }

    #[test]
    fn load_fds_from_bytes_valid() {
        let with_package = FileDescriptorProto {
            package: Some("test".to_string()),
            ..descriptor(Some("test.proto"))
        };
        let bytes = encode(vec![with_package]);

        assert!(load_fds_from_bytes(&bytes).is_ok());
    }

    #[test]
    fn load_fds_from_bytes_invalid() {
        let outcome = load_fds_from_bytes(b"not a valid protobuf");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("failed to decode FileDescriptorSet"));
    }
}