Skip to main content

connect2axum_codegen/
lib.rs

1//! Protoc/Buf code generation for `connect2axum`.
2//!
3
4mod error;
5pub(crate) mod internal;
6
7use crate::internal::ir::build_ir;
8use crate::internal::options::CodegenOptions;
9use crate::internal::{asyncapi, openapi, rest, ws};
10
11pub use connectrpc_codegen::plugin::{CodeGeneratorRequest, CodeGeneratorResponse};
12pub use error::{CodegenErrKind, CodegenResult};
13
14/// Generate a REST protoc plugin response for a request.
15///
16/// Errors are returned through the protoc plugin error field so `buf generate`
17/// and `protoc` can display them as compiler-plugin failures.
18#[must_use]
19pub fn generate_rest(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
20    match try_generate_rest(request) {
21        Ok(response) => response,
22        Err(err) => CodeGeneratorResponse {
23            error: Some(err.to_string()),
24            ..Default::default()
25        },
26    }
27}
28
29/// Generate a REST protoc plugin response, returning typed project errors.
30pub fn try_generate_rest(request: &CodeGeneratorRequest) -> CodegenResult<CodeGeneratorResponse> {
31    let options = CodegenOptions::parse(request.parameter.as_deref())?;
32    let ir = build_ir(request)?;
33    let files = request
34        .file_to_generate
35        .iter()
36        .map(|file_name| rest::generate_file(&ir, file_name, &options))
37        .collect::<CodegenResult<Vec<_>>>()?
38        .into_iter()
39        .flatten()
40        .collect();
41
42    Ok(CodeGeneratorResponse {
43        file: files,
44        ..Default::default()
45    })
46}
47
48/// Generate a WebSocket protoc plugin response for a request.
49///
50/// Errors are returned through the protoc plugin error field so `buf generate`
51/// and `protoc` can display them as compiler-plugin failures.
52#[must_use]
53pub fn generate_ws(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
54    match try_generate_ws(request) {
55        Ok(response) => response,
56        Err(err) => CodeGeneratorResponse {
57            error: Some(err.to_string()),
58            ..Default::default()
59        },
60    }
61}
62
63/// Generate a WebSocket protoc plugin response, returning typed project errors.
64pub fn try_generate_ws(request: &CodeGeneratorRequest) -> CodegenResult<CodeGeneratorResponse> {
65    let options = CodegenOptions::parse(request.parameter.as_deref())?;
66    let ir = build_ir(request)?;
67    let files = request
68        .file_to_generate
69        .iter()
70        .map(|file_name| ws::generate_file(&ir, file_name, &options))
71        .collect::<CodegenResult<Vec<_>>>()?
72        .into_iter()
73        .flatten()
74        .collect();
75
76    Ok(CodeGeneratorResponse {
77        file: files,
78        ..Default::default()
79    })
80}
81
82/// Generate a merged OpenAPI v3.1 protoc plugin response for a request.
83///
84/// This delegates schema and comment harvesting to grpc-gateway's
85/// `protoc-gen-openapiv3`, then patches the result to match connect2axum REST
86/// behavior.
87#[must_use]
88pub fn generate_openapi(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
89    match try_generate_openapi(request) {
90        Ok(response) => response,
91        Err(err) => CodeGeneratorResponse {
92            error: Some(err.to_string()),
93            ..Default::default()
94        },
95    }
96}
97
98/// Generate a merged OpenAPI v3.1 protoc plugin response, returning typed
99/// project errors.
100pub fn try_generate_openapi(
101    request: &CodeGeneratorRequest,
102) -> CodegenResult<CodeGeneratorResponse> {
103    openapi::generate(request)
104}
105
106/// Generate an AsyncAPI v3.1 protoc plugin response for generated WebSocket
107/// routes.
108///
109/// Errors are returned through the protoc plugin error field so `buf generate`
110/// and `protoc` can display them as compiler-plugin failures.
111#[must_use]
112pub fn generate_asyncapi(request: &CodeGeneratorRequest) -> CodeGeneratorResponse {
113    match try_generate_asyncapi(request) {
114        Ok(response) => response,
115        Err(err) => CodeGeneratorResponse {
116            error: Some(err.to_string()),
117            ..Default::default()
118        },
119    }
120}
121
122/// Generate an AsyncAPI v3.1 protoc plugin response, returning typed project
123/// errors.
124pub fn try_generate_asyncapi(
125    request: &CodeGeneratorRequest,
126) -> CodegenResult<CodeGeneratorResponse> {
127    asyncapi::generate(request)
128}
129
#[cfg(test)]
mod tests {
    // Unit tests for the protoc plugin entry points. Requests are assembled in
    // memory from hand-written descriptors, so no `protoc`/`buf` run is needed.

    use buffa::Message as _;
    use buffa::encoding::{Tag, WireType};
    use buffa::{MessageField, UnknownField, UnknownFieldData};
    use connectrpc_codegen::codegen::descriptor::{
        DescriptorProto, FieldDescriptorProto, FileDescriptorProto, MethodDescriptorProto,
        MethodOptions, ServiceDescriptorProto,
        field_descriptor_proto::{Label, Type},
    };

    use super::{
        CodeGeneratorRequest, CodeGeneratorResponse, generate_rest, generate_ws, try_generate_rest,
    };

