actr_framework_protoc_codegen/
modern_generator.rs1use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Selects which flavour of code [`ModernGenerator::generate`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratorRole {
    /// Emit the server-side output: imports, `RpcRequest` impls, the handler
    /// trait, the dispatcher, the `Workload` impl and usage docs.
    ServerSide,
    /// Emit the client-side output: imports, `RpcRequest` impls, the context
    /// extension client and client usage docs.
    ClientSide,
}
23
/// Generates Rust source text for one protobuf service.
pub struct ModernGenerator {
    /// Protobuf package name; used in route keys and (with `.` replaced by
    /// `_`) as the name of the imported message module.
    package_name: String,
    /// Service name; prefixes the generated `Handler`, `Dispatcher`,
    /// `Workload` and `Client` items and is part of every route key.
    service_name: String,
    /// Whether server-side or client-side code is generated.
    role: GeneratorRole,
    /// Value written into the `manufacturer` field of the generated `ActrType`.
    manufacturer: String,
}
31
32impl ModernGenerator {
33 pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
34 Self {
35 package_name: package_name.to_string(),
36 service_name: service_name.to_string(),
37 role,
38 manufacturer: "acme".to_string(), }
40 }
41
    /// Replaces the manufacturer string that is written into the generated
    /// `ActrType` (see `generate_workload_blanket_impl`).
    pub fn set_manufacturer(&mut self, manufacturer: String) {
        self.manufacturer = manufacturer;
    }
46
47 pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
49 match self.role {
50 GeneratorRole::ServerSide => self.generate_server_code(methods),
51 GeneratorRole::ClientSide => self.generate_client_code(methods),
52 }
53 }
54
55 fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
57 let sections = [
58 self.generate_imports(),
60 self.generate_message_impls(methods)?,
62 self.generate_handler_trait(methods)?,
64 self.generate_router_impl(methods)?,
66 self.generate_workload_blanket_impl(methods)?,
68 self.generate_usage_docs(methods)?,
70 ];
71
72 Ok(sections.join("\n\n"))
73 }
74
75 fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
77 let sections = [
78 self.generate_imports(),
80 self.generate_message_impls(methods)?,
82 self.generate_context_extensions(methods)?,
84 self.generate_client_usage_docs(methods)?,
86 ];
87
88 Ok(sections.join("\n\n"))
89 }
90
91 fn generate_imports(&self) -> String {
93 let proto_module = self.package_name.replace('.', "_");
96
97 format!(
98 r#"//! 自动生成的代码 - 请勿手动编辑
99//!
100//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
101
102#![allow(dead_code, unused_imports)]
103
104use async_trait::async_trait;
105use bytes::Bytes;
106use prost::Message as ProstMessage;
107
108use actr_framework::{{Context, MessageDispatcher, Workload}};
109use actr_protocol::{{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}};
110
111// 导入 protobuf 消息类型(由 prost 生成)
112use super::{proto_module}::*;
113"#
114 )
115 }
116
117 fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
127 let mut impls = Vec::new();
128
129 for method in methods {
130 let input_type = self.extract_message_type(method.input_type())?;
131 let output_type = self.extract_message_type(method.output_type())?;
132
133 let route_key = format!(
135 "{}.{}.{}",
136 self.package_name,
137 self.service_name,
138 method.name()
139 );
140
141 let payload_type = extract_payload_type_or_default(method);
143
144 let payload_type_code = payload_type.as_rust_variant();
146
147 let impl_code = format!(
149 r#"/// RpcRequest trait implementation - associates Request and Response types
150///
151/// This enables type-safe RPC calls with automatic response type inference:
152/// ```rust,ignore
153/// let response: {output_type} = ctx.call(&target, request).await?;
154/// ```
155impl RpcRequest for {input_type} {{
156 type Response = {output_type};
157
158 fn route_key() -> &'static str {{
159 "{route_key}"
160 }}
161
162 fn payload_type() -> PayloadType {{
163 {payload_type_code}
164 }}
165}}"#
166 );
167
168 impls.push(impl_code);
169 }
170
171 Ok(impls.join("\n\n"))
172 }
173
174 fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
176 let handler_trait_name = format!("{}Handler", self.service_name);
177 let handler_trait_ident = format_ident!("{}", handler_trait_name);
178
179 let mut method_sigs = Vec::new();
180 for method in methods {
181 let method_name = method.name().to_snake_case();
182 let method_ident = format_ident!("{}", method_name);
183 let input_type = self.extract_message_type(method.input_type())?;
184 let output_type = self.extract_message_type(method.output_type())?;
185 let input_ident = format_ident!("{}", input_type);
186 let output_ident = format_ident!("{}", output_type);
187
188 method_sigs.push(quote! {
189 async fn #method_ident<C: Context>(
191 &self,
192 req: #input_ident,
193 ctx: &C,
194 ) -> ActorResult<#output_ident>;
195 });
196 }
197
198 let handler_trait_without_attr = quote! {
199 pub trait #handler_trait_ident: Send + Sync + 'static {
215 #(#method_sigs)*
216 }
217 };
218
219 Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
221 }
222
223 fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
225 let router_name = format!("{}Dispatcher", self.service_name);
226 let router_ident = format_ident!("{}", router_name);
227 let workload_name = format!("{}Workload", self.service_name);
228 let workload_ident = format_ident!("{}", workload_name);
229 let handler_trait = format!("{}Handler", self.service_name);
230 let handler_trait_ident = format_ident!("{}", handler_trait);
231
232 let mut match_arms = Vec::new();
234 for method in methods {
235 let route_key = format!(
236 "{}.{}.{}",
237 self.package_name,
238 self.service_name,
239 method.name()
240 );
241 let method_name = method.name().to_snake_case();
242 let method_ident = format_ident!("{}", method_name);
243 let input_type = self.extract_message_type(method.input_type())?;
244 let input_ident = format_ident!("{}", input_type);
245
246 match_arms.push(quote! {
247 #route_key => {
248 let payload = envelope.payload.as_ref()
250 .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
251 "Missing payload in RpcEnvelope".to_string()
252 ))?;
253
254 let req = #input_ident::decode(&**payload)
256 .map_err(|e| actr_protocol::ProtocolError::Actr(
257 actr_protocol::ActrError::DecodeFailure {
258 message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
259 }
260 ))?;
261
262 let resp = workload.0.#method_ident(req, ctx).await?;
264
265 Ok(resp.encode_to_vec().into())
267 }
268 });
269 }
270
271 let workload_struct = quote! {
273 pub struct #workload_ident<T: #handler_trait_ident>(pub T);
277
278 impl<T: #handler_trait_ident> #workload_ident<T> {
279 pub fn new(handler: T) -> Self {
281 Self(handler)
282 }
283 }
284 };
285
286 let router_struct = quote! {
287 pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
297 };
298
299 let router_impl_without_attr = quote! {
300 impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
301 type Workload = #workload_ident<T>;
302
303 async fn dispatch<C: Context>(
304 workload: &Self::Workload,
305 envelope: RpcEnvelope,
306 ctx: &C,
307 ) -> ActorResult<Bytes> {
308 match envelope.route_key.as_str() {
309 #(#match_arms,)*
310 _ => Err(actr_protocol::ProtocolError::Actr(
311 actr_protocol::ActrError::UnknownRoute {
312 route_key: envelope.route_key.to_string()
313 }
314 ))
315 }
316 }
317 }
318 };
319
320 let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
322
323 Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
324 }
325
326 fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
328 let router_name = format!("{}Dispatcher", self.service_name);
329 let router_ident = format_ident!("{}", router_name);
330 let workload_name = format!("{}Workload", self.service_name);
331 let workload_ident = format_ident!("{}", workload_name);
332 let handler_trait = format!("{}Handler", self.service_name);
333 let handler_trait_ident = format_ident!("{}", handler_trait);
334
335 let actor_type_str = format!("{}.{}", self.package_name, self.service_name);
337 let manufacturer = &self.manufacturer;
338
339 Ok(quote! {
340 impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
344 type Dispatcher = #router_ident<T>;
345
346 fn actor_type(&self) -> ActrType {
347 ActrType {
348 manufacturer: #manufacturer.to_string(),
349 name: #actor_type_str.to_string(),
350 }
351 }
352 }
353 }
354 .to_string())
355 }
356
357 fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
359 let client_struct_name = format!("{}Client", self.service_name);
360 let client_ident = format_ident!("{}", client_struct_name);
361
362 let mut client_methods = Vec::new();
363 for method in methods {
364 let method_name = method.name().to_snake_case();
365 let method_ident = format_ident!("{}", method_name);
366 let input_type = self.extract_message_type(method.input_type())?;
367 let output_type = self.extract_message_type(method.output_type())?;
368 let input_ident = format_ident!("{}", input_type);
369 let output_ident = format_ident!("{}", output_type);
370
371 let route_key = format!(
372 "{}.{}.{}",
373 self.package_name,
374 self.service_name,
375 method.name()
376 );
377
378 client_methods.push(quote! {
379 pub async fn #method_ident(
381 &self,
382 req: #input_ident,
383 ) -> ActorResult<#output_ident> {
384 self.ctx.call_remote(#route_key, req).await
385 }
386 });
387 }
388
389 let extension_method_name = self.service_name.to_snake_case();
391 let extension_method_ident = format_ident!("{}", extension_method_name);
392
393 Ok(quote! {
394 pub struct #client_ident<'a> {
398 ctx: &'a Context,
399 }
400
401 impl<'a> #client_ident<'a> {
402 #(#client_methods)*
403 }
404
405 pub trait ContextExt {
409 fn #extension_method_ident(&self) -> #client_ident;
410 }
411
412 impl ContextExt for Context {
413 fn #extension_method_ident(&self) -> #client_ident {
414 #client_ident { ctx: self }
415 }
416 }
417 }
418 .to_string())
419 }
420
421 fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
423 let handler_trait = format!("{}Handler", self.service_name);
424 let first_method = methods.first();
425
426 let example_method = if let Some(method) = first_method {
427 let method_name = method.name().to_snake_case();
428 let input_type = self.extract_message_type(method.input_type())?;
429 let output_type = self.extract_message_type(method.output_type())?;
430 format!(
431 r#"
432 async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
433 // 实现业务逻辑
434 Ok({output_type}::default())
435 }}"#
436 )
437 } else {
438 " // 实现方法...".to_string()
439 };
440
441 Ok(format!(
442 r#"/*
443## 使用示例
444
445### 1. 实现业务逻辑
446
447```rust
448use actr_framework::{{Context, ActorSystem}};
449use actr_protocol::ActorResult;
450
451pub struct MyService {{
452 // 业务状态
453}}
454
455#[async_trait]
456impl {handler_trait} for MyService {{
457{example_method}
458}}
459```
460
461### 2. 启动服务
462
463```rust
464#[tokio::main]
465async fn main() -> ActorResult<()> {{
466 let config = actr_config::Config::from_file("Actr.toml")?;
467 let service = MyService {{ /* ... */ }};
468
469 ActorSystem::new(config)?
470 .attach(service) // ← 自动获得 Workload + Dispatcher
471 .start()
472 .await?
473 .wait_for_shutdown()
474 .await
475}}
476```
477
478## 架构说明
479
480- **{handler_trait}**: 用户实现的业务逻辑接口
481- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
482- **Workload**: 通过 blanket impl 自动获得(自动生成)
483
484用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
485*/
486"#,
487 self.service_name
488 ))
489 }
490
491 fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
493 let service_name_snake = self.service_name.to_snake_case();
494 let method_name_snake = methods
495 .first()
496 .map(|m| m.name().to_snake_case())
497 .unwrap_or("unknown_method".to_string());
498
499 Ok(format!(
500 r#"/*
501## 客户端使用示例
502
503```rust
504use actr_framework::Context;
505use actr_protocol::ActorResult;
506
507async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
508 use super::ContextExt;
509
510 // 类型安全的远程调用
511 let response = ctx.{service_name_snake}()
512 .{method_name_snake}(request)
513 .await?;
514
515 Ok(())
516}}
517```
518
519## 编译时路由
520
521所有远程调用在编译时确定目标服务和方法,无需运行时查找。
522*/
523"#
524 ))
525 }
526
527 fn extract_message_type(&self, type_name: &str) -> Result<String> {
529 let cleaned = type_name.trim_start_matches('.');
530 if let Some(last_part) = cleaned.split('.').next_back() {
531 Ok(last_part.to_string())
532 } else {
533 Ok(cleaned.to_string())
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// Builds a generator for the `test.v1.TestService` fixture service.
    fn generator_for(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    /// A single unary `Echo` method without options.
    fn echo_methods() -> Vec<MethodDescriptorProto> {
        let echo = MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        };
        vec![echo]
    }

    /// Qualified, partially qualified and bare names all reduce to the last segment.
    #[test]
    fn test_extract_message_type() {
        let generator = generator_for(GeneratorRole::ServerSide);

        let cases = [
            (".test.v1.EchoRequest", "EchoRequest"),
            ("test.v1.EchoResponse", "EchoResponse"),
            ("SimpleMessage", "SimpleMessage"),
        ];
        for (qualified, expected) in cases {
            assert_eq!(generator.extract_message_type(qualified).unwrap(), expected);
        }
    }

    /// The `RpcRequest` impl must expose the payload type of the method.
    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generator = generator_for(GeneratorRole::ServerSide);
        let methods = echo_methods();

        let result = generator.generate_message_impls(&methods).unwrap();

        eprintln!("Generated code:\n{result}");

        let expectations = [
            ("fn payload_type", "Should contain 'fn payload_type'"),
            ("PayloadType", "Should contain 'PayloadType'"),
            ("RpcReliable", "Should contain 'RpcReliable'"),
        ];
        for (needle, message) in expectations {
            assert!(result.contains(needle), "{}", message);
        }
    }

    /// The shared import header must bring `PayloadType` into scope.
    #[test]
    fn test_generate_imports_includes_payload_type() {
        let imports = generator_for(GeneratorRole::ServerSide).generate_imports();

        assert!(imports.contains("PayloadType"));
        assert!(imports.contains(
            "use actr_protocol::{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}"
        ));
    }

    /// Client output contains the `RpcRequest` impl for the request message.
    #[test]
    fn test_generate_client_code() {
        let generator = generator_for(GeneratorRole::ClientSide);
        let methods = echo_methods();

        let result = generator.generate(&methods);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }

    /// Server output contains the handler trait, the dispatcher and the payload type.
    #[test]
    fn test_generate_server_code() {
        let generator = generator_for(GeneratorRole::ServerSide);
        let methods = echo_methods();

        let result = generator.generate(&methods);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.contains("pub trait TestServiceHandler"));
        assert!(code.contains("pub struct TestServiceDispatcher"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }
}