use buffa_codegen::idents::escape_mod_ident;
use connectrpc_codegen::plugin::CodeGeneratorResponseFile;
use flexstr::{IntoOptimizedFlexStr as _, SharedStr};
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Ident, Path, Type};
use uni_error::UniError;
use crate::error::{CodegenErrKind, CodegenResult};
use crate::internal::guardrails::{ensure_unique_generated_identifiers, ensure_unique_routes};
use crate::internal::ir::{CommentSet, DescriptorIr, Method, ProtoFile, Service};
use crate::internal::options::CodegenOptions;
use crate::internal::resolver::TypeResolver;
use crate::internal::shape::{FileShapes, RequestShape, plan_file_shapes};
/// Suffix appended to the escaped snake_case service name to form the generated module name.
const WS_MODULE_SUFFIX: &str = "_ws";
/// Name of the per-service function that builds the `axum::Router` of WebSocket routes.
const ROUTER_FUNCTION_NAME: &str = "make_router";
/// Generates the `<stem>.connect2ws.rs` response file for `file_name`.
///
/// Returns `Ok(None)` when the file declares no HTTP bindings, or when the
/// generator produces no WebSocket modules for it.
///
/// # Errors
///
/// Fails with `FileToGenerateNotFound` when `file_name` is not present in `ir`,
/// and propagates errors from constructing or running the generator.
pub fn generate_file(
    ir: &DescriptorIr,
    file_name: &str,
    options: &CodegenOptions,
) -> CodegenResult<Option<CodeGeneratorResponseFile>> {
    let not_found = || {
        UniError::from_kind_context(
            CodegenErrKind::FileToGenerateNotFound,
            format!("file_to_generate {file_name:?} was not present in proto_file"),
        )
    };
    let proto_file = ir.file(file_name).ok_or_else(not_found)?;
    if !proto_file.has_http_bindings() {
        return Ok(None);
    }
    let generated = RustGenerator::new(ir, proto_file, options)?.generate()?;
    Ok(generated.map(|content| CodeGeneratorResponseFile {
        name: Some(output_file_name(proto_file.name.as_ref())),
        content: Some(content),
        ..Default::default()
    }))
}
/// Maps a proto file name to its generated output name.
///
/// A single trailing `.proto` is removed (if present) and `.connect2ws.rs` is
/// appended, e.g. `a/b.proto` -> `a/b.connect2ws.rs`.
fn output_file_name(proto_file_name: &str) -> String {
    const OUTPUT_SUFFIX: &str = ".connect2ws.rs";
    let stem = match proto_file_name.strip_suffix(".proto") {
        Some(stem) => stem,
        None => proto_file_name,
    };
    let mut name = String::with_capacity(stem.len() + OUTPUT_SUFFIX.len());
    name.push_str(stem);
    name.push_str(OUTPUT_SUFFIX);
    name
}
/// Per-file state for rendering the WebSocket (connect2ws) output of one proto file.
struct RustGenerator<'a> {
    /// The proto file whose services are rendered.
    proto_file: &'a ProtoFile,
    /// Code generation options, also handed to the resolver and shape planner.
    options: &'a CodegenOptions,
    /// Resolves proto names to Rust paths and identifiers (service traits,
    /// handler names, local bindings, message types).
    resolver: TypeResolver<'a>,
    /// Request shapes produced by `plan_file_shapes` for `proto_file`.
    shapes: FileShapes,
}
impl<'a> RustGenerator<'a> {
fn new(
ir: &'a DescriptorIr,
proto_file: &'a ProtoFile,
options: &'a CodegenOptions,
) -> CodegenResult<Self> {
Ok(Self {
proto_file,
options,
resolver: TypeResolver::new(ir, options),
shapes: plan_file_shapes(ir, proto_file, options)?,
})
}
fn generate(&self) -> CodegenResult<Option<String>> {
self.validate_file()?;
let modules = self
.proto_file
.services
.iter()
.map(|service| self.service_module(service))
.collect::<CodegenResult<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
if modules.is_empty() {
return Ok(None);
}
let source = format_tokens(quote! {
#(#modules)*
})?;
Ok(Some(format!(
"// @generated by connect2ws.\n\
// This file is generated from {}. Do not edit by hand.\n\n\
{source}",
self.proto_file.name.as_ref()
)))
}
fn service_module(&self, service: &Service) -> CodegenResult<Option<TokenStream>> {
let ws_methods = self.ws_methods(service)?;
if ws_methods.is_empty() {
return Ok(None);
}
self.validate_service(service, &ws_methods)?;
let module_ident = parse_ident(
&service_module_name(service.name.as_ref()),
"generated WebSocket module",
)?;
let handler_items = ws_methods
.iter()
.map(|method| self.handler_tokens(service, method))
.collect::<CodegenResult<Vec<_>>>()?;
let router_item = self.router_tokens(service, &ws_methods)?;
Ok(Some(quote! {
pub mod #module_ident {
#![allow(unused_imports)]
use std::sync::Arc;
use axum::extract::{State, WebSocketUpgrade};
use http::{Extensions, HeaderMap};
#(#handler_items)*
#router_item
}
}))
}
fn ws_methods<'m>(&self, service: &'m Service) -> CodegenResult<Vec<WsMethod<'m>>> {
let mut ws_methods = Vec::new();
for method in &service.methods {
let Some(ws_method) = self.ws_method(service, method)? else {
continue;
};
ws_methods.push(ws_method);
}
Ok(ws_methods)
}
fn ws_method<'m>(
&self,
service: &Service,
method: &'m Method,
) -> CodegenResult<Option<WsMethod<'m>>> {
let Some(binding) = method.http.as_ref() else {
return Ok(None);
};
let Some(kind) = WsMethodKind::from_method(method) else {
return Ok(None);
};
let shape = self.shape_for(method)?;
if matches!(kind, WsMethodKind::Server) && has_path_or_query(shape) {
eprintln!(
"warning: connect2ws skipping {}.{} at {} because WebSocket generation does not support path/query bindings for server-streaming methods",
service.full_name.as_ref(),
method.name.as_ref(),
binding.path.as_ref()
);
return Ok(None);
}
Ok(Some(WsMethod {
method,
kind,
route_path: ws_route_path(binding.path.as_ref()),
}))
}
fn handler_tokens(
&self,
service: &Service,
ws_method: &WsMethod<'_>,
) -> CodegenResult<TokenStream> {
let method = ws_method.method;
let comments = comment_attrs(&method.comments);
let service_trait = parse_path(
self.resolver
.connect_service_trait(service.full_name.as_ref())
.as_str(),
"Connect service trait",
)?;
let method_ident = parse_ident(
self.resolver.method_fn_name(method.name.as_ref()).as_ref(),
"WebSocket handler function",
)?;
let service_value = parse_ident(
self.resolver.value_ident("service", self.options).as_ref(),
"service binding",
)?;
let ctx_value = parse_ident(
self.resolver.value_ident("ctx", self.options).as_ref(),
"context binding",
)?;
let headers_value = parse_ident(
self.resolver.value_ident("headers", self.options).as_ref(),
"headers binding",
)?;
let extensions_value = parse_ident(
self.resolver
.value_ident("extensions", self.options)
.as_ref(),
"extensions binding",
)?;
let ws_upgrade_value = parse_ident(
self.resolver
.value_ident("ws_upgrade", self.options)
.as_ref(),
"WebSocket upgrade binding",
)?;
let stream_value = parse_ident(
self.resolver.value_ident("stream", self.options).as_ref(),
"WebSocket stream binding",
)?;
let sink_value = parse_ident(
self.resolver.value_ident("sink", self.options).as_ref(),
"WebSocket sink binding",
)?;
let request_value = parse_ident(
self.resolver.value_ident("request", self.options).as_ref(),
"request binding",
)?;
let response_value = parse_ident(
self.resolver.value_ident("response", self.options).as_ref(),
"response binding",
)?;
let err_value = parse_ident(
self.resolver.value_ident("err", self.options).as_ref(),
"error binding",
)?;
let runtime = parse_path(self.options.runtime_module.as_ref(), "runtime module")?;
let output_type = parse_type(
self.resolver
.owned_message_type(method.output_type.as_ref())?
.as_str(),
"method output type",
)?;
let request_view_type = parse_type(
&format!(
"{}<'static>",
self.resolver
.view_message_type(method.input_type.as_ref())?
.as_str()
),
"method request view type",
)?;
let callback_body = match ws_method.kind {
WsMethodKind::Server => quote! {
let #request_value = match #runtime::make_ws_request::<#request_view_type>(#stream_value).await {
Ok(#request_value) => #request_value,
Err(#err_value) => {
#runtime::close_ws(#sink_value, #err_value).await;
return;
}
};
let #response_value = #service_value.#method_ident(#ctx_value, #request_value).await;
#runtime::process_ws_stream_response::<#output_type, _>(
#response_value,
#sink_value
).await;
},
WsMethodKind::Client => quote! {
let #request_value = #runtime::make_ws_stream_request::<#request_view_type>(#stream_value);
let #response_value = #service_value.#method_ident(#ctx_value, #request_value).await;
#runtime::process_ws_response::<#output_type, _>(
#response_value,
#sink_value
).await;
},
WsMethodKind::Bidi => quote! {
let #request_value = #runtime::make_ws_stream_request::<#request_view_type>(#stream_value);
let #response_value = #service_value.#method_ident(#ctx_value, #request_value).await;
#runtime::process_ws_stream_response::<#output_type, _>(
#response_value,
#sink_value
).await;
},
};
Ok(quote! {
#(#comments)*
pub async fn #method_ident<S>(
State(#service_value): State<Arc<S>>,
#ws_upgrade_value: WebSocketUpgrade,
#headers_value: HeaderMap,
#extensions_value: Extensions,
) -> axum::response::Response
where
S: #service_trait + Send + Sync + 'static,
{
#runtime::upgrade_to_ws(
#ws_upgrade_value,
#headers_value,
#extensions_value,
|#headers_value, #extensions_value, #stream_value, #sink_value| async move {
let #ctx_value = #runtime::request_context(#headers_value, #extensions_value);
#callback_body
},
).await
}
})
}
fn router_tokens(
&self,
service: &Service,
ws_methods: &[WsMethod<'_>],
) -> CodegenResult<TokenStream> {
let service_trait = parse_path(
self.resolver
.connect_service_trait(service.full_name.as_ref())
.as_str(),
"Connect service trait",
)?;
let router_ident = parse_ident(ROUTER_FUNCTION_NAME, "router function")?;
let route_calls = ws_methods
.iter()
.map(|method| self.route_tokens(method))
.collect::<CodegenResult<Vec<_>>>()?;
Ok(quote! {
pub fn #router_ident<S>(service: Arc<S>) -> axum::Router
where
S: #service_trait + Send + Sync + 'static,
{
axum::Router::new()
#(#route_calls)*
.with_state(service)
}
})
}
fn route_tokens(&self, ws_method: &WsMethod<'_>) -> CodegenResult<TokenStream> {
let path = ws_method.route_path.as_ref();
let method_ident = parse_ident(
self.resolver
.method_fn_name(ws_method.method.name.as_ref())
.as_ref(),
"WebSocket route handler",
)?;
Ok(quote! {
.route(#path, axum::routing::any(#method_ident::<S>))
})
}
fn shape_for(&self, method: &Method) -> CodegenResult<&RequestShape> {
self.shapes
.request_shapes
.iter()
.find(|shape| shape.method == method.full_name)
.ok_or_else(|| {
UniError::from_kind_context(
CodegenErrKind::InvalidDescriptor,
format!(
"request shape for {} was not planned",
method.full_name.as_ref()
),
)
})
}
fn validate_file(&self) -> CodegenResult<()> {
let scope = format!("WebSocket file {}", self.proto_file.name.as_ref());
ensure_unique_generated_identifiers(
&scope,
self.proto_file
.services
.iter()
.filter(|service| {
service.methods.iter().any(|method| {
method.http.is_some() && WsMethodKind::from_method(method).is_some()
})
})
.map(|service| {
(
service_module_name(service.name.as_ref()),
format!("service {}", service.full_name.as_ref()),
)
}),
)
}
fn validate_service(
&self,
service: &Service,
ws_methods: &[WsMethod<'_>],
) -> CodegenResult<()> {
let scope = format!("WebSocket service {}", service.full_name.as_ref());
ensure_unique_generated_identifiers(
&scope,
ws_methods.iter().map(|ws_method| {
(
self.resolver.method_fn_name(ws_method.method.name.as_ref()),
format!("method {}", ws_method.method.full_name.as_ref()),
)
}),
)?;
ensure_unique_routes(
&scope,
ws_methods.iter().map(|ws_method| {
(
ws_method.route_path.clone(),
format!("method {}", ws_method.method.full_name.as_ref()),
)
}),
)
}
}
/// Streaming shape of an RPC that receives a WebSocket handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsMethodKind {
    /// Server streaming only (`server_streaming` set, `client_streaming` not).
    Server,
    /// Client streaming only (`client_streaming` set, `server_streaming` not).
    Client,
    /// Both directions stream.
    Bidi,
}
impl WsMethodKind {
    /// Classifies `method` by its streaming flags.
    ///
    /// Unary methods (neither side streaming) have no WebSocket form and yield
    /// `None`.
    fn from_method(method: &Method) -> Option<Self> {
        let kind = match (method.client_streaming, method.server_streaming) {
            (true, true) => Self::Bidi,
            (true, false) => Self::Client,
            (false, true) => Self::Server,
            (false, false) => return None,
        };
        Some(kind)
    }
}
/// A method selected for WebSocket generation, with its derived route.
struct WsMethod<'a> {
    /// The IR method the handler is generated for.
    method: &'a Method,
    /// Streaming kind, which selects the request and response runtime helpers.
    kind: WsMethodKind,
    /// Route registered with the router: the HTTP binding path with `/ws` appended.
    route_path: SharedStr,
}
/// Returns `true` when the request shape takes any input from the URL: at least
/// one path field or a query shape.
fn has_path_or_query(shape: &RequestShape) -> bool {
    let binds_path = !shape.path_fields.is_empty();
    let binds_query = shape.query_shape.is_some();
    binds_path || binds_query
}
/// Derives the WebSocket route for an HTTP binding path.
///
/// Trailing slashes are removed and `/ws` is appended, e.g.
/// `/test:bidi-stream` -> `/test:bidi-stream/ws`. An empty or all-slash path
/// maps to `/ws`.
pub fn ws_route_path(path: &str) -> SharedStr {
    match path.trim_end_matches('/') {
        "" => "/ws".into(),
        trimmed => format!("{trimmed}/ws").into_opt(),
    }
}
/// Converts a method's leading-detached and leading proto comments into
/// `#[doc = "..."]` attributes.
///
/// Each comment line has trailing whitespace trimmed. Blank lines inside a
/// comment are kept so rustdoc preserves paragraph breaks. Blank lines at the
/// start or end of a comment are dropped. Consecutive non-empty comments are
/// separated by one empty doc line so they do not run together into a single
/// paragraph.
fn comment_attrs(comments: &CommentSet) -> Vec<TokenStream> {
    let mut attrs = Vec::new();
    for comment in comments
        .leading_detached
        .iter()
        .chain(comments.leading.iter())
    {
        let lines: Vec<&str> = comment.lines().map(str::trim_end).collect();
        // Skip comments that contain only whitespace.
        let Some(first) = lines.iter().position(|line| !line.is_empty()) else {
            continue;
        };
        let last = lines
            .iter()
            .rposition(|line| !line.is_empty())
            .unwrap_or(first);
        if !attrs.is_empty() {
            attrs.push(quote! {
                #[doc = ""]
            });
        }
        attrs.extend(lines[first..=last].iter().map(|line| {
            quote! {
                #[doc = #line]
            }
        }));
    }
    attrs
}
/// Returns the generated module name for a service.
///
/// The name is the service's snake_case name, passed through
/// `escape_mod_ident`, followed by `WS_MODULE_SUFFIX`.
fn service_module_name(service_name: &str) -> SharedStr {
    let snake = service_name.to_snake_case();
    let module = format!("{}{}", escape_mod_ident(&snake), WS_MODULE_SUFFIX);
    module.into_opt()
}
/// Parses `value` as a single Rust identifier.
///
/// # Errors
///
/// Returns `InvalidDescriptor` naming `context` and the offending value when
/// `value` is not a valid identifier.
fn parse_ident(value: &str, context: &'static str) -> CodegenResult<Ident> {
    match syn::parse_str::<Ident>(value) {
        Ok(ident) => Ok(ident),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("{context} identifier {value:?} did not parse: {err}"),
        )),
    }
}
/// Parses `value` as a Rust path such as `::crate_name::Trait`.
///
/// # Errors
///
/// Returns `InvalidDescriptor` naming `context` and the offending value when
/// `value` is not a valid path.
fn parse_path(value: &str, context: &'static str) -> CodegenResult<Path> {
    match syn::parse_str::<Path>(value) {
        Ok(path) => Ok(path),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("{context} path {value:?} did not parse: {err}"),
        )),
    }
}
/// Parses `value` as a Rust type (possibly with generics or lifetimes).
///
/// # Errors
///
/// Returns `InvalidDescriptor` naming `context` and the offending value when
/// `value` is not a valid type.
fn parse_type(value: &str, context: &'static str) -> CodegenResult<Type> {
    match syn::parse_str::<Type>(value) {
        Ok(ty) => Ok(ty),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("{context} {value:?} did not parse: {err}"),
        )),
    }
}
/// Parses `tokens` as a complete Rust file and pretty-prints it with `prettyplease`.
///
/// # Errors
///
/// Returns `InvalidDescriptor`, with the raw token stream appended to the
/// message, when the tokens do not parse as a `syn::File`.
fn format_tokens(tokens: TokenStream) -> CodegenResult<String> {
    match syn::parse2::<syn::File>(tokens.clone()) {
        Ok(file) => Ok(prettyplease::unparse(&file)),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("generated Rust file did not parse: {err}\n{tokens}"),
        )),
    }
}
#[cfg(test)]
mod tests {
    // End-to-end tests: each builds a `CodeGeneratorRequest` from hand-written
    // descriptors and runs it through `try_generate_ws`.
    use buffa::encoding::{Tag, WireType};
    use buffa::{MessageField, UnknownField, UnknownFieldData};
    use connectrpc_codegen::codegen::descriptor::{
        DescriptorProto, FieldDescriptorProto, FileDescriptorProto, MethodDescriptorProto,
        MethodOptions, ServiceDescriptorProto,
        field_descriptor_proto::{Label, Type},
    };
    use crate::{CodeGeneratorRequest, try_generate_ws};

