actr_framework_protoc_codegen/
modern_generator.rs1use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Selects which side of an RPC service `ModernGenerator::generate` emits code for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorRole {
    /// Emit the handler trait, dispatcher, workload wrapper and usage docs for implementing a service.
    ServerSide,
    /// Emit the typed client wrapper, `ContextExt` extension and usage docs for calling a service.
    ClientSide,
}
23
24pub struct ModernGenerator {
26 package_name: String,
27 service_name: String,
28 role: GeneratorRole,
29}
30
31impl ModernGenerator {
32 pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
33 Self {
34 package_name: package_name.to_string(),
35 service_name: service_name.to_string(),
36 role,
37 }
38 }
39
40 pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
42 match self.role {
43 GeneratorRole::ServerSide => self.generate_server_code(methods),
44 GeneratorRole::ClientSide => self.generate_client_code(methods),
45 }
46 }
47
48 fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
50 let sections = [
51 self.generate_imports(),
53 self.generate_message_impls(methods)?,
55 self.generate_handler_trait(methods)?,
57 self.generate_router_impl(methods)?,
59 self.generate_workload_blanket_impl(methods)?,
61 self.generate_usage_docs(methods)?,
63 ];
64
65 Ok(sections.join("\n\n"))
66 }
67
68 fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
70 let sections = [
71 self.generate_imports(),
73 self.generate_message_impls(methods)?,
75 self.generate_context_extensions(methods)?,
77 self.generate_client_usage_docs(methods)?,
79 ];
80
81 Ok(sections.join("\n\n"))
82 }
83
84 fn generate_imports(&self) -> String {
86 let proto_module = self.package_name.replace('.', "_");
89
90 format!(
91 r#"//! 自动生成的代码 - 请勿手动编辑
92//!
93//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
94
95#![allow(dead_code, unused_imports)]
96
97use async_trait::async_trait;
98use bytes::Bytes;
99use prost::Message as ProstMessage;
100
101use actr_framework::{{Context, MessageDispatcher, Workload}};
102use actr_protocol::{{ActorResult, RpcRequest, RpcEnvelope, PayloadType}};
103
104// 导入 protobuf 消息类型(由 prost 生成)
105use super::{proto_module}::*;
106"#
107 )
108 }
109
110 fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
120 let mut impls = Vec::new();
121
122 for method in methods {
123 let input_type = self.extract_message_type(method.input_type())?;
124 let output_type = self.extract_message_type(method.output_type())?;
125
126 let route_key = format!(
128 "{}.{}.{}",
129 self.package_name,
130 self.service_name,
131 method.name()
132 );
133
134 let payload_type = extract_payload_type_or_default(method);
136
137 let payload_type_code = payload_type.as_rust_variant();
139
140 let impl_code = format!(
142 r#"/// RpcRequest trait implementation - associates Request and Response types
143///
144/// This enables type-safe RPC calls with automatic response type inference:
145/// ```rust,ignore
146/// let response: {output_type} = ctx.call(&target, request).await?;
147/// ```
148impl RpcRequest for {input_type} {{
149 type Response = {output_type};
150
151 fn route_key() -> &'static str {{
152 "{route_key}"
153 }}
154
155 fn payload_type() -> PayloadType {{
156 {payload_type_code}
157 }}
158}}"#
159 );
160
161 impls.push(impl_code);
162 }
163
164 Ok(impls.join("\n\n"))
165 }
166
167 fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
169 let handler_trait_name = format!("{}Handler", self.service_name);
170 let handler_trait_ident = format_ident!("{}", handler_trait_name);
171
172 let mut method_sigs = Vec::new();
173 for method in methods {
174 let method_name = method.name().to_snake_case();
175 let method_ident = format_ident!("{}", method_name);
176 let input_type = self.extract_message_type(method.input_type())?;
177 let output_type = self.extract_message_type(method.output_type())?;
178 let input_ident = format_ident!("{}", input_type);
179 let output_ident = format_ident!("{}", output_type);
180
181 method_sigs.push(quote! {
182 async fn #method_ident<C: Context>(
184 &self,
185 req: #input_ident,
186 ctx: &C,
187 ) -> ActorResult<#output_ident>;
188 });
189 }
190
191 let handler_trait_without_attr = quote! {
192 pub trait #handler_trait_ident: Send + Sync + 'static {
208 #(#method_sigs)*
209 }
210 };
211
212 Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
214 }
215
216 fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
218 let router_name = format!("{}Dispatcher", self.service_name);
219 let router_ident = format_ident!("{}", router_name);
220 let workload_name = format!("{}Workload", self.service_name);
221 let workload_ident = format_ident!("{}", workload_name);
222 let handler_trait = format!("{}Handler", self.service_name);
223 let handler_trait_ident = format_ident!("{}", handler_trait);
224
225 let mut match_arms = Vec::new();
227 for method in methods {
228 let route_key = format!(
229 "{}.{}.{}",
230 self.package_name,
231 self.service_name,
232 method.name()
233 );
234 let method_name = method.name().to_snake_case();
235 let method_ident = format_ident!("{}", method_name);
236 let input_type = self.extract_message_type(method.input_type())?;
237 let input_ident = format_ident!("{}", input_type);
238
239 match_arms.push(quote! {
240 #route_key => {
241 let payload = envelope.payload.as_ref()
243 .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
244 "Missing payload in RpcEnvelope".to_string()
245 ))?;
246
247 let req = #input_ident::decode(&**payload)
249 .map_err(|e| actr_protocol::ProtocolError::Actr(
250 actr_protocol::ActrError::DecodeFailure {
251 message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
252 }
253 ))?;
254
255 let resp = workload.0.#method_ident(req, ctx).await?;
257
258 Ok(resp.encode_to_vec().into())
260 }
261 });
262 }
263
264 let workload_struct = quote! {
266 pub struct #workload_ident<T: #handler_trait_ident>(pub T);
270
271 impl<T: #handler_trait_ident> #workload_ident<T> {
272 pub fn new(handler: T) -> Self {
274 Self(handler)
275 }
276 }
277 };
278
279 let router_struct = quote! {
280 pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
290 };
291
292 let router_impl_without_attr = quote! {
293 impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
294 type Workload = #workload_ident<T>;
295
296 async fn dispatch<C: Context>(
297 workload: &Self::Workload,
298 envelope: RpcEnvelope,
299 ctx: &C,
300 ) -> ActorResult<Bytes> {
301 match envelope.route_key.as_str() {
302 #(#match_arms,)*
303 _ => Err(actr_protocol::ProtocolError::Actr(
304 actr_protocol::ActrError::UnknownRoute {
305 route_key: envelope.route_key.to_string()
306 }
307 ))
308 }
309 }
310 }
311 };
312
313 let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
315
316 Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
317 }
318
319 fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
321 let router_name = format!("{}Dispatcher", self.service_name);
322 let router_ident = format_ident!("{}", router_name);
323 let workload_name = format!("{}Workload", self.service_name);
324 let workload_ident = format_ident!("{}", workload_name);
325 let handler_trait = format!("{}Handler", self.service_name);
326 let handler_trait_ident = format_ident!("{}", handler_trait);
327
328 Ok(quote! {
329 impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
333 type Dispatcher = #router_ident<T>;
334 }
335 }
336 .to_string())
337 }
338
339 fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
341 let client_struct_name = format!("{}Client", self.service_name);
342 let client_ident = format_ident!("{}", client_struct_name);
343
344 let mut client_methods = Vec::new();
345 for method in methods {
346 let method_name = method.name().to_snake_case();
347 let method_ident = format_ident!("{}", method_name);
348 let input_type = self.extract_message_type(method.input_type())?;
349 let output_type = self.extract_message_type(method.output_type())?;
350 let input_ident = format_ident!("{}", input_type);
351 let output_ident = format_ident!("{}", output_type);
352
353 let route_key = format!(
354 "{}.{}.{}",
355 self.package_name,
356 self.service_name,
357 method.name()
358 );
359
360 client_methods.push(quote! {
361 pub async fn #method_ident(
363 &self,
364 req: #input_ident,
365 ) -> ActorResult<#output_ident> {
366 self.ctx.call_remote(#route_key, req).await
367 }
368 });
369 }
370
371 let extension_method_name = self.service_name.to_snake_case();
373 let extension_method_ident = format_ident!("{}", extension_method_name);
374
375 Ok(quote! {
376 pub struct #client_ident<'a> {
380 ctx: &'a Context,
381 }
382
383 impl<'a> #client_ident<'a> {
384 #(#client_methods)*
385 }
386
387 pub trait ContextExt {
391 fn #extension_method_ident(&self) -> #client_ident;
392 }
393
394 impl ContextExt for Context {
395 fn #extension_method_ident(&self) -> #client_ident {
396 #client_ident { ctx: self }
397 }
398 }
399 }
400 .to_string())
401 }
402
403 fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
405 let handler_trait = format!("{}Handler", self.service_name);
406 let first_method = methods.first();
407
408 let example_method = if let Some(method) = first_method {
409 let method_name = method.name().to_snake_case();
410 let input_type = self.extract_message_type(method.input_type())?;
411 let output_type = self.extract_message_type(method.output_type())?;
412 format!(
413 r#"
414 async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
415 // 实现业务逻辑
416 Ok({output_type}::default())
417 }}"#
418 )
419 } else {
420 " // 实现方法...".to_string()
421 };
422
423 Ok(format!(
424 r#"/*
425## 使用示例
426
427### 1. 实现业务逻辑
428
429```rust
430use actr_framework::{{Context, ActorSystem}};
431use actr_protocol::ActorResult;
432
433pub struct MyService {{
434 // 业务状态
435}}
436
437#[async_trait]
438impl {handler_trait} for MyService {{
439{example_method}
440}}
441```
442
443### 2. 启动服务
444
445```rust
446#[tokio::main]
447async fn main() -> ActorResult<()> {{
448 let config = actr_config::Config::from_file("Actr.toml")?;
449 let service = MyService {{ /* ... */ }};
450
451 ActorSystem::new(config)?
452 .attach(service) // ← 自动获得 Workload + Dispatcher
453 .start()
454 .await?
455 .wait_for_shutdown()
456 .await
457}}
458```
459
460## 架构说明
461
462- **{handler_trait}**: 用户实现的业务逻辑接口
463- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
464- **Workload**: 通过 blanket impl 自动获得(自动生成)
465
466用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
467*/
468"#,
469 self.service_name
470 ))
471 }
472
473 fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
475 let service_name_snake = self.service_name.to_snake_case();
476 let method_name_snake = methods
477 .first()
478 .map(|m| m.name().to_snake_case())
479 .unwrap_or("unknown_method".to_string());
480
481 Ok(format!(
482 r#"/*
483## 客户端使用示例
484
485```rust
486use actr_framework::Context;
487use actr_protocol::ActorResult;
488
489async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
490 use super::ContextExt;
491
492 // 类型安全的远程调用
493 let response = ctx.{service_name_snake}()
494 .{method_name_snake}(request)
495 .await?;
496
497 Ok(())
498}}
499```
500
501## 编译时路由
502
503所有远程调用在编译时确定目标服务和方法,无需运行时查找。
504*/
505"#
506 ))
507 }
508
509 fn extract_message_type(&self, type_name: &str) -> Result<String> {
511 let cleaned = type_name.trim_start_matches('.');
512 if let Some(last_part) = cleaned.split('.').next_back() {
513 Ok(last_part.to_string())
514 } else {
515 Ok(cleaned.to_string())
516 }
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// Generator for the `test.v1.TestService` fixture in the requested role.
    fn test_generator(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    /// A single `Echo(EchoRequest) -> EchoResponse` method with no options.
    fn echo_methods() -> Vec<MethodDescriptorProto> {
        vec![MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }]
    }

    #[test]
    fn test_extract_message_type() {
        let generator = test_generator(GeneratorRole::ServerSide);

        let cases = [
            (".test.v1.EchoRequest", "EchoRequest"),
            ("test.v1.EchoResponse", "EchoResponse"),
            ("SimpleMessage", "SimpleMessage"),
        ];
        for (type_name, expected) in cases {
            assert_eq!(generator.extract_message_type(type_name).unwrap(), expected);
        }
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generated = test_generator(GeneratorRole::ServerSide)
            .generate_message_impls(&echo_methods())
            .unwrap();

        eprintln!("Generated code:\n{generated}");

        for needle in ["fn payload_type", "PayloadType", "RpcReliable"] {
            assert!(generated.contains(needle), "Should contain '{needle}'");
        }
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let imports = test_generator(GeneratorRole::ServerSide).generate_imports();

        assert!(imports.contains("PayloadType"));
        assert!(
            imports.contains("use actr_protocol::{ActorResult, RpcRequest, RpcEnvelope, PayloadType}")
        );
    }

    #[test]
    fn test_generate_client_code() {
        let code = test_generator(GeneratorRole::ClientSide)
            .generate(&echo_methods())
            .expect("client code generation should succeed");

        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }

    #[test]
    fn test_generate_server_code() {
        let code = test_generator(GeneratorRole::ServerSide)
            .generate(&echo_methods())
            .expect("server code generation should succeed");

        for needle in [
            "pub trait TestServiceHandler",
            "pub struct TestServiceDispatcher",
            "fn payload_type() -> PayloadType",
        ] {
            assert!(code.contains(needle));
        }
    }
}