actix_prost_build/
generator.rs

1use crate::{config::HttpRule, method::Method, Config};
2use proc_macro2::TokenStream;
3use prost_build::{Service, ServiceGenerator};
4use quote::quote;
5use std::{collections::HashMap, fs::File, path::Path, rc::Rc};
6use syn::Item;
7
/// Service generator that emits actix-web routing code for services whose
/// methods have matching HTTP rules in the configuration.
pub struct ActixGenerator {
    /// Top-level structs parsed out of the generated code buffer, keyed by the
    /// struct identifier. Rebuilt on every `generate` call by `parse_messages`.
    messages: Rc<HashMap<String, syn::ItemStruct>>,
    /// Configuration deserialized from the YAML file passed to `new`; its
    /// `http.rules` are matched against service methods by selector.
    config: Config,
}
12
/// Errors that can occur while constructing an [`ActixGenerator`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An I/O error, converted automatically from [`std::io::Error`]
    /// (e.g. when the configuration file cannot be opened).
    #[error("could not open file {0}")]
    File(#[from] std::io::Error),
    /// The configuration could not be deserialized from YAML, converted
    /// automatically from [`serde_yaml::Error`].
    #[error("could not parse the config {0}")]
    Parse(#[from] serde_yaml::Error),
}
20
21impl ActixGenerator {
22    pub fn new(path: impl AsRef<Path>) -> Result<ActixGenerator, Error> {
23        let file = File::open(path)?;
24        let config: Config = serde_yaml::from_reader(file)?;
25
26        Ok(ActixGenerator {
27            messages: Default::default(),
28            config,
29        })
30    }
31
32    fn map_methods_with_rules<'a, 'b>(
33        &'a self,
34        service: &'b Service,
35    ) -> Vec<(&'b prost_build::Method, &'a HttpRule)> {
36        let map: HashMap<String, &HttpRule> = self
37            .config
38            .http
39            .rules
40            .iter()
41            .map(|r| (r.selector.clone(), r))
42            .collect();
43        service
44            .methods
45            .iter()
46            .filter_map(|m| {
47                map.get(&format!(
48                    "{}.{}.{}",
49                    service.package, service.proto_name, m.proto_name
50                ))
51                .map(|r| (m, *r))
52            })
53            .collect()
54    }
55
56    fn router(&self, service: &Service) -> TokenStream {
57        let service_name = crate::string::naive_snake_case(&service.name);
58
59        let name = quote::format_ident!("route_{}", service_name);
60        let mod_name = quote::format_ident!("{}_actix", service_name);
61
62        let tonic_mod_name = quote::format_ident!("{}_server", service_name);
63        let trait_name = quote::format_ident!("{}", service.name);
64        let full_trait = quote::quote!(super::#tonic_mod_name::#trait_name);
65
66        let methods_with_config = self.map_methods_with_rules(service);
67
68        let methods: Vec<_> = methods_with_config
69            .into_iter()
70            .map(|(method, config)| {
71                Method::new(
72                    method.clone(),
73                    self.messages.get(&method.input_type).unwrap().clone(),
74                    self.messages.get(&method.output_type).unwrap().clone(),
75                    config.clone(),
76                    trait_name.clone(),
77                )
78            })
79            .collect();
80
81        if methods.is_empty() {
82            return quote!();
83        }
84        let request_structs = methods.iter().map(|m| m.request().generate_structs());
85        let fns = methods.iter().map(|m| m.generate_route());
86        let configs = methods.iter().map(|m| m.generate_config());
87        quote!(
88            pub mod #mod_name {
89                #![allow(unused_variables, dead_code, missing_docs)]
90
91                use super::*;
92                use #full_trait;
93                use std::sync::Arc;
94                use actix_web::Responder;
95
96                #(#request_structs)*
97
98                #(#fns)*
99
100                pub fn #name(
101                    config: &mut ::actix_web::web::ServiceConfig,
102                    service: Arc<dyn #trait_name + Send + Sync + 'static>,
103                ) {
104                    config.app_data(::actix_web::web::Data::from(service));
105                    #(#configs)*
106                }
107            }
108        )
109    }
110
111    fn parse_messages(&mut self, buf: &mut str) {
112        let file: syn::File = syn::parse_str(buf).unwrap();
113        self.messages = Rc::new(
114            file.items
115                .into_iter()
116                .filter_map(|item| match item {
117                    Item::Struct(message) => Some(message),
118                    _ => None,
119                })
120                .map(|message| (message.ident.to_string(), message))
121                .collect(),
122        );
123    }
124
125    fn token_stream_to_code(&self, tokens: TokenStream) -> String {
126        let ast: syn::File = syn::parse2(tokens).expect("not a valid tokenstream");
127        prettyplease::unparse(&ast)
128    }
129}
130
131impl ServiceGenerator for ActixGenerator {
132    fn generate(&mut self, service: Service, buf: &mut String) {
133        self.parse_messages(buf);
134        let router = self.router(&service);
135        buf.push_str(&self.token_stream_to_code(router));
136
137        #[cfg(feature = "conversions")]
138        {
139            use crate::conversions::ConversionsGenerator;
140            let conversions = ConversionsGenerator::new().ok().map(|mut g| {
141                g.messages = Rc::clone(&self.messages);
142                g.create_conversions(&service)
143            });
144
145            if let Some(conversions) = conversions {
146                buf.push_str(&self.token_stream_to_code(conversions));
147            }
148        }
149    }
150}