    // One file with unary, server-, client- and bidi-streaming methods. Expects
    // one output file with a handler and a `/ws` route per streaming method,
    // no handler for the unary method, and each runtime helper referenced.
    #[test]
    fn generates_expected_websocket_handlers() {
        let response = try_generate_ws(&streaming_request()).unwrap();
        assert_eq!(response.file.len(), 1);
        assert_eq!(
            response.file[0].name.as_deref(),
            Some("test/v1/streaming.connect2ws.rs")
        );
        let content = response.file[0].content.as_deref().unwrap();
        assert!(content.contains("// @generated by connect2ws."));
        assert!(content.contains("pub mod test_service_ws"));
        assert!(content.contains("pub async fn server_stream"));
        assert!(content.contains("pub async fn client_stream"));
        assert!(content.contains("pub async fn bidi_stream"));
        assert!(!content.contains("pub async fn unary"));
        assert!(content.contains("::connect2axum::make_ws_request::<"));
        assert!(content.contains("::connect2axum::make_ws_stream_request::<"));
        assert!(content.contains("::connect2axum::process_ws_response::<"));
        assert!(content.contains("::connect2axum::process_ws_stream_response::<"));
        assert!(content.contains(".route(\"/test:server-stream/ws\""));
        assert!(content.contains(".route(\"/test:client-stream/ws\""));
        assert!(content.contains(".route(\"/test:bidi-stream/ws\""));
    }

