actr_framework_protoc_codegen/
modern_generator.rs

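//! Modern code generator.
//!
//! Generates code that follows the actual architecture of actr-framework:
//! - MessageDispatcher trait: zero-sized type static dispatcher
//! - Workload trait: the business workload; associates the Dispatcher type
//! - {Service}Handler trait: the business-logic interface implemented by the user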
7
8use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Which side of a service the generator emits code for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorRole {
    /// Emit server-side code for services listed under `exports`.
    ServerSide,
    /// Emit client-side code for services listed under `dependencies`.
    ClientSide,
}
23
24/// 现代化代码生成器
25pub struct ModernGenerator {
26    package_name: String,
27    service_name: String,
28    role: GeneratorRole,
29    manufacturer: String,
30}
31
32impl ModernGenerator {
33    pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
34        Self {
35            package_name: package_name.to_string(),
36            service_name: service_name.to_string(),
37            role,
38            manufacturer: "acme".to_string(), // Default for backward compatibility
39        }
40    }
41
42    /// Set manufacturer (read from Actr.toml via --actrframework_opt)
43    pub fn set_manufacturer(&mut self, manufacturer: String) {
44        self.manufacturer = manufacturer;
45    }
46
47    /// 生成完整代码
48    pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
49        match self.role {
50            GeneratorRole::ServerSide => self.generate_server_code(methods),
51            GeneratorRole::ClientSide => self.generate_client_code(methods),
52        }
53    }
54
55    /// 生成服务端代码(exports)
56    fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
57        let sections = [
58            // 1. 生成导入
59            self.generate_imports(),
60            // 2. 生成 RpcRequest trait 实现(类型安全的 Request → Response 关联)
61            self.generate_message_impls(methods)?,
62            // 3. 生成 Handler trait(用户实现的接口)
63            self.generate_handler_trait(methods)?,
64            // 4. Generate Dispatcher implementation(zero-sized type static dispatcher)
65            self.generate_router_impl(methods)?,
66            // 5. 生成 Workload blanket 实现
67            self.generate_workload_blanket_impl(methods)?,
68            // 6. 生成使用文档
69            self.generate_usage_docs(methods)?,
70        ];
71
72        Ok(sections.join("\n\n"))
73    }
74
75    /// 生成客户端代码(dependencies)
76    fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
77        let sections = [
78            // 1. 生成导入
79            self.generate_imports(),
80            // 2. 生成 RpcRequest trait 实现(客户端也需要用于类型安全调用)
81            self.generate_message_impls(methods)?,
82            // 3. 生成 Context 扩展方法
83            self.generate_context_extensions(methods)?,
84            // 4. 生成使用文档
85            self.generate_client_usage_docs(methods)?,
86        ];
87
88        Ok(sections.join("\n\n"))
89    }
90
91    /// 生成导入语句
92    fn generate_imports(&self) -> String {
93        // 生成protobuf消息导入
94        // 假设消息类型在同级的 proto 模块中(由 prost 生成)
95        let proto_module = self.package_name.replace('.', "_");
96
97        format!(
98            r#"//! 自动生成的代码 - 请勿手动编辑
99//!
100//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
101
102#![allow(dead_code, unused_imports)]
103
104use async_trait::async_trait;
105use bytes::Bytes;
106use prost::Message as ProstMessage;
107
108use actr_framework::{{Context, MessageDispatcher, Workload}};
109use actr_protocol::{{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}};
110
111// 导入 protobuf 消息类型(由 prost 生成)
112use super::{proto_module}::*;
113"#
114        )
115    }
116
117    /// 生成 RpcRequest trait 实现
118    ///
119    /// 为每个 RPC 方法的 Request 类型生成 RpcRequest trait 实现,
120    /// 关联其对应的 Response 类型。这使得客户端可以使用类型安全的 API:
121    ///
122    /// ```rust,ignore
123    /// let response: EchoResponse = ctx.call(&target, request).await?;
124    /// //              ^^^^^^^^^^^^ 从 EchoRequest::Response 推导
125    /// ```
126    fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
127        let mut impls = Vec::new();
128
129        for method in methods {
130            let input_type = self.extract_message_type(method.input_type())?;
131            let output_type = self.extract_message_type(method.output_type())?;
132
133            // 生成路由键
134            let route_key = format!(
135                "{}.{}.{}",
136                self.package_name,
137                self.service_name,
138                method.name()
139            );
140
141            // 提取 PayloadType
142            let payload_type = extract_payload_type_or_default(method);
143
144            // 生成 PayloadType 枚举路径(不能用 quote! 因为会加引号)
145            let payload_type_code = match payload_type {
146                crate::payload_type_extractor::PayloadType::RpcReliable => {
147                    "PayloadType::RpcReliable"
148                }
149                crate::payload_type_extractor::PayloadType::RpcSignal => "PayloadType::RpcSignal",
150                crate::payload_type_extractor::PayloadType::StreamReliable => {
151                    "PayloadType::StreamReliable"
152                }
153                crate::payload_type_extractor::PayloadType::StreamLatencyFirst => {
154                    "PayloadType::StreamLatencyFirst"
155                }
156                crate::payload_type_extractor::PayloadType::MediaRtp => "PayloadType::MediaRtp",
157            };
158
159            // 手动构造代码字符串避免 quote! 添加引号
160            let impl_code = format!(
161                r#"/// RpcRequest trait implementation - associates Request and Response types
162///
163/// This enables type-safe RPC calls with automatic response type inference:
164/// ```rust,ignore
165/// let response: {output_type} = ctx.call(&target, request).await?;
166/// ```
167impl RpcRequest for {input_type} {{
168    type Response = {output_type};
169
170    fn route_key() -> &'static str {{
171        "{route_key}"
172    }}
173
174    fn payload_type() -> PayloadType {{
175        {payload_type_code}
176    }}
177}}"#
178            );
179
180            impls.push(impl_code);
181        }
182
183        Ok(impls
184            .iter()
185            .map(|i| i.to_string())
186            .collect::<Vec<_>>()
187            .join("\n\n"))
188    }
189
190    /// 生成 Handler trait
191    fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
192        let handler_trait_name = format!("{}Handler", self.service_name);
193        let handler_trait_ident = format_ident!("{}", handler_trait_name);
194
195        let mut method_sigs = Vec::new();
196        for method in methods {
197            let method_name = method.name().to_snake_case();
198            let method_ident = format_ident!("{}", method_name);
199            let input_type = self.extract_message_type(method.input_type())?;
200            let output_type = self.extract_message_type(method.output_type())?;
201            let input_ident = format_ident!("{}", input_type);
202            let output_ident = format_ident!("{}", output_type);
203
204            method_sigs.push(quote! {
205                /// RPC 方法:#method_name
206                async fn #method_ident<C: Context>(
207                    &self,
208                    req: #input_ident,
209                    ctx: &C,
210                ) -> ActorResult<#output_ident>;
211            });
212        }
213
214        let handler_trait_without_attr = quote! {
215            /// 服务处理器 trait - 用户需要实现此 trait
216            ///
217            /// # 示例
218            ///
219            /// ```rust,ignore
220            /// pub struct MyService { /* ... */ }
221            ///
222            /// #[async_trait]
223            /// impl #handler_trait_ident for MyService {
224            ///     async fn method_name(&self, req: Request, ctx: &Context) -> ActorResult<Response> {
225            ///         // 业务逻辑
226            ///         Ok(Response::default())
227            ///     }
228            /// }
229            /// ```
230            pub trait #handler_trait_ident: Send + Sync + 'static {
231                #(#method_sigs)*
232            }
233        }
234        .to_string();
235
236        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
237        Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
238    }
239
240    /// Generate Dispatcher and Workload 包装类型
241    fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
242        let router_name = format!("{}Dispatcher", self.service_name);
243        let router_ident = format_ident!("{}", router_name);
244        let workload_name = format!("{}Workload", self.service_name);
245        let workload_ident = format_ident!("{}", workload_name);
246        let handler_trait = format!("{}Handler", self.service_name);
247        let handler_trait_ident = format_ident!("{}", handler_trait);
248
249        // 生成 match 分支
250        let mut match_arms = Vec::new();
251        for method in methods {
252            let route_key = format!(
253                "{}.{}.{}",
254                self.package_name,
255                self.service_name,
256                method.name()
257            );
258            let method_name = method.name().to_snake_case();
259            let method_ident = format_ident!("{}", method_name);
260            let input_type = self.extract_message_type(method.input_type())?;
261            let input_ident = format_ident!("{}", input_type);
262
263            match_arms.push(quote! {
264                #route_key => {
265                    // Extract payload from envelope
266                    let payload = envelope.payload.as_ref()
267                        .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
268                            "Missing payload in RpcEnvelope".to_string()
269                        ))?;
270
271                    // Deserialize request
272                    let req = #input_ident::decode(&**payload)
273                        .map_err(|e| actr_protocol::ProtocolError::Actr(
274                            actr_protocol::ActrError::DecodeFailure {
275                                message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
276                            }
277                        ))?;
278
279                    // 调用业务逻辑
280                    let resp = workload.0.#method_ident(req, ctx).await?;
281
282                    // 序列化响应
283                    Ok(resp.encode_to_vec().into())
284                }
285            });
286        }
287
288        // 分开生成各个部分以确保属性正确输出
289        let workload_struct = quote! {
290            /// Workload 包装类型
291            ///
292            /// 包装用户的 Handler 实现,满足孤儿规则
293            pub struct #workload_ident<T: #handler_trait_ident>(pub T);
294
295            impl<T: #handler_trait_ident> #workload_ident<T> {
296                /// 创建新的 Workload 实例
297                pub fn new(handler: T) -> Self {
298                    Self(handler)
299                }
300            }
301        }
302        .to_string();
303
304        let router_struct = quote! {
305            /// Message dispatcher - 零大小类型 (ZST)
306            ///
307            /// 此路由器由代码生成器自动生成,将 route_key 静态路由到对应的处理方法。
308            ///
309            /// # 性能特性
310            ///
311            /// - 零内存开销(PhantomData)
312            /// - 静态 match 派发,约 5-10ns
313            /// - 编译器完全内联
314            pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
315        }
316        .to_string();
317
318        let router_impl_without_attr = quote! {
319            impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
320                type Workload = #workload_ident<T>;
321
322                async fn dispatch<C: Context>(
323                    workload: &Self::Workload,
324                    envelope: RpcEnvelope,
325                    ctx: &C,
326                ) -> ActorResult<Bytes> {
327                    match envelope.route_key.as_str() {
328                        #(#match_arms,)*
329                        _ => Err(actr_protocol::ProtocolError::Actr(
330                            actr_protocol::ActrError::UnknownRoute {
331                                route_key: envelope.route_key.to_string()
332                            }
333                        ))
334                    }
335                }
336            }
337        }
338        .to_string();
339
340        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
341        let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
342
343        Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
344    }
345
346    /// 生成 Workload 实现
347    fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
348        let router_name = format!("{}Dispatcher", self.service_name);
349        let router_ident = format_ident!("{}", router_name);
350        let workload_name = format!("{}Workload", self.service_name);
351        let workload_ident = format_ident!("{}", workload_name);
352        let handler_trait = format!("{}Handler", self.service_name);
353        let handler_trait_ident = format_ident!("{}", handler_trait);
354
355        // 从 package_name 和 service_name 构造 ActrType
356        let actor_type_str = format!("{}.{}", self.package_name, self.service_name);
357        let manufacturer = &self.manufacturer;
358
359        Ok(quote! {
360            /// Workload trait 实现
361            ///
362            /// 为包装类型实现 Workload,使其可被 ActorSystem 识别和调度
363            impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
364                type Dispatcher = #router_ident<T>;
365
366                fn actor_type(&self) -> ActrType {
367                    ActrType {
368                        manufacturer: #manufacturer.to_string(),
369                        name: #actor_type_str.to_string(),
370                    }
371                }
372            }
373        }
374        .to_string())
375    }
376
377    /// 生成 Context 扩展方法(客户端)
378    fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
379        let client_struct_name = format!("{}Client", self.service_name);
380        let client_ident = format_ident!("{}", client_struct_name);
381
382        let mut client_methods = Vec::new();
383        for method in methods {
384            let method_name = method.name().to_snake_case();
385            let method_ident = format_ident!("{}", method_name);
386            let input_type = self.extract_message_type(method.input_type())?;
387            let output_type = self.extract_message_type(method.output_type())?;
388            let input_ident = format_ident!("{}", input_type);
389            let output_ident = format_ident!("{}", output_type);
390
391            let route_key = format!(
392                "{}.{}.{}",
393                self.package_name,
394                self.service_name,
395                method.name()
396            );
397
398            client_methods.push(quote! {
399                /// 调用远程方法:#method_name
400                pub async fn #method_ident(
401                    &self,
402                    req: #input_ident,
403                ) -> ActorResult<#output_ident> {
404                    self.ctx.call_remote(#route_key, req).await
405                }
406            });
407        }
408
409        // 生成 Context 扩展
410        let extension_method_name = self.service_name.to_snake_case();
411        let extension_method_ident = format_ident!("{}", extension_method_name);
412
413        Ok(quote! {
414            /// 客户端接口
415            ///
416            /// 提供类型安全的远程调用方法
417            pub struct #client_ident<'a> {
418                ctx: &'a Context,
419            }
420
421            impl<'a> #client_ident<'a> {
422                #(#client_methods)*
423            }
424
425            /// Context 扩展 trait
426            ///
427            /// 为 Context 添加便捷的客户端方法
428            pub trait ContextExt {
429                fn #extension_method_ident(&self) -> #client_ident;
430            }
431
432            impl ContextExt for Context {
433                fn #extension_method_ident(&self) -> #client_ident {
434                    #client_ident { ctx: self }
435                }
436            }
437        }
438        .to_string())
439    }
440
441    /// 生成服务端使用文档
442    fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
443        let handler_trait = format!("{}Handler", self.service_name);
444        let first_method = methods.first();
445
446        let example_method = if let Some(method) = first_method {
447            let method_name = method.name().to_snake_case();
448            let input_type = self.extract_message_type(method.input_type())?;
449            let output_type = self.extract_message_type(method.output_type())?;
450            format!(
451                r#"
452    async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
453        // 实现业务逻辑
454        Ok({output_type}::default())
455    }}"#
456            )
457        } else {
458            "    // 实现方法...".to_string()
459        };
460
461        Ok(format!(
462            r#"/*
463## 使用示例
464
465### 1. 实现业务逻辑
466
467```rust
468use actr_framework::{{Context, ActorSystem}};
469use actr_protocol::ActorResult;
470
471pub struct MyService {{
472    // 业务状态
473}}
474
475#[async_trait]
476impl {handler_trait} for MyService {{
477{example_method}
478}}
479```
480
481### 2. 启动服务
482
483```rust
484#[tokio::main]
485async fn main() -> ActorResult<()> {{
486    let config = actr_config::Config::from_file("Actr.toml")?;
487    let service = MyService {{ /* ... */ }};
488
489    ActorSystem::new(config)?
490        .attach(service)  // ← 自动获得 Workload + Dispatcher
491        .start()
492        .await?
493        .wait_for_shutdown()
494        .await
495}}
496```
497
498## 架构说明
499
500- **{handler_trait}**: 用户实现的业务逻辑接口
501- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
502- **Workload**: 通过 blanket impl 自动获得(自动生成)
503
504用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
505*/
506"#,
507            self.service_name
508        ))
509    }
510
511    /// 生成客户端使用文档
512    fn generate_client_usage_docs(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
513        let service_name_snake = self.service_name.to_snake_case();
514
515        Ok(format!(
516            r#"/*
517## 客户端使用示例
518
519```rust
520use actr_framework::Context;
521use actr_protocol::ActorResult;
522
523async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
524    use super::ContextExt;
525
526    // 类型安全的远程调用
527    let response = ctx.{service_name_snake}()
528        .method_name(request)
529        .await?;
530
531    Ok(())
532}}
533```
534
535## 编译时路由
536
537所有远程调用在编译时确定目标服务和方法,无需运行时查找。
538*/
539"#
540        ))
541    }
542
543    /// 提取消息类型名称
544    fn extract_message_type(&self, type_name: &str) -> Result<String> {
545        let cleaned = type_name.trim_start_matches('.');
546        if let Some(last_part) = cleaned.split('.').next_back() {
547            Ok(last_part.to_string())
548        } else {
549            Ok(cleaned.to_string())
550        }
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// Builds a generator for `test.v1.TestService` with the given role.
    fn test_generator(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    /// A method list holding the single `Echo(EchoRequest) -> EchoResponse` RPC.
    fn echo_methods() -> Vec<MethodDescriptorProto> {
        vec![MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }]
    }

    #[test]
    fn test_extract_message_type() {
        let generator = test_generator(GeneratorRole::ServerSide);

        // Leading dot, no leading dot, and unqualified names.
        let fully_qualified = generator
            .extract_message_type(".test.v1.EchoRequest")
            .unwrap();
        assert_eq!(fully_qualified, "EchoRequest");

        let qualified = generator
            .extract_message_type("test.v1.EchoResponse")
            .unwrap();
        assert_eq!(qualified, "EchoResponse");

        let bare = generator.extract_message_type("SimpleMessage").unwrap();
        assert_eq!(bare, "SimpleMessage");
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generator = test_generator(GeneratorRole::ServerSide);
        let methods = echo_methods();

        let result = generator.generate_message_impls(&methods).unwrap();

        // Debug: print generated code
        eprintln!("Generated code:\n{result}");

        // The generated code must contain the payload_type() method.
        assert!(
            result.contains("fn payload_type"),
            "Should contain 'fn payload_type'"
        );
        assert!(
            result.contains("PayloadType"),
            "Should contain 'PayloadType'"
        );
        // Without method options the payload type defaults to RpcReliable.
        assert!(
            result.contains("RpcReliable"),
            "Should contain 'RpcReliable'"
        );
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let imports = test_generator(GeneratorRole::ServerSide).generate_imports();

        // PayloadType must be imported.
        assert!(imports.contains("PayloadType"));
        assert!(imports.contains(
            "use actr_protocol::{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}"
        ));
    }

    #[test]
    fn test_generate_client_code() {
        let generator = test_generator(GeneratorRole::ClientSide);
        let methods = echo_methods();

        let result = generator.generate(&methods);
        assert!(result.is_ok());

        let code = result.unwrap();
        // Client code must carry the RpcRequest impl as well.
        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }

    #[test]
    fn test_generate_server_code() {
        let generator = test_generator(GeneratorRole::ServerSide);
        let methods = echo_methods();

        let result = generator.generate(&methods);
        assert!(result.is_ok());

        let code = result.unwrap();
        // The Handler trait is generated.
        assert!(code.contains("pub trait TestServiceHandler"));
        // The Dispatcher is generated.
        assert!(code.contains("pub struct TestServiceDispatcher"));
        // payload_type is generated.
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }
}