Skip to main content

opendp_tooling/codegen/r/
r.rs

1use std::{collections::HashMap, fs, iter::once};
2
3use crate::{
4    Argument, Function, TypeRecipe, Value,
5    codegen::{flatten_type_recipe, tab_r},
6};
7
8use super::BLOCKLIST;
9
/// Rust type names that `generate_recipe_logger` wraps with `unbox2(...)` when logging.
static ATOM_TYPES: &[&str] = &[
    "u32", "u64", "i32", "i64", "f32", "f64", "usize", "bool", "String",
];
13
14/// Generates all code for an OpenDP R module.
15/// Each call corresponds to one R file.
16pub fn generate_r_module(
17    module_name: &str,
18    module: &Vec<Function>,
19    hierarchy: &HashMap<String, Vec<String>>,
20) -> String {
21    let body = module
22        .into_iter()
23        .filter(|func| func.has_ffi)
24        .filter(|func| !BLOCKLIST.contains(&func.name.as_str()))
25        .map(|func| generate_r_function(module_name, &func, hierarchy))
26        .collect::<Vec<String>>()
27        .join("\n");
28
29    format!(
30        "# Auto-generated. Do not edit.
31
32#' @include typing.R mod.R
33NULL
34
35{body}"
36    )
37}
38
39/// Generate the code for a user-facing R function.
40///
41/// This internally calls a C function that wraps the Rust OpenDP Library
42pub(crate) fn generate_r_function(
43    module_name: &str,
44    func: &Function,
45    hierarchy: &HashMap<String, Vec<String>>,
46) -> String {
47    println!("generating R: {}", func.name);
48    let mut args = (func.args.iter())
49        .map(|arg| generate_r_input_argument(arg, func))
50        .collect::<Vec<_>>();
51
52    // move default arguments to end
53    args.sort_by(|(_, l_is_default), (_, r_is_default)| l_is_default.cmp(r_is_default));
54
55    let args = args.into_iter().map(|v| v.0).collect::<Vec<_>>();
56
57    let then_func = if func.name.starts_with("make_") {
58        let offset = if func.supports_partial { 2 } else { 0 };
59        let pre_args_nl = if args.len() > 0 { "\n" } else { "" };
60        format!(
61            r#"
62
63{then_docs}
64{then_name} <- function(
65{then_args}
66) {{
67{then_log}
68  make_chain_dyn(
69    {name}({pre_args_nl}{args}),
70    lhs,
71    log_)
72}}"#,
73            then_docs = generate_then_doc_block(module_name, func, hierarchy),
74            then_name = func.name.replacen("make_", "then_", 1),
75            then_args = tab_r(
76                once("lhs".to_string())
77                    .chain(args[offset..].to_owned())
78                    .collect::<Vec<_>>()
79                    .join(",\n")
80            ),
81            then_log = tab_r(generate_logger(module_name, func, true)),
82            name = func.name,
83            args = tab_r(tab_r(tab_r(
84                if func.supports_partial {
85                    vec![
86                        "output_domain(lhs)".to_string(),
87                        "output_metric(lhs)".to_string(),
88                    ]
89                } else {
90                    vec![]
91                }
92                .into_iter()
93                .chain(func.args[offset..].iter().map(|arg| {
94                    let name = if arg.is_type {
95                        format!(".{}", arg.name())
96                    } else {
97                        sanitize_r(arg.name(), arg.is_type)
98                    };
99                    format!("{name} = {name}")
100                }))
101                .collect::<Vec<_>>()
102                .join(",\n")
103            )))
104        )
105    } else {
106        String::default()
107    };
108
109    format!(
110        r#"
111{doc_block}
112{func_name} <- function(
113{args}
114) {{
115{body}
116}}{then_func}
117"#,
118        doc_block = generate_doc_block(module_name, func, hierarchy),
119        func_name = func.name.trim_start_matches("_"),
120        args = tab_r(args.join(",\n")),
121        body = tab_r(generate_r_body(module_name, func))
122    )
123}
124
125/// generate an input argument, complete with name, hint and default.
126/// also returns a bool to make it possible to move arguments with defaults to the end of the signature.
127fn generate_r_input_argument(arg: &Argument, func: &Function) -> (String, bool) {
128    let type_names = func.type_names();
129    let default = if let Some(default) = &arg.default {
130        Some(match default.clone() {
131            Value::Null => "NULL".to_string(),
132            Value::Bool(value) => if value { "TRUE" } else { "FALSE" }.to_string(),
133            Value::Integer(int) => format!("{}L", int),
134            Value::Float(float) => float.to_string(),
135            Value::String(mut string) => {
136                if arg.is_type {
137                    string = type_names.iter().fold(string, |string, generic| {
138                        // TODO: kind of hacky. avoids needing a full parser for the type string
139                        // replace all instances of the generic with .generic
140                        string
141                            .replace(&format!("{},", generic), &format!(".{},", generic))
142                            .replace(&format!("{}>", generic), &format!(".{}>", generic))
143                            .replace(&format!("{})", generic), &format!(".{})", generic))
144                    });
145                }
146                format!("\"{}\"", string)
147            }
148        })
149    } else {
150        // let default value be None if it is a type arg and there is a public example
151        generate_public_example(func, arg).map(|_| "NULL".to_string())
152    };
153    (
154        format!(
155            r#"{name}{default}"#,
156            name = sanitize_r(arg.name(), arg.is_type),
157            default = default
158                .as_ref()
159                .map(|default| format!(" = {}", default))
160                .unwrap_or_else(String::new)
161        ),
162        default.is_some(),
163    )
164}
165
/// Locate the first pair of `pat` delimiters in `s`.
///
/// Returns the byte offsets of the opening and the closing delimiter, or `None` when
/// `s` contains fewer than two occurrences of `pat`.
fn find_quoted(s: &str, pat: char) -> Option<(usize, usize)> {
    let open = s.find(pat)?;
    // Skip the delimiter's full UTF-8 width; a fixed `+ 1` would slice inside a
    // multi-byte delimiter and panic.
    let after_open = open + pat.len_utf8();
    let close = after_open + s[after_open..].find(pat)?;
    Some((open, close))
}
171
/// Convert inline LaTeX written as `$...$` into Roxygen's `\eqn{...}` markup.
///
/// Complete `$` pairs are rewritten from left to right. A final `$` without a partner
/// is left as-is.
fn escape_latex(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    loop {
        let Some(open) = rest.find('$') else { break };
        let Some(len) = rest[open + 1..].find('$') else { break };
        let close = open + 1 + len;
        out.push_str(&rest[..open]);
        out.push_str("\\eqn{");
        out.push_str(&rest[open + 1..close]);
        out.push('}');
        rest = &rest[close + 1..];
    }
    out.push_str(rest);
    out
}
187}
188
/// Build the one-line Roxygen title for a `make_*` constructor.
///
/// For example, `make_split_lines` becomes `"split lines constructor\n\n"`; the trailing
/// blank line separates the title from the text that follows. Any other name yields an
/// empty string. Only an exact `make_` prefix is recognized, matching how `then_*`
/// companions are selected in `generate_r_function`, and it is stripped only once.
fn generate_constructor_title(name: &str) -> String {
    match name.strip_prefix("make_") {
        Some(rest) => format!("{} constructor\n\n", rest.replace('_', " ")),
        None => String::new(),
    }
}
199}
200
201/// generate a documentation block for the current function, with the function description, args, and return
202/// in Roxygen format: https://mpn.metworx.com/packages/roxygen2/7.1.1/articles/rd-formatting.html
203fn generate_doc_block(
204    module_name: &str,
205    func: &Function,
206    hierarchy: &HashMap<String, Vec<String>>,
207) -> String {
208    let title = generate_constructor_title(&func.name);
209    let description = (func.description.as_ref())
210        .map(|v| {
211            v.split("\n")
212                .map(escape_latex)
213                .collect::<Vec<_>>()
214                .join("\n")
215        })
216        .map(|v| format!("{title}{}\n", v))
217        .unwrap_or_else(String::new);
218
219    let concept = format!("@concept {}\n", module_name);
220
221    let doc_args = (func.args.iter())
222        .map(|v| generate_docstring_arg(v))
223        .collect::<Vec<String>>()
224        .join("\n");
225
226    let export = if !func.name.starts_with("_") && module_name != "data" {
227        "\n@export"
228    } else {
229        ""
230    };
231
232    format!(
233        r#"{description}
234{concept}{doc_args}{ret_arg}{examples}{export}"#,
235        concept = concept,
236        description = description,
237        doc_args = doc_args,
238        ret_arg = generate_docstring_return_arg(&func.ret, hierarchy),
239        examples = generate_docstring_examples(module_name, func),
240    )
241    .split("\n")
242    .map(|l| format!("#' {}", l).trim().to_string())
243    .collect::<Vec<_>>()
244    .join("\n")
245}
246
247/// generate a documentation block for a then_* partial constructor, with the function description, args, and return
248/// in Roxygen format: https://mpn.metworx.com/packages/roxygen2/7.1.1/articles/rd-formatting.html
249fn generate_then_doc_block(
250    module_name: &str,
251    func: &Function,
252    hierarchy: &HashMap<String, Vec<String>>,
253) -> String {
254    let title = generate_constructor_title(&func.name);
255    let offset = if func.supports_partial { 2 } else { 0 };
256
257    let doc_args = (func.args[offset..].iter())
258        .map(|v| generate_docstring_arg(v))
259        .collect::<Vec<String>>()
260        .join("\n");
261
262    format!(
263        r#"partial {title}See documentation for [{func_name}()] for details.
264
265@concept {module_name}
266@param lhs The prior transformation or metric space.
267{doc_args}{ret_arg}{examples}
268@export"#,
269        func_name = func.name,
270        ret_arg = generate_docstring_return_arg(&func.ret, hierarchy),
271        examples = generate_docstring_examples(module_name, func),
272    )
273    .split("\n")
274    .map(|l| format!("#' {}", l).trim().to_string())
275    .collect::<Vec<_>>()
276    .join("\n")
277}
278
279/// generate the part of a docstring corresponding to an argument
280fn generate_docstring_arg(arg: &Argument) -> String {
281    let name = sanitize_r(arg.name(), arg.is_type);
282    format!(
283        r#"@param {name} {description}"#,
284        name = name,
285        description = arg
286            .description
287            .as_ref()
288            .map(|v| escape_latex(v.as_str()))
289            .unwrap_or_else(|| "undocumented".to_string())
290    )
291}
292
293/// generate the part of a docstring corresponding to a return argument
294fn generate_docstring_return_arg(
295    arg: &Argument,
296    hierarchy: &HashMap<String, Vec<String>>,
297) -> String {
298    let description = if let Some(description) = &arg.description {
299        description.clone()
300    } else if let Some(type_) = arg.python_type_hint(hierarchy) {
301        type_
302    } else {
303        return String::new();
304    };
305    format!("\n@return {description}")
306}
307
308fn generate_docstring_examples(module_name: &str, func: &Function) -> String {
309    let example_path = format!("src/{}/code/{}.R", &module_name, &func.name);
310    match fs::read_to_string(example_path) {
311        Ok(example) => format!("\n@examples\n{example}"),
312        Err(_) => "".to_string(),
313    }
314}
315
316/// generate the function body, consisting of type args formatters, data converters, and the call
317/// - type arg formatters make every type arg a RuntimeType, and construct derived RuntimeTypes
318/// - data converters convert from python to c representations according to the formatted type args
319/// - the call constructs and retrieves the ffi function name, sets ctypes,
320///     makes the call, handles errors, and converts the response to python
321fn generate_r_body(module_name: &str, func: &Function) -> String {
322    format!(
323        r#"{deprecated}{flag_checker}{type_arg_formatter}{logger}{assert_is_similar}
324{make_call}
325output"#,
326        deprecated = generate_deprecated(func),
327        flag_checker = generate_flag_check(&func.features),
328        type_arg_formatter = generate_type_arg_formatter(func),
329        assert_is_similar = generate_assert_is_similar(func),
330        logger = generate_logger(module_name, func, false),
331        make_call = generate_wrapper_call(module_name, func)
332    )
333}
334
335/// generate code that provides an example of the type of the type_arg
336fn generate_public_example(func: &Function, type_arg: &Argument) -> Option<String> {
337    // the json has supplied explicit instructions to find an example
338    if let Some(example) = &type_arg.example {
339        return Some(example.to_r(Some(&func.type_names())));
340    }
341
342    let type_name = type_arg.name.as_ref().unwrap();
343
344    // rewrite args to remove references to derived types
345    let mut args = func.args.clone();
346    args.iter_mut()
347        .filter(|arg| arg.rust_type.is_some())
348        .for_each(|arg| {
349            arg.rust_type = Some(flatten_type_recipe(
350                arg.rust_type.as_ref().unwrap(),
351                &func.derived_types,
352            ))
353        });
354
355    // code generation
356    args.iter().find_map(|arg| match &arg.rust_type {
357        Some(TypeRecipe::Name(name)) => (name == type_name).then(|| arg.name()),
358        Some(TypeRecipe::Nest { origin, args }) => {
359            if origin != "Vec" {
360                return None;
361            }
362            let TypeRecipe::Name(arg_name) = &args[0] else {
363                return None;
364            };
365            if arg_name != type_name {
366                return None;
367            }
368
369            Some(format!("get_first({name})", name = arg.name()))
370        }
371        _ => None,
372    })
373}
374
375/// the generated code ensures every type arg is a RuntimeType, and constructs derived RuntimeTypes
376fn generate_type_arg_formatter(func: &Function) -> String {
377    let type_names = func.type_names();
378    let type_arg_formatter: String = func.args.iter()
379        .filter(|arg| arg.is_type)
380        .map(|type_arg| {
381            let name = sanitize_r(type_arg.name(), type_arg.is_type);
382            let generics = type_arg.generics(&type_names);
383            let generics = if generics.is_empty() {
384                "".to_string()
385            } else {
386                format!(", generics = list({})", generics.iter()
387                    .map(|v| format!("\"{}\"", sanitize_r(v, true)))
388                    .collect::<Vec<_>>().join(", "))
389            };
390            if let Some(example) = generate_public_example(func, type_arg) {
391                format!(r#"{name} <- parse_or_infer(type_name = {name}, public_example = {example}{generics})"#)
392            } else {
393                format!(r#"{name} <- rt_parse(type_name = {name}{generics})"#)
394            }
395        })
396        // additional types that are constructed by introspecting existing types
397        .chain(func.derived_types.iter()
398            .map(|type_spec|
399                format!("{name} <- {derivation}",
400                        name = sanitize_r(type_spec.name(), true),
401                        derivation = type_spec.rust_type.as_ref().unwrap().to_r(Some(&type_names)))))
402        // substitute concrete types in for generics
403        .chain(func.args.iter()
404            .filter(|arg| arg.is_type && !arg.generics(&type_names).is_empty())
405            .map(|arg|
406                format!("{name} <- rt_substitute({name}, {args})",
407                        name=sanitize_r(arg.name(), true),
408                        args=arg.generics(&type_names).iter()
409                            .map(|generic| format!("{generic} = {generic}", generic = sanitize_r(generic, true)))
410                            .collect::<Vec<_>>().join(", "))))
411        // determine types of arguments that are not type args
412        .chain(func.args.iter().filter(|arg| arg.has_implicit_type())
413            .map(|arg| {
414                let name = sanitize_r(arg.name(), arg.is_type);
415                let converter = arg.rust_type.as_ref().unwrap().to_r(Some(&type_names));
416                format!(r#".T.{name} <- {converter}"#)
417            }))
418        .collect::<Vec<_>>()
419        .join("\n");
420
421    if type_arg_formatter.is_empty() {
422        "# No type arguments to standardize.".to_string()
423    } else {
424        format!(
425            r#"# Standardize type arguments.
426{type_arg_formatter}
427"#
428        )
429    }
430}
431
432fn generate_assert_is_similar(func: &Function) -> String {
433    let assert_is_similar: String = (func.args.iter())
434        .filter(|arg| !arg.is_type)
435        .map(|arg| {
436            let expected = if arg.has_implicit_type() {
437                let name = sanitize_r(arg.name(), arg.is_type);
438                format!(".T.{name}")
439            } else {
440                arg.rust_type
441                    .as_ref()
442                    .unwrap()
443                    .to_r(Some(&func.type_names()))
444            };
445            let inferred = {
446                let name = sanitize_r(arg.name(), arg.is_type);
447                format!("rt_infer({name})")
448            };
449            (expected, inferred)
450        })
451        .filter(|(expected, _)| expected != "NULL")
452        .map(|(expected, inferred)| {
453            format!("rt_assert_is_similar(expected = {expected}, inferred = {inferred})")
454        })
455        .collect::<Vec<_>>()
456        .join("\n");
457
458    if assert_is_similar.is_empty() {
459        "".to_string()
460    } else {
461        format!(
462            r#"
463# Assert that arguments are correctly typed.
464{assert_is_similar}
465"#
466        )
467    }
468}
469
470// generates the `log_ <- ...` code that tracks arguments
471fn generate_logger(module_name: &str, func: &Function, then: bool) -> String {
472    let func_name = if then {
473        func.name.replacen("make_", "then_", 1)
474    } else {
475        func.name.clone()
476    };
477    let offset = if then && func.supports_partial { 2 } else { 0 };
478    let keys = func.args[offset..]
479        .iter()
480        .map(|arg| format!("\"{}\"", arg.name()))
481        .collect::<Vec<_>>()
482        .join(", ");
483
484    let vals = func.args[offset..]
485        .iter()
486        .map(|arg| {
487            let r_name = sanitize_r(arg.name().clone(), arg.is_type);
488            generate_recipe_logger(arg.rust_type.clone().unwrap_or(TypeRecipe::None), r_name)
489        })
490        .collect::<Vec<_>>()
491        .join(", ");
492
493    format!(
494        r#"
495log_ <- new_constructor_log("{func_name}", "{module_name}", new_hashtab(
496  list({keys}),
497  list({vals})
498))
499"#
500    )
501}
502
503// adds unboxes wherever possible to `name`, based on its type `recipe`
504fn generate_recipe_logger(recipe: TypeRecipe, name: String) -> String {
505    match recipe {
506        TypeRecipe::Name(ty) if ATOM_TYPES.contains(&ty.as_str()) => format!("unbox2({name})"),
507        TypeRecipe::Nest { origin, .. } if origin == "Tuple" => {
508            format!("lapply({name}, unbox2)")
509        }
510        _ => name,
511    }
512}
513
514/// the generated code
515/// - constructs and retrieves the ffi function name
516/// - calls the C wrapper with arguments *and* extra type info
517fn generate_wrapper_call(module_name: &str, func: &Function) -> String {
518    let args = (func.args.iter())
519        .chain(func.derived_types.iter())
520        .map(|arg| sanitize_r(arg.name(), arg.is_type))
521        .chain(
522            (func.args.iter().filter(|arg| arg.has_implicit_type()))
523                .map(|arg| sanitize_r(arg.name(), arg.is_type))
524                .map(|name| format!("rt_parse(.T.{name})")),
525        )
526        .map(|name| format!("{name},"))
527        .collect::<Vec<_>>()
528        .join(" ");
529
530    let args_str = if args.is_empty() {
531        "".to_string()
532    } else {
533        format!("\n  {args}")
534    };
535
536    let call = format!(
537        r#".Call(
538  "{module_name}__{name}",{args_str}
539  log_, PACKAGE = "opendp")"#,
540        name = func.name
541    );
542    format!(
543        r#"# Call wrapper function.
544output <- {call}"#
545    )
546}
547
548// generate call to ".Deprecated()" if needed
549fn generate_deprecated(func: &Function) -> String {
550    if let Some(deprecation) = &func.deprecation {
551        format!(
552            ".Deprecated(msg = \"{}\")\n",
553            deprecation.note.replace("\"", "\\\"")
554        )
555    } else {
556        String::default()
557    }
558}
559
/// Generate an `assert_features(...)` call, followed by a blank line, that requires every
/// feature flag in `features`.
///
/// Returns an empty string when no flags are required. Accepts a slice, so
/// `&Vec<String>` arguments still work through deref coercion.
fn generate_flag_check(features: &[String]) -> String {
    if features.is_empty() {
        return String::default();
    }
    let quoted = features
        .iter()
        .map(|f| format!("\"{f}\""))
        .collect::<Vec<_>>()
        .join(", ");
    format!("assert_features({quoted})\n\n")
}
575
/// Turn `name` into the identifier used for it in generated R code.
///
/// Type arguments are prefixed with `.` (e.g. `T` -> `.T`). Names that cannot safely be
/// used as R variables get a trailing `_` (e.g. `function` -> `function_`). These are R's
/// reserved words (see R's `?Reserved`) plus `T` and `F`, which R predefines as aliases
/// of `TRUE`/`FALSE`.
fn sanitize_r<T: AsRef<str> + ToString>(name: T, is_type: bool) -> String {
    // R reserved words (`?Reserved`), plus the `T`/`F` aliases.
    const RESERVED: &[&str] = &[
        "if", "else", "repeat", "while", "function", "for", "in", "next", "break", "TRUE",
        "FALSE", "NULL", "Inf", "NaN", "NA", "NA_integer_", "NA_real_", "NA_character_",
        "NA_complex_", "T", "F",
    ];
    let name = if is_type {
        format!(".{}", name.as_ref())
    } else {
        name.to_string()
    };
    if RESERVED.contains(&name.as_str()) {
        format!("{name}_")
    } else {
        name
    }
}
588
589impl Argument {
590    // R wants to resolve all types on the R side, before passing them to C
591    // some function arguments contain their own nontrivial types/TypeRecipes,
592    //     like `Vec<T>` or `measurement_input_carrier_type(this)`
593    pub fn has_implicit_type(&self) -> bool {
594        !self.is_type && !matches!(self.rust_type, Some(TypeRecipe::Name(_) | TypeRecipe::None))
595    }
596}
597
598impl Function {
599    pub fn type_names(&self) -> Vec<String> {
600        (self.args.iter())
601            .filter(|arg| arg.is_type)
602            .map(|arg| arg.name())
603            .chain(self.derived_types.iter().map(Argument::name))
604            .collect()
605    }
606}
607
608impl TypeRecipe {
609    /// translate the abstract derived_types info into R RuntimeType constructors
610    pub fn to_r(&self, sanitize_types: Option<&[String]>) -> String {
611        match self {
612            Self::Name(name) => sanitize_types
613                .map(|types| sanitize_r(name, types.contains(name)))
614                .unwrap_or_else(|| name.to_string()),
615            Self::Function { function, params } => format!(
616                "{function}({params})",
617                function = function,
618                params = params
619                    .iter()
620                    .map(|v| v.to_r(sanitize_types))
621                    .collect::<Vec<_>>()
622                    .join(", ")
623            ),
624            Self::Nest { origin, args } => format!(
625                "new_runtime_type(origin = \"{origin}\", args = list({args}))",
626                origin = origin,
627                args = args
628                    .iter()
629                    .map(|arg| arg.to_r(sanitize_types))
630                    .collect::<Vec<_>>()
631                    .join(", ")
632            ),
633            Self::None => "NULL".to_string(),
634        }
635    }
636}