    // A file whose only HTTP-bound method is unary produces no output file.
    #[test]
    fn skips_files_without_websocket_methods() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["test/v1/unary.proto".into()],
            proto_file: vec![unary_test_file()],
            ..Default::default()
        };
        let response = try_generate_ws(&request).unwrap();
        assert!(response.file.is_empty());
    }

    // A server-streaming method whose path template binds a field
    // (`{data}`) is skipped, so the file produces no output.
    #[test]
    fn skips_server_streaming_methods_with_path_or_query_bindings() {
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["test/v1/path.proto".into()],
            proto_file: vec![server_streaming_path_file()],
            ..Default::default()
        };
        let response = try_generate_ws(&request).unwrap();
        assert!(response.file.is_empty());
    }

    // Two methods bound to the same HTTP path map to the same `/ws` route and
    // must be rejected with a duplicate-route error naming that route.
    #[test]
    fn rejects_duplicate_websocket_routes() {
        let mut file = streaming_test_file();
        file.service[0].method.push(streaming_method(
            "ServerStreamAgain",
            http_rule(4, "/test:server-stream", Some("*")),
            false,
            true,
        ));
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["test/v1/streaming.proto".into()],
            proto_file: vec![file],
            ..Default::default()
        };
        let err = try_generate_ws(&request).unwrap_err();
        assert!(err.to_string().contains("duplicate route"));
        assert!(err.to_string().contains("/test:server-stream/ws"));
    }

    // `Bidi_Stream` and `BidiStream` both map to the handler name
    // `bidi_stream`, so generation must fail with a duplicate-identifier error.
    #[test]
    fn rejects_duplicate_websocket_handler_idents() {
        let mut file = streaming_test_file();
        file.service[0].method.push(streaming_method(
            "Bidi_Stream",
            http_rule(4, "/test:bidi-stream-again", Some("*")),
            true,
            true,
        ));
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["test/v1/streaming.proto".into()],
            proto_file: vec![file],
            ..Default::default()
        };
        let err = try_generate_ws(&request).unwrap_err();
        assert!(
            err.to_string()
                .contains("duplicate generated Rust identifier")
        );
        assert!(err.to_string().contains("bidi_stream"));
    }

    // Request that asks to generate `streaming_test_file()`.
    fn streaming_request() -> CodeGeneratorRequest {
        CodeGeneratorRequest {
            file_to_generate: vec!["test/v1/streaming.proto".into()],
            proto_file: vec![streaming_test_file()],
            ..Default::default()
        }
    }

    // `test/v1/streaming.proto`: `TestService` with one method of each
    // streaming kind, each bound with `http_rule(4, ...)` and body `*`.
    fn streaming_test_file() -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some("test/v1/streaming.proto".into()),
            package: Some("test.v1".into()),
            message_type: vec![
                DescriptorProto {
                    name: Some("TestRequest".into()),
                    field: vec![field("data", 1, Type::TYPE_STRING)],
                    ..Default::default()
                },
                DescriptorProto {
                    name: Some("TestResponse".into()),
                    field: vec![field("message", 1, Type::TYPE_STRING)],
                    ..Default::default()
                },
            ],
            service: vec![ServiceDescriptorProto {
                name: Some("TestService".into()),
                method: vec![
                    streaming_method(
                        "Unary",
                        http_rule(4, "/test:unary", Some("*")),
                        false,
                        false,
                    ),
                    streaming_method(
                        "ServerStream",
                        http_rule(4, "/test:server-stream", Some("*")),
                        false,
                        true,
                    ),
                    streaming_method(
                        "ClientStream",
                        http_rule(4, "/test:client-stream", Some("*")),
                        true,
                        false,
                    ),
                    streaming_method(
                        "BidiStream",
                        http_rule(4, "/test:bidi-stream", Some("*")),
                        true,
                        true,
                    ),
                ],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    // `test/v1/unary.proto`: a single HTTP-bound unary method.
    fn unary_test_file() -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some("test/v1/unary.proto".into()),
            package: Some("test.v1".into()),
            message_type: vec![
                DescriptorProto {
                    name: Some("TestRequest".into()),
                    field: vec![field("data", 1, Type::TYPE_STRING)],
                    ..Default::default()
                },
                DescriptorProto {
                    name: Some("TestResponse".into()),
                    ..Default::default()
                },
            ],
            service: vec![ServiceDescriptorProto {
                name: Some("TestService".into()),
                method: vec![streaming_method(
                    "Unary",
                    http_rule(4, "/test:unary", Some("*")),
                    false,
                    false,
                )],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    // `test/v1/path.proto`: a server-streaming method whose path template
    // binds the `data` request field.
    fn server_streaming_path_file() -> FileDescriptorProto {
        FileDescriptorProto {
            name: Some("test/v1/path.proto".into()),
            package: Some("test.v1".into()),
            message_type: vec![
                DescriptorProto {
                    name: Some("TestRequest".into()),
                    field: vec![field("data", 1, Type::TYPE_STRING)],
                    ..Default::default()
                },
                DescriptorProto {
                    name: Some("TestResponse".into()),
                    ..Default::default()
                },
            ],
            service: vec![ServiceDescriptorProto {
                name: Some("TestService".into()),
                method: vec![streaming_method(
                    "ServerStream",
                    http_rule(4, "/test/{data}/server-stream", Some("*")),
                    false,
                    true,
                )],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    // Method descriptor using `TestRequest` -> `TestResponse` with the given
    // options and streaming flags.
    fn streaming_method(
        name: &str,
        options: MessageField<MethodOptions>,
        client_streaming: bool,
        server_streaming: bool,
    ) -> MethodDescriptorProto {
        MethodDescriptorProto {
            name: Some(name.into()),
            input_type: Some(".test.v1.TestRequest".into()),
            output_type: Some(".test.v1.TestResponse".into()),
            client_streaming: Some(client_streaming),
            server_streaming: Some(server_streaming),
            options,
            ..Default::default()
        }
    }

    // Optional scalar field whose JSON name equals its proto name.
    fn field(name: &str, number: i32, kind: Type) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            label: Some(Label::LABEL_OPTIONAL),
            r#type: Some(kind),
            json_name: Some(name.into()),
            ..Default::default()
        }
    }

    // Builds `MethodOptions` carrying an HTTP rule, hand-encoded as an unknown
    // length-delimited field.
    // - `verb_field` is the rule's field number for the verb; every call here
    //   passes 4, which is `post` in google/api/http.proto.
    // - `path` is encoded as that verb's URL template.
    // - `body`, when present, is written as field 7.
    fn http_rule(verb_field: u32, path: &str, body: Option<&str>) -> MessageField<MethodOptions> {
        let mut rule = Vec::new();
        Tag::new(verb_field, WireType::LengthDelimited).encode(&mut rule);
        buffa::types::encode_string(path, &mut rule);
        if let Some(body) = body {
            Tag::new(7, WireType::LengthDelimited).encode(&mut rule);
            buffa::types::encode_string(body, &mut rule);
        }
        let mut options = MethodOptions::default();
        // 72_295_728 is the field number of the `google.api.http` method option
        // (google/api/annotations.proto).
        options.__buffa_unknown_fields.push(UnknownField {
            number: 72_295_728,
            data: UnknownFieldData::LengthDelimited(rule),
        });
        MessageField::some(options)
    }
}