python_ast/ast/tree/
assign.rs1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::{format_ident, quote};
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 CodeGen, CodeGenContext, ExprType, Name, Node, PythonOptions, SymbolTableNode,
8 SymbolTableScopes,
9};
10
11#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
12pub struct Assign {
13 pub targets: Vec<Name>,
14 pub value: ExprType,
15 pub type_comment: Option<String>,
16}
17
18impl<'a> FromPyObject<'a> for Assign {
19 fn extract(ob: &'a PyAny) -> PyResult<Self> {
20 let targets: Vec<Name> = ob
21 .getattr("targets")
22 .expect(
23 ob.error_message("<unknown>", "error getting unary operator")
24 .as_str(),
25 )
26 .extract()
27 .expect("1");
28
29 let python_value = ob.getattr("value").expect(
30 ob.error_message("<unknown>", "assignment statement value not found")
31 .as_str(),
32 );
33
34 let value = ExprType::extract(python_value).expect(
35 ob.error_message("<unknown>", "error getting value of assignment statement")
36 .as_str(),
37 );
38
39 Ok(Assign {
40 targets: targets,
41 value: value,
42 type_comment: None,
43 })
44 }
45}
46
47impl<'a> CodeGen for Assign {
48 type Context = CodeGenContext;
49 type Options = PythonOptions;
50 type SymbolTable = SymbolTableScopes;
51
52 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
53 let mut symbols = symbols;
54 let mut position = 0;
55 for target in self.targets {
56 symbols.insert(
57 target.id,
58 SymbolTableNode::Assign {
59 position: position,
60 value: self.value.clone(),
61 },
62 );
63 position += 1;
64 }
65 symbols
66 }
67
68 fn to_rust(
69 self,
70 ctx: Self::Context,
71 options: Self::Options,
72 symbols: Self::SymbolTable,
73 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
74 let mut stream = TokenStream::new();
75 for target in self.targets.into_iter().map(|n| n.id) {
76 let ident = format_ident!("{}", target);
77 stream.extend(quote!(#ident));
78 }
79 let value = self.value.to_rust(ctx, options, symbols)?;
80 Ok(quote!(#stream = #value))
81 }
82}