// mongo_orm_macro/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields};
#[proc_macro_derive(mongo_doc, attributes(foreign_key))]
pub fn mongo_entity_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
// Extract the struct name
let name = input.ident;
// Extract the fields from the struct
let data = match input.data {
Data::Struct(data) => data,
_ => panic!("MongoEntity can only be derived for structs"),
};
let mut to_doc_fields = Vec::new();
let mut from_doc_fields = Vec::new();
if let Fields::Named(fields) = data.fields {
for field in fields.named {
let field_name = field.ident.clone().unwrap();
let ty = field.ty.clone(); // Use the type of the field
let is_foreign_key = field
.attrs
.iter()
.any(|attr| attr.path().is_ident("foreign_key"));
if field_name == "id" {
// Special handling for `id` field
to_doc_fields.push(if let syn::Type::Path(type_path) = &ty {
let type_string = quote!(#type_path).to_string();
if type_string.contains("Option") {
// If the type is Option<String> or Option<ObjectId>
quote! {
if let Some(value) = &self.#field_name {
if let Ok(object_id) = bson::oid::ObjectId::parse_str(value) {
doc.insert("_id", object_id);
} else {
doc.insert("_id", value.clone());
}
} else {
doc.insert("_id", bson::oid::ObjectId::new());
}
}
} else if type_string.contains("String") {
// If the type is String
quote! {
if let Ok(object_id) = bson::oid::ObjectId::parse_str(&self.#field_name) {
doc.insert("_id", object_id);
} else {
doc.insert("_id", self.#field_name.clone());
}
}
} else {
// For other types, directly store the value
quote! {
doc.insert("_id", bson::to_bson(&self.#field_name).unwrap());
}
}
} else {
panic!("Unsupported type for `id` field");
});
// Generate logic for deserialization based on the type of `id`
from_doc_fields.push(if let syn::Type::Path(type_path) = &ty {
let type_string = quote!(#type_path).to_string();
if type_string.contains("Option") {
// If the type is Option<String> or Option<ObjectId>
quote! {
#field_name: match doc.get("_id") {
Some(bson::Bson::ObjectId(oid)) => Some(oid.to_hex()),
Some(bson::Bson::String(s)) => Some(s.clone()),
_ => None,
},
}
} else if type_string.contains("String") {
// If the type is String
quote! {
#field_name: match doc.get("_id") {
Some(bson::Bson::ObjectId(oid)) => oid.to_hex(),
Some(bson::Bson::String(s)) => s.clone(),
_ => panic!("Missing or invalid _id field for {}", stringify!(#field_name)),
},
}
} else {
// Any other type is directly deserialized
quote! {
#field_name: bson::from_bson(doc.remove("_id").expect()).unwrap(),
}
}
} else {
panic!("Unsupported type for `id` field");
});
} else if is_foreign_key {
to_doc_fields.push(if let syn::Type::Path(type_path) = &ty {
let type_string = quote!(#type_path).to_string();
if type_string.contains("Option") {
// Handle Option<String> or Option<ObjectId>
quote! {
if let Some(value) = &self.#field_name {
if let Ok(object_id) = bson::oid::ObjectId::parse_str(value) {
doc.insert(stringify!(#field_name), object_id);
} else {
doc.insert(stringify!(#field_name), value.clone());
}
}
}
} else if type_string.contains("String") {
// If the type is String
quote! {
if let Ok(object_id) = bson::oid::ObjectId::parse_str(&self.#field_name) {
doc.insert(stringify!(#field_name), object_id);
}
else {
doc.insert(stringify!(#field_name), self.#field_name.clone());
}
}
} else if type_string.contains("ObjectId") {
// If the type is ObjectId
quote! {
doc.insert(stringify!(#field_name), self.#field_name.clone());
}
} else {
panic!("Unsupported type for `foreign_key` field: {}", stringify!(#field_name));
}
} else {
panic!("Unsupported type for `foreign_key` field: {}", stringify!(#field_name));
});
from_doc_fields.push(if let syn::Type::Path(type_path) = &ty {
let type_string = quote!(#type_path).to_string();
if type_string.contains("Option") {
quote! {
#field_name: match doc.get(stringify!(#field_name)) {
Some(bson::Bson::ObjectId(oid)) => Some(oid.to_hex()),
Some(bson::Bson::String(s)) => Some(s.clone()),
_ => None,
},
}
}else if type_string.contains("String") {
// If the type is String
quote! {
#field_name: match doc.get(stringify!(#field_name)) {
Some(bson::Bson::ObjectId(oid)) => oid.to_hex(),
Some(bson::Bson::String(s)) => s.clone(),
_ => panic!("Invalid or missing foreign key: {}", stringify!(#field_name)),
},
}
} else if type_string.contains("ObjectId") {
// If the type is ObjectId
quote! {
#field_name: match doc.get(stringify!(#field_name)) {
Some(bson::Bson::ObjectId(oid)) => oid.clone(),
_ => panic!("Invalid or missing foreign key: {}", stringify!(#field_name)),
},
}
} else {
panic!("Unsupported type for `foreign_key` field: {}", stringify!(#field_name));
}
} else {
panic!("Unsupported type for `foreign_key` field: {}", stringify!(#field_name));
});
} else {
// Handle regular fields
to_doc_fields.push(quote! {
doc.insert(stringify!(#field_name), bson::to_bson(&self.#field_name).unwrap());
});
from_doc_fields.push(quote! {
#field_name: bson::from_bson(doc.remove(stringify!(#field_name)).unwrap()).unwrap(),
});
}
}
}
// Generate the implementation
let expanded = quote! {
impl MongoEntity for #name {
fn to_document(&self) -> bson::Document {
let mut doc = bson::Document::new();
#(#to_doc_fields)*
doc
}
fn from_document(mut doc: bson::Document) -> Self {
Self {
#(#from_doc_fields)*
}
}
}
};
// Convert the implementation into a token stream and return
TokenStream::from(expanded)
}