// actr_framework_protoc_codegen/modern_generator.rs

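//! Modern code generator.
//!
//! Generates code that follows the actual architecture of actr-framework:
//! - `MessageDispatcher` trait: a zero-sized-type static dispatcher
//! - `Workload` trait: the business workload, which names its associated `Dispatcher` type
//! - `{Service}Handler` trait: the business-logic interface implemented by the user
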
8use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Which side of an RPC service the generator emits code for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GeneratorRole {
    /// Server-side code for services this actor exports
    /// (message impls, handler trait, dispatcher, workload).
    ServerSide,
    /// Client-side code for services this actor depends on (typed call stubs).
    ClientSide,
}

24/// 现代化代码生成器
25pub struct ModernGenerator {
26    package_name: String,
27    service_name: String,
28    role: GeneratorRole,
29}
30
31impl ModernGenerator {
32    pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
33        Self {
34            package_name: package_name.to_string(),
35            service_name: service_name.to_string(),
36            role,
37        }
38    }
39
40    /// 生成完整代码
41    pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
42        match self.role {
43            GeneratorRole::ServerSide => self.generate_server_code(methods),
44            GeneratorRole::ClientSide => self.generate_client_code(methods),
45        }
46    }
47
48    /// 生成服务端代码(exports)
49    fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
50        let sections = [
51            // 1. 生成导入
52            self.generate_imports(),
53            // 2. 生成 RpcRequest trait 实现(类型安全的 Request → Response 关联)
54            self.generate_message_impls(methods)?,
55            // 3. 生成 Handler trait(用户实现的接口)
56            self.generate_handler_trait(methods)?,
57            // 4. Generate Dispatcher implementation(zero-sized type static dispatcher)
58            self.generate_router_impl(methods)?,
59            // 5. 生成 Workload blanket 实现
60            self.generate_workload_blanket_impl(methods)?,
61            // 6. 生成使用文档
62            self.generate_usage_docs(methods)?,
63        ];
64
65        Ok(sections.join("\n\n"))
66    }
67
68    /// 生成客户端代码(dependencies)
69    fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
70        let sections = [
71            // 1. 生成导入
72            self.generate_imports(),
73            // 2. 生成 RpcRequest trait 实现(客户端也需要用于类型安全调用)
74            self.generate_message_impls(methods)?,
75            // 3. 生成 Context 扩展方法
76            self.generate_context_extensions(methods)?,
77            // 4. 生成使用文档
78            self.generate_client_usage_docs(methods)?,
79        ];
80
81        Ok(sections.join("\n\n"))
82    }
83
84    /// 生成导入语句
85    fn generate_imports(&self) -> String {
86        // 生成protobuf消息导入
87        // 假设消息类型在同级的 proto 模块中(由 prost 生成)
88        let proto_module = self.package_name.replace('.', "_");
89
90        format!(
91            r#"//! 自动生成的代码 - 请勿手动编辑
92//!
93//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
94
95#![allow(dead_code, unused_imports)]
96
97use async_trait::async_trait;
98use bytes::Bytes;
99use prost::Message as ProstMessage;
100
101use actr_framework::{{Context, MessageDispatcher, Workload}};
102use actr_protocol::{{ActorResult, RpcRequest, RpcEnvelope, PayloadType}};
103
104// 导入 protobuf 消息类型(由 prost 生成)
105use super::{proto_module}::*;
106"#
107        )
108    }
109
110    /// 生成 RpcRequest trait 实现
111    ///
112    /// 为每个 RPC 方法的 Request 类型生成 RpcRequest trait 实现,
113    /// 关联其对应的 Response 类型。这使得客户端可以使用类型安全的 API:
114    ///
115    /// ```rust,ignore
116    /// let response: EchoResponse = ctx.call(&target, request).await?;
117    /// //              ^^^^^^^^^^^^ 从 EchoRequest::Response 推导
118    /// ```
119    fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
120        let mut impls = Vec::new();
121
122        for method in methods {
123            let input_type = self.extract_message_type(method.input_type())?;
124            let output_type = self.extract_message_type(method.output_type())?;
125
126            // 生成路由键
127            let route_key = format!(
128                "{}.{}.{}",
129                self.package_name,
130                self.service_name,
131                method.name()
132            );
133
134            // 提取 PayloadType
135            let payload_type = extract_payload_type_or_default(method);
136
137            // 生成 PayloadType 枚举路径(不能用 quote! 因为会加引号)
138            let payload_type_code = payload_type.as_rust_variant();
139
140            // 手动构造代码字符串避免 quote! 添加引号
141            let impl_code = format!(
142                r#"/// RpcRequest trait implementation - associates Request and Response types
143///
144/// This enables type-safe RPC calls with automatic response type inference:
145/// ```rust,ignore
146/// let response: {output_type} = ctx.call(&target, request).await?;
147/// ```
148impl RpcRequest for {input_type} {{
149    type Response = {output_type};
150
151    fn route_key() -> &'static str {{
152        "{route_key}"
153    }}
154
155    fn payload_type() -> PayloadType {{
156        {payload_type_code}
157    }}
158}}"#
159            );
160
161            impls.push(impl_code);
162        }
163
164        Ok(impls.join("\n\n"))
165    }
166
167    /// 生成 Handler trait
168    fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
169        let handler_trait_name = format!("{}Handler", self.service_name);
170        let handler_trait_ident = format_ident!("{}", handler_trait_name);
171
172        let mut method_sigs = Vec::new();
173        for method in methods {
174            let method_name = method.name().to_snake_case();
175            let method_ident = format_ident!("{}", method_name);
176            let input_type = self.extract_message_type(method.input_type())?;
177            let output_type = self.extract_message_type(method.output_type())?;
178            let input_ident = format_ident!("{}", input_type);
179            let output_ident = format_ident!("{}", output_type);
180
181            method_sigs.push(quote! {
182                /// RPC 方法:#method_name
183                async fn #method_ident<C: Context>(
184                    &self,
185                    req: #input_ident,
186                    ctx: &C,
187                ) -> ActorResult<#output_ident>;
188            });
189        }
190
191        let handler_trait_without_attr = quote! {
192            /// 服务处理器 trait - 用户需要实现此 trait
193            ///
194            /// # 示例
195            ///
196            /// ```rust,ignore
197            /// pub struct MyService { /* ... */ }
198            ///
199            /// #[async_trait]
200            /// impl #handler_trait_ident for MyService {
201            ///     async fn method_name(&self, req: Request, ctx: &Context) -> ActorResult<Response> {
202            ///         // 业务逻辑
203            ///         Ok(Response::default())
204            ///     }
205            /// }
206            /// ```
207            pub trait #handler_trait_ident: Send + Sync + 'static {
208                #(#method_sigs)*
209            }
210        };
211
212        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
213        Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
214    }
215
216    /// Generate Dispatcher and Workload 包装类型
217    fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
218        let router_name = format!("{}Dispatcher", self.service_name);
219        let router_ident = format_ident!("{}", router_name);
220        let workload_name = format!("{}Workload", self.service_name);
221        let workload_ident = format_ident!("{}", workload_name);
222        let handler_trait = format!("{}Handler", self.service_name);
223        let handler_trait_ident = format_ident!("{}", handler_trait);
224
225        // 生成 match 分支
226        let mut match_arms = Vec::new();
227        for method in methods {
228            let route_key = format!(
229                "{}.{}.{}",
230                self.package_name,
231                self.service_name,
232                method.name()
233            );
234            let method_name = method.name().to_snake_case();
235            let method_ident = format_ident!("{}", method_name);
236            let input_type = self.extract_message_type(method.input_type())?;
237            let input_ident = format_ident!("{}", input_type);
238
239            match_arms.push(quote! {
240                #route_key => {
241                    // Extract payload from envelope
242                    let payload = envelope.payload.as_ref()
243                        .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
244                            "Missing payload in RpcEnvelope".to_string()
245                        ))?;
246
247                    // Deserialize request
248                    let req = #input_ident::decode(&**payload)
249                        .map_err(|e| actr_protocol::ProtocolError::Actr(
250                            actr_protocol::ActrError::DecodeFailure {
251                                message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
252                            }
253                        ))?;
254
255                    // 调用业务逻辑
256                    let resp = workload.0.#method_ident(req, ctx).await?;
257
258                    // 序列化响应
259                    Ok(resp.encode_to_vec().into())
260                }
261            });
262        }
263
264        // 分开生成各个部分以确保属性正确输出
265        let workload_struct = quote! {
266            /// Workload 包装类型
267            ///
268            /// 包装用户的 Handler 实现,满足孤儿规则
269            pub struct #workload_ident<T: #handler_trait_ident>(pub T);
270
271            impl<T: #handler_trait_ident> #workload_ident<T> {
272                /// 创建新的 Workload 实例
273                pub fn new(handler: T) -> Self {
274                    Self(handler)
275                }
276            }
277        };
278
279        let router_struct = quote! {
280            /// Message dispatcher - 零大小类型 (ZST)
281            ///
282            /// 此路由器由代码生成器自动生成,将 route_key 静态路由到对应的处理方法。
283            ///
284            /// # 性能特性
285            ///
286            /// - 零内存开销(PhantomData)
287            /// - 静态 match 派发,约 5-10ns
288            /// - 编译器完全内联
289            pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
290        };
291
292        let router_impl_without_attr = quote! {
293            impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
294                type Workload = #workload_ident<T>;
295
296                async fn dispatch<C: Context>(
297                    workload: &Self::Workload,
298                    envelope: RpcEnvelope,
299                    ctx: &C,
300                ) -> ActorResult<Bytes> {
301                    match envelope.route_key.as_str() {
302                        #(#match_arms,)*
303                        _ => Err(actr_protocol::ProtocolError::Actr(
304                            actr_protocol::ActrError::UnknownRoute {
305                                route_key: envelope.route_key.to_string()
306                            }
307                        ))
308                    }
309                }
310            }
311        };
312
313        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
314        let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
315
316        Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
317    }
318
319    /// 生成 Workload 实现
320    fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
321        let router_name = format!("{}Dispatcher", self.service_name);
322        let router_ident = format_ident!("{}", router_name);
323        let workload_name = format!("{}Workload", self.service_name);
324        let workload_ident = format_ident!("{}", workload_name);
325        let handler_trait = format!("{}Handler", self.service_name);
326        let handler_trait_ident = format_ident!("{}", handler_trait);
327
328        Ok(quote! {
329            /// Workload trait 实现
330            ///
331            /// 为包装类型实现 Workload,使其可被 ActorSystem 识别和调度
332            impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
333                type Dispatcher = #router_ident<T>;
334            }
335        }
336        .to_string())
337    }
338
339    /// 生成 Context 扩展方法(客户端)
340    fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
341        let client_struct_name = format!("{}Client", self.service_name);
342        let client_ident = format_ident!("{}", client_struct_name);
343
344        let mut client_methods = Vec::new();
345        for method in methods {
346            let method_name = method.name().to_snake_case();
347            let method_ident = format_ident!("{}", method_name);
348            let input_type = self.extract_message_type(method.input_type())?;
349            let output_type = self.extract_message_type(method.output_type())?;
350            let input_ident = format_ident!("{}", input_type);
351            let output_ident = format_ident!("{}", output_type);
352
353            let route_key = format!(
354                "{}.{}.{}",
355                self.package_name,
356                self.service_name,
357                method.name()
358            );
359
360            client_methods.push(quote! {
361                /// 调用远程方法:#method_name
362                pub async fn #method_ident(
363                    &self,
364                    req: #input_ident,
365                ) -> ActorResult<#output_ident> {
366                    self.ctx.call_remote(#route_key, req).await
367                }
368            });
369        }
370
371        // 生成 Context 扩展
372        let extension_method_name = self.service_name.to_snake_case();
373        let extension_method_ident = format_ident!("{}", extension_method_name);
374
375        Ok(quote! {
376            /// 客户端接口
377            ///
378            /// 提供类型安全的远程调用方法
379            pub struct #client_ident<'a> {
380                ctx: &'a Context,
381            }
382
383            impl<'a> #client_ident<'a> {
384                #(#client_methods)*
385            }
386
387            /// Context 扩展 trait
388            ///
389            /// 为 Context 添加便捷的客户端方法
390            pub trait ContextExt {
391                fn #extension_method_ident(&self) -> #client_ident;
392            }
393
394            impl ContextExt for Context {
395                fn #extension_method_ident(&self) -> #client_ident {
396                    #client_ident { ctx: self }
397                }
398            }
399        }
400        .to_string())
401    }
402
403    /// 生成服务端使用文档
404    fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
405        let handler_trait = format!("{}Handler", self.service_name);
406        let first_method = methods.first();
407
408        let example_method = if let Some(method) = first_method {
409            let method_name = method.name().to_snake_case();
410            let input_type = self.extract_message_type(method.input_type())?;
411            let output_type = self.extract_message_type(method.output_type())?;
412            format!(
413                r#"
414    async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
415        // 实现业务逻辑
416        Ok({output_type}::default())
417    }}"#
418            )
419        } else {
420            "    // 实现方法...".to_string()
421        };
422
423        Ok(format!(
424            r#"/*
425## 使用示例
426
427### 1. 实现业务逻辑
428
429```rust
430use actr_framework::{{Context, ActorSystem}};
431use actr_protocol::ActorResult;
432
433pub struct MyService {{
434    // 业务状态
435}}
436
437#[async_trait]
438impl {handler_trait} for MyService {{
439{example_method}
440}}
441```
442
443### 2. 启动服务
444
445```rust
446#[tokio::main]
447async fn main() -> ActorResult<()> {{
448    let config = actr_config::Config::from_file("Actr.toml")?;
449    let service = MyService {{ /* ... */ }};
450
451    ActorSystem::new(config)?
452        .attach(service)  // ← 自动获得 Workload + Dispatcher
453        .start()
454        .await?
455        .wait_for_shutdown()
456        .await
457}}
458```
459
460## 架构说明
461
462- **{handler_trait}**: 用户实现的业务逻辑接口
463- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
464- **Workload**: 通过 blanket impl 自动获得(自动生成)
465
466用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
467*/
468"#,
469            self.service_name
470        ))
471    }
472
473    /// 生成客户端使用文档
474    fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
475        let service_name_snake = self.service_name.to_snake_case();
476        let method_name_snake = methods
477            .first()
478            .map(|m| m.name().to_snake_case())
479            .unwrap_or("unknown_method".to_string());
480
481        Ok(format!(
482            r#"/*
483## 客户端使用示例
484
485```rust
486use actr_framework::Context;
487use actr_protocol::ActorResult;
488
489async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
490    use super::ContextExt;
491
492    // 类型安全的远程调用
493    let response = ctx.{service_name_snake}()
494        .{method_name_snake}(request)
495        .await?;
496
497    Ok(())
498}}
499```
500
501## 编译时路由
502
503所有远程调用在编译时确定目标服务和方法,无需运行时查找。
504*/
505"#
506        ))
507    }
508
509    /// 提取消息类型名称
510    fn extract_message_type(&self, type_name: &str) -> Result<String> {
511        let cleaned = type_name.trim_start_matches('.');
512        if let Some(last_part) = cleaned.split('.').next_back() {
513            Ok(last_part.to_string())
514        } else {
515            Ok(cleaned.to_string())
516        }
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// The single `Echo` RPC (`EchoRequest` -> `EchoResponse`) shared by these tests.
    fn echo_methods() -> Vec<MethodDescriptorProto> {
        vec![MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }]
    }

    #[test]
    fn test_extract_message_type() {
        let generator = ModernGenerator::new("test.v1", "TestService", GeneratorRole::ServerSide);

        assert_eq!(
            generator
                .extract_message_type(".test.v1.EchoRequest")
                .unwrap(),
            "EchoRequest"
        );
        assert_eq!(
            generator
                .extract_message_type("test.v1.EchoResponse")
                .unwrap(),
            "EchoResponse"
        );
        assert_eq!(
            generator.extract_message_type("SimpleMessage").unwrap(),
            "SimpleMessage"
        );
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generator = ModernGenerator::new("test.v1", "TestService", GeneratorRole::ServerSide);

        let result = generator.generate_message_impls(&echo_methods()).unwrap();

        // The generated impl must provide `payload_type()`.
        assert!(
            result.contains("fn payload_type"),
            "Should contain 'fn payload_type'"
        );
        assert!(
            result.contains("PayloadType"),
            "Should contain 'PayloadType'"
        );
        // Without method options the default payload type is RpcReliable.
        assert!(
            result.contains("RpcReliable"),
            "Should contain 'RpcReliable'"
        );
        // The route key is `{package}.{Service}.{Method}` using the proto method name.
        assert!(
            result.contains("\"test.v1.TestService.Echo\""),
            "Should contain the route key"
        );
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let generator = ModernGenerator::new("test.v1", "TestService", GeneratorRole::ServerSide);
        let imports = generator.generate_imports();

        // PayloadType must be imported for the generated `payload_type()` impls.
        assert!(imports.contains("PayloadType"));
        assert!(
            imports
                .contains("use actr_protocol::{ActorResult, RpcRequest, RpcEnvelope, PayloadType}")
        );
        // The prost module path is derived from the package name.
        assert!(imports.contains("use super::test_v1::*;"));
    }

    #[test]
    fn test_generate_client_code() {
        let generator = ModernGenerator::new("test.v1", "TestService", GeneratorRole::ClientSide);

        let result = generator.generate(&echo_methods());
        assert!(result.is_ok());

        let code = result.unwrap();
        // Client code also needs the RpcRequest impls.
        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
        assert!(code.contains("pub trait ContextExt"));
    }

    #[test]
    fn test_generate_server_code() {
        let generator = ModernGenerator::new("test.v1", "TestService", GeneratorRole::ServerSide);

        let result = generator.generate(&echo_methods());
        assert!(result.is_ok());

        let code = result.unwrap();
        // Handler trait
        assert!(code.contains("pub trait TestServiceHandler"));
        // Dispatcher and workload wrapper
        assert!(code.contains("pub struct TestServiceDispatcher"));
        assert!(code.contains("pub struct TestServiceWorkload"));
        // payload_type
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }
}