Skip to main content

lutra_codegen/
lib.rs

1mod python;
2mod rust;
3mod sql;
4
5pub use lutra_compiler::ProgramRepr;
6
7use std::borrow::Cow;
8use std::{fs, path};
9
10use lutra_bin::{Encode, ir};
11use lutra_compiler::DiscoverParams;
12
13/// Wrapper for [generate], designed to play well with Cargo's `build.rs`.
14///
15/// Writes into `OUT_DIR/lutra.rs`.
16/// Prints cargo::rerun-if-changed directives.
17/// Panics on error.
18#[track_caller]
19pub fn check_and_generate(project_dir: impl AsRef<path::Path>, options: GenerateOptions) {
20    let out_dir = path::PathBuf::from(std::env::var_os("OUT_DIR").unwrap());
21    let out_path = out_dir.join("lutra.rs");
22
23    // tracing_subscriber::fmt::Subscriber::builder()
24    //     .without_time()
25    //     .with_target(false)
26    //     .with_max_level(tracing::Level::DEBUG)
27    //     .with_writer(std::io::stderr)
28    //     .init();
29
30    // discover
31    let source = lutra_compiler::discover(DiscoverParams {
32        project: Some(project_dir.as_ref().to_path_buf()),
33    });
34    let source = match source {
35        Ok(s) => s,
36        Err(e) => panic!("{e}"),
37    };
38
39    // compile
40    let project = lutra_compiler::check(source, Default::default());
41    let project = match project {
42        Ok(p) => p,
43        Err(e) => panic!("{e}"),
44    };
45
46    // generate
47    let code = rust::run(&project, &options, out_dir).unwrap();
48    fs::write(out_path, code).unwrap();
49
50    // print directives with absolute paths for reliability
51    for (p, _) in project.source.get_sources() {
52        let abs_path = project.source.get_absolute_path(p);
53        println!("cargo::rerun-if-changed={}", abs_path.display());
54    }
55}
56
57#[track_caller]
58pub fn generate(
59    project: &lutra_compiler::Project,
60    target: Target,
61    out_path: impl AsRef<path::Path>,
62    options: GenerateOptions,
63) -> Result<(), lutra_compiler::error::Error> {
64    let out_path = out_path.as_ref();
65    let out_dir = out_path.parent().unwrap().to_path_buf();
66
67    // generate
68    let code = match target {
69        Target::Rust => rust::run(project, &options, out_dir).unwrap(),
70        Target::Python => python::run(project, &options, out_dir).unwrap(),
71        Target::Sql => sql::run(project, &options, out_dir).unwrap(),
72    };
73
74    fs::write(out_path, code)?;
75
76    Ok(())
77}
78
/// Output language selected when calling [generate].
pub enum Target {
    /// Generate code with the `rust` backend.
    Rust,
    /// Generate code with the `python` backend.
    Python,
    /// Generate code with the `sql` backend.
    Sql,
}
84
85#[track_caller]
86pub fn generate_program_bytecode(project_dir: &path::Path, program: &str, out_file: &path::Path) {
87    // discover the project
88    let source = lutra_compiler::discover(DiscoverParams {
89        project: Some(project_dir.into()),
90    })
91    .unwrap();
92
93    // compile
94    let project =
95        lutra_compiler::check(source, Default::default()).unwrap_or_else(|e| panic!("{e}"));
96
97    // lower & bytecode
98    let params =
99        lutra_compiler::CompileParams::new(program, lutra_compiler::ProgramRepr::BytecodeLt);
100    let (program, _ty) = lutra_compiler::compile(&project, &params).unwrap();
101
102    let buf = program.into_bytecode_lt().unwrap().encode();
103    std::fs::write(out_file, buf).unwrap();
104}
105
// Options that control what the code generators emit.
//
// NOTE: plain `//` comments are used on purpose. With the `clap` feature enabled this struct
// derives `clap::Args`, and clap's derive picks up `///` doc comments as help text, so
// adding doc comments here would change the CLI output.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct GenerateOptions {
    // When set, type definitions are not generated.
    #[cfg_attr(feature = "clap", arg(long = "no-types", default_value_t = false))]
    pub no_types: bool,

    // When set, `lutra_bin::Encode` / `lutra_bin::Decode` implementations are not generated.
    #[cfg_attr(
        feature = "clap",
        arg(long = "no-encode-decode", default_value_t = false)
    )]
    pub no_encode_decode: bool,

    // When set, traits for functions are not generated. Defaults to `true`, so function
    // traits are opt-in. On the command line the flag takes an optional value; a bare
    // `--no-function-traits` is read as `true`.
    #[cfg_attr(
        feature = "clap",
        arg(
            long = "no-function-traits",
            default_value_t = true,
            num_args = 0..=1,
            default_missing_value = "true"
        )
    )]
    pub no_function_traits: bool,

    // When set, module client wrappers are generated for generated programs.
    #[cfg_attr(feature = "clap", arg(long))]
    pub client: bool,

    // Module paths whose programs are generated as `ProgramRepr::BytecodeLt`.
    #[cfg_attr(feature = "clap", arg(long))]
    pub programs_bytecode_lt: Vec<String>,

    // Module paths whose programs are generated as `ProgramRepr::SqlPg`.
    #[cfg_attr(feature = "clap", arg(long))]
    pub programs_sql_pg: Vec<String>,

    // Module paths whose programs are generated as `ProgramRepr::SqlDuckdb`.
    #[cfg_attr(feature = "clap", arg(long))]
    pub programs_sql_duckdb: Vec<String>,

    // Path to the `lutra_bin` dependency.
    #[cfg_attr(feature = "clap", arg(long, default_value = "::lutra_bin"))]
    pub lutra_bin_path: String,

    // Path to the `lutra_runner` dependency.
    #[cfg_attr(feature = "clap", arg(long, default_value = "::lutra_runner"))]
    pub lutra_runner_path: String,
}
147
148impl Default for GenerateOptions {
149    fn default() -> Self {
150        Self {
151            no_types: false,
152            no_encode_decode: false,
153            no_function_traits: true,
154            client: false,
155            programs_bytecode_lt: Vec::new(),
156            programs_sql_pg: Vec::new(),
157            programs_sql_duckdb: Vec::new(),
158            lutra_bin_path: "::lutra_bin".into(),
159            lutra_runner_path: "::lutra_runner".into(),
160        }
161    }
162}
163
164impl GenerateOptions {
165    /// Do not generate type definitions
166    pub fn no_generate_types(mut self) -> Self {
167        self.no_types = true;
168        self
169    }
170
171    /// Do not generate [lutra_bin::Encode] and [lutra_bin::Decode] implementations
172    pub fn no_generate_encode_decode(mut self) -> Self {
173        self.no_encode_decode = true;
174        self
175    }
176
177    /// Do not generate traits for functions
178    pub fn generate_function_traits(mut self) -> Self {
179        self.no_function_traits = false;
180        self
181    }
182
183    /// Generate module client wrappers for generated programs
184    pub fn generate_client(mut self) -> Self {
185        self.client = true;
186        self
187    }
188
189    /// Set path to [lutra_bin] dependency
190    pub fn with_lutra_bin_path(mut self, path: String) -> Self {
191        self.lutra_bin_path = path;
192        self
193    }
194
195    /// Set path to [lutra_runner] dependency
196    pub fn with_lutra_runner_path(mut self, path: String) -> Self {
197        self.lutra_runner_path = path;
198        self
199    }
200
201    /// Generates programs for all functions in a module
202    pub fn generate_programs(mut self, module_path: impl ToString, fmt: ProgramRepr) -> Self {
203        let module_path = module_path.to_string();
204        match fmt {
205            ProgramRepr::BytecodeLt => self.programs_bytecode_lt.push(module_path),
206            ProgramRepr::SqlPg => self.programs_sql_pg.push(module_path),
207            ProgramRepr::SqlDuckdb => self.programs_sql_duckdb.push(module_path),
208        }
209        self
210    }
211
212    pub(crate) fn generates_types(&self) -> bool {
213        !self.no_types
214    }
215
216    pub(crate) fn generates_encode_decode(&self) -> bool {
217        !self.no_encode_decode
218    }
219
220    pub(crate) fn generates_function_traits(&self) -> bool {
221        !self.no_function_traits
222    }
223
224    pub(crate) fn generates_client(&self) -> bool {
225        self.client
226    }
227
228    pub(crate) fn included_program_repr(&self, module_path: &str) -> Option<ProgramRepr> {
229        if self.programs_bytecode_lt.iter().any(|p| p == module_path) {
230            Some(ProgramRepr::BytecodeLt)
231        } else if self.programs_sql_pg.iter().any(|p| p == module_path) {
232            Some(ProgramRepr::SqlPg)
233        } else if self.programs_sql_duckdb.iter().any(|p| p == module_path) {
234            Some(ProgramRepr::SqlDuckdb)
235        } else {
236            None
237        }
238    }
239
240    pub(crate) fn has_programs_in_subtree(&self, module_path: &str) -> bool {
241        self.programs_bytecode_lt
242            .iter()
243            .chain(self.programs_sql_pg.iter())
244            .chain(self.programs_sql_duckdb.iter())
245            .any(|p| {
246                p == module_path
247                    || (!module_path.is_empty()
248                        && p.starts_with(module_path)
249                        && p[module_path.len()..].starts_with("::"))
250                    || (module_path.is_empty() && !p.is_empty())
251            })
252    }
253}
254
255/// Types might not have names, because they are defined inline.
256/// This function traverses a type definition and generates names for all of the types.
257fn infer_names(def_name: &str, ty: &mut ir::Ty) {
258    if ty.name.is_none() {
259        ty.name = Some(def_name.to_string());
260    }
261
262    let mut name_prefix = Vec::new();
263    infer_names_re(ty, &mut name_prefix);
264}
265
266fn infer_names_of_program_ty(ty: &mut lutra_bin::rr::ProgramType, program_name: &str) {
267    let mut name_camel = vec![snake_to_sentence(program_name)];
268    {
269        name_camel.push("Input".into());
270        infer_names_re(&mut ty.input, &mut name_camel);
271        name_camel.pop();
272    }
273    {
274        name_camel.push("Output".into());
275        infer_names_re(&mut ty.output, &mut name_camel);
276        name_camel.pop();
277    }
278}
279
280fn infer_names_re(ty: &mut ir::Ty, name_prefix: &mut Vec<String>) {
281    if ty.name.is_none() {
282        ty.name = Some(name_prefix.concat());
283    } else {
284        name_prefix.push(ty.name.clone().unwrap());
285    }
286
287    match &mut ty.kind {
288        ir::TyKind::Primitive(_) | ir::TyKind::Ident(_) => {}
289
290        ir::TyKind::Tuple(fields) => {
291            for (index, field) in fields.iter_mut().enumerate() {
292                let name = tuple_field_name(&field.name, index);
293                name_prefix.push(name.into_owned());
294
295                infer_names_re(&mut field.ty, name_prefix);
296                name_prefix.pop();
297            }
298        }
299
300        ir::TyKind::Array(items_ty) => {
301            name_prefix.push("Items".to_string());
302            infer_names_re(items_ty, name_prefix);
303            name_prefix.pop();
304        }
305
306        ir::TyKind::Enum(variants) => {
307            for v in variants {
308                name_prefix.push(v.name.clone());
309                infer_names_re(&mut v.ty, name_prefix);
310                name_prefix.pop();
311            }
312        }
313
314        ir::TyKind::Function(func) => {
315            for param in &mut func.params {
316                infer_names_re(param, name_prefix);
317            }
318            infer_names_re(&mut func.body, name_prefix);
319        }
320    }
321}
322
/// Returns the name of a tuple field: its explicit name when it has one, otherwise
/// `field{index}`.
///
/// An explicit name is borrowed; the positional fallback is allocated.
pub fn tuple_field_name(name: &Option<String>, index: usize) -> Cow<'_, str> {
    match name {
        Some(explicit) => Cow::Borrowed(explicit.as_str()),
        None => Cow::Owned(format!("field{index}")),
    }
}
328
/// Converts a camel-case identifier into snake case.
///
/// Every upper-case character is lowered and preceded by an underscore, unless it is the
/// first character or already follows an underscore. All other characters are copied
/// unchanged. Each upper-case character is treated on its own, so `ABc` becomes `a_bc`.
fn camel_to_snake(camel: &str) -> String {
    let mut snake = String::with_capacity(camel.len());
    for current in camel.chars() {
        if current.is_uppercase() {
            // separate words, but never start with or double up an underscore
            if !snake.is_empty() && !snake.ends_with('_') {
                snake.push('_');
            }
            // a lower-case mapping always yields at least one character
            snake.push(current.to_lowercase().next().unwrap());
        } else {
            snake.push(current);
        }
    }

    snake
}
344
/// Converts a snake-case identifier into capitalized words without separators.
///
/// Underscores are dropped. The first character of each underscore-separated word is
/// upper-cased and all remaining characters are lower-cased, e.g. `get_HTTP_response`
/// becomes `GetHttpResponse`.
fn snake_to_sentence(snake: &str) -> String {
    let mut sentence = String::with_capacity(snake.len());

    // empty words (leading, trailing or repeated underscores) contribute nothing
    for word in snake.split('_') {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            sentence.push(initial.to_uppercase().next().unwrap());
        }
        for ch in rest {
            sentence.push(ch.to_lowercase().next().unwrap());
        }
    }

    sentence
}