1use actr_protocol::ActrType;
9use anyhow::{Result, anyhow};
10use heck::ToSnakeCase;
11use prost_types::MethodDescriptorProto;
12use quote::{format_ident, quote};
13
14use crate::payload_type_extractor::extract_payload_type_or_default;
15
/// A remote service whose routes a server-side dispatcher forwards to another actor
/// instead of handling them locally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteServiceInfo {
    /// Protobuf package of the remote service; the prefix of each of its route keys.
    pub package_name: String,
    /// Protobuf service name; the middle part of each of its route keys.
    pub service_name: String,
    /// Method names as declared in the proto. Each one becomes the route key
    /// `<package_name>.<service_name>.<method>`.
    pub methods: Vec<String>,
    /// Target actor type in `<manufacturer>:<name>[:<version>]` form.
    pub actr_type: String,
}
24
/// Selects which half of a service's generated API [`ModernGenerator`] emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorRole {
    /// Implementation side: handler trait, dispatcher, workload wrapper and, optionally,
    /// forwarding arms for remote services.
    ServerSide,
    /// Caller side: a typed `<Service>Client` plus a `ContextExt` extension trait.
    ClientSide,
}
33
34pub struct ModernGenerator {
36 package_name: String,
37 service_name: String,
38 role: GeneratorRole,
39}
40
41impl ModernGenerator {
42 pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
43 Self {
44 package_name: package_name.to_string(),
45 service_name: service_name.to_string(),
46 role,
47 }
48 }
49
50 pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
52 match self.role {
53 GeneratorRole::ServerSide => self.generate_server_code(methods, &[]),
54 GeneratorRole::ClientSide => self.generate_client_code(methods),
55 }
56 }
57
58 pub fn generate_with_remotes(
60 &self,
61 methods: &[MethodDescriptorProto],
62 remote_services: &[RemoteServiceInfo],
63 ) -> Result<String> {
64 match self.role {
65 GeneratorRole::ServerSide => self.generate_server_code(methods, remote_services),
66 GeneratorRole::ClientSide => self.generate_client_code(methods),
67 }
68 }
69
70 fn generate_server_code(
72 &self,
73 methods: &[MethodDescriptorProto],
74 remote_services: &[RemoteServiceInfo],
75 ) -> Result<String> {
76 let sections = [
77 self.generate_imports(),
79 self.generate_message_impls(methods)?,
81 self.generate_handler_trait(methods)?,
83 self.generate_router_impl(methods, remote_services)?,
85 self.generate_workload_blanket_impl(methods)?,
87 self.generate_usage_docs(methods)?,
89 ];
90
91 Ok(sections.join("\n\n"))
92 }
93
94 fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
96 let sections = [
97 self.generate_imports(),
99 self.generate_message_impls(methods)?,
101 self.generate_context_extensions(methods)?,
103 self.generate_client_usage_docs(methods)?,
105 ];
106
107 Ok(sections.join("\n\n"))
108 }
109
110 fn generate_imports(&self) -> String {
112 let proto_module = self.package_name.replace('.', "_");
115
116 format!(
117 r#"// Auto-generated code - DO NOT EDIT
118// Generated by actr-cli's protoc-gen-actrframework plugin
119#[allow(dead_code, unused_imports)]
120use async_trait::async_trait;
121use bytes::Bytes;
122use prost::Message as ProstMessage;
123
124use actr_framework::{{Context, Dest, MessageDispatcher, Workload}};
125use actr_protocol::{{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}};
126
127// Import protobuf message types (generated by prost)
128use super::{proto_module}::*;
129"#
130 )
131 }
132
133 fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
144 let mut impls = Vec::new();
145
146 for method in methods {
147 let input_type = self.extract_message_type(method.input_type())?;
148 let output_type = self.extract_message_type(method.output_type())?;
149
150 let route_key = format!(
152 "{}.{}.{}",
153 self.package_name,
154 self.service_name,
155 method.name()
156 );
157
158 let payload_type = extract_payload_type_or_default(method);
160
161 let payload_type_code = payload_type.as_rust_variant();
163
164 let impl_code = format!(
166 r#"/// RpcRequest trait implementation - associates Request and Response types
167///
168/// This enables type-safe RPC calls with automatic response type inference:
169/// ```rust,ignore
170/// let response: {output_type} = ctx.call(&target, request).await?;
171/// ```
172impl RpcRequest for {input_type} {{
173 type Response = {output_type};
174
175 fn route_key() -> &'static str {{
176 "{route_key}"
177 }}
178
179 fn payload_type() -> PayloadType {{
180 {payload_type_code}
181 }}
182}}"#
183 );
184
185 impls.push(impl_code);
186 }
187
188 Ok(impls.join("\n\n"))
189 }
190
191 fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
193 let handler_trait_name = format!("{}Handler", self.service_name);
194 let handler_trait_ident = format_ident!("{}", handler_trait_name);
195
196 let mut method_sigs = Vec::new();
197 for method in methods {
198 let method_name = method.name().to_snake_case();
199 let method_ident = format_ident!("{}", method_name);
200 let input_type = self.extract_message_type(method.input_type())?;
201 let output_type = self.extract_message_type(method.output_type())?;
202 let input_ident = format_ident!("{}", input_type);
203 let output_ident = format_ident!("{}", output_type);
204
205 method_sigs.push(quote! {
206 async fn #method_ident<C: Context>(
208 &self,
209 req: #input_ident,
210 ctx: &C,
211 ) -> ActorResult<#output_ident>;
212 });
213 }
214
215 let handler_trait_without_attr = quote! {
225 pub trait #handler_trait_ident: actr_framework::MaybeSendSync + 'static {
241 #(#method_sigs)*
242 }
243 };
244
245 Ok(format!(
252 "#[cfg_attr(not(target_arch = \"wasm32\"), async_trait)]\n\
253 #[cfg_attr(target_arch = \"wasm32\", async_trait(?Send))]\n\
254 {handler_trait_without_attr}"
255 ))
256 }
257
258 fn generate_router_impl(
260 &self,
261 methods: &[MethodDescriptorProto],
262 remote_services: &[RemoteServiceInfo],
263 ) -> Result<String> {
264 let router_name = format!("{}Dispatcher", self.service_name);
265 let router_ident = format_ident!("{}", router_name);
266 let workload_name = format!("{}Workload", self.service_name);
267 let workload_ident = format_ident!("{}", workload_name);
268 let handler_trait = format!("{}Handler", self.service_name);
269 let handler_trait_ident = format_ident!("{}", handler_trait);
270
271 let mut match_arms = Vec::new();
273 for method in methods {
274 let route_key = format!(
275 "{}.{}.{}",
276 self.package_name,
277 self.service_name,
278 method.name()
279 );
280 let method_name = method.name().to_snake_case();
281 let method_ident = format_ident!("{}", method_name);
282 let input_type = self.extract_message_type(method.input_type())?;
283 let input_ident = format_ident!("{}", input_type);
284
285 match_arms.push(quote! {
286 #route_key => {
287 let payload = envelope.payload.as_ref()
289 .ok_or_else(|| actr_protocol::ActrError::DecodeFailure(
290 "Missing payload in RpcEnvelope".to_string()
291 ))?;
292
293 let req = #input_ident::decode(&**payload)
295 .map_err(|e| actr_protocol::ActrError::DecodeFailure(
296 format!("Failed to decode {}: {}", stringify!(#input_ident), e)
297 ))?;
298
299 let resp = workload.0.#method_ident(req, ctx).await?;
301
302 Ok(resp.encode_to_vec().into())
304 }
305 });
306 }
307
308 use std::collections::HashMap;
311 let mut services_by_actr_type: HashMap<String, Vec<&RemoteServiceInfo>> = HashMap::new();
312 for remote_service in remote_services {
313 services_by_actr_type
314 .entry(remote_service.actr_type.clone())
315 .or_default()
316 .push(remote_service);
317 }
318
319 for (actr_type_str, services) in services_by_actr_type {
321 let parsed = ActrType::from_string_repr(&actr_type_str).map_err(|e| {
322 anyhow!(
323 "Invalid remote actr_type '{}': expected <manufacturer>:<name>[:<version>] ({})",
324 actr_type_str,
325 e
326 )
327 })?;
328 let manufacturer = parsed.manufacturer;
329 let name = parsed.name;
330
331 let mut route_keys = Vec::new();
333 for service in &services {
334 for method in &service.methods {
335 let route_key = format!(
336 "{}.{}.{}",
337 service.package_name, service.service_name, method
338 );
339 route_keys.push(route_key);
340 }
341 }
342
343 match_arms.push(quote! {
345 #(#route_keys)|* => {
346 let target_type = actr_protocol::ActrType {
347 manufacturer: #manufacturer.to_string(),
348 name: #name.to_string(),
349 version: "1.0.0".to_string(),
350 };
351 let target_id = ctx.discover_route_candidate(&target_type).await?;
352 ctx.call_raw(
353 &target_id,
354 envelope.route_key.as_str(),
355 envelope.payload.clone().unwrap_or_default(),
356 ).await
357 }
358 });
359 }
360
361 let workload_struct = quote! {
363 pub struct #workload_ident<T: #handler_trait_ident>(pub T);
367
368 impl<T: #handler_trait_ident> #workload_ident<T> {
369 pub fn new(handler: T) -> Self {
371 Self(handler)
372 }
373 }
374 };
375
376 let router_struct = quote! {
377 pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
388 };
389
390 let router_impl_without_attr = quote! {
391 impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
392 type Workload = #workload_ident<T>;
393
394 async fn dispatch<C: Context>(
395 workload: &Self::Workload,
396 envelope: RpcEnvelope,
397 ctx: &C,
398 ) -> ActorResult<Bytes> {
399 match envelope.route_key.as_str() {
400 #(#match_arms,)*
401 _ => Err(actr_protocol::ActrError::UnknownRoute(
402 envelope.route_key.to_string()
403 ))
404 }
405 }
406 }
407 };
408
409 let router_impl = format!(
415 "#[cfg_attr(not(target_arch = \"wasm32\"), async_trait)]\n\
416 #[cfg_attr(target_arch = \"wasm32\", async_trait(?Send))]\n\
417 {router_impl_without_attr}"
418 );
419
420 Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
421 }
422
423 fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
425 let router_name = format!("{}Dispatcher", self.service_name);
426 let router_ident = format_ident!("{}", router_name);
427 let workload_name = format!("{}Workload", self.service_name);
428 let workload_ident = format_ident!("{}", workload_name);
429 let handler_trait = format!("{}Handler", self.service_name);
430 let handler_trait_ident = format_ident!("{}", handler_trait);
431
432 Ok(quote! {
444 impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
449 type Dispatcher = #router_ident<T>;
450 }
451
452 impl<T: #handler_trait_ident> actr_framework::ServiceHandler for #workload_ident<T> {
461 type Workload = Self;
462 }
463 }
464 .to_string())
465 }
466
467 fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
469 let client_struct_name = format!("{}Client", self.service_name);
470 let client_ident = format_ident!("{}", client_struct_name);
471
472 let mut client_methods = Vec::new();
473 for method in methods {
474 let method_name = method.name().to_snake_case();
475 let method_ident = format_ident!("{}", method_name);
476 let input_type = self.extract_message_type(method.input_type())?;
477 let output_type = self.extract_message_type(method.output_type())?;
478 let input_ident = format_ident!("{}", input_type);
479 let output_ident = format_ident!("{}", output_type);
480
481 client_methods.push(quote! {
482 pub async fn #method_ident(
484 &self,
485 target: ActrId,
486 req: #input_ident,
487 ) -> ActorResult<#output_ident> {
488 self.ctx.call(&Dest::from(target), req).await
489 }
490 });
491 }
492
493 let extension_method_name = self.service_name.to_snake_case();
495 let extension_method_ident = format_ident!("{}", extension_method_name);
496
497 Ok(quote! {
498 pub struct #client_ident<'a, C: Context> {
502 ctx: &'a C,
503 }
504
505 impl<'a, C: Context> #client_ident<'a, C> {
506 #(#client_methods)*
507 }
508
509 pub trait ContextExt {
513 fn #extension_method_ident(&self) -> #client_ident<'_, Self> where Self: Sized + Context;
514 }
515
516 impl<T: Context> ContextExt for T {
517 fn #extension_method_ident(&self) -> #client_ident<'_, Self> {
518 #client_ident { ctx: self }
519 }
520 }
521 }
522 .to_string())
523 }
524
525 fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
527 let handler_trait = format!("{}Handler", self.service_name);
528 let first_method = methods.first();
529
530 let example_method = if let Some(method) = first_method {
531 let method_name = method.name().to_snake_case();
532 let input_type = self.extract_message_type(method.input_type())?;
533 let output_type = self.extract_message_type(method.output_type())?;
534 format!(
535 r#"
536 async fn {method_name}(&self, req: {input_type}, ctx: &Context) -> ActorResult<{output_type}> {{
537 // Implement business logic
538 Ok({output_type}::default())
539 }}"#
540 )
541 } else {
542 " // Implement methods...".to_string()
543 };
544
545 Ok(format!(
546 r#"/*
547## Usage Example
548
549### 1. Implement Business Logic
550
551```rust
552use actr_framework::Context;
553use actr_protocol::ActorResult;
554use async_trait::async_trait;
555
556pub struct MyService {{
557 // Business state
558}}
559
560#[async_trait]
561impl {handler_trait} for MyService {{
562{example_method}
563}}
564```
565
566### 2. Register the Entry Point
567
568```rust
569actr_framework::entry!({}Workload<MyService>);
570```
571
572## Architecture
573
574- **{handler_trait}**: user-implemented business logic interface
575- **{}Dispatcher**: zero-sized type static dispatcher (auto-generated)
576- **{}Workload<T>**: generated wrapper that satisfies orphan rules
577
578Users only need to implement {handler_trait}; the framework auto-provides routing and workload capabilities.
579*/
580"#,
581 self.service_name, self.service_name, self.service_name
582 ))
583 }
584
585 fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
587 let service_name_snake = self.service_name.to_snake_case();
588 let method_name_snake = methods
589 .first()
590 .map(|m| m.name().to_snake_case())
591 .unwrap_or("unknown_method".to_string());
592
593 Ok(format!(
594 r#"/*
595## Dependency Usage Example
596
597```rust
598use actr_framework::Context;
599use actr_protocol::{{ActorResult, ActrId}};
600
601async fn call_remote_service(ctx: &impl Context, target: ActrId) -> ActorResult<()> {{
602 use super::ContextExt;
603
604 // Type-safe remote call
605 let response = ctx.{service_name_snake}()
606 .{method_name_snake}(target, request)
607 .await?;
608
609 Ok(())
610}}
611```
612
613## Compile-time Routing
614
615All remote calls determine the target service and method at compile time — no runtime lookup needed.
616*/
617"#
618 ))
619 }
620
621 fn extract_message_type(&self, type_name: &str) -> Result<String> {
623 let cleaned = type_name.trim_start_matches('.');
624 if let Some(last_part) = cleaned.split('.').next_back() {
625 Ok(last_part.to_string())
626 } else {
627 Ok(cleaned.to_string())
628 }
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// The `Echo(EchoRequest) -> EchoResponse` method shared by several tests.
    fn echo_method() -> MethodDescriptorProto {
        MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }
    }

    /// Generator for `test.v1.TestService` in the requested role.
    fn test_service(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    #[test]
    fn test_extract_message_type() {
        let codegen = test_service(GeneratorRole::ServerSide);
        let cases = [
            (".test.v1.EchoRequest", "EchoRequest"),
            ("test.v1.EchoResponse", "EchoResponse"),
            ("SimpleMessage", "SimpleMessage"),
        ];

        for (qualified, expected) in cases {
            assert_eq!(codegen.extract_message_type(qualified).unwrap(), expected);
        }
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let codegen = test_service(GeneratorRole::ServerSide);
        let result = codegen.generate_message_impls(&[echo_method()]).unwrap();

        eprintln!("Generated code:\n{result}");

        let expectations = [
            ("fn payload_type", "Should contain 'fn payload_type'"),
            ("PayloadType", "Should contain 'PayloadType'"),
            ("RpcReliable", "Should contain 'RpcReliable'"),
        ];
        for (needle, message) in expectations {
            assert!(result.contains(needle), "{message}");
        }
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let imports = test_service(GeneratorRole::ServerSide).generate_imports();

        assert!(imports.contains("PayloadType"));
        assert!(imports.contains(
            "use actr_protocol::{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}"
        ));
        assert!(
            imports.contains("use actr_framework::{Context, Dest, MessageDispatcher, Workload}")
        );
    }

    #[test]
    fn test_generate_client_code() {
        let code = test_service(GeneratorRole::ClientSide)
            .generate(&[echo_method()])
            .expect("client-side generation should succeed");

        for needle in [
            "impl RpcRequest for EchoRequest",
            "fn payload_type() -> PayloadType",
            "use actr_framework::{Context, Dest, MessageDispatcher, Workload}",
            "use actr_protocol::{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}",
        ] {
            assert!(code.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_generate_server_code() {
        let code = test_service(GeneratorRole::ServerSide)
            .generate(&[echo_method()])
            .expect("server-side generation should succeed");

        for needle in [
            "pub trait TestServiceHandler",
            "pub struct TestServiceDispatcher",
            "fn payload_type() -> PayloadType",
        ] {
            assert!(code.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_generate_server_code_with_no_local_methods() {
        let code = ModernGenerator::new("test.v1", "BridgeService", GeneratorRole::ServerSide)
            .generate(&[])
            .unwrap();

        for needle in [
            "pub trait BridgeServiceHandler",
            "pub struct BridgeServiceWorkload",
            "pub struct BridgeServiceDispatcher",
            "UnknownRoute",
        ] {
            assert!(code.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_generate_server_code_with_remote_forwarding_and_no_local_methods() {
        let codegen =
            ModernGenerator::new("demo.app", "DemoClientApp", GeneratorRole::ServerSide);
        let remotes = [RemoteServiceInfo {
            package_name: "echo".to_string(),
            service_name: "EchoService".to_string(),
            methods: vec!["Echo".to_string()],
            actr_type: "acme:EchoService:1.0.0".to_string(),
        }];

        let code = codegen.generate_with_remotes(&[], &remotes).unwrap();

        for needle in [
            "pub trait DemoClientAppHandler",
            "pub struct DemoClientAppWorkload",
            "\"echo.EchoService.Echo\"",
            "manufacturer",
            "\"acme\"",
            "name",
            "\"EchoService\"",
            "discover_route_candidate",
            "call_raw",
        ] {
            assert!(code.contains(needle), "missing `{needle}`");
        }
    }
}