actr_framework_protoc_codegen/
modern_generator.rs

1//! 现代化代码生成器
2//!
3//! 基于 actr-framework 的实际架构生成代码:
4//! - MessageDispatcher trait: zero-sized type static dispatcher
5//! - Workload trait: 业务工作负载,associates Dispatcher type
6//! - {Service}Handler trait: 用户实现的业务逻辑接口
7
8use anyhow::Result;
9use heck::ToSnakeCase;
10use prost_types::MethodDescriptorProto;
11use quote::{format_ident, quote};
12
13use crate::payload_type_extractor::extract_payload_type_or_default;
14
/// Selects which side of an RPC contract the generator emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratorRole {
    /// Server-side code for services listed under `exports`: `RpcRequest` impls, the
    /// `{Service}Handler` trait, the Dispatcher/Workload types and usage docs.
    ServerSide,
    /// Client-side code for services listed under `dependencies`: `RpcRequest` impls,
    /// typed client helpers and usage docs.
    ClientSide,
}

/// Code generator for a single protobuf service on top of the `actr-framework` architecture.
#[derive(Debug, Clone)]
pub struct ModernGenerator {
    /// Protobuf package (e.g. `test.v1`); used in route keys, the generated `ActrType`
    /// name and the name of the imported proto module.
    package_name: String,
    /// Service name as declared in the `.proto` file.
    service_name: String,
    /// Whether server-side or client-side code is generated.
    role: GeneratorRole,
    /// Manufacturer written into the generated `ActrType`.
    manufacturer: String,
}
31
32impl ModernGenerator {
33    pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
34        Self {
35            package_name: package_name.to_string(),
36            service_name: service_name.to_string(),
37            role,
38            manufacturer: "acme".to_string(), // Default for backward compatibility
39        }
40    }
41
42    /// Set manufacturer (read from Actr.toml via --actrframework_opt)
43    pub fn set_manufacturer(&mut self, manufacturer: String) {
44        self.manufacturer = manufacturer;
45    }
46
47    /// 生成完整代码
48    pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
49        match self.role {
50            GeneratorRole::ServerSide => self.generate_server_code(methods),
51            GeneratorRole::ClientSide => self.generate_client_code(methods),
52        }
53    }
54
55    /// 生成服务端代码(exports)
56    fn generate_server_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
57        let sections = [
58            // 1. 生成导入
59            self.generate_imports(),
60            // 2. 生成 RpcRequest trait 实现(类型安全的 Request → Response 关联)
61            self.generate_message_impls(methods)?,
62            // 3. 生成 Handler trait(用户实现的接口)
63            self.generate_handler_trait(methods)?,
64            // 4. Generate Dispatcher implementation(zero-sized type static dispatcher)
65            self.generate_router_impl(methods)?,
66            // 5. 生成 Workload blanket 实现
67            self.generate_workload_blanket_impl(methods)?,
68            // 6. 生成使用文档
69            self.generate_usage_docs(methods)?,
70        ];
71
72        Ok(sections.join("\n\n"))
73    }
74
75    /// 生成客户端代码(dependencies)
76    fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
77        let sections = [
78            // 1. 生成导入
79            self.generate_imports(),
80            // 2. 生成 RpcRequest trait 实现(客户端也需要用于类型安全调用)
81            self.generate_message_impls(methods)?,
82            // 3. 生成 Context 扩展方法
83            self.generate_context_extensions(methods)?,
84            // 4. 生成使用文档
85            self.generate_client_usage_docs(methods)?,
86        ];
87
88        Ok(sections.join("\n\n"))
89    }
90
91    /// 生成导入语句
92    fn generate_imports(&self) -> String {
93        // 生成protobuf消息导入
94        // 假设消息类型在同级的 proto 模块中(由 prost 生成)
95        let proto_module = self.package_name.replace('.', "_");
96
97        format!(
98            r#"//! 自动生成的代码 - 请勿手动编辑
99//!
100//! 由 actr-cli 的 protoc-gen-actrframework 插件生成
101
102#![allow(dead_code, unused_imports)]
103
104use async_trait::async_trait;
105use bytes::Bytes;
106use prost::Message as ProstMessage;
107
108use actr_framework::{{Context, MessageDispatcher, Workload}};
109use actr_protocol::{{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}};
110
111// 导入 protobuf 消息类型(由 prost 生成)
112use super::{proto_module}::*;
113"#
114        )
115    }
116
117    /// 生成 RpcRequest trait 实现
118    ///
119    /// 为每个 RPC 方法的 Request 类型生成 RpcRequest trait 实现,
120    /// 关联其对应的 Response 类型。这使得客户端可以使用类型安全的 API:
121    ///
122    /// ```rust,ignore
123    /// let response: EchoResponse = ctx.call(&target, request).await?;
124    /// //              ^^^^^^^^^^^^ 从 EchoRequest::Response 推导
125    /// ```
126    fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
127        let mut impls = Vec::new();
128
129        for method in methods {
130            let input_type = self.extract_message_type(method.input_type())?;
131            let output_type = self.extract_message_type(method.output_type())?;
132
133            // 生成路由键
134            let route_key = format!(
135                "{}.{}.{}",
136                self.package_name,
137                self.service_name,
138                method.name()
139            );
140
141            // 提取 PayloadType
142            let payload_type = extract_payload_type_or_default(method);
143
144            // 生成 PayloadType 枚举路径(不能用 quote! 因为会加引号)
145            let payload_type_code = payload_type.as_rust_variant();
146
147            // 手动构造代码字符串避免 quote! 添加引号
148            let impl_code = format!(
149                r#"/// RpcRequest trait implementation - associates Request and Response types
150///
151/// This enables type-safe RPC calls with automatic response type inference:
152/// ```rust,ignore
153/// let response: {output_type} = ctx.call(&target, request).await?;
154/// ```
155impl RpcRequest for {input_type} {{
156    type Response = {output_type};
157
158    fn route_key() -> &'static str {{
159        "{route_key}"
160    }}
161
162    fn payload_type() -> PayloadType {{
163        {payload_type_code}
164    }}
165}}"#
166            );
167
168            impls.push(impl_code);
169        }
170
171        Ok(impls.join("\n\n"))
172    }
173
174    /// 生成 Handler trait
175    fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
176        let handler_trait_name = format!("{}Handler", self.service_name);
177        let handler_trait_ident = format_ident!("{}", handler_trait_name);
178
179        let mut method_sigs = Vec::new();
180        for method in methods {
181            let method_name = method.name().to_snake_case();
182            let method_ident = format_ident!("{}", method_name);
183            let input_type = self.extract_message_type(method.input_type())?;
184            let output_type = self.extract_message_type(method.output_type())?;
185            let input_ident = format_ident!("{}", input_type);
186            let output_ident = format_ident!("{}", output_type);
187
188            method_sigs.push(quote! {
189                /// RPC 方法:#method_name
190                async fn #method_ident<C: Context>(
191                    &self,
192                    req: #input_ident,
193                    ctx: &C,
194                ) -> ActorResult<#output_ident>;
195            });
196        }
197
198        let handler_trait_without_attr = quote! {
199            /// 服务处理器 trait - 用户需要实现此 trait
200            ///
201            /// # 示例
202            ///
203            /// ```rust,ignore
204            /// pub struct MyService { /* ... */ }
205            ///
206            /// #[async_trait]
207            /// impl #handler_trait_ident for MyService {
208            ///     async fn method_name(&self, req: Request, ctx: &Context) -> ActorResult<Response> {
209            ///         // 业务逻辑
210            ///         Ok(Response::default())
211            ///     }
212            /// }
213            /// ```
214            pub trait #handler_trait_ident: Send + Sync + 'static {
215                #(#method_sigs)*
216            }
217        };
218
219        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
220        Ok(format!("#[async_trait]\n{handler_trait_without_attr}"))
221    }
222
223    /// Generate Dispatcher and Workload 包装类型
224    fn generate_router_impl(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
225        let router_name = format!("{}Dispatcher", self.service_name);
226        let router_ident = format_ident!("{}", router_name);
227        let workload_name = format!("{}Workload", self.service_name);
228        let workload_ident = format_ident!("{}", workload_name);
229        let handler_trait = format!("{}Handler", self.service_name);
230        let handler_trait_ident = format_ident!("{}", handler_trait);
231
232        // 生成 match 分支
233        let mut match_arms = Vec::new();
234        for method in methods {
235            let route_key = format!(
236                "{}.{}.{}",
237                self.package_name,
238                self.service_name,
239                method.name()
240            );
241            let method_name = method.name().to_snake_case();
242            let method_ident = format_ident!("{}", method_name);
243            let input_type = self.extract_message_type(method.input_type())?;
244            let input_ident = format_ident!("{}", input_type);
245
246            match_arms.push(quote! {
247                #route_key => {
248                    // Extract payload from envelope
249                    let payload = envelope.payload.as_ref()
250                        .ok_or_else(|| actr_protocol::ProtocolError::DecodeError(
251                            "Missing payload in RpcEnvelope".to_string()
252                        ))?;
253
254                    // Deserialize request
255                    let req = #input_ident::decode(&**payload)
256                        .map_err(|e| actr_protocol::ProtocolError::Actr(
257                            actr_protocol::ActrError::DecodeFailure {
258                                message: format!("Failed to decode {}: {}", stringify!(#input_ident), e)
259                            }
260                        ))?;
261
262                    // 调用业务逻辑
263                    let resp = workload.0.#method_ident(req, ctx).await?;
264
265                    // 序列化响应
266                    Ok(resp.encode_to_vec().into())
267                }
268            });
269        }
270
271        // 分开生成各个部分以确保属性正确输出
272        let workload_struct = quote! {
273            /// Workload 包装类型
274            ///
275            /// 包装用户的 Handler 实现,满足孤儿规则
276            pub struct #workload_ident<T: #handler_trait_ident>(pub T);
277
278            impl<T: #handler_trait_ident> #workload_ident<T> {
279                /// 创建新的 Workload 实例
280                pub fn new(handler: T) -> Self {
281                    Self(handler)
282                }
283            }
284        };
285
286        let router_struct = quote! {
287            /// Message dispatcher - 零大小类型 (ZST)
288            ///
289            /// 此路由器由代码生成器自动生成,将 route_key 静态路由到对应的处理方法。
290            ///
291            /// # 性能特性
292            ///
293            /// - 零内存开销(PhantomData)
294            /// - 静态 match 派发,约 5-10ns
295            /// - 编译器完全内联
296            pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
297        };
298
299        let router_impl_without_attr = quote! {
300            impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
301                type Workload = #workload_ident<T>;
302
303                async fn dispatch<C: Context>(
304                    workload: &Self::Workload,
305                    envelope: RpcEnvelope,
306                    ctx: &C,
307                ) -> ActorResult<Bytes> {
308                    match envelope.route_key.as_str() {
309                        #(#match_arms,)*
310                        _ => Err(actr_protocol::ProtocolError::Actr(
311                            actr_protocol::ActrError::UnknownRoute {
312                                route_key: envelope.route_key.to_string()
313                            }
314                        ))
315                    }
316                }
317            }
318        };
319
320        // 手动添加 #[async_trait] 属性,避免 quote! 宏插入空格
321        let router_impl = format!("#[async_trait]\n{router_impl_without_attr}");
322
323        Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
324    }
325
326    /// 生成 Workload 实现
327    fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
328        let router_name = format!("{}Dispatcher", self.service_name);
329        let router_ident = format_ident!("{}", router_name);
330        let workload_name = format!("{}Workload", self.service_name);
331        let workload_ident = format_ident!("{}", workload_name);
332        let handler_trait = format!("{}Handler", self.service_name);
333        let handler_trait_ident = format_ident!("{}", handler_trait);
334
335        // 从 package_name 和 service_name 构造 ActrType
336        let actor_type_str = format!("{}.{}", self.package_name, self.service_name);
337        let manufacturer = &self.manufacturer;
338
339        Ok(quote! {
340            /// Workload trait 实现
341            ///
342            /// 为包装类型实现 Workload,使其可被 ActorSystem 识别和调度
343            impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
344                type Dispatcher = #router_ident<T>;
345
346                fn actor_type(&self) -> ActrType {
347                    ActrType {
348                        manufacturer: #manufacturer.to_string(),
349                        name: #actor_type_str.to_string(),
350                    }
351                }
352            }
353        }
354        .to_string())
355    }
356
357    /// 生成 Context 扩展方法(客户端)
358    fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
359        let client_struct_name = format!("{}Client", self.service_name);
360        let client_ident = format_ident!("{}", client_struct_name);
361
362        let mut client_methods = Vec::new();
363        for method in methods {
364            let method_name = method.name().to_snake_case();
365            let method_ident = format_ident!("{}", method_name);
366            let input_type = self.extract_message_type(method.input_type())?;
367            let output_type = self.extract_message_type(method.output_type())?;
368            let input_ident = format_ident!("{}", input_type);
369            let output_ident = format_ident!("{}", output_type);
370
371            let route_key = format!(
372                "{}.{}.{}",
373                self.package_name,
374                self.service_name,
375                method.name()
376            );
377
378            client_methods.push(quote! {
379                /// 调用远程方法:#method_name
380                pub async fn #method_ident(
381                    &self,
382                    req: #input_ident,
383                ) -> ActorResult<#output_ident> {
384                    self.ctx.call_remote(#route_key, req).await
385                }
386            });
387        }
388
389        // 生成 Context 扩展
390        let extension_method_name = self.service_name.to_snake_case();
391        let extension_method_ident = format_ident!("{}", extension_method_name);
392
393        Ok(quote! {
394            /// 客户端接口
395            ///
396            /// 提供类型安全的远程调用方法
397            pub struct #client_ident<'a> {
398                ctx: &'a Context,
399            }
400
401            impl<'a> #client_ident<'a> {
402                #(#client_methods)*
403            }
404
405            /// Context 扩展 trait
406            ///
407            /// 为 Context 添加便捷的客户端方法
408            pub trait ContextExt {
409                fn #extension_method_ident(&self) -> #client_ident;
410            }
411
412            impl ContextExt for Context {
413                fn #extension_method_ident(&self) -> #client_ident {
414                    #client_ident { ctx: self }
415                }
416            }
417        }
418        .to_string())
419    }
420
421    /// 生成服务端使用文档
422    fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
423        let handler_trait = format!("{}Handler", self.service_name);
424        let first_method = methods.first();
425
426        let example_method = if let Some(method) = first_method {
427            let method_name = method.name().to_snake_case();
428            let input_type = self.extract_message_type(method.input_type())?;
429            let output_type = self.extract_message_type(method.output_type())?;
430            format!(
431                r#"
432    async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
433        // 实现业务逻辑
434        Ok({output_type}::default())
435    }}"#
436            )
437        } else {
438            "    // 实现方法...".to_string()
439        };
440
441        Ok(format!(
442            r#"/*
443## 使用示例
444
445### 1. 实现业务逻辑
446
447```rust
448use actr_framework::{{Context, ActorSystem}};
449use actr_protocol::ActorResult;
450
451pub struct MyService {{
452    // 业务状态
453}}
454
455#[async_trait]
456impl {handler_trait} for MyService {{
457{example_method}
458}}
459```
460
461### 2. 启动服务
462
463```rust
464#[tokio::main]
465async fn main() -> ActorResult<()> {{
466    let config = actr_config::Config::from_file("Actr.toml")?;
467    let service = MyService {{ /* ... */ }};
468
469    ActorSystem::new(config)?
470        .attach(service)  // ← 自动获得 Workload + Dispatcher
471        .start()
472        .await?
473        .wait_for_shutdown()
474        .await
475}}
476```
477
478## 架构说明
479
480- **{handler_trait}**: 用户实现的业务逻辑接口
481- **{}Dispatcher**: zero-sized type static dispatcher(自动生成)
482- **Workload**: 通过 blanket impl 自动获得(自动生成)
483
484用户只需实现 {handler_trait},框架会自动提供路由和工作负载能力。
485*/
486"#,
487            self.service_name
488        ))
489    }
490
491    /// 生成客户端使用文档
492    fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
493        let service_name_snake = self.service_name.to_snake_case();
494        let method_name_snake = methods
495            .first()
496            .map(|m| m.name().to_snake_case())
497            .unwrap_or("unknown_method".to_string());
498
499        Ok(format!(
500            r#"/*
501## 客户端使用示例
502
503```rust
504use actr_framework::Context;
505use actr_protocol::ActorResult;
506
507async fn call_remote_service(ctx: &Context) -> ActorResult<()> {{
508    use super::ContextExt;
509
510    // 类型安全的远程调用
511    let response = ctx.{service_name_snake}()
512        .{method_name_snake}(request)
513        .await?;
514
515    Ok(())
516}}
517```
518
519## 编译时路由
520
521所有远程调用在编译时确定目标服务和方法,无需运行时查找。
522*/
523"#
524        ))
525    }
526
527    /// 提取消息类型名称
528    fn extract_message_type(&self, type_name: &str) -> Result<String> {
529        let cleaned = type_name.trim_start_matches('.');
530        if let Some(last_part) = cleaned.split('.').next_back() {
531            Ok(last_part.to_string())
532        } else {
533            Ok(cleaned.to_string())
534        }
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// Generator for the `test.v1.TestService` fixture with the given role.
    fn make_generator(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    /// A single `Echo` RPC: `.test.v1.EchoRequest` -> `.test.v1.EchoResponse`.
    fn echo_methods() -> Vec<MethodDescriptorProto> {
        vec![MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }]
    }

    #[test]
    fn test_extract_message_type() {
        let generator = make_generator(GeneratorRole::ServerSide);

        let cases = [
            (".test.v1.EchoRequest", "EchoRequest"),
            ("test.v1.EchoResponse", "EchoResponse"),
            ("SimpleMessage", "SimpleMessage"),
        ];
        for (reference, expected) in cases {
            assert_eq!(generator.extract_message_type(reference).unwrap(), expected);
        }
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let result = make_generator(GeneratorRole::ServerSide)
            .generate_message_impls(&echo_methods())
            .unwrap();

        // Dump the generated code to help diagnose failures.
        eprintln!("Generated code:\n{result}");

        // `payload_type()` must be generated, and the default payload type is RpcReliable.
        for needle in ["fn payload_type", "PayloadType", "RpcReliable"] {
            assert!(result.contains(needle), "Should contain '{needle}'");
        }
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let imports = make_generator(GeneratorRole::ServerSide).generate_imports();

        assert!(imports.contains("PayloadType"));
        assert!(imports.contains(
            "use actr_protocol::{ActorResult, ActrType, RpcRequest, RpcEnvelope, PayloadType}"
        ));
    }

    #[test]
    fn test_generate_client_code() {
        let code = make_generator(GeneratorRole::ClientSide)
            .generate(&echo_methods())
            .expect("client code generation should succeed");

        // Client code needs the RpcRequest impls too.
        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }

    #[test]
    fn test_generate_server_code() {
        let code = make_generator(GeneratorRole::ServerSide)
            .generate(&echo_methods())
            .expect("server code generation should succeed");

        for needle in [
            "pub trait TestServiceHandler",
            "pub struct TestServiceDispatcher",
            "fn payload_type() -> PayloadType",
        ] {
            assert!(code.contains(needle));
        }
    }
}