actr_framework_protoc_codegen/
modern_generator.rs1use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Selects which side of a service [`ModernGenerator::generate`] produces code for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorRole {
    /// Emits message impls, the handler trait, the dispatcher, the `Workload`
    /// blanket impl and server usage docs.
    ServerSide,
    /// Emits message impls, the typed client/context extension and client usage docs.
    ClientSide,
}
23
24pub struct ModernGenerator {
26 package_name: String,
27 service_name: String,
28 role: GeneratorRole,
29 manufacturer: String,
30}
31
32impl ModernGenerator {
33 pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
34 Self {
35 package_name: package_name.to_string(),
36 service_name: service_name.to_string(),
37 role,
38 manufacturer: "acme".to_string(), }
40 }
41
42 pub fn set_manufacturer(&mut self, manufacturer: String) {
44 self.manufacturer = manufacturer;
45 }
46
47 pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
49 match self.role {
50 GeneratorRole::ServerSide => self.generate_server_code(methods),
51 GeneratorRole::ClientSide => self.generate_client_code(methods),
52 }
53 }
54
55 fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
57 let sections = [
58 self.generate_imports(),
60 self.generate_message_impls(methods)?,
62 self.generate_handler_trait(methods)?,
64 self.generate_router_impl(methods)?,
66 self.generate_workload_blanket_impl(methods)?,
68 self.generate_usage_docs(methods)?,
70 ];
71
72 Ok(sections.join("\n\n"))
73 }
74
75 fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
77 let sections = [
78 self.generate_imports(),
80 self.generate_message_impls(methods)?,
82 self.generate_context_extensions(methods)?,
84 self.generate_client_usage_docs(methods)?,
86 ];
87
88 Ok(sections.join("\n\n"))
89 }
90
91 fn generate_imports(&self) -> String {
93 let proto_module = self.package_name.replace('.', "_");
96
97 format!(
98 r#"//! 自动生成的代码 - 请勿手动编辑
99//!
100//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
101
102#![allow(dead_code, unused_imports)]
103
104use async_trait::async_trait;
105use bytes::Bytes;
106use prost::Message as ProstMessage;
107
108use actr_framework::{{Context, MessageDispatcher, Workload}};
109use actr_protocol::{{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}};
110
111// 导入 protobuf 消息类型(由 prost 生成)
112use super::{proto_module}::*;
113"#
114 )
115 }
116
117 fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
127 let mut impls = Vec::new();
128
129 for method in methods {
130 let input_type = self.extract_message_type(method.input_type())?;
131 let output_type = self.extract_message_type(method.output_type())?;
132
133 let route_key = format!(
135 "{}.{}.{}",
136 self.package_name,
137 self.service_name,
138 method.name()
139 );
140
141 let payload_type = extract_payload_type_or_default(method);
143
144 let payload_type_code = match payload_type {
146 crate::payload_type_extractor::PayloadType::RpcReliable => {
147 "PayloadType::RpcReliable"
148 }
149 crate::payload_type_extractor::PayloadType::RpcSignal => "PayloadType::RpcSignal",
150 crate::payload_type_extractor::PayloadType::StreamReliable => {
151 "PayloadType::StreamReliable"
152 }
153 crate::payload_type_extractor::PayloadType::StreamLatencyFirst => {
154 "PayloadType::StreamLatencyFirst"
155 }
156 crate::payload_type_extractor::PayloadType::MediaRtp => "PayloadType::MediaRtp",
157 };
158
159 let impl_code = format!(
161 r#"/// RpcRequest trait implementation - associates Request and Response types
162///
163/// This enables type-safe RPC calls with automatic response type inference:
164/// ```rust,ignore
165/// let response: {output_type} = ctx.call(&target, request).await?;
166/// ```
167impl RpcRequest for {input_type} {{
168 type Response = {output_type};
169
170 fn route_key() -> &'static str {{
171 "{route_key}"
172 }}
173
174 fn payload_type() -> PayloadType {{
175 {payload_type_code}
176 }}
177}}"#
178 );
179
180 impls.push(impl_code);
181 }
182
183 Ok(impls
184 .iter()
185 .map(|i| i.to_string())
186 .collect::<Vec<_>>()
187 .join("\n\n"))
188 }
189
190 fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
192 let handler_trait_name = format!("{}Handler", self.service_name);
193 let handler_trait_ident = format_ident!("{}", handler_trait_name);
194
195 let mut method_sigs = Vec::new();
196 for method in methods {
197 let method_name = method.name().to_snake_case();
198 let method_ident = format_ident!("{}", method_name);
199 let input_type = self.extract_message_type(method.input_type())?;
200 let output_type = self.extract_message_type(method.output_type())?;
201 let input_ident = format_ident!("{}", input_type);
202 let output_ident = format_ident!("{}", output_type);
203
204 method_sigs.push(quote! {
205 async fn #method_ident<C: Context>(
207 &self,
208 req: #input_ident,
209 ctx: &C,
210 ) -> ActorResult<#output_ident>;
211 });
212 }
213
214 let handler_trait_without_attr = quote! {
215 pub trait #handler_trait_ident: Send + Sync + 'static {
231 #(#method_sigs)*
232 }
233 }
234 .to_string();
235
236 Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
238 }
239
240 fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
242 let router_name = format!("{}Dispatcher", self.service_name);
243 let router_ident = format_ident!("{}", router_name);
244 let workload_name = format!("{}Workload", self.service_name);
245 let workload_ident = format_ident!("{}", workload_name);
246 let handler_trait = format!("{}Handler", self.service_name);
247 let handler_trait_ident = format_ident!("{}", handler_trait);
248
249 let mut match_arms = Vec::new();
251 for method in methods {
252 let route_key = format!(
253 "{}.{}.{}",
254 self.package_name,
255 self.service_name,
256 method.name()
257 );
258 let method_name = method.name().to_snake_case();
259 let method_ident = format_ident!("{}", method_name);
260 let input_type = self.extract_message_type(method.input_type())?;
261 let input_ident = format_ident!("{}", input_type);
262
263 match_arms.push(quote! {
264 #route_key => {
265 let payload = envelope.payload.as_ref()
267 .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
268 "Missing payload in RpcEnvelope".to_string()
269 ))?;
270
271 let req = #input_ident::decode(&**payload)
273 .map_err(|e| actr_protocol::ProtocolError::Actr(
274 actr_protocol::ActrError::DecodeFailure {
275 message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
276 }
277 ))?;
278
279 let resp = workload.0.#method_ident(req, ctx).await?;
281
282 Ok(resp.encode_to_vec().into())
284 }
285 });
286 }
287
288 let workload_struct = quote! {
290 pub struct #workload_ident<T: #handler_trait_ident>(pub T);
294
295 impl<T: #handler_trait_ident> #workload_ident<T> {
296 pub fn new(handler: T) -> Self {
298 Self(handler)
299 }
300 }
301 }
302 .to_string();
303
304 let router_struct = quote! {
305 pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
315 }
316 .to_string();
317
318 let router_impl_without_attr = quote! {
319 impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
320 type Workload = #workload_ident<T>;
321
322 async fn dispatch<C: Context>(
323 workload: &Self::Workload,
324 envelope: RpcEnvelope,
325 ctx: &C,
326 ) -> ActorResult<Bytes> {
327 match envelope.route_key.as_str() {
328 #(#match_arms,)*
329 _ => Err(actr_protocol::ProtocolError::Actr(
330 actr_protocol::ActrError::UnknownRoute {
331 route_key: envelope.route_key.to_string()
332 }
333 ))
334 }
335 }
336 }
337 }
338 .to_string();
339
340 let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
342
343 Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
344 }
345
346 fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
348 let router_name = format!("{}Dispatcher", self.service_name);
349 let router_ident = format_ident!("{}", router_name);
350 let workload_name = format!("{}Workload", self.service_name);
351 let workload_ident = format_ident!("{}", workload_name);
352 let handler_trait = format!("{}Handler", self.service_name);
353 let handler_trait_ident = format_ident!("{}", handler_trait);
354
355 let actor_type_str = format!("{}.{}", self.package_name, self.service_name);
357 let manufacturer = &self.manufacturer;
358
359 Ok(quote! {
360 impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
364 type Dispatcher = #router_ident<T>;
365
366 fn actor_type(&self) -> ActrType {
367 ActrType {
368 manufacturer: #manufacturer.to_string(),
369 name: #actor_type_str.to_string(),
370 }
371 }
372 }
373 }
374 .to_string())
375 }
376
377 fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
379 let client_struct_name = format!("{}Client", self.service_name);
380 let client_ident = format_ident!("{}", client_struct_name);
381
382 let mut client_methods = Vec::new();
383 for method in methods {
384 let method_name = method.name().to_snake_case();
385 let method_ident = format_ident!("{}", method_name);
386 let input_type = self.extract_message_type(method.input_type())?;
387 let output_type = self.extract_message_type(method.output_type())?;
388 let input_ident = format_ident!("{}", input_type);
389 let output_ident = format_ident!("{}", output_type);
390
391 let route_key = format!(
392 "{}.{}.{}",
393 self.package_name,
394 self.service_name,
395 method.name()
396 );
397
398 client_methods.push(quote! {
399 pub async fn #method_ident(
401 &self,
402 req: #input_ident,
403 ) -> ActorResult<#output_ident> {
404 self.ctx.call_remote(#route_key, req).await
405 }
406 });
407 }
408
409 let extension_method_name = self.service_name.to_snake_case();
411 let extension_method_ident = format_ident!("{}", extension_method_name);
412
413 Ok(quote! {
414 pub struct #client_ident<'a> {
418 ctx: &'a Context,
419 }
420
421 impl<'a> #client_ident<'a> {
422 #(#client_methods)*
423 }
424
425 pub trait ContextExt {
429 fn #extension_method_ident(&self) -> #client_ident;
430 }
431
432 impl ContextExt for Context {
433 fn #extension_method_ident(&self) -> #client_ident {
434 #client_ident { ctx: self }
435 }
436 }
437 }
438 .to_string())
439 }
440
441 fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
443 let handler_trait = format!("{}Handler", self.service_name);
444 let first_method = methods.first();
445
446 let example_method = if let Some(method) = first_method {
447 let method_name = method.name().to_snake_case();
448 let input_type = self.extract_message_type(method.input_type())?;
449 let output_type = self.extract_message_type(method.output_type())?;
450 format!(
451 r#"
452 async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
453 // 实现业务逻辑
454 Ok({output_type}::default())
455 }}"#
456 )
457 } else {
458 " // 实现方法...".to_string()
459 };
460
461 Ok(format!(
462 r#"/*
463## 使用示例
464
465### 1. 实现业务逻辑
466
467```rust
468use actr_framework::{{Context, ActorSystem}};
469use actr_protocol::ActorResult;
470
471pub struct MyService {{
472 // 业务状态
473}}
474
475#[async_trait]
476impl {handler_trait} for MyService {{
477{example_method}
478}}
479```
480
481### 2. 启动服务
482
483```rust
484#[tokio::main]
485async fn main() -> ActorResult<()> {{
486 let config = actr_config::Config::from_file("Actr.toml")?;
487 let service = MyService {{ /* ... */ }};
488
489 ActorSystem::new(config)?
490 .attach(service) // ← 自动获得 Workload + Dispatcher
491 .start()
492 .await?
493 .wait_for_shutdown()
494 .await
495}}
496```
497
498## 架构说明
499
500- **{handler_trait}**: 用户实现的业务逻辑接口
501- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
502- **Workload**: 通过 blanket impl 自动获得(自动生成)
503
504用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
505*/
506"#,
507 self.service_name
508 ))
509 }
510
511 fn generate_client_usage_docs(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
513 let service_name_snake = self.service_name.to_snake_case();
514
515 Ok(format!(
516 r#"/*
517## 客户端使用示例
518
519```rust
520use actr_framework::Context;
521use actr_protocol::ActorResult;
522
523async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
524 use super::ContextExt;
525
526 // 类型安全的远程调用
527 let response = ctx.{service_name_snake}()
528 .method_name(request)
529 .await?;
530
531 Ok(())
532}}
533```
534
535## 编译时路由
536
537所有远程调用在编译时确定目标服务和方法,无需运行时查找。
538*/
539"#
540 ))
541 }
542
543 fn extract_message_type(&self, type_name: &str) -> Result<String> {
545 let cleaned = type_name.trim_start_matches('.');
546 if let Some(last_part) = cleaned.split('.').next_back() {
547 Ok(last_part.to_string())
548 } else {
549 Ok(cleaned.to_string())
550 }
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// Builds the unary `Echo` method descriptor shared by the generation tests.
    fn echo_method() -> MethodDescriptorProto {
        MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }
    }

    /// Creates a generator for `test.v1.TestService` with the given role.
    fn new_generator(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    #[test]
    fn test_extract_message_type() {
        let generator = new_generator(GeneratorRole::ServerSide);

        assert_eq!(
            generator
                .extract_message_type(".test.v1.EchoRequest")
                .unwrap(),
            "EchoRequest"
        );
        assert_eq!(
            generator
                .extract_message_type("test.v1.EchoResponse")
                .unwrap(),
            "EchoResponse"
        );
        assert_eq!(
            generator.extract_message_type("SimpleMessage").unwrap(),
            "SimpleMessage"
        );
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generator = new_generator(GeneratorRole::ServerSide);

        let result = generator
            .generate_message_impls(&[echo_method()])
            .expect("message impl generation should succeed");

        assert!(
            result.contains("fn payload_type"),
            "Should contain 'fn payload_type'"
        );
        assert!(
            result.contains("PayloadType"),
            "Should contain 'PayloadType'"
        );
        assert!(
            result.contains("RpcReliable"),
            "Should contain 'RpcReliable'"
        );
        assert!(
            result.contains(r#""test.v1.TestService.Echo""#),
            "Should contain the fully-qualified route key"
        );
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let generator = new_generator(GeneratorRole::ServerSide);
        let imports = generator.generate_imports();

        assert!(imports.contains("PayloadType"));
        assert!(imports.contains(
            "use actr_protocol::{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}"
        ));
        assert!(imports.contains("use super::test_v1::*;"));
    }

    #[test]
    fn test_generate_client_code() {
        let generator = new_generator(GeneratorRole::ClientSide);

        let code = generator
            .generate(&[echo_method()])
            .expect("client generation should succeed");

        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
        assert!(code.contains("TestServiceClient"));
        assert!(code.contains("ContextExt"));
    }

    #[test]
    fn test_generate_server_code() {
        let generator = new_generator(GeneratorRole::ServerSide);

        let code = generator
            .generate(&[echo_method()])
            .expect("server generation should succeed");

        assert!(code.contains("pub trait TestServiceHandler"));
        assert!(code.contains("pub struct TestServiceDispatcher"));
        assert!(code.contains("TestServiceWorkload"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
        // Default manufacturer set by `ModernGenerator::new`.
        assert!(code.contains(r#""acme""#));
    }

    #[test]
    fn test_set_manufacturer_is_reflected_in_actor_type() {
        let mut generator = new_generator(GeneratorRole::ServerSide);
        generator.set_manufacturer("example".to_string());

        let code = generator
            .generate(&[echo_method()])
            .expect("server generation should succeed");

        assert!(code.contains(r#""example""#), "manufacturer literal missing");
        assert!(
            code.contains(r#""test.v1.TestService""#),
            "actor type name missing"
        );
    }

    #[test]
    fn test_generate_without_methods() {
        for role in [GeneratorRole::ServerSide, GeneratorRole::ClientSide] {
            let code = new_generator(role)
                .generate(&[])
                .expect("generation with no methods should succeed");
            assert!(code.contains("use super::test_v1::*;"));
        }
    }
}