Skip to main content

connect2axum_codegen/
lib.rs

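//! Protoc/Buf code generation for `connect2axum`.
//!
//! Provides protoc plugin entry points for REST (`generate_rest`), WebSocket
//! (`generate_ws`), OpenAPI v3.1 (`generate_openapi`), and AsyncAPI v3.1
//! (`generate_asyncapi`) output. Each has a `try_*` counterpart that returns typed
//! [`CodegenResult`] errors instead of a plugin error response.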
3
4mod error;
5pub(crate) mod internal;
6
7use crate::internal::ir::build_ir;
8use crate::internal::options::CodegenOptions;
9use crate::internal::{asyncapi, openapi, rest, ws};
10
11pub use connectrpc_codegen::plugin::{CodeGeneratorRequest, CodeGeneratorResponse};
12pub use error::{CodegenErrKind, CodegenResult};
13
/// Bit value of `CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL` from
/// `google/protobuf/compiler/plugin.proto`.
///
/// Setting this bit in `supported_features` advertises that the plugin handles
/// proto3 `optional` fields.
const FEATURE_PROTO3_OPTIONAL: u64 = 0b1;
16
17/// Generate a REST protoc plugin response for a request.
18///
19/// Errors are returned through the protoc plugin error field so `buf generate`
20/// and `protoc` can display them as compiler-plugin failures.
21#[must_use]
22pub fn generate_rest(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
23    match try_generate_rest(request) {
24        Ok(response) => response,
25        Err(err) => plugin_error_response(err.to_string()),
26    }
27}
28
29/// Generate a REST protoc plugin response, returning typed project errors.
30pub fn try_generate_rest(request: &CodeGeneratorRequest) -> CodegenResult<CodeGeneratorResponse> {
31    let options = CodegenOptions::parse(request.parameter.as_deref())?;
32    let ir = build_ir(request)?;
33    let files = request
34        .file_to_generate
35        .iter()
36        .map(|file_name| rest::generate_file(&ir, file_name, &options))
37        .collect::<CodegenResult<Vec<_>>>()?
38        .into_iter()
39        .flatten()
40        .collect();
41
42    Ok(with_supported_features(CodeGeneratorResponse {
43        file: files,
44        ..Default::default()
45    }))
46}
47
48/// Generate a WebSocket protoc plugin response for a request.
49///
50/// Errors are returned through the protoc plugin error field so `buf generate`
51/// and `protoc` can display them as compiler-plugin failures.
52#[must_use]
53pub fn generate_ws(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
54    match try_generate_ws(request) {
55        Ok(response) => response,
56        Err(err) => plugin_error_response(err.to_string()),
57    }
58}
59
60/// Generate a WebSocket protoc plugin response, returning typed project errors.
61pub fn try_generate_ws(request: &CodeGeneratorRequest) -> CodegenResult<CodeGeneratorResponse> {
62    let options = CodegenOptions::parse(request.parameter.as_deref())?;
63    let ir = build_ir(request)?;
64    let files = request
65        .file_to_generate
66        .iter()
67        .map(|file_name| ws::generate_file(&ir, file_name, &options))
68        .collect::<CodegenResult<Vec<_>>>()?
69        .into_iter()
70        .flatten()
71        .collect();
72
73    Ok(with_supported_features(CodeGeneratorResponse {
74        file: files,
75        ..Default::default()
76    }))
77}
78
79/// Generate a merged OpenAPI v3.1 protoc plugin response for a request.
80///
81/// This delegates schema and comment harvesting to grpc-gateway's
82/// `protoc-gen-openapiv3`, then patches the result to match connect2axum REST
83/// behavior.
84#[must_use]
85pub fn generate_openapi(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
86    match try_generate_openapi(request) {
87        Ok(response) => response,
88        Err(err) => plugin_error_response(err.to_string()),
89    }
90}
91
92/// Generate a merged OpenAPI v3.1 protoc plugin response, returning typed
93/// project errors.
94pub fn try_generate_openapi(
95    request: &CodeGeneratorRequest,
96) -> CodegenResult<CodeGeneratorResponse> {
97    openapi::generate(request).map(with_supported_features)
98}
99
100/// Generate an AsyncAPI v3.1 protoc plugin response for generated WebSocket
101/// routes.
102///
103/// Errors are returned through the protoc plugin error field so `buf generate`
104/// and `protoc` can display them as compiler-plugin failures.
105#[must_use]
106pub fn generate_asyncapi(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
107    match try_generate_asyncapi(request) {
108        Ok(response) => response,
109        Err(err) => plugin_error_response(err.to_string()),
110    }
111}
112
113/// Generate an AsyncAPI v3.1 protoc plugin response, returning typed project
114/// errors.
115pub fn try_generate_asyncapi(
116    request: &CodeGeneratorRequest,
117) -> CodegenResult<CodeGeneratorResponse> {
118    asyncapi::generate(request).map(with_supported_features)
119}
120
121fn plugin_error_response(error: String) -> CodeGeneratorResponse {
122    with_supported_features(CodeGeneratorResponse {
123        error: Some(error),
124        ..Default::default()
125    })
126}
127
128fn with_supported_features(mut response: CodeGeneratorResponse) -> CodeGeneratorResponse {
129    response.supported_features =
130        Some(response.supported_features.unwrap_or_default() | FEATURE_PROTO3_OPTIONAL);
131    response
132}
133
// Unit tests for the protoc plugin entry points. Requests are assembled directly
// from descriptor structs, so no protoc/buf invocation is involved.
#[cfg(test)]
mod tests {
    // Anonymous trait import: brings `encode_to_vec` / `decode_from_slice` into
    // scope for `plugin_protocol_messages_round_trip`.
    use buffa::Message as _;
    use buffa::encoding::{Tag, WireType};
    use buffa::{MessageField, UnknownField, UnknownFieldData};
    use connectrpc_codegen::codegen::descriptor::{
        DescriptorProto, FieldDescriptorProto, FileDescriptorProto, MethodDescriptorProto,
        MethodOptions, ServiceDescriptorProto,
        field_descriptor_proto::{Label, Type},
    };

    use super::{
        CodeGeneratorRequest, CodeGeneratorResponse, FEATURE_PROTO3_OPTIONAL, generate_asyncapi,
        generate_openapi, generate_rest, generate_ws, try_generate_rest,
    };

    /// A default request (no files, no parameter) produces no files and no error.
    #[test]
    fn empty_request_generates_empty_response() {
        let request = CodeGeneratorRequest::default();

        let response = generate_rest(&request);

        assert!(response.file.is_empty());
        assert!(response.error.is_none());
    }

    /// All four generators set the proto3-optional bit in `supported_features`,
    /// even for an empty request.
    #[test]
    fn generators_advertise_proto3_optional_support() {
        let request = CodeGeneratorRequest::default();

        for response in [
            generate_rest(&request),
            generate_ws(&request),
            generate_openapi(&request),
            generate_asyncapi(&request),
        ] {
            assert_eq!(
                response.supported_features.unwrap_or_default() & FEATURE_PROTO3_OPTIONAL,
                FEATURE_PROTO3_OPTIONAL
            );
        }
    }

    /// An unrecognised plugin parameter is reported through the response `error`
    /// field with no files, and the feature bit is still advertised.
    #[test]
    fn unknown_option_generates_plugin_error_response() {
        let request = CodeGeneratorRequest {
            parameter: Some("surprise=true".into()),
            ..Default::default()
        };

        let response = generate_rest(&request);

        assert!(response.file.is_empty());
        assert!(
            response
                .error
                .as_deref()
                .is_some_and(|err| err.contains("unknown plugin option: surprise"))
        );
        assert_eq!(
            response.supported_features.unwrap_or_default() & FEATURE_PROTO3_OPTIONAL,
            FEATURE_PROTO3_OPTIONAL
        );
    }

    /// Output file names are derived from the input `.proto` paths and follow the
    /// order of `file_to_generate`.
    #[test]
    fn generates_deterministic_file_names_for_files_with_http_bindings() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["hello/v1/hello.proto".into(), "echo.proto".into()],
            proto_file: vec![
                test_file("hello/v1/hello.proto", "hello.v1", true),
                test_file("echo.proto", "echo.v1", true),
            ],
            ..Default::default()
        };

        let response = try_generate_rest(&request).unwrap();

        let names = response
            .file
            .iter()
            .map(|file| file.name.as_deref())
            .collect::<Vec<_>>();
        assert_eq!(
            names,
            vec![
                Some("hello/v1/hello.connect2rest.rs"),
                Some("echo.connect2rest.rs")
            ]
        );
    }

    /// A file whose only method has no HTTP binding produces no output file.
    #[test]
    fn skips_files_without_http_bindings() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["hello/v1/hello.proto".into()],
            proto_file: vec![test_file("hello/v1/hello.proto", "hello.v1", false)],
            ..Default::default()
        };

        let response = try_generate_rest(&request).unwrap();

        assert!(response.file.is_empty());
    }

    /// A `file_to_generate` entry that is absent from `proto_file` is returned as
    /// an error from `try_generate_rest`.
    #[test]
    fn missing_file_to_generate_is_a_typed_error() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["missing.proto".into()],
            proto_file: vec![],
            ..Default::default()
        };

        let err = try_generate_rest(&request).unwrap_err();

        assert!(
            err.to_string()
                .contains("file_to_generate \"missing.proto\" was not present in proto_file")
        );
    }

    /// The request and response survive an encode/decode round trip, and a
    /// decoded empty request still generates without error.
    #[test]
    fn plugin_protocol_messages_round_trip() {
        let request = CodeGeneratorRequest::default();
        let request_bytes = request.encode_to_vec();
        let decoded_request =
            CodeGeneratorRequest::decode_from_slice(&request_bytes).expect("request decodes");

        let response = generate_rest(&decoded_request);
        let response_bytes = response.encode_to_vec();
        let decoded_response =
            CodeGeneratorResponse::decode_from_slice(&response_bytes).expect("response decodes");

        assert!(decoded_response.file.is_empty());
        assert!(decoded_response.error.is_none());
    }

    /// Builds a `FileDescriptorProto` named `name` in `package` that contains
    /// `HelloRequest { string name = 1; }`, an empty `HelloResponse`, and
    /// `HelloService.SayHello`.
    ///
    /// `SayHello` gets the HTTP binding from `method_options()` only when
    /// `with_http_binding` is true; otherwise its options are left unset.
    fn test_file(name: &str, package: &str, with_http_binding: bool) -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some(name.into()),
            package: Some(package.into()),
            message_type: vec![
                DescriptorProto {
                    name: Some("HelloRequest".into()),
                    field: vec![FieldDescriptorProto {
                        name: Some("name".into()),
                        number: Some(1),
                        label: Some(Label::LABEL_OPTIONAL),
                        r#type: Some(Type::TYPE_STRING),
                        json_name: Some("name".into()),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
                DescriptorProto {
                    name: Some("HelloResponse".into()),
                    ..Default::default()
                },
            ],
            service: vec![ServiceDescriptorProto {
                name: Some("HelloService".into()),
                method: vec![MethodDescriptorProto {
                    name: Some("SayHello".into()),
                    // Fully-qualified type names, with a leading dot, in this file's package.
                    input_type: Some(format!(".{package}.HelloRequest")),
                    output_type: Some(format!(".{package}.HelloResponse")),
                    options: if with_http_binding {
                        method_options()
                    } else {
                        MessageField::none()
                    },
                    ..Default::default()
                }],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    /// Builds `MethodOptions` that carry a single HTTP binding for the path
    /// `/hello/{name}`, stored as a raw unknown field.
    ///
    /// NOTE(review): 72_295_728 appears to be the `google.api.http` extension number,
    /// and field 2 of the nested rule `HttpRule.get`. If so, this is `GET /hello/{name}`.
    /// Confirm against `google/api/annotations.proto` and `google/api/http.proto`.
    fn method_options() -> MessageField<MethodOptions> {
        // Hand-encode the nested rule message: a field-2 length-delimited tag
        // followed by the path string (`encode_string`, presumably length-prefixed).
        let mut rule = Vec::new();
        Tag::new(2, WireType::LengthDelimited).encode(&mut rule);
        buffa::types::encode_string("/hello/{name}", &mut rule);

        // Attach the encoded rule as an unknown field on the options, so the
        // fixture needs no generated extension types.
        let mut options = MethodOptions::default();
        options.__buffa_unknown_fields.push(UnknownField {
            number: 72_295_728,
            data: UnknownFieldData::LengthDelimited(rule),
        });
        MessageField::some(options)
    }
}