1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tupa_parser::{Comparator, Expr, ExprKind, Item, PipelineDecl, Program, Stmt, Type};
4use tupa_typecheck::analyze_effects;
5
6#[derive(Serialize, Deserialize)]
7pub struct ExecutionPlan {
8 pub name: String,
9 pub version: String,
10 pub seed: Option<u64>,
11 pub input_schema: TypeSchema,
12 pub output_schema: Option<TypeSchema>,
13 pub steps: Vec<StepPlan>,
14 pub constraints: Vec<ConstraintPlan>,
15 pub metrics: HashMap<String, f64>,
16 pub metric_plans: Vec<MetricPlan>,
17}
18
19#[derive(Serialize, Deserialize)]
20pub struct StepPlan {
21 pub name: String,
22 pub function_ref: String,
23 pub effects: Vec<String>,
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct ConstraintPlan {
28 pub metric: String,
29 pub comparator: String,
30 pub threshold: f64,
31}
32
33#[derive(Serialize, Deserialize)]
34pub struct TypeSchema {
35 pub kind: String,
36 pub elem: Option<Box<TypeSchema>>,
37 pub len: Option<i64>,
38 pub name: Option<String>,
39 pub tensor_shape: Option<Vec<Option<usize>>>,
40 pub tensor_dtype: Option<String>,
41}
42
43#[derive(Serialize, Deserialize)]
44pub struct MetricPlan {
45 pub name: String,
46 pub function_ref: String,
47 pub args: serde_json::Value,
48}
49
50pub fn type_to_schema(ty: &Type) -> TypeSchema {
51 match ty {
52 Type::Tensor(t) => TypeSchema {
53 kind: "tensor".into(),
54 elem: None,
55 len: None,
56 name: None,
57 tensor_shape: Some(t.shape.iter().map(|&x| x.map(|n| n as usize)).collect()),
58 tensor_dtype: Some(format!("{:?}", t.dtype)),
59 },
60 Type::Array { elem, len } => TypeSchema {
61 kind: "array".into(),
62 elem: Some(Box::new(type_to_schema(elem))),
63 len: Some(*len),
64 name: None,
65 tensor_shape: None,
66 tensor_dtype: None,
67 },
68 Type::Slice { elem } => TypeSchema {
69 kind: "slice".into(),
70 elem: Some(Box::new(type_to_schema(elem))),
71 len: None,
72 name: None,
73 tensor_shape: None,
74 tensor_dtype: None,
75 },
76 Type::Safe { base, .. } => type_to_schema(base),
77 Type::Ident(name) => match name.as_str() {
78 "i64" => TypeSchema {
79 kind: "i64".into(),
80 elem: None,
81 len: None,
82 name: None,
83 tensor_shape: None,
84 tensor_dtype: None,
85 },
86 "f64" => TypeSchema {
87 kind: "f64".into(),
88 elem: None,
89 len: None,
90 name: None,
91 tensor_shape: None,
92 tensor_dtype: None,
93 },
94 "bool" => TypeSchema {
95 kind: "bool".into(),
96 elem: None,
97 len: None,
98 name: None,
99 tensor_shape: None,
100 tensor_dtype: None,
101 },
102 "string" => TypeSchema {
103 kind: "string".into(),
104 elem: None,
105 len: None,
106 name: None,
107 tensor_shape: None,
108 tensor_dtype: None,
109 },
110 _ => TypeSchema {
111 kind: "ident".into(),
112 elem: None,
113 len: None,
114 name: Some(name.clone()),
115 tensor_shape: None,
116 tensor_dtype: None,
117 },
118 },
119 _ => TypeSchema {
120 kind: "unknown".into(),
121 elem: None,
122 len: None,
123 name: None,
124 tensor_shape: None,
125 tensor_dtype: None,
126 },
127 }
128}
129
130fn constraint_to_plan(c: &tupa_parser::Constraint) -> ConstraintPlan {
131 ConstraintPlan {
132 metric: c.metric.clone(),
133 comparator: match c.comparator {
134 Comparator::Lt => "lt".into(),
135 Comparator::Le => "le".into(),
136 Comparator::Eq => "eq".into(),
137 Comparator::Ge => "ge".into(),
138 Comparator::Gt => "gt".into(),
139 },
140 threshold: c.threshold,
141 }
142}
143
144fn extract_metrics(pipeline: &PipelineDecl) -> HashMap<String, f64> {
145 let mut map = HashMap::new();
146 if let Some(block) = &pipeline.validation {
147 for stmt in block {
148 if let Stmt::Let { name, expr, .. } = stmt {
149 match &expr.kind {
150 ExprKind::Int(n) => {
151 map.insert(name.clone(), *n as f64);
152 }
153 ExprKind::Float(f) => {
154 map.insert(name.clone(), *f);
155 }
156 _ => {}
157 }
158 }
159 }
160 }
161 map
162}
163
164fn expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
165 match &expr.kind {
166 ExprKind::Int(n) => Some(serde_json::json!(*n)),
167 ExprKind::Float(f) => Some(serde_json::json!(*f)),
168 ExprKind::Bool(b) => Some(serde_json::json!(*b)),
169 ExprKind::ArrayLiteral(items) => {
170 let mut arr = Vec::new();
171 for it in items {
172 if let Some(v) = expr_to_json(it) {
173 arr.push(v);
174 } else {
175 return None;
176 }
177 }
178 Some(serde_json::Value::Array(arr))
179 }
180 _ => None,
181 }
182}
183
184fn extract_metric_plans(module_name: &str, pipeline: &PipelineDecl) -> Vec<MetricPlan> {
185 let mut list = Vec::new();
186 if let Some(block) = &pipeline.validation {
187 for stmt in block {
188 if let Stmt::Let { name, expr, .. } = stmt {
189 if let ExprKind::Call { callee, args } = &expr.kind {
190 if let ExprKind::Ident(func) = &callee.kind {
191 let json_args = if args.len() == 1 {
193 expr_to_json(&args[0]).unwrap_or(serde_json::Value::Null)
194 } else {
195 let mut arr = Vec::new();
196 for a in args {
197 if let Some(v) = expr_to_json(a) {
198 arr.push(v);
199 }
200 }
201 serde_json::Value::Array(arr)
202 };
203 list.push(MetricPlan {
204 name: name.clone(),
205 function_ref: format!("{module_name}::{func}"),
206 args: json_args,
207 });
208 }
209 }
210 }
211 }
212 }
213 list
214}
215
216pub fn codegen_pipeline(
217 module_name: &str,
218 pipeline: &PipelineDecl,
219 program: &Program,
220) -> serde_json::Result<String> {
221 let steps: Vec<StepPlan> = pipeline
222 .steps
223 .iter()
224 .map(|step| {
225 let effects = analyze_effects(&step.body, &HashMap::new()).to_names();
226 let mut function_ref = format!("{module_name}::step_{}", step.name);
227
228 if let ExprKind::Call { callee, args } = &step.body.kind {
230 if let ExprKind::Ident(func_name) = &callee.kind {
231 let is_simple_call = args.len() == 1
233 && matches!(&args[0].kind, ExprKind::Ident(n) if n == "input");
234
235 if is_simple_call {
236 for item in &program.items {
238 if let Item::Function(f) = item {
239 if &f.name == func_name {
240 if let Some(spec) = &f.external_spec {
241 if let Some(py_target) = &spec.python {
242 function_ref = format!("py:{}", py_target);
243 }
244 }
245 break;
246 }
247 }
248 }
249 }
250 }
251 }
252
253 StepPlan {
254 name: step.name.clone(),
255 function_ref,
256 effects,
257 }
258 })
259 .collect();
260 let plan = ExecutionPlan {
261 name: pipeline.name.clone(),
262 version: env!("CARGO_PKG_VERSION").to_string(),
263 seed: pipeline.seed,
264 input_schema: type_to_schema(&pipeline.input_ty),
265 output_schema: pipeline.output_ty.as_ref().map(type_to_schema),
266 steps,
267 constraints: pipeline
268 .constraints
269 .iter()
270 .map(constraint_to_plan)
271 .collect(),
272 metrics: extract_metrics(pipeline),
273 metric_plans: extract_metric_plans(module_name, pipeline),
274 };
275 serde_json::to_string_pretty(&plan)
276}