1#![doc = include_str!("../README.md")]
2
3use quote::format_ident;
4use syn::parse_quote;
5
6pub fn service_generator() -> Box<ServiceGenerator> {
14 Box::new(ServiceGenerator {})
15}
16
17struct Service {
18 rpc_trait_name: syn::Ident,
20
21 fqn: String,
23
24 methods: Vec<Method>,
26}
27
28struct Method {
29 name: syn::Ident,
31
32 proto_name: String,
34
35 input_type: syn::Type,
37
38 output_type: syn::Type,
40}
41
42impl Service {
43 fn from_prost(s: prost_build::Service) -> Self {
44 let fqn = format!("{}.{}", s.package, s.proto_name);
45 let rpc_trait_name = format_ident!("{}", &s.name);
46 let methods = s
47 .methods
48 .into_iter()
49 .map(|m| Method::from_prost(&s.package, &s.proto_name, m))
50 .collect();
51
52 Self {
53 rpc_trait_name,
54 fqn,
55 methods,
56 }
57 }
58}
59
60impl Method {
61 fn from_prost(pkg_name: &str, svc_name: &str, m: prost_build::Method) -> Self {
62 let as_type = |s| -> syn::Type {
63 let Ok(typ) = syn::parse_str::<syn::Type>(s) else {
64 panic!(
65 "twirp-build failed generated invalid Rust while processing {pkg}.{svc}/{name}). this is a bug in twirp-build, please file a GitHub issue",
66 pkg = pkg_name,
67 svc = svc_name,
68 name = m.proto_name,
69 );
70 };
71 typ
72 };
73
74 let input_type = as_type(&m.input_type);
75 let output_type = as_type(&m.output_type);
76 let name = format_ident!("{}", m.name);
77 let message = m.proto_name;
78
79 Self {
80 name,
81 proto_name: message,
82 input_type,
83 output_type,
84 }
85 }
86}
87
/// Stateless service generator that emits Twirp bindings for each proto service it is given.
///
/// The type has no fields, so all values are interchangeable; the derives make it usable
/// in debug output and constructible through `Default`.
#[derive(Debug, Clone, Default)]
pub struct ServiceGenerator;
89
90impl prost_build::ServiceGenerator for ServiceGenerator {
91 fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
92 let service = Service::from_prost(service);
93
94 let service_fqn_path = format!("/{}", service.fqn);
96 let mut trait_methods: Vec<syn::TraitItemFn> = Vec::with_capacity(service.methods.len());
97 let mut proxy_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
98 for m in &service.methods {
99 let name = &m.name;
100 let input_type = &m.input_type;
101 let output_type = &m.output_type;
102
103 trait_methods.push(parse_quote! {
104 async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>>;
105 });
106
107 proxy_methods.push(parse_quote! {
108 async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>> {
109 T::#name(&*self, req).await
110 }
111 });
112 }
113
114 let rpc_trait_name = &service.rpc_trait_name;
115 let server_trait: syn::ItemTrait = parse_quote! {
116 #[twirp::async_trait::async_trait]
117 pub trait #rpc_trait_name: Send + Sync {
118 #(#trait_methods)*
119 }
120 };
121 let server_trait_impl: syn::ItemImpl = parse_quote! {
122 #[twirp::async_trait::async_trait]
123 impl<T> #rpc_trait_name for std::sync::Arc<T>
124 where
125 T: #rpc_trait_name + Sync + Send
126 {
127 #(#proxy_methods)*
128 }
129 };
130
131 let mut expr: syn::Expr = parse_quote! {
133 twirp::details::TwirpRouterBuilder::new(#service_fqn_path, api)
134 };
135 for m in &service.methods {
136 let name = &m.name;
137 let input_type = &m.input_type;
138 let path = format!("/{}", m.proto_name);
139
140 expr = parse_quote! {
141 #expr.route(#path, |api: T, req: twirp::Request<#input_type>| async move {
142 api.#name(req).await
143 })
144 };
145 }
146 let router: syn::ItemFn = parse_quote! {
147 pub fn router<T>(api: T) -> twirp::Router
148 where
149 T: #rpc_trait_name + Clone + Send + Sync + 'static
150 {
151 #expr.build()
152 }
153 };
154
155 let mut client_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
159 for m in &service.methods {
160 let name = &m.name;
161 let input_type = &m.input_type;
162 let output_type = &m.output_type;
163 let request_path = format!("{}/{}", service.fqn, m.proto_name);
164
165 client_methods.push(parse_quote! {
166 async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>> {
167 self.request(#request_path, req).await
168 }
169 })
170 }
171 let client_trait: syn::ItemImpl = parse_quote! {
172 #[twirp::async_trait::async_trait]
173 impl #rpc_trait_name for twirp::client::Client {
174 #(#client_methods)*
175 }
176 };
177
178 let service_fqn = &service.fqn;
184 let handler_name = format_ident!("{rpc_trait_name}Handler");
185 let handler_struct: syn::ItemStruct = parse_quote! {
186 pub struct #handler_name {
187 inner: std::sync::Arc<dyn #rpc_trait_name>,
188 }
189 };
190 let mut method_matches: Vec<syn::Arm> = Vec::with_capacity(service.methods.len());
191 for m in &service.methods {
192 let name = &m.name;
193 let method = &m.proto_name;
194 method_matches.push(parse_quote! {
195 #method => {
196 twirp::details::encode_response(self.inner.#name(twirp::details::decode_request(req).await?).await?)
197 }
198 });
199 }
200 let handler_impl: syn::ItemImpl = parse_quote! {
201 impl #handler_name {
202 #[allow(clippy::new_ret_no_self)]
203 pub fn new<M: #rpc_trait_name + 'static>(inner: M) -> Self {
204 Self { inner: std::sync::Arc::new(inner) }
205 }
206 }
207
208 };
209 let handler_direct_impl: syn::ItemImpl = parse_quote! {
210 #[twirp::async_trait::async_trait]
211 impl twirp::client::DirectHandler for #handler_name {
212 fn service(&self) -> &str {
213 #service_fqn
214 }
215 async fn handle(&self, method: &str, req: twirp::reqwest::Request) -> twirp::Result<twirp::reqwest::Response> {
216 match method {
217 #(#method_matches)*
218 _ => Err(twirp::bad_route(format!("unknown rpc `{method}` for service `{}`, url: {:?}", #service_fqn, req.url()))),
219 }
220 }
221 }
222 };
223 let direct_api_handler: syn::ItemMod = parse_quote! {
224 #[allow(dead_code)]
225 pub mod handler {
226 use super::*;
227
228 #handler_struct
229 #handler_impl
230 #handler_direct_impl
231 }
232 };
233
234 let ast: syn::File = parse_quote! {
237 pub use twirp;
238
239 #server_trait
240 #server_trait_impl
241
242 #router
243
244 #client_trait
245
246 #direct_api_handler
247 };
248
249 let code = prettyplease::unparse(&ast);
250 buf.push_str(&code);
251 }
252}