    #[test]
    fn empty_request_generates_empty_response() {
        let request = CodeGeneratorRequest::default();

        let response = generate_rest(&request);

        assert!(response.file.is_empty());
        assert!(response.error.is_none());
    }

    #[test]
    fn unknown_option_generates_plugin_error_response() {
        let request = CodeGeneratorRequest {
            parameter: Some("surprise=true".into()),
            ..Default::default()
        };

        let response = generate_rest(&request);

        assert!(response.file.is_empty());
        assert!(
            response
                .error
                .as_deref()
                .is_some_and(|err| err.contains("unknown plugin option: surprise"))
        );
    }

    // The WebSocket entry point shares option parsing and IR construction with
    // the REST one, so the same request-level behavior must hold for it too.
    #[test]
    fn ws_empty_request_generates_empty_response() {
        let request = CodeGeneratorRequest::default();

        let response = generate_ws(&request);

        assert!(response.file.is_empty());
        assert!(response.error.is_none());
    }

    #[test]
    fn ws_unknown_option_generates_plugin_error_response() {
        let request = CodeGeneratorRequest {
            parameter: Some("surprise=true".into()),
            ..Default::default()
        };

        let response = generate_ws(&request);

        assert!(response.file.is_empty());
        assert!(
            response
                .error
                .as_deref()
                .is_some_and(|err| err.contains("unknown plugin option: surprise"))
        );
    }

    #[test]
    fn generates_deterministic_file_names_for_files_with_http_bindings() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["hello/v1/hello.proto".into(), "echo.proto".into()],
            proto_file: vec![
                test_file("hello/v1/hello.proto", "hello.v1", true),
                test_file("echo.proto", "echo.v1", true),
            ],
            ..Default::default()
        };

        let response = try_generate_rest(&request).unwrap();

        let names = response
            .file
            .iter()
            .map(|file| file.name.as_deref())
            .collect::<Vec<_>>();
        assert_eq!(
            names,
            vec![
                Some("hello/v1/hello.connect2rest.rs"),
                Some("echo.connect2rest.rs")
            ]
        );
    }

    #[test]
    fn skips_files_without_http_bindings() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["hello/v1/hello.proto".into()],
            proto_file: vec![test_file("hello/v1/hello.proto", "hello.v1", false)],
            ..Default::default()
        };

        let response = try_generate_rest(&request).unwrap();

        assert!(response.file.is_empty());
    }

    #[test]
    fn missing_file_to_generate_is_a_typed_error() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["missing.proto".into()],
            proto_file: vec![],
            ..Default::default()
        };

        let err = try_generate_rest(&request).unwrap_err();

        assert!(
            err.to_string()
                .contains("file_to_generate \"missing.proto\" was not present in proto_file")
        );
    }

    #[test]
    fn plugin_protocol_messages_round_trip() {
        let request = CodeGeneratorRequest::default();
        let request_bytes = request.encode_to_vec();
        let decoded_request =
            CodeGeneratorRequest::decode_from_slice(&request_bytes).expect("request decodes");

        let response = generate_rest(&decoded_request);
        let response_bytes = response.encode_to_vec();
        let decoded_response =
            CodeGeneratorResponse::decode_from_slice(&response_bytes).expect("response decodes");

        assert!(decoded_response.file.is_empty());
        assert!(decoded_response.error.is_none());
    }

    /// Builds a minimal file descriptor containing `HelloRequest`/`HelloResponse`
    /// messages and a `HelloService.SayHello` RPC in `package`.
    ///
    /// When `with_http_binding` is true, the RPC carries the method options
    /// built by [`method_options`]; otherwise it has no options at all.
    fn test_file(name: &str, package: &str, with_http_binding: bool) -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some(name.into()),
            package: Some(package.into()),
            message_type: vec![
                DescriptorProto {
                    name: Some("HelloRequest".into()),
                    field: vec![FieldDescriptorProto {
                        name: Some("name".into()),
                        number: Some(1),
                        label: Some(Label::LABEL_OPTIONAL),
                        r#type: Some(Type::TYPE_STRING),
                        json_name: Some("name".into()),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
                DescriptorProto {
                    name: Some("HelloResponse".into()),
                    ..Default::default()
                },
            ],
            service: vec![ServiceDescriptorProto {
                name: Some("HelloService".into()),
                method: vec![MethodDescriptorProto {
                    name: Some("SayHello".into()),
                    input_type: Some(format!(".{package}.HelloRequest")),
                    output_type: Some(format!(".{package}.HelloResponse")),
                    options: if with_http_binding {
                        method_options()
                    } else {
                        MessageField::none()
                    },
                    ..Default::default()
                }],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    /// Builds `MethodOptions` carrying an HTTP binding as a raw unknown field,
    /// the form in which an unregistered extension arrives in a request.
    ///
    /// The payload is a hand-encoded rule with field 2 set to `/hello/{name}`
    /// (field 2 of `google.api.HttpRule` is `get`). It is attached under field
    /// number 72_295_728, which is the `google.api.http` extension number.
    fn method_options() -> MessageField<MethodOptions> {
        let mut rule = Vec::new();
        Tag::new(2, WireType::LengthDelimited).encode(&mut rule);
        buffa::types::encode_string("/hello/{name}", &mut rule);

        let mut options = MethodOptions::default();
        options.__buffa_unknown_fields.push(UnknownField {
            number: 72_295_728,
            data: UnknownFieldData::LengthDelimited(rule),
        });
        MessageField::some(options)
    }
}