1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
8 PyAttributeExtractor, extract_list,
9};
10
11#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
13pub struct ListComp {
14 pub elt: Box<ExprType>,
16 pub generators: Vec<Comprehension>,
18 pub lineno: Option<usize>,
20 pub col_offset: Option<usize>,
21 pub end_lineno: Option<usize>,
22 pub end_col_offset: Option<usize>,
23}
24
25#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
27pub struct SetComp {
28 pub elt: Box<ExprType>,
30 pub generators: Vec<Comprehension>,
32 pub lineno: Option<usize>,
34 pub col_offset: Option<usize>,
35 pub end_lineno: Option<usize>,
36 pub end_col_offset: Option<usize>,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
41pub struct GeneratorExp {
42 pub elt: Box<ExprType>,
44 pub generators: Vec<Comprehension>,
46 pub lineno: Option<usize>,
48 pub col_offset: Option<usize>,
49 pub end_lineno: Option<usize>,
50 pub end_col_offset: Option<usize>,
51}
52
53#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
55pub struct DictComp {
56 pub key: Box<ExprType>,
58 pub value: Box<ExprType>,
60 pub generators: Vec<Comprehension>,
62 pub lineno: Option<usize>,
64 pub col_offset: Option<usize>,
65 pub end_lineno: Option<usize>,
66 pub end_col_offset: Option<usize>,
67}
68
69#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
71pub struct Comprehension {
72 pub target: ExprType,
74 pub iter: ExprType,
76 pub ifs: Vec<ExprType>,
78 pub is_async: bool,
80}
81
82impl<'a> FromPyObject<'a> for ListComp {
83 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
84 let elt = ob.extract_attr_with_context("elt", "list comprehension element")?;
86 let elt: ExprType = elt.extract()?;
87
88 let generators: Vec<Comprehension> = extract_list(ob, "generators", "list comprehension generators")?;
90
91 Ok(ListComp {
92 elt: Box::new(elt),
93 generators,
94 lineno: ob.lineno(),
95 col_offset: ob.col_offset(),
96 end_lineno: ob.end_lineno(),
97 end_col_offset: ob.end_col_offset(),
98 })
99 }
100}
101
102impl<'a> FromPyObject<'a> for SetComp {
103 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
104 let elt = ob.extract_attr_with_context("elt", "set comprehension element")?;
106 let elt: ExprType = elt.extract()?;
107
108 let generators: Vec<Comprehension> = extract_list(ob, "generators", "set comprehension generators")?;
110
111 Ok(SetComp {
112 elt: Box::new(elt),
113 generators,
114 lineno: ob.lineno(),
115 col_offset: ob.col_offset(),
116 end_lineno: ob.end_lineno(),
117 end_col_offset: ob.end_col_offset(),
118 })
119 }
120}
121
122impl<'a> FromPyObject<'a> for GeneratorExp {
123 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
124 let elt = ob.extract_attr_with_context("elt", "generator expression element")?;
126 let elt: ExprType = elt.extract()?;
127
128 let generators: Vec<Comprehension> = extract_list(ob, "generators", "generator expression generators")?;
130
131 Ok(GeneratorExp {
132 elt: Box::new(elt),
133 generators,
134 lineno: ob.lineno(),
135 col_offset: ob.col_offset(),
136 end_lineno: ob.end_lineno(),
137 end_col_offset: ob.end_col_offset(),
138 })
139 }
140}
141
142impl<'a> FromPyObject<'a> for DictComp {
143 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
144 let key = ob.extract_attr_with_context("key", "dict comprehension key")?;
146 let key: ExprType = key.extract()?;
147
148 let value = ob.extract_attr_with_context("value", "dict comprehension value")?;
150 let value: ExprType = value.extract()?;
151
152 let generators: Vec<Comprehension> = extract_list(ob, "generators", "dict comprehension generators")?;
154
155 Ok(DictComp {
156 key: Box::new(key),
157 value: Box::new(value),
158 generators,
159 lineno: ob.lineno(),
160 col_offset: ob.col_offset(),
161 end_lineno: ob.end_lineno(),
162 end_col_offset: ob.end_col_offset(),
163 })
164 }
165}
166
167impl<'a> FromPyObject<'a> for Comprehension {
168 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
169 let target = ob.extract_attr_with_context("target", "comprehension target")?;
171 let target: ExprType = target.extract()?;
172
173 let iter = ob.extract_attr_with_context("iter", "comprehension iter")?;
175 let iter: ExprType = iter.extract()?;
176
177 let ifs: Vec<ExprType> = extract_list(ob, "ifs", "comprehension conditions").unwrap_or_default();
179
180 let is_async: bool = ob.getattr("is_async")?.extract().unwrap_or(false);
182
183 Ok(Comprehension {
184 target,
185 iter,
186 ifs,
187 is_async,
188 })
189 }
190}
191
192impl Node for ListComp {
193 fn lineno(&self) -> Option<usize> { self.lineno }
194 fn col_offset(&self) -> Option<usize> { self.col_offset }
195 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
196 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
197}
198
199impl Node for SetComp {
200 fn lineno(&self) -> Option<usize> { self.lineno }
201 fn col_offset(&self) -> Option<usize> { self.col_offset }
202 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
203 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
204}
205
206impl Node for GeneratorExp {
207 fn lineno(&self) -> Option<usize> { self.lineno }
208 fn col_offset(&self) -> Option<usize> { self.col_offset }
209 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
210 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
211}
212
213impl Node for DictComp {
214 fn lineno(&self) -> Option<usize> { self.lineno }
215 fn col_offset(&self) -> Option<usize> { self.col_offset }
216 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
217 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
218}
219
220impl CodeGen for ListComp {
221 type Context = CodeGenContext;
222 type Options = PythonOptions;
223 type SymbolTable = SymbolTableScopes;
224
225 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
226 let symbols = (*self.elt).clone().find_symbols(symbols);
228 self.generators.into_iter().fold(symbols, |acc, generator| {
229 let acc = generator.target.find_symbols(acc);
230 let acc = generator.iter.find_symbols(acc);
231 generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
232 })
233 }
234
235 fn to_rust(
236 self,
237 ctx: Self::Context,
238 options: Self::Options,
239 symbols: Self::SymbolTable,
240 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
241 if self.generators.len() == 1 {
244 let generator = &self.generators[0];
245 let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
246 let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
247
248 if generator.ifs.is_empty() {
249 Ok(quote! {
251 (#iter_expr).into_iter().map(|_item| #elt).collect::<Vec<_>>()
252 })
253 } else {
254 let conditions: Result<Vec<_>, _> = generator.ifs.iter()
256 .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
257 .collect();
258 let conditions = conditions?;
259 Ok(quote! {
260 (#iter_expr).into_iter()
261 .filter(|_item| { #(#conditions)&&* })
262 .map(|_item| #elt)
263 .collect::<Vec<_>>()
264 })
265 }
266 } else {
267 Ok(quote! {
270 vec![] })
272 }
273 }
274}
275
276impl CodeGen for SetComp {
277 type Context = CodeGenContext;
278 type Options = PythonOptions;
279 type SymbolTable = SymbolTableScopes;
280
281 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
282 let symbols = (*self.elt).clone().find_symbols(symbols);
284 self.generators.into_iter().fold(symbols, |acc, generator| {
285 let acc = generator.target.find_symbols(acc);
286 let acc = generator.iter.find_symbols(acc);
287 generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
288 })
289 }
290
291 fn to_rust(
292 self,
293 ctx: Self::Context,
294 options: Self::Options,
295 symbols: Self::SymbolTable,
296 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
297 if self.generators.len() == 1 {
300 let generator = &self.generators[0];
301 let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
302 let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
303
304 if generator.ifs.is_empty() {
305 Ok(quote! {
307 (#iter_expr).into_iter().map(|_item| #elt).collect::<std::collections::HashSet<_>>()
308 })
309 } else {
310 let conditions: Result<Vec<_>, _> = generator.ifs.iter()
312 .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
313 .collect();
314 let conditions = conditions?;
315 Ok(quote! {
316 (#iter_expr).into_iter()
317 .filter(|_item| { #(#conditions)&&* })
318 .map(|_item| #elt)
319 .collect::<std::collections::HashSet<_>>()
320 })
321 }
322 } else {
323 Ok(quote! {
326 std::collections::HashSet::new() })
328 }
329 }
330}
331
332impl CodeGen for GeneratorExp {
333 type Context = CodeGenContext;
334 type Options = PythonOptions;
335 type SymbolTable = SymbolTableScopes;
336
337 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
338 let symbols = (*self.elt).clone().find_symbols(symbols);
340 self.generators.into_iter().fold(symbols, |acc, generator| {
341 let acc = generator.target.find_symbols(acc);
342 let acc = generator.iter.find_symbols(acc);
343 generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
344 })
345 }
346
347 fn to_rust(
348 self,
349 ctx: Self::Context,
350 options: Self::Options,
351 symbols: Self::SymbolTable,
352 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
353 if self.generators.len() == 1 {
356 let generator = &self.generators[0];
357 let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
358 let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
359
360 if generator.ifs.is_empty() {
361 Ok(quote! {
363 (#iter_expr).into_iter().map(|_item| #elt)
364 })
365 } else {
366 let conditions: Result<Vec<_>, _> = generator.ifs.iter()
368 .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
369 .collect();
370 let conditions = conditions?;
371 Ok(quote! {
372 (#iter_expr).into_iter()
373 .filter(|_item| { #(#conditions)&&* })
374 .map(|_item| #elt)
375 })
376 }
377 } else {
378 Ok(quote! {
381 std::iter::empty() })
383 }
384 }
385}
386
387impl CodeGen for DictComp {
388 type Context = CodeGenContext;
389 type Options = PythonOptions;
390 type SymbolTable = SymbolTableScopes;
391
392 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
393 let symbols = (*self.key).clone().find_symbols(symbols);
395 let symbols = (*self.value).clone().find_symbols(symbols);
396 self.generators.into_iter().fold(symbols, |acc, generator| {
397 let acc = generator.target.find_symbols(acc);
398 let acc = generator.iter.find_symbols(acc);
399 generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
400 })
401 }
402
403 fn to_rust(
404 self,
405 ctx: Self::Context,
406 options: Self::Options,
407 symbols: Self::SymbolTable,
408 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
409 if self.generators.len() == 1 {
411 let generator = &self.generators[0];
412 let key = (*self.key).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
413 let value = (*self.value).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
414 let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
415
416 if generator.ifs.is_empty() {
417 Ok(quote! {
419 (#iter_expr).into_iter().map(|_item| (#key, #value)).collect::<std::collections::HashMap<_, _>>()
420 })
421 } else {
422 let conditions: Result<Vec<_>, _> = generator.ifs.iter()
424 .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
425 .collect();
426 let conditions = conditions?;
427 Ok(quote! {
428 (#iter_expr).into_iter()
429 .filter(|_item| { #(#conditions)&&* })
430 .map(|_item| (#key, #value))
431 .collect::<std::collections::HashMap<_, _>>()
432 })
433 }
434 } else {
435 Ok(quote! {
437 std::collections::HashMap::new() })
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    // NOTE(review): this module contains no tests. Conversion from Python
    // (`extract_bound`) and Rust lowering (`to_rust`) for the comprehension
    // node types are currently unexercised here.
}