pbdb_macros/
lib.rs

1use std::{
2  env,
3  path::{Path, PathBuf},
4};
5
6use proc_macro2::TokenStream;
7use prost::Message;
8use quote::{format_ident, quote};
9
10mod descriptor {
11  include!(concat!(env!("OUT_DIR"), "/pbdb.descriptor.rs"));
12}
13
14#[proc_macro]
15pub fn pbdb_impls(_: proc_macro::TokenStream) -> proc_macro::TokenStream {
16  process_fds(&read_descriptor(
17    &PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR environment variable not set"))
18      .join("file_descriptor_set.bin"),
19  ))
20  .into()
21}
22
23fn read_descriptor(path: &Path) -> descriptor::FileDescriptorSet {
24  let bytes = std::fs::read(path).unwrap();
25  descriptor::FileDescriptorSet::decode(bytes.as_slice()).unwrap()
26}
27
28fn process_fds(fds: &descriptor::FileDescriptorSet) -> TokenStream {
29  let (globals, options): (Vec<_>, Vec<_>) = fds
30    .file
31    .iter()
32    .map(|file| &file.message_type)
33    .flatten()
34    .filter_map(|dp| process_dp(dp))
35    .unzip();
36  quote! {
37    pub fn open_db(
38      path: &std::path::Path
39    ) -> Result<::pbdb::DbGuard, pbdb::private::rocksdb::Error> {
40      use ::pbdb::private::{DB, rocksdb};
41      let mut opts = rocksdb::Options::default();
42      opts.create_if_missing(true);
43      opts.create_missing_column_families(true);
44      let mut cfs = vec![];
45      cfs.push(
46        rocksdb::ColumnFamilyDescriptor::new(
47          "__SingleRecord",
48          rocksdb::Options::default()
49        )
50      );
51      #(#options)*
52      let db = rocksdb::DB::open_cf_descriptors(&opts, path, cfs)?;
53      let mut write = DB.write();
54      assert!((*write).is_none(), "Trying to open DB without closing previous one.");
55      *write = Some(db);
56      Ok(::pbdb::DbGuard{})
57    }
58    #(#globals)*
59  }
60}
61
62fn process_dp(dp: &descriptor::DescriptorProto) -> Option<(TokenStream, TokenStream)> {
63  generate_collection(dp).or_else(|| generate_single_record(dp))
64}
65
66fn generate_collection(dp: &descriptor::DescriptorProto) -> Option<(TokenStream, TokenStream)> {
67  let id_fields: Vec<_> = dp
68    .field
69    .iter()
70    .filter(|field| {
71      field.options.as_ref().map_or(false, |options| {
72        options.id() != descriptor::field_options::IdType::NotUsed
73      })
74    })
75    .collect();
76  if id_fields.len() > 1 {
77    unimplemented!("Multiple id fields are not supported yet");
78  }
79  if let Some(id_field) = id_fields.first() {
80    if id_field.r#type() != descriptor::field_descriptor_proto::Type::String {
81      unimplemented!("Non-string id fields are not supported yet");
82    }
83    if id_field.label() == descriptor::field_descriptor_proto::Label::Repeated {
84      unimplemented!("Repeated id fields are not supported yet");
85    }
86    let message_name = format_ident!("{}", dp.name());
87    let id_field_name = format_ident!("{}", id_field.name());
88    let conversion =
89      if id_field.options.as_ref().unwrap().id() == descriptor::field_options::IdType::Default {
90        quote! {
91          as_bytes().to_vec()
92        }
93      } else {
94        quote! {
95          to_lowercase().as_bytes().to_vec()
96        }
97      };
98    Some((
99      quote! {
100        impl ::pbdb::Collection for #message_name {
101          const CF_NAME: &'static str = stringify!(#message_name);
102          type Id = String;
103          type SerializedId = Vec<u8>;
104
105          fn get_id(&self) -> Self::SerializedId {
106            self.#id_field_name.#conversion
107          }
108
109          fn build_id(id: &Self::Id) -> Self::SerializedId {
110            id.#conversion
111          }
112        }
113      },
114      quote! {
115        cfs.push(
116          rocksdb::ColumnFamilyDescriptor::new(
117            stringify!(#message_name),
118            rocksdb::Options::default()
119          )
120        );
121      },
122    ))
123  } else {
124    None
125  }
126}
127
128fn generate_single_record(dp: &descriptor::DescriptorProto) -> Option<(TokenStream, TokenStream)> {
129  if dp
130    .options
131    .as_ref()
132    .map_or(false, |options| options.single_record == Some(true))
133  {
134    let message_name = format_ident!("{}", dp.name());
135    Some((
136      quote! {
137        impl ::pbdb::SingleRecord for #message_name {
138          const RECORD_ID: &'static str = stringify!(#message_name);
139        }
140      },
141      quote! {},
142    ))
143  } else {
144    None
145  }
146}