Skip to main content

tupa_codegen/
execution_plan.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tupa_parser::{Comparator, Expr, ExprKind, Item, PipelineDecl, Program, Stmt, Type};
4use tupa_typecheck::analyze_effects;
5
6#[derive(Serialize, Deserialize)]
7pub struct ExecutionPlan {
8    pub name: String,
9    pub version: String,
10    pub seed: Option<u64>,
11    pub input_schema: TypeSchema,
12    pub output_schema: Option<TypeSchema>,
13    pub steps: Vec<StepPlan>,
14    pub constraints: Vec<ConstraintPlan>,
15    pub metrics: HashMap<String, f64>,
16    pub metric_plans: Vec<MetricPlan>,
17}
18
19#[derive(Serialize, Deserialize)]
20pub struct StepPlan {
21    pub name: String,
22    pub function_ref: String,
23    pub effects: Vec<String>,
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct ConstraintPlan {
28    pub metric: String,
29    pub comparator: String,
30    pub threshold: f64,
31}
32
33#[derive(Serialize, Deserialize)]
34pub struct TypeSchema {
35    pub kind: String,
36    pub elem: Option<Box<TypeSchema>>,
37    pub len: Option<i64>,
38    pub name: Option<String>,
39    pub tensor_shape: Option<Vec<Option<usize>>>,
40    pub tensor_dtype: Option<String>,
41}
42
43#[derive(Serialize, Deserialize)]
44pub struct MetricPlan {
45    pub name: String,
46    pub function_ref: String,
47    pub args: serde_json::Value,
48}
49
50pub fn type_to_schema(ty: &Type) -> TypeSchema {
51    match ty {
52        Type::Tensor(t) => TypeSchema {
53            kind: "tensor".into(),
54            elem: None,
55            len: None,
56            name: None,
57            tensor_shape: Some(t.shape.iter().map(|&x| x.map(|n| n as usize)).collect()),
58            tensor_dtype: Some(format!("{:?}", t.dtype)),
59        },
60        Type::Array { elem, len } => TypeSchema {
61            kind: "array".into(),
62            elem: Some(Box::new(type_to_schema(elem))),
63            len: Some(*len),
64            name: None,
65            tensor_shape: None,
66            tensor_dtype: None,
67        },
68        Type::Slice { elem } => TypeSchema {
69            kind: "slice".into(),
70            elem: Some(Box::new(type_to_schema(elem))),
71            len: None,
72            name: None,
73            tensor_shape: None,
74            tensor_dtype: None,
75        },
76        Type::Safe { base, .. } => type_to_schema(base),
77        Type::Ident(name) => match name.as_str() {
78            "i64" => TypeSchema {
79                kind: "i64".into(),
80                elem: None,
81                len: None,
82                name: None,
83                tensor_shape: None,
84                tensor_dtype: None,
85            },
86            "f64" => TypeSchema {
87                kind: "f64".into(),
88                elem: None,
89                len: None,
90                name: None,
91                tensor_shape: None,
92                tensor_dtype: None,
93            },
94            "bool" => TypeSchema {
95                kind: "bool".into(),
96                elem: None,
97                len: None,
98                name: None,
99                tensor_shape: None,
100                tensor_dtype: None,
101            },
102            "string" => TypeSchema {
103                kind: "string".into(),
104                elem: None,
105                len: None,
106                name: None,
107                tensor_shape: None,
108                tensor_dtype: None,
109            },
110            _ => TypeSchema {
111                kind: "ident".into(),
112                elem: None,
113                len: None,
114                name: Some(name.clone()),
115                tensor_shape: None,
116                tensor_dtype: None,
117            },
118        },
119        _ => TypeSchema {
120            kind: "unknown".into(),
121            elem: None,
122            len: None,
123            name: None,
124            tensor_shape: None,
125            tensor_dtype: None,
126        },
127    }
128}
129
130fn constraint_to_plan(c: &tupa_parser::Constraint) -> ConstraintPlan {
131    ConstraintPlan {
132        metric: c.metric.clone(),
133        comparator: match c.comparator {
134            Comparator::Lt => "lt".into(),
135            Comparator::Le => "le".into(),
136            Comparator::Eq => "eq".into(),
137            Comparator::Ge => "ge".into(),
138            Comparator::Gt => "gt".into(),
139        },
140        threshold: c.threshold,
141    }
142}
143
144fn extract_metrics(pipeline: &PipelineDecl) -> HashMap<String, f64> {
145    let mut map = HashMap::new();
146    if let Some(block) = &pipeline.validation {
147        for stmt in block {
148            if let Stmt::Let { name, expr, .. } = stmt {
149                match &expr.kind {
150                    ExprKind::Int(n) => {
151                        map.insert(name.clone(), *n as f64);
152                    }
153                    ExprKind::Float(f) => {
154                        map.insert(name.clone(), *f);
155                    }
156                    _ => {}
157                }
158            }
159        }
160    }
161    map
162}
163
164fn expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
165    match &expr.kind {
166        ExprKind::Int(n) => Some(serde_json::json!(*n)),
167        ExprKind::Float(f) => Some(serde_json::json!(*f)),
168        ExprKind::Bool(b) => Some(serde_json::json!(*b)),
169        ExprKind::ArrayLiteral(items) => {
170            let mut arr = Vec::new();
171            for it in items {
172                if let Some(v) = expr_to_json(it) {
173                    arr.push(v);
174                } else {
175                    return None;
176                }
177            }
178            Some(serde_json::Value::Array(arr))
179        }
180        _ => None,
181    }
182}
183
184fn extract_metric_plans(module_name: &str, pipeline: &PipelineDecl) -> Vec<MetricPlan> {
185    let mut list = Vec::new();
186    if let Some(block) = &pipeline.validation {
187        for stmt in block {
188            if let Stmt::Let { name, expr, .. } = stmt {
189                if let ExprKind::Call { callee, args } = &expr.kind {
190                    if let ExprKind::Ident(func) = &callee.kind {
191                        // build args JSON as array of supported literals
192                        let json_args = if args.len() == 1 {
193                            expr_to_json(&args[0]).unwrap_or(serde_json::Value::Null)
194                        } else {
195                            let mut arr = Vec::new();
196                            for a in args {
197                                if let Some(v) = expr_to_json(a) {
198                                    arr.push(v);
199                                }
200                            }
201                            serde_json::Value::Array(arr)
202                        };
203                        list.push(MetricPlan {
204                            name: name.clone(),
205                            function_ref: format!("{module_name}::{func}"),
206                            args: json_args,
207                        });
208                    }
209                }
210            }
211        }
212    }
213    list
214}
215
216pub fn codegen_pipeline(
217    module_name: &str,
218    pipeline: &PipelineDecl,
219    program: &Program,
220) -> serde_json::Result<String> {
221    let steps: Vec<StepPlan> = pipeline
222        .steps
223        .iter()
224        .map(|step| {
225            let effects = analyze_effects(&step.body, &HashMap::new()).to_names();
226            let mut function_ref = format!("{module_name}::step_{}", step.name);
227
228            // Check if body is a direct call to an external function
229            if let ExprKind::Call { callee, args } = &step.body.kind {
230                if let ExprKind::Ident(func_name) = &callee.kind {
231                    // We only optimize direct calls with 'input' argument for now
232                    let is_simple_call = args.len() == 1
233                        && matches!(&args[0].kind, ExprKind::Ident(n) if n == "input");
234
235                    if is_simple_call {
236                        // Find function definition
237                        for item in &program.items {
238                            if let Item::Function(f) = item {
239                                if &f.name == func_name {
240                                    if let Some(spec) = &f.external_spec {
241                                        if let Some(py_target) = &spec.python {
242                                            function_ref = format!("py:{}", py_target);
243                                        }
244                                    }
245                                    break;
246                                }
247                            }
248                        }
249                    }
250                }
251            }
252
253            StepPlan {
254                name: step.name.clone(),
255                function_ref,
256                effects,
257            }
258        })
259        .collect();
260    let plan = ExecutionPlan {
261        name: pipeline.name.clone(),
262        version: env!("CARGO_PKG_VERSION").to_string(),
263        seed: pipeline.seed,
264        input_schema: type_to_schema(&pipeline.input_ty),
265        output_schema: pipeline.output_ty.as_ref().map(type_to_schema),
266        steps,
267        constraints: pipeline
268            .constraints
269            .iter()
270            .map(constraint_to_plan)
271            .collect(),
272        metrics: extract_metrics(pipeline),
273        metric_plans: extract_metric_plans(module_name, pipeline),
274    };
275    serde_json::to_string_pretty(&plan)
276}