runa_wayland_scanner_codegen/
lib.rs

1use std::borrow::Cow;
2
3use ahash::AHashMap as HashMap;
4use heck::{ToPascalCase, ToShoutySnekCase};
5use proc_macro2::{Ident, Span, TokenStream};
6use quote::{format_ident, quote};
7use spec_parser::protocol::{Enum, Interface, Message, Protocol};
8use syn::Lifetime;
9use thiserror::Error;
/// Errors reported while turning a parsed protocol into generated code.
#[derive(Error, Debug)]
pub enum Error {
    /// A destructor was used where an argument type is required.
    // NOTE(review): not constructed anywhere in the visible code; the
    // destructor case in `generate_arg_type_with_lifetime` panics instead.
    #[error("Destructor is not a valid argument type")]
    InvalidArgumentType,

    /// An enum reference names an interface that is not present in the enum
    /// table. Carries the interface name.
    #[error("Enum refers to an unknown interface: {0}")]
    UnknownInterface(String),

    /// An enum reference names an enum that the looked-up interface does not
    /// define. Carries the enum name.
    #[error("Unknown type: {0}")]
    UnknownEnum(String),

    /// The same enum is used by both `int` and `uint` arguments, so its
    /// underlying integer type cannot be decided. Carries the enum reference.
    #[error("Cannot decide type of enum: {0}")]
    BadEnumType(String),
}
/// `Result` alias whose error type defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
25
/// What the generator knows about a single protocol enum.
struct EnumInfo {
    /// Whether the protocol declares the enum as a bitfield.
    is_bitfield: bool,
    /// `Some(true)` when the enum is used by `int` arguments, `Some(false)`
    /// when used by `uint` arguments, `None` until `scan_enum` has seen an
    /// argument that refers to it.
    is_int:      Option<bool>,
}
30/// Mapping from enum names to whether it's a bitflags
31/// interface -> name -> bool.
32struct EnumInfos<'a>(HashMap<&'a str, HashMap<&'a str, EnumInfo>>);
33impl<'a> EnumInfos<'a> {
34    fn get(&self, current_iface_name: &str, enum_: &str) -> Result<&EnumInfo, Error> {
35        Ok(if let Some((iface, enum_)) = enum_.split_once('.') {
36            self.0
37                .get(iface)
38                .ok_or_else(|| Error::UnknownInterface(iface.to_string()))?
39                .get(enum_)
40                .ok_or_else(|| Error::UnknownEnum(enum_.to_string()))?
41        } else {
42            self.0
43                .get(current_iface_name)
44                .unwrap()
45                .get(enum_)
46                .ok_or_else(|| Error::UnknownEnum(enum_.to_string()))?
47        })
48    }
49
50    fn get_mut(&mut self, current_iface_name: &str, enum_: &str) -> Result<&mut EnumInfo, Error> {
51        Ok(if let Some((iface, enum_)) = enum_.split_once('.') {
52            self.0
53                .get_mut(iface)
54                .ok_or_else(|| Error::UnknownInterface(iface.to_string()))?
55                .get_mut(enum_)
56                .ok_or_else(|| Error::UnknownEnum(enum_.to_string()))?
57        } else {
58            self.0
59                .get_mut(current_iface_name)
60                .unwrap()
61                .get_mut(enum_)
62                .ok_or_else(|| Error::UnknownEnum(enum_.to_string()))?
63        })
64    }
65}
66fn to_path<'a>(arr: impl IntoIterator<Item = &'a str>, leading_colon: bool) -> syn::Path {
67    syn::Path {
68        leading_colon: if leading_colon {
69            Some(Default::default())
70        } else {
71            None
72        },
73        segments:      arr
74            .into_iter()
75            .map(|s| syn::PathSegment::from(syn::Ident::new(s, Span::call_site())))
76            .collect(),
77    }
78}
/// Builds a `syn::Path` from a literal path written inline, e.g.
/// `path!(u32)` or `path!(::some_crate::module::Item)`.
///
/// Every segment is stringified and handed to `to_path`.
macro_rules! path {
    // Relative path: no leading `::`.
    ($($seg:ident)::*) => {
        to_path([ $( stringify!($seg) ),* ], false)
    };
    // Absolute path: starts with `::`.
    (::$($seg:ident)::*) => {
        to_path([ $( stringify!($seg) ),* ], true)
    };
}
/// Like `path!`, but wraps the resulting path in a `syn::Type::Path` with no
/// qualified-self part, so it can be used wherever a type is expected.
macro_rules! type_path {
    // Relative path: no leading `::`.
    ($($seg:ident)::*) => {
        syn::Type::Path(syn::TypePath {
            qself: None,
            path: to_path([ $( stringify!($seg) ),* ], false),
        })
    };
    // Absolute path: starts with `::`.
    (::$($seg:ident)::*) => {
        syn::Type::Path(syn::TypePath {
            qself: None,
            path: to_path([ $( stringify!($seg) ),* ], true),
        })
    };
}
101fn enum_type_name(enum_: &str, iface_version: &HashMap<&str, u32>) -> syn::Path {
102    if let Some((iface, name)) = enum_.split_once('.') {
103        let version = iface_version.get(iface).unwrap();
104        to_path(
105            [
106                "__generated_root",
107                iface,
108                &format!("v{version}"),
109                "enums",
110                &name.to_pascal_case(),
111            ],
112            false,
113        )
114    } else {
115        to_path(["enums", &enum_.to_pascal_case()], false)
116    }
117}
118fn generate_arg_type_with_lifetime(
119    arg: &spec_parser::protocol::Arg,
120    lifetime: &Option<Lifetime>,
121    iface_version: &HashMap<&str, u32>,
122) -> syn::Type {
123    use spec_parser::protocol::Type::*;
124    if let Some(enum_) = &arg.enum_ {
125        syn::Type::Path(syn::TypePath {
126            path:  enum_type_name(enum_, iface_version),
127            qself: None,
128        })
129    } else {
130        match arg.typ {
131            Int => type_path!(i32),
132            Uint => type_path!(u32),
133            Fixed => type_path!(::runa_wayland_scanner::types::Fixed),
134            Array =>
135                if let Some(lifetime) = lifetime {
136                    // &#lifetime [u8]
137                    syn::Type::Reference(syn::TypeReference {
138                        and_token:  Default::default(),
139                        lifetime:   Some(lifetime.clone()),
140                        mutability: None,
141                        elem:       Box::new(syn::Type::Slice(syn::TypeSlice {
142                            bracket_token: Default::default(),
143                            elem:          Box::new(type_path!(u8)),
144                        })),
145                    })
146                } else {
147                    // Box<[u8]>
148                    syn::Type::Path(syn::TypePath {
149                        qself: None,
150                        path:  syn::Path {
151                            leading_colon: None,
152                            segments:      [syn::PathSegment {
153                                ident:     syn::Ident::new("Box", Span::call_site()),
154                                arguments: syn::PathArguments::AngleBracketed(
155                                    syn::AngleBracketedGenericArguments {
156                                        colon2_token: None,
157                                        lt_token:     Default::default(),
158                                        args:         [syn::GenericArgument::Type(
159                                            syn::Type::Slice(syn::TypeSlice {
160                                                bracket_token: Default::default(),
161                                                elem:          Box::new(type_path!(u8)),
162                                            }),
163                                        )]
164                                        .into_iter()
165                                        .collect(),
166                                        gt_token:     Default::default(),
167                                    },
168                                ),
169                            }]
170                            .into_iter()
171                            .collect(),
172                        },
173                    })
174                },
175            Fd => type_path!(::runa_wayland_scanner::types::Fd),
176            String =>
177                if let Some(lifetime) = lifetime {
178                    // Str<#lifetime>
179                    let mut ty = path!(::runa_wayland_scanner::types::Str);
180                    ty.segments.last_mut().unwrap().arguments =
181                        syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
182                            colon2_token: None,
183                            lt_token:     Default::default(),
184                            args:         [syn::GenericArgument::Lifetime(lifetime.clone())]
185                                .into_iter()
186                                .collect(),
187                            gt_token:     Default::default(),
188                        });
189                    syn::Type::Path(syn::TypePath {
190                        path:  ty,
191                        qself: None,
192                    })
193                } else {
194                    type_path!(::runa_wayland_scanner::types::String)
195                },
196            Object => type_path!(::runa_wayland_scanner::types::Object),
197            NewId => type_path!(::runa_wayland_scanner::types::NewId),
198            Destructor => panic!("InvalidArgumentType"),
199        }
200    }
201}
202fn generate_arg_type(
203    arg: &spec_parser::protocol::Arg,
204    is_owned: bool,
205    iface_version: &HashMap<&str, u32>,
206) -> syn::Type {
207    generate_arg_type_with_lifetime(
208        arg,
209        &(!is_owned).then(|| Lifetime::new("'a", Span::call_site())),
210        iface_version,
211    )
212}
213
214fn type_is_borrowed(ty: &spec_parser::protocol::Type) -> bool {
215    use spec_parser::protocol::Type::*;
216    match ty {
217        Int | Uint | Fixed | Fd | Object | NewId => false,
218        String | Array => true,
219        Destructor => false,
220    }
221}
222
223fn generate_serialize_for_type(
224    current_iface_name: &str,
225    name: &syn::Ident,
226    arg: &spec_parser::protocol::Arg,
227    enum_info: &EnumInfos,
228) -> TokenStream {
229    use spec_parser::protocol::Type::*;
230    if let Fd = arg.typ {
231        return quote! {
232            fds.extend(Some(self.#name.take().expect("trying to send raw fd")));
233        }
234    }
235    let get = match arg.typ {
236        Int | Uint =>
237            if let Some(enum_) = &arg.enum_ {
238                let info = enum_info.get(current_iface_name, enum_).unwrap();
239                let repr = if info.is_int.unwrap_or(false) {
240                    path!(i32)
241                } else {
242                    path!(u32)
243                };
244                if info.is_bitfield {
245                    quote! {
246                        &self.#name.bits().to_ne_bytes()[..]
247                    }
248                } else {
249                    quote! {
250                        {
251                            let b: #repr = self.#name.into();
252                            &b.to_ne_bytes()[..]
253                        }
254                    }
255                }
256            } else {
257                quote! {
258                    &self.#name.to_ne_bytes()[..]
259                }
260            },
261        Fixed | Object | NewId => quote! {
262            &self.#name.0.to_ne_bytes()[..]
263        },
264        Fd => unreachable!(),
265        String => quote! {
266            self.#name.0
267        },
268        Array => quote! {
269            self.#name
270        },
271        Destructor => quote! {},
272    };
273    match arg.typ {
274        Int | Uint | Fixed | Object | NewId => quote! {
275            buf.put_slice(#get);
276        },
277        Array => quote! {
278            let tmp = #get;
279            // buf aligned to 4 bytes, plus length prefix
280            let aligned_len = ((tmp.len() + 3) & !3) + 4;
281            // [0, 4): length
282            // [4, buf.len()+4): buf
283            // [buf.len()+4, aligned_len): alignment
284            buf.put_u32_ne(tmp.len() as u32);
285            buf.put_slice(tmp);
286            buf.put_bytes(0, aligned_len - tmp.len() - 4);
287        },
288        String => quote! {
289            let tmp = #get;
290            // tmp doesn't have the trailing nul byte, so we add 1 to obtain its real length
291            let aligned_len = ((tmp.len() + 1 + 3) & !3) + 4;
292            // [0, 4): length
293            // [4, buf.len()+4): buf
294            // [buf.len()+4, buf.len()+1+4): nul        -- written together
295            // [buf.len()+1+4, aligned_len): alignment  -╯
296            buf.put_u32_ne((tmp.len() + 1) as u32);
297            buf.put_slice(tmp);
298            buf.put_bytes(0, aligned_len - tmp.len() - 4);
299        },
300        Fd => unreachable!(),
301        Destructor => quote! {},
302    }
303}
304fn generate_deserialize_for_type(
305    current_iface_name: &str,
306    arg_name: &str,
307    arg: &spec_parser::protocol::Arg,
308    enum_info: &EnumInfos,
309    iface_version: &HashMap<&str, u32>,
310) -> TokenStream {
311    use spec_parser::protocol::Type::*;
312    match arg.typ {
313        Int | Uint => {
314            let v = if arg.typ == Int {
315                quote! {  pop_i32(&mut data) }
316            } else {
317                quote! { pop_u32(&mut data) }
318            };
319            let err = if arg.typ == Int {
320                quote! { ::runa_wayland_scanner::io::de::Error::InvalidIntEnum }
321            } else {
322                quote! { ::runa_wayland_scanner::io::de::Error::InvalidUintEnum }
323            };
324            if let Some(enum_) = &arg.enum_ {
325                let is_bitfield = enum_info
326                    .get(current_iface_name, enum_)
327                    .unwrap()
328                    .is_bitfield;
329                let enum_ty = enum_type_name(enum_, iface_version);
330                let enum_ty = quote! { super::#enum_ty };
331                if is_bitfield {
332                    quote! { {
333                        let tmp = #v;
334                        #enum_ty::from_bits(tmp).ok_or_else(|| #err(tmp, std::any::type_name::<#enum_ty>()))?
335                    } }
336                } else {
337                    quote! { {
338                        let tmp = #v;
339                        #enum_ty::try_from(tmp).map_err(|_| #err(tmp, std::any::type_name::<#enum_ty>()))?
340                    } }
341                }
342            } else {
343                v
344            }
345        },
346        Fixed => quote! { ::runa_wayland_scanner::types::Fixed::from_bits(pop_i32(&mut data)) },
347        Object | NewId => quote! { pop_u32(&mut data).into() },
348        Fd => quote! { pop_fd(&mut fds).into() },
349        String => quote! { {
350            let len = pop_u32(&mut data) as usize;
351            let bytes = pop_bytes(&mut data, len);
352            if bytes[len - 1] != b'\0' {
353                return Err(::runa_wayland_scanner::io::de::Error::MissingNul(#arg_name));
354            }
355            bytes[..len - 1].into()
356        } },
357        Array => quote! { {
358            let len = pop_u32(&mut data);
359            pop_bytes(&mut data, len as usize)
360        } },
361        Destructor => quote! {},
362    }
363}
/// Generates the Rust definition of a single request or event.
///
/// The output consists of a public struct named after the message (in Pascal
/// case) with one public field per argument, an `OPCODE` associated constant,
/// and `Serialize` / `Deserialize` implementations.
///
/// * `iface_name` - interface the message belongs to, used to resolve
///   unqualified enum references.
/// * `opcode` - opcode of the message.
/// * `is_owned` - generate owned field types; otherwise borrowed types with
///   lifetime `'a` are used and the struct gets a `<'a>` parameter when any
///   argument is a string or an array.
/// * `_parent` - unused.
///
/// # Panics
///
/// Panics if the message has more than 255 file descriptor arguments, or if
/// any of the per-argument generators panics.
fn generate_message_variant(
    iface_name: &str,
    opcode: u16,
    request: &Message,
    is_owned: bool,
    _parent: &Ident,
    iface_version: &HashMap<&str, u32>,
    enum_info: &EnumInfos,
) -> TokenStream {
    let args = &request.args;
    // Struct fields: `pub name: Type`, each preceded by its doc comment.
    let args = args.iter().map(|arg| {
        let name = format_ident!("{}", arg.name);
        let ty = generate_arg_type(arg, is_owned, iface_version);
        let doc_comment = generate_doc_comment(&arg.description);
        quote! {
            #doc_comment
            pub #name: #ty
        }
    });
    let pname = request.name.as_str().to_pascal_case();
    let name = format_ident!("{}", pname);
    // The struct only needs a lifetime parameter when it borrows something.
    let is_borrowed = if !is_owned {
        request.args.iter().any(|arg| type_is_borrowed(&arg.typ))
    } else {
        false
    };
    let empty = quote! {};
    let lta = quote!(<'a>);
    let lifetime = if is_borrowed { &lta } else { &empty };
    // Generate expression for length calculation
    // Fixed part: 4 bytes for every argument that occupies space in the byte
    // stream (file descriptors and destructors do not), plus the header.
    let fixed_len: u32 = request
        .args
        .iter()
        .filter(|arg| match arg.typ {
            | spec_parser::protocol::Type::Int
            | spec_parser::protocol::Type::Uint
            | spec_parser::protocol::Type::Fixed
            | spec_parser::protocol::Type::Object
            | spec_parser::protocol::Type::String // string and array has a length prefix, so they
            | spec_parser::protocol::Type::Array  // count, too.
            | spec_parser::protocol::Type::NewId => true,
            | spec_parser::protocol::Type::Fd
            | spec_parser::protocol::Type::Destructor=> false,
        })
        .count() as u32 *
        4 +
        8; // 8 bytes for header
    // Variable part: one `+ <padded payload length>` term per string/array
    // argument, evaluated at run time by the generated `len` method.
    let variable_len = request.args.iter().map(|arg| {
        let name = format_ident!("{}", arg.name);
        match &arg.typ {
            spec_parser::protocol::Type::Int |
            spec_parser::protocol::Type::Uint |
            spec_parser::protocol::Type::Fixed |
            spec_parser::protocol::Type::Object |
            spec_parser::protocol::Type::NewId |
            spec_parser::protocol::Type::Fd |
            spec_parser::protocol::Type::Destructor => quote! {},
            spec_parser::protocol::Type::String => quote! {
                + ((self.#name.0.len() + 1 + 3) & !3) as u32 // + 1 for NUL byte
            },
            spec_parser::protocol::Type::Array => quote! {
                + ((self.#name.len() + 3) & !3) as u32
            },
        }
    });
    // Per-field serialization statements, in argument order.
    let serialize = request.args.iter().map(|arg| {
        let name = format_ident!("{}", arg.name);
        generate_serialize_for_type(iface_name, &name, arg, enum_info)
    });
    // Per-field initializers `name: <expr>` for the deserialized struct.
    let deserialize = request.args.iter().map(|arg| {
        let name = format_ident!("{}", arg.name);
        let deserialize =
            generate_deserialize_for_type(iface_name, &arg.name, arg, enum_info, iface_version);
        quote! {
            #name: #deserialize
        }
    });
    let doc_comment = generate_doc_comment(&request.description);
    // Number of file descriptors carried by the message.
    let nfds: u8 = request
        .args
        .iter()
        .filter(|arg| arg.typ == spec_parser::protocol::Type::Fd)
        .count()
        .try_into()
        .unwrap();
    // Serializing takes the file descriptors out of `self`, which needs a
    // mutable binding.
    let mut_ = if nfds != 0 {
        quote! {mut}
    } else {
        quote! {}
    };
    // Messages without file descriptors additionally derive `Clone` and `Copy`.
    let extra_derives = if nfds != 0 {
        quote! {}
    } else {
        quote! {, Clone, Copy}
    };
    let public = quote! {
        #doc_comment
        #[derive(Debug, PartialEq, Eq #extra_derives)]
        pub struct #name #lifetime {
            #(#args),*
        }
        impl #lifetime #name #lifetime {
            pub const OPCODE: u16 = #opcode;
        }
        impl #lifetime ::runa_wayland_scanner::io::ser::Serialize for #name #lifetime {
            fn serialize<Fds: Extend<std::os::unix::io::OwnedFd>>(
                #mut_ self,
                buf: &mut ::runa_wayland_scanner::BytesMut,
                fds: &mut Fds,
            ) {
                use ::runa_wayland_scanner::BufMut;
                let msg_len = self.len() as u32;
                let prefix: u32 = (msg_len << 16) + (#opcode as u32);
                buf.put_u32_ne(prefix);
                #(#serialize)*
            }
            #[inline]
            fn len(&self) -> u16 {
                (#fixed_len #(#variable_len)*) as u16
            }
            #[inline]
            fn nfds(&self) -> u8 {
                #nfds
            }
        }
        impl<'a> ::runa_wayland_scanner::io::de::Deserialize<'a> for #name #lifetime {
            #[inline]
            fn deserialize(
                mut data: &'a [u8], mut fds: &'a [::std::os::unix::io::RawFd]
            ) -> Result<Self, ::runa_wayland_scanner::io::de::Error> {
                use ::runa_wayland_scanner::io::{pop_fd, pop_bytes, pop_i32, pop_u32};
                Ok(Self {
                    #(#deserialize),*
                })
            }
        }
    };

    public
}
504
505fn generate_dispatch_trait(
506    messages: &[Message],
507    event_or_request: EventOrRequest,
508    iface_version: &HashMap<&str, u32>,
509) -> TokenStream {
510    let ty = match event_or_request {
511        EventOrRequest::Event => format_ident!("EventDispatch"),
512        EventOrRequest::Request => format_ident!("RequestDispatch"),
513    };
514    let hidden = if matches!(event_or_request, EventOrRequest::Event) {
515        quote!(#[doc(hidden)])
516    } else {
517        quote!()
518    };
519    let methods = messages.iter().map(|m| {
520        let name = if m.name == "move" || m.name == "type" {
521            format_ident!("{}_", m.name)
522        } else {
523            format_ident!("{}", m.name)
524        };
525        let retty = format_ident!("{}Fut", m.name.to_pascal_case());
526        let args = m.args.iter().map(|arg| {
527            let name = format_ident!("{}", arg.name);
528            let typ = generate_arg_type(arg, false, iface_version);
529            quote! {
530                #name: #typ
531            }
532        });
533        let doc_comment = generate_doc_comment(&m.description);
534        quote! {
535            #doc_comment
536            fn #name<'a>(
537                ctx: &'a mut Ctx,
538                object_id: u32,
539                #(#args),*
540            )
541            -> Self::#retty<'a>;
542        }
543    });
544    let fut_docs = messages.iter().map(|m| {
545        let name: Cow<'_, _> = if m.name == "move" || m.name == "type" {
546            format!("{}_", m.name).into()
547        } else {
548            m.name.as_str().into()
549        };
550        format!("Type of future returned by [`{name}`](Self::{name})")
551    });
552    let futs = messages
553        .iter()
554        .map(|m| format_ident!("{}Fut", m.name.to_pascal_case()));
555    quote! {
556        #hidden
557        pub trait #ty<Ctx> {
558            type Error;
559            #(
560                #[doc = #fut_docs]
561                type #futs<'a>: ::std::future::Future<Output = ::std::result::Result<(), Self::Error>> + 'a
562                    where Ctx: 'a;
563            )*
564            #(#methods)*
565        }
566    }
567}
568
/// Selects which of the two message directions of an interface is being
/// generated.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum EventOrRequest {
    /// Generates the `events` module, `Event` enum and `EventDispatch` trait.
    Event,
    /// Generates the `requests` module, `Request` enum and `RequestDispatch`
    /// trait.
    Request,
}
574
/// Generates everything for one direction (requests or events) of an
/// interface.
///
/// Returns an empty token stream when `messages` is empty. Otherwise the
/// output contains:
///
/// * a module (`requests` or `events`) with one struct per message, whose
///   opcode is the message's position in `messages`;
/// * an enum (`Request` or `Event`) with one variant per message;
/// * the dispatch trait produced by `generate_dispatch_trait`;
/// * `Serialize` and `Deserialize` implementations for the enum that forward
///   to the matching message struct.
fn generate_event_or_request(
    iface_name: &str,
    messages: &[Message],
    iface_version: &HashMap<&str, u32>,
    enum_info: &EnumInfos,
    event_or_request: EventOrRequest,
) -> TokenStream {
    if messages.is_empty() {
        quote! {}
    } else {
        let mod_name = match event_or_request {
            EventOrRequest::Event => format_ident!("events"),
            EventOrRequest::Request => format_ident!("requests"),
        };
        let type_name = match event_or_request {
            EventOrRequest::Event => format_ident!("Event"),
            EventOrRequest::Request => format_ident!("Request"),
        };
        // The enum needs a lifetime as soon as any of its variants borrows.
        let enum_is_borrowed = messages
            .iter()
            .any(|v| v.args.iter().any(|arg| type_is_borrowed(&arg.typ)));
        // Message structs, always generated in their borrowed form.
        let public = messages.iter().enumerate().map(|(opcode, v)| {
            generate_message_variant(
                iface_name,
                opcode as u16,
                v,
                false,
                &mod_name,
                iface_version,
                enum_info,
            )
        });
        let enum_lifetime = if enum_is_borrowed {
            quote! { <'a> }
        } else {
            quote! {}
        };
        // Enum variants: `Name(module::Name)` or `Name(module::Name<'a>)`.
        let enum_members = messages.iter().map(|v| {
            let name = format_ident!("{}", v.name.to_pascal_case());
            let is_borrowed = v.args.iter().any(|arg| type_is_borrowed(&arg.typ));
            let lifetime = if is_borrowed {
                quote! { <'a> }
            } else {
                quote! {}
            };
            quote! {
                #name(#mod_name::#name #lifetime),
            }
        });

        let enum_serialize_cases = messages.iter().map(|v| {
            let name = format_ident!("{}", v.name.to_pascal_case());
            quote! {
                Self::#name(v) => v.serialize(buf, fds),
            }
        });
        // One match arm per opcode, delegating to the message struct.
        let enum_deserialize_cases = messages.iter().enumerate().map(|(opcode, v)| {
            let name = format_ident!("{}", v.name.to_pascal_case());
            let opcode = opcode as u32;
            quote! {
                #opcode => {
                    Ok(Self::#name(<#mod_name::#name>::deserialize(data, fds)?))
                },
            }
        });
        let enum_len_cases = messages.iter().map(|v| {
            let name = format_ident!("{}", v.name.to_pascal_case());
            quote! {
                Self::#name(v) => v.len(),
            }
        });
        let enum_nfds_cases = messages.iter().map(|v| {
            let name = format_ident!("{}", v.name.to_pascal_case());
            quote! {
                Self::#name(v) => v.nfds(),
            }
        });
        let dispatch = generate_dispatch_trait(messages, event_or_request, iface_version);
        let public = quote! {
            pub mod #mod_name {
                use super::enums;
                use super::__generated_root;
                #(#public)*
            }
            #[doc = "Collection of all possible types of messages, see individual message types "]
            #[doc = "for more information."]
            #[derive(Debug, PartialEq, Eq)]
            pub enum #type_name #enum_lifetime {
                #(#enum_members)*
            }
            #dispatch
            impl #enum_lifetime ::runa_wayland_scanner::io::ser::Serialize for #type_name #enum_lifetime {
                fn serialize<Fds: Extend<std::os::unix::io::OwnedFd>>(
                    self,
                    buf: &mut ::runa_wayland_scanner::BytesMut,
                    fds: &mut Fds,
                ) {
                    match self {
                        #(#enum_serialize_cases)*
                    }
                }
                #[inline]
                fn len(&self) -> u16 {
                    match self {
                        #(#enum_len_cases)*
                    }
                }
                #[inline]
                fn nfds(&self) -> u8 {
                    match self {
                        #(#enum_nfds_cases)*
                    }
                }
            }
            impl<'a> ::runa_wayland_scanner::io::de::Deserialize<'a> for #type_name #enum_lifetime {
                fn deserialize(
                    mut data: &'a [u8], mut fds: &'a [::std::os::unix::io::RawFd]
                ) -> ::std::result::Result<Self, ::runa_wayland_scanner::io::de::Error> {
                    use ::runa_wayland_scanner::io::pop_u32;
                    let _object_id = pop_u32(&mut data);
                    let header = pop_u32(&mut data);
                    let opcode = header & 0xFFFF;
                    match opcode {
                        #(#enum_deserialize_cases)*
                        _ => Err(
                            ::runa_wayland_scanner::io::de::Error::UnknownOpcode(
                                opcode, std::any::type_name::<Self>())),
                    }
                }
            }
        };
        public
    }
}
709fn wrap_links(line: &str) -> String {
710    let links = linkify::LinkFinder::new().links(line);
711    let mut result = String::new();
712    let mut curr_pos = 0;
713    for link in links {
714        if link.start() > curr_pos {
715            result.push_str(&line[curr_pos..link.start()]);
716        }
717        result.push('<');
718        result.push_str(link.as_str());
719        result.push('>');
720        curr_pos = link.end();
721    }
722    result.push_str(&line[curr_pos..]);
723    result
724}
725
726use lazy_static::lazy_static;
727use regex::Regex;
lazy_static! {
    // Matches a bracketed decimal number such as `[0]`. Used by
    // `generate_doc_comment` to spot lines that start with such a marker.
    static ref LINKREF_REGEX: Regex = Regex::new(r"\[([0-9]+)\]").unwrap();
}
731fn generate_doc_comment(description: &Option<(String, String)>) -> TokenStream {
732    if let Some((summary, desc)) = description {
733        let desc = desc.split('\n').map(|s| {
734            let s = s.trim();
735            if let Some(m) = LINKREF_REGEX.find(s) {
736                if m.start() == 0 {
737                    // Fix cases like "[0] link". Change it to "[0]: link"
738                    let s: Cow<'_, _> = if !s[m.end()..].starts_with(':') {
739                        format!("{}:{}", s[..m.end()].trim(), s[m.end()..].trim()).into()
740                    } else {
741                        s.into()
742                    };
743                    return quote! {
744                        #[doc = #s]
745                    }
746                }
747            }
748            let s = wrap_links(s);
749            quote! {
750                #[doc = #s]
751            }
752        });
753        quote! {
754            #[doc = #summary]
755            #[doc = ""]
756            #(#desc)*
757        }
758    } else {
759        quote! {}
760    }
761}
762fn generate_enums(enums: &[Enum], current_iface_name: &str, enum_info: &EnumInfos) -> TokenStream {
763    let enums = enums.iter().map(|e| {
764        let doc = generate_doc_comment(&e.description);
765        let name = format_ident!("{}", e.name.to_pascal_case());
766        let info = enum_info.get(current_iface_name, &e.name).unwrap();
767        let is_bitfield = e.bitfield;
768        assert_eq!(info.is_bitfield, is_bitfield);
769        let repr = if info.is_int.unwrap_or(false) {
770            quote! { i32 }
771        } else {
772            quote! { u32 }
773        };
774        let members = e.entries.iter().map(|e| {
775            let name = if e.name.chars().all(|x| x.is_ascii_digit()) {
776                format_ident!("_{}", e.name)
777            } else if is_bitfield {
778                format_ident!("{}", e.name.TO_SHOUTY_SNEK_CASE())
779            } else {
780                format_ident!("{}", e.name.to_pascal_case())
781            };
782            let value = if info.is_int.unwrap_or(false) {
783                let value = e.value as i32;
784                quote! { #value }
785            } else {
786                let value = e.value;
787                quote! { #value }
788            };
789            let summary = e.summary.as_deref().unwrap_or("");
790            let summary = summary.replace('[', "\\[").replace(']', "\\]");
791            if is_bitfield {
792                quote! {
793                    #[doc = #summary]
794                    const #name = #value;
795                }
796            } else {
797                quote! {
798                    #[doc = #summary]
799                    #name = #value,
800                }
801            }
802        });
803        if is_bitfield {
804            quote! {
805                ::runa_wayland_scanner::bitflags! {
806                    #doc
807                    #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
808                    #[repr(transparent)]
809                    pub struct #name: #repr {
810                        #(#members)*
811                    }
812                }
813            }
814        } else {
815            quote! {
816                #doc
817                #[derive(
818                    ::runa_wayland_scanner::num_enum::IntoPrimitive,
819                    ::runa_wayland_scanner::num_enum::TryFromPrimitive,
820                    Debug, Clone, Copy, PartialEq, Eq
821                )]
822                #[repr(#repr)]
823                pub enum #name {
824                    #(#members)*
825                }
826            }
827        }
828    });
829    quote! {
830        pub mod enums {
831            #(#enums)*
832        }
833    }
834}
835fn generate_interface(
836    iface: &Interface,
837    iface_version: &HashMap<&str, u32>,
838    enum_info: &EnumInfos,
839) -> TokenStream {
840    let name = format_ident!("{}", iface.name);
841    let version = format_ident!("v{}", iface.version);
842
843    let requests = generate_event_or_request(
844        &iface.name,
845        &iface.requests,
846        iface_version,
847        enum_info,
848        EventOrRequest::Request,
849    );
850    let events = generate_event_or_request(
851        &iface.name,
852        &iface.events,
853        iface_version,
854        enum_info,
855        EventOrRequest::Event,
856    );
857    let doc_comment = generate_doc_comment(&iface.description);
858    let enums = generate_enums(&iface.enums, &iface.name, enum_info);
859
860    let iface_name = &iface.name;
861    let iface_version = iface.version;
862    quote! {
863        #doc_comment
864        pub mod #name {
865            #![allow(unused_imports, unused_mut, unused_variables)]
866            pub mod #version {
867                use super::super::__generated_root;
868                /// Name of the interface
869                pub const NAME: &str = #iface_name;
870                pub const VERSION: u32 = #iface_version;
871                #requests
872                #events
873                #enums
874            }
875        }
876    }
877}
878
879fn scan_enum(proto: &Protocol, enum_info: &mut EnumInfos) -> Result<()> {
880    for iface in proto.interfaces.iter() {
881        for req in iface.requests.iter().chain(iface.events.iter()) {
882            for arg in req.args.iter() {
883                if let Some(ref enum_) = arg.enum_ {
884                    let info = enum_info.get_mut(&iface.name, enum_)?;
885                    let is_int = arg.typ == spec_parser::protocol::Type::Int;
886
887                    eprintln!("{}::{}: is_int={}", iface.name, enum_, is_int);
888                    if let Some(old) = info.is_int {
889                        if old != is_int {
890                            return Err(Error::BadEnumType(enum_.clone()))
891                        }
892                    } else {
893                        info.is_int = Some(is_int);
894                    }
895                }
896            }
897        }
898    }
899    Ok(())
900}
901
902pub fn generate_protocol(proto: &Protocol) -> Result<TokenStream> {
903    let iface_version = proto
904        .interfaces
905        .iter()
906        .map(|i| (i.name.as_str(), i.version))
907        .collect();
908    let mut enum_info = EnumInfos(
909        proto
910            .interfaces
911            .iter()
912            .map(|i| {
913                (
914                    i.name.as_str(),
915                    i.enums
916                        .iter()
917                        .map(move |e| {
918                            (e.name.as_str(), EnumInfo {
919                                is_bitfield: e.bitfield,
920                                is_int:      None,
921                            })
922                        })
923                        .collect(),
924                )
925            })
926            .collect(),
927    );
928    scan_enum(proto, &mut enum_info)?;
929    let interfaces = proto
930        .interfaces
931        .iter()
932        .map(|v| generate_interface(v, &iface_version, &enum_info));
933    let name = format_ident!("{}", proto.name);
934    Ok(quote! {
935        #[allow(clippy::needless_lifetimes)]
936        pub mod #name {
937            use super::#name as __generated_root;
938            #(#interfaces)*
939        }
940    })
941}