python_proto_importer/postprocess/
fds.rs

1use anyhow::{Context, Result};
2use prost::Message;
3use prost_reflect::DescriptorPool;
4use prost_types::FileDescriptorSet;
5use std::collections::HashSet;
6use std::path::Path;
7
8/// Load a FileDescriptorSet (binary) and return a DescriptorPool
9#[allow(dead_code)]
10pub fn load_fds_from_bytes(bytes: &[u8]) -> Result<DescriptorPool> {
11    let pool = DescriptorPool::decode(bytes).context("failed to decode FileDescriptorSet")?;
12    Ok(pool)
13}
14
/// Report whether `module` looks like a protoc-generated Python module.
///
/// The check is purely lexical: it returns `true` when the module path ends
/// with `_pb2` or `_pb2_grpc`. It accepts both bare (`service_pb2`) and dotted
/// (`api.v1.service_pb2`) paths. The suffix must terminate the string, so
/// `service_pb2.something` is rejected. No descriptor pool is consulted.
// NOTE(review): the allow suggests this has no non-test callers yet — confirm.
#[allow(dead_code)]
pub fn is_proto_generated_module(module: &str) -> bool {
    const GENERATED_SUFFIXES: [&str; 2] = ["_pb2", "_pb2_grpc"];
    GENERATED_SUFFIXES
        .iter()
        .any(|&suffix| module.ends_with(suffix))
}
22
23/// Decode bytes into FileDescriptorSet and collect generated module basenames
24/// like "foo_pb2", "foo_pb2_grpc" for each file in the set.
25pub fn collect_generated_basenames_from_bytes(bytes: &[u8]) -> Result<HashSet<String>> {
26    let fds = FileDescriptorSet::decode(bytes).context("decode FDS via prost-types failed")?;
27    let mut set = HashSet::new();
28    for file in fds.file {
29        if let Some(stem) = file
30            .name
31            .as_deref()
32            .and_then(|name| Path::new(name).file_stem().and_then(|s| s.to_str()))
33        {
34            set.insert(format!("{stem}_pb2"));
35            set.insert(format!("{stem}_pb2_grpc"));
36        }
37    }
38    Ok(set)
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;
    use prost_types::{FileDescriptorProto, FileDescriptorSet};

    /// A descriptor whose only populated field is `name`.
    fn named_file(name: &str) -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some(name.to_string()),
            ..Default::default()
        }
    }

    /// Serialize `file` as a binary `FileDescriptorSet`.
    fn encoded_set(file: Vec<FileDescriptorProto>) -> Vec<u8> {
        FileDescriptorSet { file }.encode_to_vec()
    }

    /// Turn string literals into the owned set type returned by the collector.
    fn owned_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn is_proto_generated_module_pb2() {
        for module in ["service_pb2", "api.v1.service_pb2"] {
            assert!(is_proto_generated_module(module), "{module} should match");
        }
        for module in ["service", "service_pb2.something"] {
            assert!(!is_proto_generated_module(module), "{module} should not match");
        }
    }

    #[test]
    fn is_proto_generated_module_grpc() {
        for module in ["service_pb2_grpc", "api.v1.service_pb2_grpc"] {
            assert!(is_proto_generated_module(module), "{module} should match");
        }
        for module in ["service_grpc", "service_pb2_grpc.something"] {
            assert!(!is_proto_generated_module(module), "{module} should not match");
        }
    }

    #[test]
    fn collect_generated_basenames_empty() {
        let bytes = encoded_set(Vec::new());
        let result = collect_generated_basenames_from_bytes(&bytes).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn collect_generated_basenames_single_file() {
        let bytes = encoded_set(vec![named_file("service/api.proto")]);
        let result = collect_generated_basenames_from_bytes(&bytes).unwrap();
        assert_eq!(result, owned_set(&["api_pb2", "api_pb2_grpc"]));
    }

    #[test]
    fn collect_generated_basenames_multiple_files() {
        let bytes = encoded_set(vec![
            named_file("service/user.proto"),
            named_file("api/payment.proto"),
            named_file("common.proto"),
        ]);
        let result = collect_generated_basenames_from_bytes(&bytes).unwrap();
        let expected = owned_set(&[
            "user_pb2",
            "user_pb2_grpc",
            "payment_pb2",
            "payment_pb2_grpc",
            "common_pb2",
            "common_pb2_grpc",
        ]);
        assert_eq!(result, expected);
    }

    #[test]
    fn collect_generated_basenames_file_without_name() {
        let unnamed = FileDescriptorProto {
            name: None, // deliberately left without a name
            ..Default::default()
        };
        let bytes = encoded_set(vec![named_file("valid.proto"), unnamed]);
        let result = collect_generated_basenames_from_bytes(&bytes).unwrap();
        // The unnamed descriptor contributes nothing.
        assert_eq!(result, owned_set(&["valid_pb2", "valid_pb2_grpc"]));
    }

    #[test]
    fn collect_generated_basenames_nested_paths() {
        let bytes = encoded_set(vec![named_file("deeply/nested/path/service.proto")]);
        let result = collect_generated_basenames_from_bytes(&bytes).unwrap();
        assert_eq!(result, owned_set(&["service_pb2", "service_pb2_grpc"]));
    }

    #[test]
    fn collect_generated_basenames_invalid_bytes() {
        let err = collect_generated_basenames_from_bytes(b"invalid protobuf data")
            .expect_err("garbage input must not decode");
        assert!(err
            .to_string()
            .contains("decode FDS via prost-types failed"));
    }

    #[test]
    fn load_fds_from_bytes_valid() {
        // Minimal but well-formed descriptor: a file name plus a package.
        let file = FileDescriptorProto {
            package: Some("test".to_string()),
            ..named_file("test.proto")
        };
        let bytes = encoded_set(vec![file]);
        assert!(load_fds_from_bytes(&bytes).is_ok());
    }

    #[test]
    fn load_fds_from_bytes_invalid() {
        let err = load_fds_from_bytes(b"not a valid protobuf")
            .expect_err("garbage input must not decode");
        assert!(err
            .to_string()
            .contains("failed to decode FileDescriptorSet"));
    }
}
204}