proto_convert_derive/
lib.rs1use proc_macro::TokenStream;
153use proc_macro2::Span;
154use quote::quote;
155use syn::parse::Parser;
156use syn::{self, Attribute, DeriveInput, Expr, Field, Lit, Meta, Type};
157use syn::{punctuated::Punctuated, token::Comma};
158
159#[proc_macro_derive(ProtoConvert, attributes(proto))]
160pub fn proto_convert_derive(input: TokenStream) -> TokenStream {
161 let ast: DeriveInput = syn::parse(input).unwrap();
162 let name = &ast.ident;
163 let proto_module = get_proto_module(&ast.attrs).unwrap_or_else(|| "proto".to_string());
164 let proto_name = get_proto_struct_rename(&ast.attrs).unwrap_or_else(|| name.to_string());
165 let proto_path =
166 syn::parse_str::<syn::Path>(&format!("{}::{}", proto_module, proto_name)).unwrap();
167
168 match &ast.data {
169 syn::Data::Struct(data_struct) => {
170 match &data_struct.fields {
171 syn::Fields::Named(fields_named) => {
172 let fields = &fields_named.named;
173 let primitives = ["i32", "u32", "i64", "u64", "f32", "f64", "bool", "String"];
174 let from_proto_fields = fields.iter().map(|field| {
175 let field_name = field.ident.as_ref().unwrap();
176 if has_proto_ignore(field) {
177 quote! {
178 #field_name: Default::default()
179 }
180 } else {
181 let proto_field_ident = if let Some(rename) = get_proto_rename(field) {
182 syn::Ident::new(&rename, Span::call_site())
183 } else {
184 field_name.clone()
185 };
186 let field_type = &field.ty;
187 let is_transparent = has_transparent_attr(field);
188 let derive_from_with = get_proto_derive_from_with(field);
189
190 if let Some(from_with_path) = derive_from_with {
191 let from_with_path: syn::Path = syn::parse_str(&from_with_path).expect("Failed to parse derive_from_with path");
192 quote! {
193 #field_name: #from_with_path(proto_struct.#proto_field_ident)
194 }
195 } else if is_transparent {
196 quote! {
197 #field_name: <#field_type>::from(proto_struct.#proto_field_ident)
198 }
199 } else if is_option_type(field_type) {
200 let inner_type = get_inner_type_from_option(field_type).unwrap();
201 if is_vec_type(&inner_type) {
202 quote! {
203 #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
204 }
205 } else {
206 quote! {
207 #field_name: proto_struct.#proto_field_ident.map(Into::into)
208 }
209 }
210 } else if is_vec_type(field_type) {
211 if let Some(inner_type) = get_inner_type_from_vec(field_type) {
212 if is_proto_type_with_module(&inner_type, &proto_module) {
213 quote! {
214 #field_name: proto_struct.#proto_field_ident
215 }
216 } else {
217 quote! {
218 #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
219 }
220 }
221 } else {
222 quote! {
223 #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
224 }
225 }
226 } else if let syn::Type::Path(type_path) = field_type {
227 let is_primitive = type_path.path.segments.len() == 1 &&
228 primitives.iter().any(|&p| type_path.path.segments[0].ident == p);
229 let is_proto_type = type_path.path.segments.first()
230 .is_some_and(|segment| segment.ident == proto_module.as_str());
231 if is_primitive {
232 quote! { #field_name: proto_struct.#proto_field_ident }
233 } else if is_proto_type {
234 quote! {
235 #field_name: proto_struct.#proto_field_ident.expect(concat!("no ", stringify!(#proto_field_ident), " in proto"))
236 }
237 } else {
238 quote! {
239 #field_name: #field_type::from(proto_struct.#proto_field_ident)
240 }
241 }
242 } else {
243 panic!("Only path types are supported for field '{}'", field_name);
244 }
245 }
246 });
247
248 let from_my_fields = fields.iter().filter(|field| !has_proto_ignore(field)).map(|field| {
249 let field_name = field.ident.as_ref().unwrap();
250 let proto_field_ident = if let Some(rename) = get_proto_rename(field) {
251 syn::Ident::new(&rename, Span::call_site())
252 } else {
253 field_name.clone()
254 };
255 let field_type = &field.ty;
256 let is_transparent = has_transparent_attr(field);
257 let derive_into_with = get_proto_derive_into_with(field);
258
259 if let Some(into_with_path) = derive_into_with {
260 let into_with_path: syn::Path = syn::parse_str(&into_with_path).expect("Failed to parse derive_into_with path");
261 quote! {
262 #proto_field_ident: #into_with_path(my_struct.#field_name)
263 }
264 } else if is_transparent {
265 quote! {
266 #proto_field_ident: my_struct.#field_name.into()
267 }
268 } else if is_option_type(field_type) {
269 let inner_type = get_inner_type_from_option(field_type).unwrap();
270 if is_vec_type(&inner_type) {
271 quote! {
272 #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
273 }
274 } else {
275 quote! {
276 #proto_field_ident: my_struct.#field_name.map(Into::into)
277 }
278 }
279 } else if is_vec_type(field_type) {
280 if let Some(inner_type) = get_inner_type_from_vec(field_type) {
281 if is_proto_type_with_module(&inner_type, &proto_module) {
282 quote! {
283 #proto_field_ident: my_struct.#field_name
284 }
285 } else {
286 quote! {
287 #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
288 }
289 }
290 } else {
291 quote! {
292 #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
293 }
294 }
295 } else if let syn::Type::Path(type_path) = field_type {
296 let is_primitive = type_path.path.segments.len() == 1
297 && primitives.iter().any(|&p| type_path.path.segments[0].ident == p);
298 let is_proto_type = type_path.path.segments.first()
299 .is_some_and(|segment| segment.ident == proto_module.as_str());
300 if is_primitive {
301 quote! { #proto_field_ident: my_struct.#field_name }
302 } else if is_proto_type {
303 quote! { #proto_field_ident: Some(my_struct.#field_name) }
304 } else {
305 quote! { #proto_field_ident: my_struct.#field_name.into() }
306 }
307 } else {
308 panic!("Only path types are supported for field '{}'", field_name);
309 }
310 });
311
312 let gen = quote! {
313 impl From<#proto_path> for #name {
314 fn from(proto_struct: #proto_path) -> Self {
315 Self {
316 #(#from_proto_fields),*
317 }
318 }
319 }
320
321 impl From<#name> for #proto_path {
322 fn from(my_struct: #name) -> Self {
323 Self {
324 #(#from_my_fields),*
325 }
326 }
327 }
328 };
329 gen.into()
330 }
331 syn::Fields::Unnamed(fields_unnamed) => {
332 if fields_unnamed.unnamed.len() != 1 {
333 panic!("ProtoConvert only supports tuple structs with exactly one field, found {}", fields_unnamed.unnamed.len());
334 }
335 let inner_type = &fields_unnamed.unnamed[0].ty;
336 let gen = quote! {
337 impl From<#inner_type> for #name {
338 fn from(value: #inner_type) -> Self {
339 #name(value)
340 }
341 }
342
343 impl From<#name> for #inner_type {
344 fn from(my: #name) -> Self {
345 my.0
346 }
347 }
348 };
349 gen.into()
350 }
351 syn::Fields::Unit => {
352 panic!("ProtoConvert does not support unit structs");
353 }
354 }
355 }
356
357 syn::Data::Enum(data_enum) => {
358 let variants = &data_enum.variants;
359 let enum_name_str = name.to_string();
360 let enum_prefix = enum_name_str.to_uppercase();
361 let proto_enum_path: syn::Path = syn::parse_str(&format!("{}::{}", proto_module, name))
362 .expect("Failed to parse proto enum path");
363
364 let from_i32_arms = variants.iter().map(|variant| {
365 let variant_ident = &variant.ident;
366 let variant_str = variant_ident.to_string();
367 let direct_candidate = variant_str.clone();
368 let screaming_variant = to_screaming_snake_case(&variant_str);
369 let prefixed_candidate = format!("{}_{}", enum_prefix, screaming_variant);
370 let direct_candidate_lit = syn::LitStr::new(&direct_candidate, Span::call_site());
371 let prefixed_candidate_lit = syn::LitStr::new(&prefixed_candidate, Span::call_site());
372 quote! {
373 candidate if candidate == #direct_candidate_lit || candidate == #prefixed_candidate_lit => #name::#variant_ident,
374 }
375 });
376
377 let from_proto_arms = variants.iter().map(|variant| {
378 let variant_ident = &variant.ident;
379 let variant_str = variant_ident.to_string();
380 let screaming_variant = to_screaming_snake_case(&variant_str);
381 let prefixed_candidate = format!("{}_{}", enum_prefix, screaming_variant);
382 let prefixed_candidate_lit = syn::LitStr::new(&prefixed_candidate, Span::call_site());
383 quote! {
384 #name::#variant_ident => <#proto_enum_path>::from_str_name(#prefixed_candidate_lit)
385 .unwrap_or_else(|| panic!("No matching proto variant for {:?}", rust_enum)),
386 }
387 });
388
389 let gen = quote! {
390 impl From<i32> for #name {
391 fn from(value: i32) -> Self {
392 let proto_val = <#proto_enum_path>::from_i32(value)
393 .unwrap_or_else(|| panic!("Unknown enum value: {}", value));
394 let proto_str = proto_val.as_str_name();
395 match proto_str {
396 #(#from_i32_arms)*
397 _ => panic!("No matching Rust variant for proto enum string: {}", proto_str),
398 }
399 }
400 }
401
402 impl From<#name> for i32 {
403 fn from(rust_enum: #name) -> Self {
404 let proto: #proto_enum_path = rust_enum.into();
405 proto as i32
406 }
407 }
408
409 impl From<#name> for #proto_enum_path {
410 fn from(rust_enum: #name) -> Self {
411 match rust_enum {
412 #(#from_proto_arms)*
413 }
414 }
415 }
416
417 impl From<#proto_enum_path> for #name {
418 fn from(proto_enum: #proto_enum_path) -> Self {
419 let i32_val: i32 = proto_enum.into();
420 #name::from(i32_val)
421 }
422 }
423 };
424 gen.into()
425 }
426 _ => panic!("ProtoConvert only supports structs and enums, not unions"),
427 }
428}
429
/// Converts an identifier such as `TaskStatus` into `TASK_STATUS`.
///
/// An underscore is inserted before every uppercase character except the first
/// one. Each character is then ASCII-uppercased, so acronyms are split per
/// letter (`ABC` -> `A_B_C`), and non-ASCII letters keep their case.
fn to_screaming_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len()), |mut out, (idx, ch)| {
            if idx > 0 && ch.is_uppercase() {
                out.push('_');
            }
            out.push(ch.to_ascii_uppercase());
            out
        })
}
440
441fn is_option_type(ty: &Type) -> bool {
442 if let Type::Path(type_path) = ty {
443 if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Option" {
444 return true;
445 }
446 }
447 false
448}
449
450fn get_inner_type_from_option(ty: &Type) -> Option<Type> {
451 if let Type::Path(type_path) = ty {
452 if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Option" {
453 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
454 &type_path.path.segments[0].arguments
455 {
456 if let Some(syn::GenericArgument::Type(inner_type)) = angle_bracketed.args.first() {
457 return Some(inner_type.clone());
458 }
459 }
460 }
461 }
462 None
463}
464
465fn is_vec_type(ty: &Type) -> bool {
466 if let Type::Path(type_path) = ty {
467 if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Vec" {
468 return true;
469 }
470 }
471 false
472}
473
474fn get_inner_type_from_vec(ty: &Type) -> Option<Type> {
475 if let Type::Path(type_path) = ty {
476 if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Vec" {
477 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
478 &type_path.path.segments[0].arguments
479 {
480 if let Some(syn::GenericArgument::Type(inner_type)) = angle_bracketed.args.first() {
481 return Some(inner_type.clone());
482 }
483 }
484 }
485 }
486 None
487}
488
489fn is_proto_type_with_module(ty: &Type, proto_module: &str) -> bool {
490 if let Type::Path(type_path) = ty {
491 if let Some(segment) = type_path.path.segments.first() {
492 return segment.ident == proto_module;
493 }
494 }
495 false
496}
497
498fn get_proto_module(attrs: &[Attribute]) -> Option<String> {
499 for attr in attrs {
500 if attr.path().is_ident("proto") {
501 if let Meta::List(meta_list) = &attr.meta {
502 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
503 .parse2(meta_list.tokens.clone())
504 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
505 for meta in nested_metas {
506 if let Meta::NameValue(meta_nv) = meta {
507 if meta_nv.path.is_ident("module") {
508 if let Expr::Lit(expr_lit) = meta_nv.value {
509 if let Lit::Str(lit_str) = expr_lit.lit {
510 return Some(lit_str.value());
511 }
512 }
513 panic!("module value must be a string literal, e.g., #[proto(module = \"path\")]");
514 }
515 }
516 }
517 }
518 }
519 }
520 None
521}
522
523fn get_proto_struct_rename(attrs: &[Attribute]) -> Option<String> {
524 for attr in attrs {
525 if attr.path().is_ident("proto") {
526 if let Meta::List(meta_list) = &attr.meta {
527 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
528 .parse2(meta_list.tokens.clone())
529 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
530 for meta in nested_metas {
531 if let Meta::NameValue(meta_nv) = meta {
532 if meta_nv.path.is_ident("rename") {
533 if let Expr::Lit(expr_lit) = meta_nv.value {
534 if let Lit::Str(lit_str) = expr_lit.lit {
535 return Some(lit_str.value());
536 }
537 }
538 panic!("rename value must be a string literal, e.g., #[proto(rename = \"...\")]");
539 }
540 }
541 }
542 }
543 }
544 }
545 None
546}
547
548fn has_transparent_attr(field: &Field) -> bool {
549 for attr in &field.attrs {
550 if attr.path().is_ident("proto") {
551 if let Meta::List(meta_list) = &attr.meta {
552 let tokens = &meta_list.tokens;
553 let token_str = quote!(#tokens).to_string();
554 if token_str.contains("transparent") {
555 return true;
556 }
557 }
558 }
559 }
560 false
561}
562
563fn get_proto_rename(field: &Field) -> Option<String> {
564 for attr in &field.attrs {
565 if attr.path().is_ident("proto") {
566 if let Meta::List(meta_list) = &attr.meta {
567 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
568 .parse2(meta_list.tokens.clone())
569 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
570 for meta in nested_metas {
571 if let Meta::NameValue(meta_nv) = meta {
572 if meta_nv.path.is_ident("rename") {
573 if let Expr::Lit(expr_lit) = &meta_nv.value {
574 if let Lit::Str(lit_str) = &expr_lit.lit {
575 return Some(lit_str.value());
576 }
577 }
578 panic!("rename value must be a string literal, e.g., rename = \"xyz\"");
579 }
580 }
581 }
582 }
583 }
584 }
585 None
586}
587
588fn get_proto_derive_from_with(field: &Field) -> Option<String> {
589 for attr in &field.attrs {
590 if attr.path().is_ident("proto") {
591 if let Meta::List(meta_list) = &attr.meta {
592 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
593 .parse2(meta_list.tokens.clone())
594 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
595 for meta in nested_metas {
596 if let Meta::NameValue(meta_nv) = meta {
597 if meta_nv.path.is_ident("derive_from_with") {
598 if let Expr::Lit(expr_lit) = &meta_nv.value {
599 if let Lit::Str(lit_str) = &expr_lit.lit {
600 return Some(lit_str.value());
601 }
602 }
603 panic!("derive_from_with value must be a string literal, e.g., derive_from_with = \"path::to::function\"");
604 }
605 }
606 }
607 }
608 }
609 }
610 None
611}
612
613fn get_proto_derive_into_with(field: &Field) -> Option<String> {
614 for attr in &field.attrs {
615 if attr.path().is_ident("proto") {
616 if let Meta::List(meta_list) = &attr.meta {
617 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
618 .parse2(meta_list.tokens.clone())
619 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
620 for meta in nested_metas {
621 if let Meta::NameValue(meta_nv) = meta {
622 if meta_nv.path.is_ident("derive_into_with") {
623 if let Expr::Lit(expr_lit) = &meta_nv.value {
624 if let Lit::Str(lit_str) = &expr_lit.lit {
625 return Some(lit_str.value());
626 }
627 }
628 panic!("derive_into_with value must be a string literal, e.g., derive_into_with = \"path::to::function\"");
629 }
630 }
631 }
632 }
633 }
634 }
635 None
636}
637
638fn has_proto_ignore(field: &Field) -> bool {
639 for attr in &field.attrs {
640 if attr.path().is_ident("proto") {
641 if let Meta::List(meta_list) = &attr.meta {
642 let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
643 .parse2(meta_list.tokens.clone())
644 .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
645 for meta in nested_metas {
646 if let Meta::Path(path) = meta {
647 if path.is_ident("ignore") {
648 return true;
649 }
650 }
651 }
652 }
653 }
654 }
655 false
656}