Skip to main content

kt_list_comprehensions/
lib.rs

1//! Python-like list comprehensions for Rust.
2//! 
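//! Provides the `list_comprehension!` macro, which supports:
//! - One or more `for ... in ...` clauses, each later clause nested inside the one before it
//! - Any number of `if` conditions after each `for` clause
//! - Arbitrary Rust expressions for the mapping, iterables, and conditions, and Rust patterns for the loop bindings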
7//! 
8//! # Examples
9//! ```rust
10//! use kt_list_comprehensions::list_comprehension;
11//!
12//! let vec = vec![-1, 2, 2, 3, 4];
13//!
14//! let result: Vec<i32> = list_comprehension![x * 2 for x in vec if x > 0].collect();
15//!
16//! assert_eq!(result, [4, 4, 6, 8]);
17//! ```
18//!
19//! ```rust
20//! use kt_list_comprehensions::list_comprehension;
21//!
22//! let vec_of_vectors = vec![vec![1, 2, 3], vec![4, 5, 6]];
23//!
24//! let result: Vec<i32> = list_comprehension![x for vec in vec_of_vectors for x in vec].collect();
25//!
26//! assert_eq!(result, [1, 2, 3, 4, 5, 6]);
27//! ```
28//! 
29//! ```rust
30//! use kt_list_comprehensions::list_comprehension;
31//! 
32//! let vec = vec!["1", "2", "not a number", "-3"];
33//! 
34//! let parse_i32 = |string: &str| string.parse::<i32>();
35//! 
36//! let result: Vec<i32> = list_comprehension![parse_i32(number).unwrap() for number in vec if parse_i32(number).is_ok() if parse_i32(number).unwrap() > 0].collect();
37//! 
38//! assert_eq!(result, [1, 2]);
39//! ```
40
41use proc_macro::TokenStream;
42use syn::parse::ParseStream;
43
44/// Python-like list comprehensions in Rust.
45/// 
46/// # Examples
47/// ```rust
48/// use kt_list_comprehensions::list_comprehension;
49/// 
50/// let vec = vec![-1, 2, 2, 3, 4];
51///
52/// let result: Vec<i32> = list_comprehension![x * 2 for x in vec if x > 0].collect();
53/// 
54/// assert_eq!(result, [4, 4, 6, 8]);
55/// ```
56/// 
57/// ```rust
58/// use kt_list_comprehensions::list_comprehension;
59/// 
60/// let vec_of_vectors = vec![vec![1, 2, 3], vec![4, 5, 6]];
61///
62/// let result: Vec<i32> = list_comprehension![x for vec in vec_of_vectors for x in vec].collect();
63/// 
64/// assert_eq!(result, [1, 2, 3, 4, 5, 6]);
65/// ```
66///
67/// ```rust
68/// use kt_list_comprehensions::list_comprehension;
69///
70/// let vec = vec!["1", "2", "not a number", "-3"];
71///
72/// let parse_i32 = |string: &str| string.parse::<i32>();
73///
74/// let result: Vec<i32> = list_comprehension![parse_i32(number).unwrap() for number in vec if parse_i32(number).is_ok() if parse_i32(number).unwrap() > 0].collect();
75///
76/// assert_eq!(result, [1, 2]);
77/// ```
78#[proc_macro]
79pub fn list_comprehension(token_stream: TokenStream) -> TokenStream {
80	let comprehension = syn::parse_macro_input!(token_stream as Comprehension);
81	
82	let tokens = quote::quote! {
83		#comprehension
84	};
85	
86	tokens.into()
87}
88
89struct Comprehension {
90	mapping: Mapping,
91	for_if_clause: ForIfClause,
92	additional_for_if_clauses: Vec<ForIfClause>,
93}
94
95impl syn::parse::Parse for Comprehension {
96	fn parse(input: ParseStream) -> syn::Result<Self> {
97		Ok(Self {
98			mapping: input.parse::<Mapping>()?,
99			for_if_clause: input.parse::<ForIfClause>()?,
100			additional_for_if_clauses: parse_zero_or_more::<ForIfClause>(input)
101		})
102	}
103}
104
105impl quote::ToTokens for Comprehension {
106	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
107		let all_clauses = core::iter::once(&self.for_if_clause)
108			.chain(&self.additional_for_if_clauses);
109		
110		let mut clause_iterator = all_clauses.rev();
111		
112		let mut output = {
113			let innermost_clause = clause_iterator
114				.next()
115				.expect("List comprehensions require at least one for-if clause");
116
117			let ForIfClause {
118				pattern,
119				expression: sequence,
120				conditions
121			} = innermost_clause;
122			
123			let Mapping(mapping) = &self.mapping;
124			
125			quote::quote! {
126				::core::iter::IntoIterator::into_iter(#sequence)
127					.filter_map(move |#pattern| {
128						(true #(&& (#conditions))*)
129							.then(|| #mapping)
130					})
131			}
132		};
133		
134		output = clause_iterator.fold(output, |current_output, next_clause| {
135			let ForIfClause {
136				pattern,
137				expression: sequence,
138				conditions,
139			} = next_clause;
140			
141			quote::quote! {
142				::core::iter::IntoIterator::into_iter(#sequence)
143					.filter_map(move |#pattern| {
144						(true #(&& (#conditions))*)
145							.then(|| #current_output)
146					})
147					.flatten()
148			}
149		});
150		
151		tokens.extend(output);
152	}
153}
154
155struct Mapping(syn::Expr);
156
157impl syn::parse::Parse for Mapping {
158	fn parse(input: ParseStream) -> syn::Result<Self> {
159		input.parse::<syn::Expr>()
160			.map(Self)
161	}
162}
163
164impl quote::ToTokens for Mapping {
165	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
166		self.0.to_tokens(tokens);
167	}
168}
169
170struct ForIfClause {
171	pattern: Pattern,
172	expression: syn::Expr,
173	conditions: Vec<Condition>,
174}
175
176impl syn::parse::Parse for ForIfClause {
177	fn parse(input: ParseStream) -> syn::Result<Self> {
178		input.parse::<syn::Token![for]>()?;
179		
180		let pattern = input.parse()?;
181		
182		input.parse::<syn::Token![in]>()?;
183		
184		let expression = input.parse()?;
185		
186		let conditions = parse_zero_or_more::<Condition>(input);
187		
188		Ok(Self {
189			pattern,
190			expression,
191			conditions
192		})
193	}
194}
195
196struct Pattern(syn::Pat);
197
198impl syn::parse::Parse for Pattern {
199	fn parse(input: ParseStream) -> syn::Result<Self> {
200		input.call(syn::Pat::parse_single)
201			.map(Self)
202	}
203}
204
205impl quote::ToTokens for Pattern {
206	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
207		self.0.to_tokens(tokens);
208	}
209}
210
211struct Condition(syn::Expr);
212
213impl syn::parse::Parse for Condition {
214	fn parse(input: ParseStream) -> syn::Result<Self> {
215		input.parse::<syn::Token![if]>()?;
216		
217		input.parse::<syn::Expr>()
218			.map(Self)
219	}
220}
221
222impl quote::ToTokens for Condition {
223	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
224		self.0.to_tokens(tokens);
225	}
226}
227
228fn parse_zero_or_more<T>(input: ParseStream) -> Vec<T>
229where
230	T: syn::parse::Parse
231{
232	let mut out = Vec::new();
233	
234	while let Ok(item) = input.parse::<T>() {
235		out.push(item);
236	}
237	
238	out
239}