Skip to main content

lutra_codegen/
lib.rs

1mod python;
2mod rust;
3mod sql;
4
5pub use lutra_compiler::ProgramFormat;
6
7use std::borrow::Cow;
8use std::{fs, path};
9
10use lutra_bin::{Encode, ir};
11use lutra_compiler::DiscoverParams;
12
13/// Wrapper for [generate], designed to play well with Cargo's `build.rs`.
14///
15/// Writes into `OUT_DIR/lutra.rs`.
16/// Prints cargo::rerun-if-changed directives.
17/// Panics on error.
18#[track_caller]
19pub fn check_and_generate(project_dir: impl AsRef<path::Path>, options: GenerateOptions) {
20    let out_dir = path::PathBuf::from(std::env::var_os("OUT_DIR").unwrap());
21    let out_path = out_dir.join("lutra.rs");
22
23    // tracing_subscriber::fmt::Subscriber::builder()
24    //     .without_time()
25    //     .with_target(false)
26    //     .with_max_level(tracing::Level::DEBUG)
27    //     .with_writer(std::io::stderr)
28    //     .init();
29
30    // discover
31    let source = lutra_compiler::discover(DiscoverParams {
32        project: Some(project_dir.as_ref().to_path_buf()),
33    });
34    let source = match source {
35        Ok(s) => s,
36        Err(e) => panic!("{e}"),
37    };
38
39    // compile
40    let project = lutra_compiler::check(source, Default::default());
41    let project = match project {
42        Ok(p) => p,
43        Err(e) => panic!("{e}"),
44    };
45
46    // generate
47    let code = rust::run(&project, &options, out_dir).unwrap();
48    fs::write(out_path, code).unwrap();
49
50    // print directives with absolute paths for reliability
51    for (p, _) in project.source.get_sources() {
52        let abs_path = project.source.get_absolute_path(p);
53        println!("cargo::rerun-if-changed={}", abs_path.display());
54    }
55}
56
57#[track_caller]
58pub fn generate(
59    project: &lutra_compiler::Project,
60    target: Target,
61    out_path: impl AsRef<path::Path>,
62    options: GenerateOptions,
63) -> Result<(), lutra_compiler::error::Error> {
64    let out_path = out_path.as_ref();
65    let out_dir = out_path.parent().unwrap().to_path_buf();
66
67    // generate
68    let code = match target {
69        Target::Rust => rust::run(project, &options, out_dir).unwrap(),
70        Target::Python => python::run(project, &options, out_dir).unwrap(),
71        Target::Sql => sql::run(project, &options, out_dir).unwrap(),
72    };
73
74    fs::write(out_path, code)?;
75
76    Ok(())
77}
78
/// Language that [generate] emits code for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// Generate Rust code.
    Rust,
    /// Generate Python code.
    Python,
    /// Generate SQL.
    Sql,
}
84
85#[track_caller]
86pub fn generate_program_bytecode(project_dir: &path::Path, program: &str, out_file: &path::Path) {
87    // discover the project
88    let source = lutra_compiler::discover(DiscoverParams {
89        project: Some(project_dir.into()),
90    })
91    .unwrap();
92
93    // compile
94    let project =
95        lutra_compiler::check(source, Default::default()).unwrap_or_else(|e| panic!("{e}"));
96
97    // lower & bytecode
98    let (program, _ty) = lutra_compiler::compile(
99        &project,
100        program,
101        None,
102        lutra_compiler::ProgramFormat::BytecodeLt,
103    )
104    .unwrap();
105
106    let buf = program.into_bytecode_lt().unwrap().encode();
107    std::fs::write(out_file, buf).unwrap();
108}
109
110#[derive(Debug, Clone)]
111pub struct GenerateOptions {
112    generate_types: bool,
113    generate_encode_decode: bool,
114    generate_function_traits: bool,
115
116    include_programs: Vec<(String, ProgramFormat)>,
117
118    lutra_bin_path: String,
119}
120
121impl Default for GenerateOptions {
122    fn default() -> Self {
123        Self {
124            generate_types: true,
125            generate_encode_decode: true,
126            generate_function_traits: false,
127            include_programs: Vec::new(),
128            lutra_bin_path: "::lutra_bin".into(),
129        }
130    }
131}
132
133impl GenerateOptions {
134    /// Do not generate type definitions
135    pub fn no_generate_types(mut self) -> Self {
136        self.generate_types = false;
137        self
138    }
139
140    /// Do not generate [lutra_bin::Encode] and [lutra_bin::Decode] implementations
141    pub fn no_generate_encode_decode(mut self) -> Self {
142        self.generate_encode_decode = false;
143        self
144    }
145
146    /// Do not generate traits for functions
147    pub fn generate_function_traits(mut self) -> Self {
148        self.generate_function_traits = true;
149        self
150    }
151
152    /// Set path to [lutra_bin] dependency
153    pub fn with_lutra_bin_path(mut self, path: String) -> Self {
154        self.lutra_bin_path = path;
155        self
156    }
157
158    /// Generates programs for all functions in a module
159    pub fn generate_programs(mut self, module_path: impl ToString, fmt: ProgramFormat) -> Self {
160        self.include_programs.push((module_path.to_string(), fmt));
161        self
162    }
163}
164
165/// Types might not have names, because they are defined inline.
166/// This function traverses a type definition and generates names for all of the types.
167fn infer_names(def_name: &str, ty: &mut ir::Ty) {
168    if ty.name.is_none() {
169        ty.name = Some(def_name.to_string());
170    }
171
172    let mut name_prefix = Vec::new();
173    infer_names_re(ty, &mut name_prefix);
174}
175
176fn infer_names_of_program_ty(ty: &mut lutra_bin::rr::ProgramType, program_name: &str) {
177    let mut name_camel = vec![snake_to_sentence(program_name)];
178    {
179        name_camel.push("Input".into());
180        infer_names_re(&mut ty.input, &mut name_camel);
181        name_camel.pop();
182    }
183    {
184        name_camel.push("Output".into());
185        infer_names_re(&mut ty.output, &mut name_camel);
186        name_camel.pop();
187    }
188}
189
190fn infer_names_re(ty: &mut ir::Ty, name_prefix: &mut Vec<String>) {
191    if ty.name.is_none() {
192        ty.name = Some(name_prefix.concat());
193    } else {
194        name_prefix.push(ty.name.clone().unwrap());
195    }
196
197    match &mut ty.kind {
198        ir::TyKind::Primitive(_) | ir::TyKind::Ident(_) => {}
199
200        ir::TyKind::Tuple(fields) => {
201            for (index, field) in fields.iter_mut().enumerate() {
202                let name = tuple_field_name(&field.name, index);
203                name_prefix.push(name.into_owned());
204
205                infer_names_re(&mut field.ty, name_prefix);
206                name_prefix.pop();
207            }
208        }
209
210        ir::TyKind::Array(items_ty) => {
211            name_prefix.push("Items".to_string());
212            infer_names_re(items_ty, name_prefix);
213            name_prefix.pop();
214        }
215
216        ir::TyKind::Enum(variants) => {
217            for v in variants {
218                name_prefix.push(v.name.clone());
219                infer_names_re(&mut v.ty, name_prefix);
220                name_prefix.pop();
221            }
222        }
223
224        ir::TyKind::Function(func) => {
225            for param in &mut func.params {
226                infer_names_re(param, name_prefix);
227            }
228            infer_names_re(&mut func.body, name_prefix);
229        }
230    }
231}
232
/// Returns the name used for a tuple field.
///
/// This is the field's declared name, borrowed, when it has one. Otherwise it
/// is the owned positional name `field{index}`.
pub fn tuple_field_name(name: &Option<String>, index: usize) -> Cow<'_, str> {
    match name {
        Some(declared) => Cow::Borrowed(declared.as_str()),
        None => Cow::Owned(format!("field{index}")),
    }
}
238
/// Converts a `CamelCase` identifier to `snake_case` (e.g. `CamelCase` ->
/// `camel_case`).
///
/// Every uppercase character is lowercased and, unless it starts the string or
/// already follows an underscore, is preceded by `_`. Runs of capitals are
/// therefore split letter by letter (`HTTPServer` -> `h_t_t_p_server`).
fn camel_to_snake(camel: &str) -> String {
    let mut snake = String::with_capacity(camel.len());
    for current in camel.chars() {
        if current.is_uppercase() {
            // Separate words with `_`, but not at the very start and not right
            // after an existing `_` (so `Foo_Bar` gives `foo_bar`, not `foo__bar`).
            if !snake.is_empty() && !snake.ends_with('_') {
                snake.push('_');
            }
            // `to_lowercase` may yield several chars; keep all of them.
            snake.extend(current.to_lowercase());
        } else {
            snake.push(current);
        }
    }

    snake
}
254
/// Converts a `snake_case` identifier to `PascalCase` (e.g. `my_program` ->
/// `MyProgram`).
///
/// Underscores are dropped. The first character, and the first character after
/// each run of underscores, is uppercased; every other character is lowercased.
/// Full Unicode case mappings are applied, so a character whose uppercase form
/// spans several characters (e.g. `ß` -> `SS`) is expanded, not truncated.
fn snake_to_sentence(snake: &str) -> String {
    let mut sentence = String::with_capacity(snake.len());
    let mut next_upper = true;
    for current in snake.chars() {
        if current == '_' {
            next_upper = true;
            continue;
        }

        if next_upper {
            sentence.extend(current.to_uppercase());
        } else {
            sentence.extend(current.to_lowercase());
        }
        next_upper = false;
    }

    sentence
}