1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tupa_parser::{Comparator, Expr, ExprKind, Item, PipelineDecl, Program, Stmt, Type};
4use tupa_typecheck::analyze_effects;
5
6#[derive(Serialize, Deserialize)]
7pub struct ExecutionPlan {
8 pub name: String,
9 pub version: String,
10 pub seed: Option<u64>,
11 pub input_schema: TypeSchema,
12 pub output_schema: Option<TypeSchema>,
13 pub steps: Vec<StepPlan>,
14 pub constraints: Vec<ConstraintPlan>,
15 pub metrics: HashMap<String, f64>,
16 pub metric_plans: Vec<MetricPlan>,
17}
18
19#[derive(Serialize, Deserialize)]
20pub struct StepPlan {
21 pub name: String,
22 pub function_ref: String,
23 pub effects: Vec<String>,
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct ConstraintPlan {
28 pub metric: String,
29 pub comparator: String,
30 pub threshold: f64,
31}
32
33#[derive(Serialize, Deserialize)]
34pub struct TypeSchema {
35 pub kind: String,
36 pub elem: Option<Box<TypeSchema>>,
37 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub fields: Option<HashMap<String, TypeSchema>>,
39 pub len: Option<i64>,
40 pub name: Option<String>,
41 pub tensor_shape: Option<Vec<Option<usize>>>,
42 pub tensor_dtype: Option<String>,
43}
44
45#[derive(Serialize, Deserialize)]
46pub struct MetricPlan {
47 pub name: String,
48 pub function_ref: String,
49 pub args: serde_json::Value,
50}
51
52pub fn type_to_schema(ty: &Type) -> TypeSchema {
53 match ty {
54 Type::Tensor(t) => TypeSchema {
55 kind: "tensor".into(),
56 elem: None,
57 fields: None,
58 len: None,
59 name: None,
60 tensor_shape: Some(t.shape.iter().map(|&x| x.map(|n| n as usize)).collect()),
61 tensor_dtype: Some(format!("{:?}", t.dtype)),
62 },
63 Type::Array { elem, len } => TypeSchema {
64 kind: "array".into(),
65 elem: Some(Box::new(type_to_schema(elem))),
66 fields: None,
67 len: Some(*len),
68 name: None,
69 tensor_shape: None,
70 tensor_dtype: None,
71 },
72 Type::Slice { elem } => TypeSchema {
73 kind: "slice".into(),
74 elem: Some(Box::new(type_to_schema(elem))),
75 fields: None,
76 len: None,
77 name: None,
78 tensor_shape: None,
79 tensor_dtype: None,
80 },
81 Type::Record(fields) => TypeSchema {
82 kind: "object".into(),
83 elem: None,
84 fields: Some(
85 fields
86 .iter()
87 .map(|(name, ty)| (name.clone(), type_to_schema(ty)))
88 .collect(),
89 ),
90 len: None,
91 name: None,
92 tensor_shape: None,
93 tensor_dtype: None,
94 },
95 Type::Safe { base, .. } => type_to_schema(base),
96 Type::Ident(name) => match name.as_str() {
97 "i64" => TypeSchema {
98 kind: "i64".into(),
99 elem: None,
100 fields: None,
101 len: None,
102 name: None,
103 tensor_shape: None,
104 tensor_dtype: None,
105 },
106 "f64" => TypeSchema {
107 kind: "f64".into(),
108 elem: None,
109 fields: None,
110 len: None,
111 name: None,
112 tensor_shape: None,
113 tensor_dtype: None,
114 },
115 "bool" => TypeSchema {
116 kind: "bool".into(),
117 elem: None,
118 fields: None,
119 len: None,
120 name: None,
121 tensor_shape: None,
122 tensor_dtype: None,
123 },
124 "string" => TypeSchema {
125 kind: "string".into(),
126 elem: None,
127 fields: None,
128 len: None,
129 name: None,
130 tensor_shape: None,
131 tensor_dtype: None,
132 },
133 _ => TypeSchema {
134 kind: "ident".into(),
135 elem: None,
136 fields: None,
137 len: None,
138 name: Some(name.clone()),
139 tensor_shape: None,
140 tensor_dtype: None,
141 },
142 },
143 _ => TypeSchema {
144 kind: "unknown".into(),
145 elem: None,
146 fields: None,
147 len: None,
148 name: None,
149 tensor_shape: None,
150 tensor_dtype: None,
151 },
152 }
153}
154
155fn constraint_to_plan(c: &tupa_parser::Constraint) -> ConstraintPlan {
156 ConstraintPlan {
157 metric: c.metric.clone(),
158 comparator: match c.comparator {
159 Comparator::Lt => "lt".into(),
160 Comparator::Le => "le".into(),
161 Comparator::Eq => "eq".into(),
162 Comparator::Ge => "ge".into(),
163 Comparator::Gt => "gt".into(),
164 },
165 threshold: c.threshold,
166 }
167}
168
169fn extract_metrics(pipeline: &PipelineDecl) -> HashMap<String, f64> {
170 let mut map = HashMap::new();
171 if let Some(block) = &pipeline.validation {
172 for stmt in block {
173 if let Stmt::Let { name, expr, .. } = stmt {
174 match &expr.kind {
175 ExprKind::Int(n) => {
176 map.insert(name.clone(), *n as f64);
177 }
178 ExprKind::Float(f) => {
179 map.insert(name.clone(), *f);
180 }
181 _ => {}
182 }
183 }
184 }
185 }
186 map
187}
188
189fn expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
190 match &expr.kind {
191 ExprKind::Int(n) => Some(serde_json::json!(*n)),
192 ExprKind::Float(f) => Some(serde_json::json!(*f)),
193 ExprKind::Bool(b) => Some(serde_json::json!(*b)),
194 ExprKind::ArrayLiteral(items) => {
195 let mut arr = Vec::new();
196 for it in items {
197 if let Some(v) = expr_to_json(it) {
198 arr.push(v);
199 } else {
200 return None;
201 }
202 }
203 Some(serde_json::Value::Array(arr))
204 }
205 _ => None,
206 }
207}
208
209fn extract_metric_plans(module_name: &str, pipeline: &PipelineDecl) -> Vec<MetricPlan> {
210 let mut list = Vec::new();
211 if let Some(block) = &pipeline.validation {
212 for stmt in block {
213 if let Stmt::Let { name, expr, .. } = stmt {
214 if let ExprKind::Call { callee, args } = &expr.kind {
215 if let ExprKind::Ident(func) = &callee.kind {
216 let json_args = if args.len() == 1 {
218 expr_to_json(&args[0]).unwrap_or(serde_json::Value::Null)
219 } else {
220 let mut arr = Vec::new();
221 for a in args {
222 if let Some(v) = expr_to_json(a) {
223 arr.push(v);
224 }
225 }
226 serde_json::Value::Array(arr)
227 };
228 list.push(MetricPlan {
229 name: name.clone(),
230 function_ref: format!("{module_name}::{func}"),
231 args: json_args,
232 });
233 }
234 }
235 }
236 }
237 }
238 list
239}
240
241pub fn codegen_pipeline(
242 module_name: &str,
243 pipeline: &PipelineDecl,
244 program: &Program,
245) -> serde_json::Result<String> {
246 let steps: Vec<StepPlan> = pipeline
247 .steps
248 .iter()
249 .map(|step| {
250 let effects = analyze_effects(&step.body, &HashMap::new()).to_names();
251 let mut function_ref = format!("{module_name}::step_{}", step.name);
252
253 if let ExprKind::Call { callee, args } = &step.body.kind {
255 if let ExprKind::Ident(func_name) = &callee.kind {
256 let is_simple_call = args.len() == 1
258 && matches!(&args[0].kind, ExprKind::Ident(n) if n == "input");
259
260 if is_simple_call {
261 for item in &program.items {
263 if let Item::Function(f) = item {
264 if &f.name == func_name {
265 if let Some(spec) = &f.external_spec {
266 if let Some(py_target) = &spec.python {
267 function_ref = format!("py:{}", py_target);
268 }
269 }
270 break;
271 }
272 }
273 }
274 }
275 }
276 }
277
278 StepPlan {
279 name: step.name.clone(),
280 function_ref,
281 effects,
282 }
283 })
284 .collect();
285 let plan = ExecutionPlan {
286 name: pipeline.name.clone(),
287 version: env!("CARGO_PKG_VERSION").to_string(),
288 seed: pipeline.seed,
289 input_schema: type_to_schema(&pipeline.input_ty),
290 output_schema: pipeline.output_ty.as_ref().map(type_to_schema),
291 steps,
292 constraints: pipeline
293 .constraints
294 .iter()
295 .map(constraint_to_plan)
296 .collect(),
297 metrics: extract_metrics(pipeline),
298 metric_plans: extract_metric_plans(module_name, pipeline),
299 };
300 serde_json::to_string_pretty(&plan)
301}