1#![warn(missing_docs)]
7
8use oak_python::ast::{BinaryOperator, Expression, Literal, Statement};
9use python_types::{PythonError, PythonResult, PythonValue};
10
11#[derive(Debug, Clone, PartialEq)]
13pub enum Instruction {
14 LoadConst(PythonValue), LoadLocal(usize), LoadGlobal(String), LoadAttr(String), LoadIndex(usize), StoreLocal(usize), StoreGlobal(String), StoreAttr(String), StoreIndex(usize), Add, Sub, Mul, Div, Mod, Exp, FloorDiv, Eq, Neq, Lt, Lte, Gt, Gte, And, Or, Not, Jump(i32), JumpIfFalse(i32), JumpIfTrue(i32), JumpIfNone(i32), JumpIfNotNone(i32), Call(usize), Return, NewList(usize), NewDict(usize), NoneOp, TrueOp, FalseOp, Raise, TryBegin, TryEnd, Except, }
75
76#[derive(Debug, Clone, PartialEq)]
78pub struct Function {
79 pub name: String,
81 pub params: Vec<String>,
83 pub instructions: Vec<Instruction>,
85 pub locals: Vec<String>,
87}
88
89#[derive(Debug, Clone, PartialEq)]
91pub struct Module {
92 pub name: String,
94 pub functions: Vec<Function>,
96 pub globals: Vec<String>,
98}
99
100pub fn ast_to_ir(ast: &oak_python::PythonRoot) -> PythonResult<Module> {
102 let mut module = Module { name: "__main__".to_string(), functions: Vec::new(), globals: Vec::new() };
103
104 if module.functions.is_empty() {
106 let main_function =
107 Function { name: "__main__".to_string(), params: Vec::new(), instructions: Vec::new(), locals: Vec::new() };
108 module.functions.push(main_function);
109 }
110
111 for statement in &ast.program.statements {
113 match statement {
114 Statement::FunctionDef { name, parameters, body, .. } => {
115 let mut function = Function {
117 name: name.clone(),
118 params: parameters.iter().map(|p| p.name.clone()).collect(),
119 instructions: Vec::new(),
120 locals: parameters.iter().map(|p| p.name.clone()).collect(),
121 };
122
123 for stmt in body {
125 process_statement(stmt, &mut function.instructions, &mut function.locals)?;
126 }
127
128 if !function.instructions.iter().any(|inst| matches!(inst, Instruction::Return)) {
130 function.instructions.push(Instruction::NoneOp);
131 function.instructions.push(Instruction::Return);
132 }
133
134 if module.functions.len() == 1 && module.functions[0].name == "__main__" {
136 module.functions[0] = function;
137 }
138 else {
139 module.functions.push(function);
140 }
141 }
142 Statement::Assignment { target, value } => {
143 process_expression(value, &mut module.functions.last_mut().unwrap().instructions)?;
145
146 if let Expression::Name(name) = target {
148 if let Some(func) = module.functions.last_mut() {
150 if func.locals.contains(name) {
151 let index = func.locals.iter().position(|x| x == name).unwrap();
152 func.instructions.push(Instruction::StoreLocal(index));
153 }
154 else {
155 func.instructions.push(Instruction::StoreGlobal(name.clone()));
156 if !module.globals.contains(name) {
157 module.globals.push(name.clone());
158 }
159 }
160 }
161 }
162 }
163 Statement::Expression(expr) => {
164 process_expression(expr, &mut module.functions.last_mut().unwrap().instructions)?;
166 }
167 Statement::Return(expr) => {
168 if let Some(expr) = expr {
170 process_expression(expr, &mut module.functions.last_mut().unwrap().instructions)?;
171 }
172 else {
173 module.functions.last_mut().unwrap().instructions.push(Instruction::NoneOp);
174 }
175 module.functions.last_mut().unwrap().instructions.push(Instruction::Return);
176 }
177 _ => {
178 }
180 }
181 }
182
183 if let Some(func) = module.functions.last_mut() {
185 if !func.instructions.iter().any(|inst| matches!(inst, Instruction::Return)) {
186 func.instructions.push(Instruction::NoneOp);
187 func.instructions.push(Instruction::Return);
188 }
189 }
190
191 Ok(module)
192}
193
194fn process_statement(stmt: &Statement, instructions: &mut Vec<Instruction>, locals: &mut Vec<String>) -> PythonResult<()> {
196 match stmt {
197 Statement::Assignment { target, value } => {
198 process_expression(value, instructions)?;
200
201 if let Expression::Name(name) = target {
203 if !locals.contains(name) {
204 locals.push(name.clone());
205 }
206 let index = locals.iter().position(|x| x == name).unwrap();
207 instructions.push(Instruction::StoreLocal(index));
208 }
209 }
210 Statement::Expression(expr) => {
211 process_expression(expr, instructions)?;
212 }
213 Statement::Return(expr) => {
214 if let Some(expr) = expr {
215 process_expression(expr, instructions)?;
216 }
217 else {
218 instructions.push(Instruction::NoneOp);
219 }
220 instructions.push(Instruction::Return);
221 }
222 _ => {
223 }
225 }
226 Ok(())
227}
228
229fn process_expression(expr: &Expression, instructions: &mut Vec<Instruction>) -> PythonResult<()> {
231 match expr {
232 Expression::Literal(lit) => {
233 match lit {
234 Literal::Integer(i) => {
235 instructions.push(Instruction::LoadConst(PythonValue::Integer(*i)));
236 }
237 Literal::Float(f) => {
238 instructions.push(Instruction::LoadConst(PythonValue::Float(*f)));
239 }
240 Literal::String(s) => {
241 instructions.push(Instruction::LoadConst(PythonValue::String(s.clone())));
242 }
243 Literal::Boolean(b) => {
244 if *b {
245 instructions.push(Instruction::TrueOp);
246 }
247 else {
248 instructions.push(Instruction::FalseOp);
249 }
250 }
251 Literal::None => {
252 instructions.push(Instruction::NoneOp);
253 }
254 _ => {
255 }
257 }
258 }
259 Expression::Name(name) => {
260 instructions.push(Instruction::LoadGlobal(name.clone()));
262 }
263 Expression::BinaryOp { left, operator, right } => {
264 process_expression(left, instructions)?;
265 process_expression(right, instructions)?;
266
267 match operator {
268 BinaryOperator::Add => instructions.push(Instruction::Add),
269 BinaryOperator::Sub => instructions.push(Instruction::Sub),
270 BinaryOperator::Mult => instructions.push(Instruction::Mul),
271 BinaryOperator::Div => instructions.push(Instruction::Div),
272 _ => {
273 }
275 }
276 }
277 Expression::Call { func, args, .. } => {
278 for arg in args {
280 process_expression(arg, instructions)?;
281 }
282
283 process_expression(func, instructions)?;
285
286 instructions.push(Instruction::Call(args.len() as usize));
288 }
289 _ => {
290 }
292 }
293 Ok(())
294}
295
296pub fn optimize_ir(ir: &Module) -> PythonResult<Module> {
298 let mut optimized_module = Module { name: ir.name.clone(), functions: Vec::new(), globals: ir.globals.clone() };
299
300 for func in &ir.functions {
302 let mut optimized_func = Function {
303 name: func.name.clone(),
304 params: func.params.clone(),
305 instructions: optimize_instructions(&func.instructions),
306 locals: func.locals.clone(),
307 };
308 optimized_module.functions.push(optimized_func);
309 }
310
311 Ok(optimized_module)
312}
313
314fn optimize_instructions(instructions: &[Instruction]) -> Vec<Instruction> {
316 let mut optimized = Vec::new();
317 let mut i = 0;
318
319 while i < instructions.len() {
320 optimized.push(instructions[i].clone());
323 i += 1;
324 }
325
326 optimized
327}