Skip to main content

proto_build_kit/
extract.rs

1// SPDX-License-Identifier: MIT
2//! Read `MethodOptions` extension values from a descriptor pool.
3//!
4//! Custom proto `extend google.protobuf.MethodOptions` extensions store
5//! their VALUES in the binary `MethodOptions` payload. Encoding the
6//! whole `FileDescriptorSet` to bytes (the standard codegen-input form)
7//! effectively hides them — the bytes survive, but they're treated as unknown
8//! fields on the consumer side. To read them at build time, walk the
9//! descriptor POOL instead (which `protox` builds with `prost-reflect`,
10//! preserving the extension VALUES).
11//!
12//! This module provides the generic walker: "for every method in the
13//! pool, look up extension `<fqn>` and, when present, record its
14//! string value indexed by the method's response-message FQN."
15
16use std::collections::BTreeMap;
17
18use prost_reflect::Value;
19
20/// Walk every method declared in `pool`, look for the
21/// `MethodOptions`-level extension named `extension_fqn`, and return a
22/// map keyed by **response-message FQN** (e.g. `my.v1.Resource`) with
23/// the extension's string value.
24///
25/// Methods that don't declare the extension are skipped silently.
26/// Multiple methods returning the same response type with different
27/// extension values: **first encountered wins** (per pool iteration
28/// order). Returning the same value from every method is the convention
29/// — mismatch is a service-author bug worth catching with a
30/// conformance test.
31///
32/// Returns an empty map when no methods declare the extension.
33///
34/// # Example
35///
36/// ```ignore
37/// // Given a .proto with:
38/// //   import "envelope/v1/conventions.proto";
39/// //   service UserService {
40/// //     rpc GetUser(GetUserRequest) returns (User) {
41/// //       option (envelope.v1.etag_field) = "version";
42/// //     }
43/// //   }
44/// let out = compile_protos(&["user.proto"], &[staged.path()])?;
45/// let etag_fields = extract_method_string_extension(&out.pool, "envelope.v1.etag_field");
46/// assert_eq!(etag_fields.get("my.v1.User"), Some(&"version".to_string()));
47/// ```
48#[must_use]
49pub fn extract_method_string_extension(
50    pool: &prost_reflect::DescriptorPool,
51    extension_fqn: &str,
52) -> BTreeMap<String, String> {
53    let mut out: BTreeMap<String, String> = BTreeMap::new();
54
55    for service in pool.services() {
56        for method in service.methods() {
57            let response_fqn = method.output().full_name().to_string();
58            if let Some(value) = read_string_extension(&method.options(), extension_fqn) {
59                out.entry(response_fqn).or_insert(value);
60            }
61        }
62    }
63
64    out
65}
66
67fn read_string_extension(opts: &prost_reflect::DynamicMessage, fqn: &str) -> Option<String> {
68    use prost_reflect::ReflectMessage as _;
69    let pool = opts.descriptor().parent_pool().clone();
70    let ext = pool.get_extension_by_name(fqn)?;
71    if !opts.has_extension(&ext) {
72        return None;
73    }
74    let value = opts.get_extension(&ext);
75    match &*value {
76        Value::String(s) => Some(s.clone()),
77        _ => None,
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Stager, compile_protos};

    /// Shared option definitions: two string extensions plus one integer
    /// extension, used to check that non-string values are skipped.
    const CONVENTIONS_PROTO: &[u8] = br#"
syntax = "proto3";
package fixture.opts;
import "google/protobuf/descriptor.proto";
extend google.protobuf.MethodOptions {
  optional string etag_field = 60001;
  optional string location_field = 60002;
  optional int32 max_page_size = 60003;
}
"#;

    fn stage_conv() -> tempfile::TempDir {
        Stager::new()
            .add("fixture/opts/conventions.proto", CONVENTIONS_PROTO)
            .stage()
            .unwrap()
    }

    /// Stage `proto` as `fixture/v1/x.proto` alongside the conventions file,
    /// compile both, and extract `extension_fqn` from the resulting pool.
    fn extract_with_conventions(
        proto: &'static [u8],
        extension_fqn: &str,
    ) -> BTreeMap<String, String> {
        let staged_proto = Stager::new()
            .add("fixture/v1/x.proto", proto)
            .stage()
            .unwrap();
        let staged_conv = stage_conv();

        let out = compile_protos(
            &["fixture/v1/x.proto", "fixture/opts/conventions.proto"],
            &[staged_proto.path(), staged_conv.path()],
        )
        .expect("compile");

        extract_method_string_extension(&out.pool, extension_fqn)
    }

    #[test]
    fn extracts_string_extension_value() {
        let proto = br#"
syntax = "proto3";
package fixture.v1;
import "fixture/opts/conventions.proto";

service Svc {
  rpc Get(GetReq) returns (User) {
    option (fixture.opts.etag_field) = "version";
  }
}
message GetReq { string id = 1; }
message User   { string id = 1; uint64 version = 2; }
"#;
        let map = extract_with_conventions(proto, "fixture.opts.etag_field");
        assert_eq!(map.get("fixture.v1.User"), Some(&"version".to_string()));
    }

    /// The extension's defining file is not compiled at all, so the pool
    /// cannot resolve `extension_fqn`.
    #[test]
    fn returns_empty_map_when_no_methods_declare_extension() {
        let proto = br#"
syntax = "proto3";
package fixture.v1;

service Svc {
  rpc Get(GetReq) returns (Resp);
}
message GetReq { string id = 1; }
message Resp   { string body = 1; }
"#;
        let staged = Stager::new()
            .add("fixture/v1/x.proto", proto)
            .stage()
            .unwrap();
        let out = compile_protos(&["fixture/v1/x.proto"], &[staged.path()]).expect("compile");
        let map = extract_method_string_extension(&out.pool, "fixture.opts.etag_field");
        assert!(map.is_empty());
    }

    /// The extension is known to the pool but no method sets it. A different
    /// extension set on the same method must not leak into the result.
    #[test]
    fn returns_empty_map_when_extension_is_known_but_unset() {
        let proto = br#"
syntax = "proto3";
package fixture.v1;
import "fixture/opts/conventions.proto";

service Svc {
  rpc Get(GetReq) returns (Resp) {
    option (fixture.opts.location_field) = "name";
  }
}
message GetReq { string id = 1; }
message Resp   { string name = 1; }
"#;
        let etags = extract_with_conventions(proto, "fixture.opts.etag_field");
        assert!(etags.is_empty());

        let locations = extract_with_conventions(proto, "fixture.opts.location_field");
        assert_eq!(locations.get("fixture.v1.Resp"), Some(&"name".to_string()));
        assert_eq!(locations.len(), 1);
    }

    #[test]
    fn skips_non_string_extension_values() {
        let proto = br#"
syntax = "proto3";
package fixture.v1;
import "fixture/opts/conventions.proto";

service Svc {
  rpc List(Req) returns (Page) { option (fixture.opts.max_page_size) = 50; }
}
message Req  { string id = 1; }
message Page { repeated string ids = 1; }
"#;
        let map = extract_with_conventions(proto, "fixture.opts.max_page_size");
        assert!(map.is_empty());
    }

    #[test]
    fn first_method_wins_when_multiple_share_response_type() {
        let proto = br#"
syntax = "proto3";
package fixture.v1;
import "fixture/opts/conventions.proto";

service Svc {
  rpc First(Req)  returns (Shared) { option (fixture.opts.etag_field) = "v1"; }
  rpc Second(Req) returns (Shared) { option (fixture.opts.etag_field) = "v2"; }
}
message Req    { string id = 1; }
message Shared { string id = 1; }
"#;
        let map = extract_with_conventions(proto, "fixture.opts.etag_field");
        assert_eq!(map.get("fixture.v1.Shared"), Some(&"v1".to_string()));
    }
}