kdam_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, FieldsNamed, Meta, Path};
4
5/// Derive [BarExt](https://docs.rs/kdam/latest/kdam/trait.BarExt.html) trait for a struct.
6///
7/// # Example
8///
9/// ```no_test
10/// use kdam::{tqdm, Bar, BarExt};
11/// use std::{io::Result, num::NonZeroU16};
12/// 
13/// #[derive(BarExt)]
14/// struct CustomBar {
15///     #[bar]
16///     pb: Bar,
17/// }
18/// 
19/// impl CustomBar {
20///     /// Render progress bar text.
21///     fn render(&mut self) -> String {
22///         let fmt_percentage = self.pb.fmt_percentage(0);
23///         let padding = 1 + fmt_percentage.chars().count() as u16 + self.pb.animation.spaces() as u16;
24/// 
25///         let ncols = self.pb.ncols_for_animation(padding);
26/// 
27///         if ncols == 0 {
28///             self.pb.bar_length = padding - 1;
29///             fmt_percentage
30///         } else {
31///             self.pb.bar_length = padding + ncols;
32///             self.pb.animation.fmt_render(
33///                 NonZeroU16::new(ncols).unwrap(),
34///                 self.pb.percentage(),
35///                 &None,
36///             ) + " "
37///                 + &fmt_percentage
38///         }
39///     }
40/// }
41/// ```
42#[proc_macro_derive(BarExt, attributes(bar))]
43pub fn bar_ext(input: TokenStream) -> TokenStream {
44    let input = parse_macro_input!(input as DeriveInput);
45    let mut bar_field = None;
46
47    if let Data::Struct(DataStruct {
48        fields: Fields::Named(FieldsNamed { named, .. }),
49        ..
50    }) = &input.data
51    {
52        for field in named {
53            for attr in &field.attrs {
54                if let Meta::Path(Path { segments, .. }) = &attr.meta {
55                    for segment in segments {
56                        if &segment.ident.to_string() == "bar" {
57                            bar_field = Some(field.ident.clone());
58                        }
59                    }
60                }
61            }
62        }
63    } else {
64        unimplemented!("BarExt derive macro is only derivable on structs.")
65    }
66
67    if bar_field.is_none() {
68        panic!("One struct field needs to use #[bar] attribute.")
69    }
70
71    let crate_name = if std::env::var("CARGO_CRATE_NAME")
72        .expect("CARGO_CRATE_NAME env variable not set by cargo.")
73        == "kdam"
74    {
75        "crate"
76    } else {
77        "kdam"
78    };
79    let crate_name = format_ident!("{}", crate_name);
80
81    let bar_field = bar_field
82        .flatten()
83        .expect("#[bar] attribute struct field has not a valid identifier.");
84    let name = &input.ident;
85    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
86
87    let expanded = quote! {
88        impl #impl_generics #crate_name::BarExt for #name #ty_generics #where_clause {
89            fn clear(&mut self) -> ::std::io::Result<()> {
90                self.#bar_field.clear()
91            }
92
93            fn input<T: Into<String>>(&mut self, text: T) -> ::std::io::Result<String> {
94                self.clear()?;
95                self.#bar_field.writer.print(text.into().as_bytes())?;
96
97                let mut buf = String::new();
98                ::std::io::stdin().read_line(&mut buf)?;
99
100                if self.#bar_field.leave {
101                    self.refresh()?;
102                }
103
104                Ok(buf)
105            }
106
107            fn refresh(&mut self) -> ::std::io::Result<()> {
108                self.#bar_field.elapsed_time();
109
110                if self.#bar_field.completed() {
111                    if !self.#bar_field.leave && self.#bar_field.position > 0 {
112                        return self.clear();
113                    }
114        
115                    self.#bar_field.total = self.#bar_field.counter;
116                }
117
118                let text = self.render();
119                let bar_length = #crate_name::term::Colorizer::len_ansi(text.as_str()) as u16;
120        
121                if bar_length > self.#bar_field.bar_length {
122                    self.clear()?;
123                    self.#bar_field.bar_length = bar_length;
124                }
125        
126                self.#bar_field.writer.print_at(self.#bar_field.position, text.as_bytes())?;
127                Ok(())
128            }
129
130            fn render(&mut self) -> String {
131                Self::render(self)
132            }
133
134            fn reset(&mut self, total: Option<usize>) {
135                self.#bar_field.reset(total);
136            }
137
138            fn update(&mut self, n: usize) -> ::std::io::Result<bool> {
139                self.#bar_field.counter += n;
140                let should_refresh = self.#bar_field.should_refresh();
141
142                if should_refresh {
143                    self.refresh()?;
144                }
145
146                Ok(should_refresh)
147            }
148
149            fn update_to(&mut self, n: usize) -> ::std::io::Result<bool> {
150                self.#bar_field.counter = n;
151                self.update(0)
152            }
153
154            fn write<T: Into<String>>(&mut self, text: T) -> ::std::io::Result<()> {
155                self.#bar_field.clear()?;
156                self.#bar_field.writer.print(format!("\r{}\n", text.into()).as_bytes())?;
157
158                if self.#bar_field.leave {
159                    self.refresh()?;
160                }
161
162                Ok(())
163            }
164
165            fn write_to<T: ::std::io::Write>(&mut self, writer: &mut T, n: Option<usize>) -> ::std::io::Result<bool> {
166                let text;
167
168                if let Some(n) = n {
169                    self.#bar_field.counter += n;
170
171                    if self.#bar_field.should_refresh() {
172                        text = #crate_name::term::Colorizer::trim_ansi(self.render().as_str());
173                    } else {
174                        return Ok(false);
175                    }
176                } else {
177                    text = #crate_name::term::Colorizer::trim_ansi(self.render().as_str());
178                }
179
180                self.#bar_field.bar_length = #crate_name::term::Colorizer::len_ansi(text.as_str()) as u16;
181                writer.write_all((text + "\n").as_bytes())?;
182                writer.flush()?;
183                Ok(true)
184            }
185        }
186    };
187
188    TokenStream::from(expanded)
189}