1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr};
5
/// Layout used for a constructor with named fields.
///
/// Selected with `#[haskell(style = "app")]` or `#[haskell(style = "record")]`; when no
/// style is given anywhere, `resolve_style` falls back to [`Style::Record`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Style {
    /// Positional application, generated via `::ghci::haskell::app` / `parse_app`.
    App,
    /// Record syntax, generated via `::ghci::haskell::record` / `parse_record`.
    Record,
}
11
12struct HaskellAttrs {
13 name: Option<String>,
14 transparent: bool,
15 style: Option<Style>,
16 skip: bool,
17 bound_to: Option<String>,
18 bound_from: Option<String>,
19}
20
21fn parse_haskell_attrs(attrs: &[syn::Attribute]) -> HaskellAttrs {
22 let mut name = None;
23 let mut transparent = false;
24 let mut style = None;
25 let mut skip = false;
26 let mut bound_to = None;
27 let mut bound_from = None;
28 for attr in attrs {
29 if attr.path().is_ident("haskell") {
30 let _ = attr.parse_nested_meta(|meta| {
31 if meta.path.is_ident("name") {
32 let value = meta.value()?;
33 let s: LitStr = value.parse()?;
34 name = Some(s.value());
35 } else if meta.path.is_ident("transparent") {
36 transparent = true;
37 } else if meta.path.is_ident("style") {
38 let value = meta.value()?;
39 let s: LitStr = value.parse()?;
40 match s.value().as_str() {
41 "app" => style = Some(Style::App),
42 "record" => style = Some(Style::Record),
43 _ => {
44 return Err(meta.error("expected \"app\" or \"record\""));
45 }
46 }
47 } else if meta.path.is_ident("skip") {
48 skip = true;
49 } else if meta.path.is_ident("bound") {
50 meta.parse_nested_meta(|inner| {
51 if inner.path.is_ident("ToHaskell") {
52 let value = inner.value()?;
53 let s: LitStr = value.parse()?;
54 bound_to = Some(s.value());
55 } else if inner.path.is_ident("FromHaskell") {
56 let value = inner.value()?;
57 let s: LitStr = value.parse()?;
58 bound_from = Some(s.value());
59 } else {
60 return Err(inner.error("expected `ToHaskell` or `FromHaskell`"));
61 }
62 Ok(())
63 })?;
64 }
65 Ok(())
66 });
67 }
68 }
69 HaskellAttrs {
70 name,
71 transparent,
72 style,
73 skip,
74 bound_to,
75 bound_from,
76 }
77}
78
79fn add_trait_bounds(mut generics: syn::Generics, trait_path: &TokenStream2) -> syn::Generics {
80 for param in &mut generics.params {
81 if let GenericParam::Type(type_param) = param {
82 type_param.bounds.push(syn::parse_quote!(#trait_path));
83 }
84 }
85 generics
86}
87
88fn apply_custom_bounds(
89 generics: &syn::Generics,
90 bound_str: &str,
91) -> syn::Result<(TokenStream2, TokenStream2, TokenStream2)> {
92 let (impl_generics, ty_generics, _) = generics.split_for_impl();
93 let impl_generics = quote! { #impl_generics };
94 let ty_generics = quote! { #ty_generics };
95 let predicates: TokenStream2 = bound_str.parse().map_err(|e| {
96 syn::Error::new(
97 proc_macro2::Span::call_site(),
98 format!("failed to parse bound: {e}"),
99 )
100 })?;
101 let where_clause = quote! { where #predicates };
102 Ok((impl_generics, ty_generics, where_clause))
103}
104
105fn resolve_style(style: Option<Style>) -> Style {
107 style.unwrap_or(Style::Record)
108}
109
110#[proc_macro_derive(ToHaskell, attributes(haskell))]
133pub fn derive_to_haskell(input: TokenStream) -> TokenStream {
134 let input = parse_macro_input!(input as DeriveInput);
135 match expand_to_haskell(input) {
136 Ok(ts) => ts.into(),
137 Err(e) => e.to_compile_error().into(),
138 }
139}
140
141fn expand_to_haskell(input: DeriveInput) -> syn::Result<TokenStream2> {
142 let attrs = parse_haskell_attrs(&input.attrs);
143 let ident = &input.ident;
144 let haskell_name = attrs.name.unwrap_or_else(|| ident.to_string());
145 let container_style = attrs.style;
146
147 let (impl_generics, ty_generics, where_clause) = if let Some(ref bound) = attrs.bound_to {
148 let r = apply_custom_bounds(&input.generics, bound)?;
149 (r.0, r.1, r.2)
150 } else {
151 let trait_path: TokenStream2 = quote!(::ghci::ToHaskell);
152 let generics = add_trait_bounds(input.generics.clone(), &trait_path);
153 let (ig, tg, wc) = generics.split_for_impl();
154 (quote! { #ig }, quote! { #tg }, quote! { #wc })
155 };
156
157 let body = match &input.data {
158 Data::Struct(s) => {
159 if attrs.transparent {
160 expand_to_haskell_transparent_struct(&s.fields)?
161 } else {
162 expand_to_haskell_struct(&s.fields, &haskell_name, container_style)?
163 }
164 }
165 Data::Enum(e) => {
166 if attrs.transparent {
167 return Err(syn::Error::new_spanned(
168 ident,
169 "`#[haskell(transparent)]` is not supported on enums",
170 ));
171 }
172 let arms = e
173 .variants
174 .iter()
175 .map(|v| {
176 let variant_ident = &v.ident;
177 let variant_attrs = parse_haskell_attrs(&v.attrs);
178 if variant_attrs.transparent {
179 return expand_to_haskell_transparent_variant(
180 ident,
181 variant_ident,
182 &v.fields,
183 );
184 }
185 let variant_name = variant_attrs
186 .name
187 .unwrap_or_else(|| variant_ident.to_string());
188 let variant_style = variant_attrs.style.or(container_style);
189 expand_to_haskell_variant(
190 ident,
191 variant_ident,
192 &v.fields,
193 &variant_name,
194 variant_style,
195 )
196 })
197 .collect::<syn::Result<Vec<_>>>()?;
198 quote! {
199 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
200 match self {
201 #(#arms)*
202 }
203 }
204 }
205 }
206 Data::Union(u) => {
207 return Err(syn::Error::new_spanned(
208 u.union_token,
209 "`ToHaskell` cannot be derived for unions",
210 ));
211 }
212 };
213
214 Ok(quote! {
215 #[automatically_derived]
216 impl #impl_generics ::ghci::ToHaskell for #ident #ty_generics #where_clause {
217 #body
218 }
219 })
220}
221
222fn expand_to_haskell_struct(
223 fields: &Fields,
224 haskell_name: &str,
225 style: Option<Style>,
226) -> syn::Result<TokenStream2> {
227 match fields {
228 Fields::Named(named) => {
229 let effective_style = resolve_style(style);
230 if effective_style == Style::App {
231 let arg_calls = named
232 .named
233 .iter()
234 .filter(|f| !parse_haskell_attrs(&f.attrs).skip)
235 .map(|f| {
236 let field_ident = f.ident.as_ref().unwrap();
237 quote! { .arg(&self.#field_ident) }
238 });
239 Ok(quote! {
240 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
241 ::ghci::haskell::app(buf, #haskell_name)
242 #(#arg_calls)*
243 .finish()
244 }
245 })
246 } else {
247 let field_calls = named
248 .named
249 .iter()
250 .filter(|f| !parse_haskell_attrs(&f.attrs).skip)
251 .map(|f| {
252 let field_ident = f.ident.as_ref().unwrap();
253 let fattrs = parse_haskell_attrs(&f.attrs);
254 let field_name = fattrs.name.unwrap_or_else(|| field_ident.to_string());
255 quote! { .field(#field_name, &self.#field_ident) }
256 });
257 Ok(quote! {
258 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
259 ::ghci::haskell::record(buf, #haskell_name)
260 #(#field_calls)*
261 .finish()
262 }
263 })
264 }
265 }
266 Fields::Unnamed(unnamed) => {
267 let arg_calls = unnamed.unnamed.iter().enumerate().map(|(i, _)| {
268 let index = syn::Index::from(i);
269 quote! { .arg(&self.#index) }
270 });
271 Ok(quote! {
272 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
273 ::ghci::haskell::app(buf, #haskell_name)
274 #(#arg_calls)*
275 .finish()
276 }
277 })
278 }
279 Fields::Unit => Ok(quote! {
280 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
281 buf.write_str(#haskell_name)
282 }
283 }),
284 }
285}
286
287fn expand_to_haskell_transparent_struct(fields: &Fields) -> syn::Result<TokenStream2> {
288 match fields {
289 Fields::Named(named) if named.named.len() == 1 => {
290 let field_ident = named.named[0].ident.as_ref().unwrap();
291 Ok(quote! {
292 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
293 ::ghci::ToHaskell::write_haskell(&self.#field_ident, buf)
294 }
295 })
296 }
297 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => Ok(quote! {
298 fn write_haskell(&self, buf: &mut impl ::std::fmt::Write) -> ::std::fmt::Result {
299 ::ghci::ToHaskell::write_haskell(&self.0, buf)
300 }
301 }),
302 _ => Err(syn::Error::new(
303 proc_macro2::Span::call_site(),
304 "`#[haskell(transparent)]` requires exactly one field",
305 )),
306 }
307}
308
309fn expand_to_haskell_transparent_variant(
310 enum_ident: &syn::Ident,
311 variant_ident: &syn::Ident,
312 fields: &Fields,
313) -> syn::Result<TokenStream2> {
314 match fields {
315 Fields::Named(named) if named.named.len() == 1 => {
316 let field_ident = named.named[0].ident.as_ref().unwrap();
317 Ok(quote! {
318 #enum_ident::#variant_ident { #field_ident } => {
319 ::ghci::ToHaskell::write_haskell(#field_ident, buf)
320 }
321 })
322 }
323 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => Ok(quote! {
324 #enum_ident::#variant_ident(__f0) => {
325 ::ghci::ToHaskell::write_haskell(__f0, buf)
326 }
327 }),
328 _ => Err(syn::Error::new_spanned(
329 variant_ident,
330 "`#[haskell(transparent)]` requires exactly one field",
331 )),
332 }
333}
334
335fn expand_to_haskell_variant(
336 enum_ident: &syn::Ident,
337 variant_ident: &syn::Ident,
338 fields: &Fields,
339 haskell_name: &str,
340 style: Option<Style>,
341) -> syn::Result<TokenStream2> {
342 match fields {
343 Fields::Named(named) => {
344 let effective_style = resolve_style(style);
345 let all_field_idents: Vec<_> = named
346 .named
347 .iter()
348 .map(|f| f.ident.as_ref().unwrap())
349 .collect();
350 if effective_style == Style::App {
351 let non_skipped: Vec<_> = named
352 .named
353 .iter()
354 .filter(|f| !parse_haskell_attrs(&f.attrs).skip)
355 .map(|f| f.ident.as_ref().unwrap())
356 .collect();
357 Ok(quote! {
358 #enum_ident::#variant_ident { #(#all_field_idents),* } => {
359 ::ghci::haskell::app(buf, #haskell_name)
360 #(.arg(#non_skipped))*
361 .finish()
362 }
363 })
364 } else {
365 let field_stmts: Vec<_> = named
366 .named
367 .iter()
368 .filter(|f| !parse_haskell_attrs(&f.attrs).skip)
369 .map(|f| {
370 let fattrs = parse_haskell_attrs(&f.attrs);
371 let field_ident = f.ident.as_ref().unwrap();
372 let field_name = fattrs.name.unwrap_or_else(|| field_ident.to_string());
373 (field_name, field_ident)
374 })
375 .collect();
376 let field_names: Vec<_> = field_stmts.iter().map(|(n, _)| n.clone()).collect();
377 let field_idents: Vec<_> = field_stmts.iter().map(|(_, i)| *i).collect();
378 Ok(quote! {
379 #enum_ident::#variant_ident { #(#all_field_idents),* } => {
380 ::ghci::haskell::record(buf, #haskell_name)
381 #(.field(#field_names, #field_idents))*
382 .finish()
383 }
384 })
385 }
386 }
387 Fields::Unnamed(unnamed) => {
388 let vars: Vec<_> = (0..unnamed.unnamed.len())
389 .map(|i| format_ident!("__f{}", i))
390 .collect();
391 Ok(quote! {
392 #enum_ident::#variant_ident(#(#vars),*) => {
393 ::ghci::haskell::app(buf, #haskell_name)
394 #(.arg(#vars))*
395 .finish()
396 }
397 })
398 }
399 Fields::Unit => Ok(quote! {
400 #enum_ident::#variant_ident => buf.write_str(#haskell_name),
401 }),
402 }
403}
404
405#[proc_macro_derive(FromHaskell, attributes(haskell))]
431pub fn derive_from_haskell(input: TokenStream) -> TokenStream {
432 let input = parse_macro_input!(input as DeriveInput);
433 match expand_from_haskell(input) {
434 Ok(ts) => ts.into(),
435 Err(e) => e.to_compile_error().into(),
436 }
437}
438
439fn expand_from_haskell(input: DeriveInput) -> syn::Result<TokenStream2> {
440 let attrs = parse_haskell_attrs(&input.attrs);
441 let ident = &input.ident;
442 let haskell_name = attrs.name.unwrap_or_else(|| ident.to_string());
443 let container_style = attrs.style;
444
445 let (impl_generics, ty_generics, where_clause) = if let Some(ref bound) = attrs.bound_from {
446 let r = apply_custom_bounds(&input.generics, bound)?;
447 (r.0, r.1, r.2)
448 } else {
449 let trait_path: TokenStream2 = quote!(::ghci::FromHaskell);
450 let generics = add_trait_bounds(input.generics.clone(), &trait_path);
451 let (ig, tg, wc) = generics.split_for_impl();
452 (quote! { #ig }, quote! { #tg }, quote! { #wc })
453 };
454
455 let body = match &input.data {
456 Data::Struct(s) => {
457 if attrs.transparent {
458 expand_from_haskell_transparent_struct(&s.fields)?
459 } else {
460 expand_from_haskell_struct(&s.fields, &haskell_name, container_style)?
461 }
462 }
463 Data::Enum(e) => {
464 if attrs.transparent {
465 return Err(syn::Error::new_spanned(
466 ident,
467 "`#[haskell(transparent)]` is not supported on enums",
468 ));
469 }
470 let type_name = ident.to_string();
471 let tries = e
472 .variants
473 .iter()
474 .map(|v| {
475 let variant_ident = &v.ident;
476 let variant_attrs = parse_haskell_attrs(&v.attrs);
477 if variant_attrs.transparent {
478 return expand_from_haskell_transparent_variant(
479 ident,
480 variant_ident,
481 &v.fields,
482 );
483 }
484 let variant_name = variant_attrs
485 .name
486 .unwrap_or_else(|| variant_ident.to_string());
487 let variant_style = variant_attrs.style.or(container_style);
488 expand_from_haskell_variant(
489 ident,
490 variant_ident,
491 &v.fields,
492 &variant_name,
493 variant_style,
494 )
495 })
496 .collect::<syn::Result<Vec<_>>>()?;
497 quote! {
498 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
499 #(#tries)*
500 ::core::result::Result::Err(::ghci::HaskellParseError::ParseError {
501 message: ::std::format!("failed to parse {} from {:?}", #type_name, input),
502 })
503 }
504 }
505 }
506 Data::Union(u) => {
507 return Err(syn::Error::new_spanned(
508 u.union_token,
509 "`FromHaskell` cannot be derived for unions",
510 ));
511 }
512 };
513
514 Ok(quote! {
515 #[automatically_derived]
516 impl #impl_generics ::ghci::FromHaskell for #ident #ty_generics #where_clause {
517 #body
518 }
519 })
520}
521
522fn expand_from_haskell_struct(
523 fields: &Fields,
524 haskell_name: &str,
525 style: Option<Style>,
526) -> syn::Result<TokenStream2> {
527 match fields {
528 Fields::Named(named) => {
529 let effective_style = resolve_style(style);
530 if effective_style == Style::App {
531 let field_inits: Vec<_> = named
532 .named
533 .iter()
534 .map(|f| {
535 let field_ident = f.ident.as_ref().unwrap();
536 let fattrs = parse_haskell_attrs(&f.attrs);
537 if fattrs.skip {
538 quote! { #field_ident: ::core::default::Default::default() }
539 } else {
540 quote! { #field_ident: __p.arg()? }
541 }
542 })
543 .collect();
544 Ok(quote! {
545 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
546 let mut __p = ::ghci::haskell::parse_app(#haskell_name, input)?;
547 let __val = Self { #(#field_inits),* };
548 let rest = __p.finish()?;
549 ::core::result::Result::Ok((__val, rest))
550 }
551 })
552 } else {
553 let field_inits = named.named.iter().map(|f| {
554 let field_ident = f.ident.as_ref().unwrap();
555 let fattrs = parse_haskell_attrs(&f.attrs);
556 if fattrs.skip {
557 quote! { #field_ident: ::core::default::Default::default() }
558 } else {
559 let field_name = fattrs.name.unwrap_or_else(|| field_ident.to_string());
560 quote! { #field_ident: rec.field(#field_name)? }
561 }
562 });
563 Ok(quote! {
564 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
565 let (rec, rest) = ::ghci::haskell::parse_record(#haskell_name, input)?;
566 ::core::result::Result::Ok((Self { #(#field_inits),* }, rest))
567 }
568 })
569 }
570 }
571 Fields::Unnamed(unnamed) => {
572 let vars: Vec<_> = (0..unnamed.unnamed.len())
573 .map(|i| format_ident!("__f{}", i))
574 .collect();
575 Ok(quote! {
576 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
577 let mut __p = ::ghci::haskell::parse_app(#haskell_name, input)?;
578 #(let #vars = __p.arg()?;)*
579 let rest = __p.finish()?;
580 ::core::result::Result::Ok((Self(#(#vars),*), rest))
581 }
582 })
583 }
584 Fields::Unit => Ok(quote! {
585 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
586 let mut __p = ::ghci::haskell::parse_app(#haskell_name, input)?;
587 let rest = __p.finish()?;
588 ::core::result::Result::Ok((Self, rest))
589 }
590 }),
591 }
592}
593
594fn expand_from_haskell_transparent_struct(fields: &Fields) -> syn::Result<TokenStream2> {
595 match fields {
596 Fields::Named(named) if named.named.len() == 1 => {
597 let field_ident = named.named[0].ident.as_ref().unwrap();
598 Ok(quote! {
599 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
600 let (val, rest) = ::ghci::FromHaskell::parse_haskell(input)?;
601 ::core::result::Result::Ok((Self { #field_ident: val }, rest))
602 }
603 })
604 }
605 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => Ok(quote! {
606 fn parse_haskell(input: &str) -> ::core::result::Result<(Self, &str), ::ghci::HaskellParseError> {
607 let (val, rest) = ::ghci::FromHaskell::parse_haskell(input)?;
608 ::core::result::Result::Ok((Self(val), rest))
609 }
610 }),
611 _ => Err(syn::Error::new(
612 proc_macro2::Span::call_site(),
613 "`#[haskell(transparent)]` requires exactly one field",
614 )),
615 }
616}
617
618fn expand_from_haskell_transparent_variant(
619 enum_ident: &syn::Ident,
620 variant_ident: &syn::Ident,
621 fields: &Fields,
622) -> syn::Result<TokenStream2> {
623 match fields {
624 Fields::Named(named) if named.named.len() == 1 => {
625 let field_ident = named.named[0].ident.as_ref().unwrap();
626 Ok(quote! {
627 if let ::core::result::Result::Ok((val, rest)) = ::ghci::FromHaskell::parse_haskell(input) {
628 return ::core::result::Result::Ok((#enum_ident::#variant_ident { #field_ident: val }, rest));
629 }
630 })
631 }
632 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => Ok(quote! {
633 if let ::core::result::Result::Ok((val, rest)) = ::ghci::FromHaskell::parse_haskell(input) {
634 return ::core::result::Result::Ok((#enum_ident::#variant_ident(val), rest));
635 }
636 }),
637 _ => Err(syn::Error::new_spanned(
638 variant_ident,
639 "`#[haskell(transparent)]` requires exactly one field",
640 )),
641 }
642}
643
644fn expand_from_haskell_variant(
645 enum_ident: &syn::Ident,
646 variant_ident: &syn::Ident,
647 fields: &Fields,
648 haskell_name: &str,
649 style: Option<Style>,
650) -> syn::Result<TokenStream2> {
651 match fields {
652 Fields::Named(named) => {
653 let effective_style = resolve_style(style);
654 if effective_style == Style::App {
655 let field_inits: Vec<_> = named
656 .named
657 .iter()
658 .map(|f| {
659 let field_ident = f.ident.as_ref().unwrap();
660 let fattrs = parse_haskell_attrs(&f.attrs);
661 if fattrs.skip {
662 quote! { #field_ident: ::core::default::Default::default() }
663 } else {
664 quote! { #field_ident: __p.arg()? }
665 }
666 })
667 .collect();
668 Ok(quote! {
669 if let ::core::result::Result::Ok(mut __p) = ::ghci::haskell::parse_app(#haskell_name, input) {
670 let __val = #enum_ident::#variant_ident { #(#field_inits),* };
671 let rest = __p.finish()?;
672 return ::core::result::Result::Ok((__val, rest));
673 }
674 })
675 } else {
676 let field_inits: Vec<_> = named
677 .named
678 .iter()
679 .map(|f| {
680 let field_ident = f.ident.as_ref().unwrap();
681 let fattrs = parse_haskell_attrs(&f.attrs);
682 if fattrs.skip {
683 quote! { #field_ident: ::core::default::Default::default() }
684 } else {
685 let field_name = fattrs
686 .name
687 .unwrap_or_else(|| f.ident.as_ref().unwrap().to_string());
688 quote! { #field_ident: rec.field(#field_name)? }
689 }
690 })
691 .collect();
692 Ok(quote! {
693 if let ::core::result::Result::Ok((rec, rest)) = ::ghci::haskell::parse_record(#haskell_name, input) {
694 return ::core::result::Result::Ok((
695 #enum_ident::#variant_ident { #(#field_inits),* },
696 rest,
697 ));
698 }
699 })
700 }
701 }
702 Fields::Unnamed(unnamed) => {
703 let vars: Vec<_> = (0..unnamed.unnamed.len())
704 .map(|i| format_ident!("__f{}", i))
705 .collect();
706 Ok(quote! {
707 if let ::core::result::Result::Ok(mut __p) = ::ghci::haskell::parse_app(#haskell_name, input) {
708 #(let #vars = __p.arg()?;)*
709 let rest = __p.finish()?;
710 return ::core::result::Result::Ok((#enum_ident::#variant_ident(#(#vars),*), rest));
711 }
712 })
713 }
714 Fields::Unit => Ok(quote! {
715 if let ::core::result::Result::Ok(mut __p) = ::ghci::haskell::parse_app(#haskell_name, input) {
716 let rest = __p.finish()?;
717 return ::core::result::Result::Ok((#enum_ident::#variant_ident, rest));
718 }
719 }),
720 }
721}