opendp_tooling/codegen/r/
r.rs

1use std::{collections::HashMap, fs, iter::once};
2
3use crate::{
4    codegen::{flatten_type_recipe, tab_r},
5    Argument, Function, TypeRecipe, Value,
6};
7
8use super::BLOCKLIST;
9
/// Rust type names treated as atoms when logging constructor arguments:
/// an argument whose type recipe is one of these names is wrapped in `unbox2(...)`.
static ATOM_TYPES: &[&str] = &[
    "u32", "u64", "i32", "i64", "f32", "f64", "usize", "bool", "String",
];
13
14/// Generates all code for an OpenDP R module.
15/// Each call corresponds to one R file.
16pub fn generate_r_module(
17    module_name: &str,
18    module: &Vec<Function>,
19    hierarchy: &HashMap<String, Vec<String>>,
20) -> String {
21    let body = module
22        .into_iter()
23        .filter(|func| func.has_ffi)
24        .filter(|func| !BLOCKLIST.contains(&func.name.as_str()))
25        .map(|func| generate_r_function(module_name, &func, hierarchy))
26        .collect::<Vec<String>>()
27        .join("\n");
28
29    format!(
30        "# Auto-generated. Do not edit.
31
32#' @include typing.R mod.R
33NULL
34
35{body}"
36    )
37}
38
39/// Generate the code for a user-facing R function.
40///
41/// This internally calls a C function that wraps the Rust OpenDP Library
42pub(crate) fn generate_r_function(
43    module_name: &str,
44    func: &Function,
45    hierarchy: &HashMap<String, Vec<String>>,
46) -> String {
47    println!("generating R: {}", func.name);
48    let mut args = (func.args.iter())
49        .map(|arg| generate_r_input_argument(arg, func))
50        .collect::<Vec<_>>();
51
52    // move default arguments to end
53    args.sort_by(|(_, l_is_default), (_, r_is_default)| l_is_default.cmp(r_is_default));
54
55    let args = args.into_iter().map(|v| v.0).collect::<Vec<_>>();
56
57    let then_func = if func.name.starts_with("make_") {
58        let offset = if func.supports_partial { 2 } else { 0 };
59        let pre_args_nl = if args.len() > 0 { "\n" } else { "" };
60        format!(
61            r#"
62
63{then_docs}
64{then_name} <- function(
65{then_args}
66) {{
67{then_log}
68  make_chain_dyn(
69    {name}({pre_args_nl}{args}),
70    lhs,
71    log_)
72}}"#,
73            then_docs = generate_then_doc_block(module_name, func, hierarchy),
74            then_name = func.name.replacen("make_", "then_", 1),
75            then_args = tab_r(
76                once("lhs".to_string())
77                    .chain(args[offset..].to_owned())
78                    .collect::<Vec<_>>()
79                    .join(",\n")
80            ),
81            then_log = tab_r(generate_logger(module_name, func, true)),
82            name = func.name,
83            args = tab_r(tab_r(tab_r(
84                if func.supports_partial {
85                    vec![
86                        "output_domain(lhs)".to_string(),
87                        "output_metric(lhs)".to_string(),
88                    ]
89                } else {
90                    vec![]
91                }
92                .into_iter()
93                .chain(func.args[offset..].iter().map(|arg| {
94                    let name = if arg.is_type {
95                        format!(".{}", arg.name())
96                    } else {
97                        sanitize_r(arg.name(), arg.is_type)
98                    };
99                    format!("{name} = {name}")
100                }))
101                .collect::<Vec<_>>()
102                .join(",\n")
103            )))
104        )
105    } else {
106        String::default()
107    };
108
109    format!(
110        r#"
111{doc_block}
112{func_name} <- function(
113{args}
114) {{
115{body}
116}}{then_func}
117"#,
118        doc_block = generate_doc_block(module_name, func, hierarchy),
119        func_name = func.name.trim_start_matches("_"),
120        args = tab_r(args.join(",\n")),
121        body = tab_r(generate_r_body(module_name, func))
122    )
123}
124
125/// generate an input argument, complete with name, hint and default.
126/// also returns a bool to make it possible to move arguments with defaults to the end of the signature.
127fn generate_r_input_argument(arg: &Argument, func: &Function) -> (String, bool) {
128    let default = if let Some(default) = &arg.default {
129        Some(match default.clone() {
130            Value::Null => "NULL".to_string(),
131            Value::Bool(value) => if value { "TRUE" } else { "FALSE" }.to_string(),
132            Value::Integer(int) => format!("{}L", int),
133            Value::Float(float) => float.to_string(),
134            Value::String(mut string) => {
135                if arg.is_type {
136                    string = arg.generics.iter().fold(string, |string, generic| {
137                        // TODO: kind of hacky. avoids needing a full parser for the type string
138                        // replace all instances of the generic with .generic
139                        string
140                            .replace(&format!("{},", generic), &format!(".{},", generic))
141                            .replace(&format!("{}>", generic), &format!(".{}>", generic))
142                            .replace(&format!("{})", generic), &format!(".{})", generic))
143                    });
144                }
145                format!("\"{}\"", string)
146            }
147        })
148    } else {
149        // let default value be None if it is a type arg and there is a public example
150        generate_public_example(func, arg).map(|_| "NULL".to_string())
151    };
152    (
153        format!(
154            r#"{name}{default}"#,
155            name = sanitize_r(arg.name(), arg.is_type),
156            default = default
157                .as_ref()
158                .map(|default| format!(" = {}", default))
159                .unwrap_or_else(String::new)
160        ),
161        default.is_some(),
162    )
163}
164
/// Finds the first two occurrences of `pat` in `s`.
///
/// Returns the byte offsets `(left, right)` at which the first and second occurrence start,
/// or `None` when `s` contains fewer than two occurrences of `pat`.
fn find_quoted(s: &str, pat: char) -> Option<(usize, usize)> {
    let left = s.find(pat)?;
    // Skip the full encoded width of `pat`: slicing at `left + 1` would land inside
    // a multi-byte character and panic.
    let after_left = left + pat.len_utf8();
    let right = after_left + s[after_left..].find(pat)?;
    Some((left, right))
}
170
171// R wants special formatting for latex
172fn escape_latex(s: &str) -> String {
173    // if a substring is enclosed in latex "$"
174    if let Some((l, u)) = find_quoted(s, '$') {
175        [
176            escape_latex(&s[..l]).as_str(),
177            "\\eqn{",
178            &s[l + 1..u],
179            "}",
180            escape_latex(&s[u + 1..]).as_str(),
181        ]
182        .join("")
183    } else {
184        s.to_string()
185    }
186}
187
// docstrings in R lead with a one-line title
/// Builds the title line for a `make_*` constructor, followed by a blank line.
///
/// Exactly one leading `make_` is removed and the remaining underscores become spaces.
/// Names that do not start with `make_` get no title (an empty string), which matches
/// the `make_` check used when deciding whether to emit a `then_*` function.
fn generate_constructor_title(name: &str) -> String {
    match name.strip_prefix("make_") {
        Some(subject) => format!("{} constructor\n\n", subject.replace('_', " ")),
        None => String::new(),
    }
}
199
200/// generate a documentation block for the current function, with the function description, args, and return
201/// in Roxygen format: https://mpn.metworx.com/packages/roxygen2/7.1.1/articles/rd-formatting.html
202fn generate_doc_block(
203    module_name: &str,
204    func: &Function,
205    hierarchy: &HashMap<String, Vec<String>>,
206) -> String {
207    let title = generate_constructor_title(&func.name);
208    let description = (func.description.as_ref())
209        .map(|v| {
210            v.split("\n")
211                .map(escape_latex)
212                .collect::<Vec<_>>()
213                .join("\n")
214        })
215        .map(|v| format!("{title}{}\n", v))
216        .unwrap_or_else(String::new);
217
218    let concept = format!("@concept {}\n", module_name);
219
220    let doc_args = (func.args.iter())
221        .map(|v| generate_docstring_arg(v))
222        .collect::<Vec<String>>()
223        .join("\n");
224
225    let export = if !func.name.starts_with("_") && module_name != "data" {
226        "\n@export"
227    } else {
228        ""
229    };
230
231    format!(
232        r#"{description}
233{concept}{doc_args}{ret_arg}{examples}{export}"#,
234        concept = concept,
235        description = description,
236        doc_args = doc_args,
237        ret_arg = generate_docstring_return_arg(&func.ret, hierarchy),
238        examples = generate_docstring_examples(module_name, func),
239    )
240    .split("\n")
241    .map(|l| format!("#' {}", l).trim().to_string())
242    .collect::<Vec<_>>()
243    .join("\n")
244}
245
246/// generate a documentation block for a then_* partial constructor, with the function description, args, and return
247/// in Roxygen format: https://mpn.metworx.com/packages/roxygen2/7.1.1/articles/rd-formatting.html
248fn generate_then_doc_block(
249    module_name: &str,
250    func: &Function,
251    hierarchy: &HashMap<String, Vec<String>>,
252) -> String {
253    let title = generate_constructor_title(&func.name);
254    let offset = if func.supports_partial { 2 } else { 0 };
255
256    let doc_args = (func.args[offset..].iter())
257        .map(|v| generate_docstring_arg(v))
258        .collect::<Vec<String>>()
259        .join("\n");
260
261    format!(
262        r#"partial {title}See documentation for [{func_name}()] for details.
263
264@concept {module_name}
265@param lhs The prior transformation or metric space.
266{doc_args}{ret_arg}{examples}
267@export"#,
268        func_name = func.name,
269        ret_arg = generate_docstring_return_arg(&func.ret, hierarchy),
270        examples = generate_docstring_examples(module_name, func),
271    )
272    .split("\n")
273    .map(|l| format!("#' {}", l).trim().to_string())
274    .collect::<Vec<_>>()
275    .join("\n")
276}
277
278/// generate the part of a docstring corresponding to an argument
279fn generate_docstring_arg(arg: &Argument) -> String {
280    let name = sanitize_r(arg.name(), arg.is_type);
281    format!(
282        r#"@param {name} {description}"#,
283        name = name,
284        description = arg
285            .description
286            .as_ref()
287            .map(|v| escape_latex(v.as_str()))
288            .unwrap_or_else(|| "undocumented".to_string())
289    )
290}
291
292/// generate the part of a docstring corresponding to a return argument
293fn generate_docstring_return_arg(
294    arg: &Argument,
295    hierarchy: &HashMap<String, Vec<String>>,
296) -> String {
297    let description = if let Some(description) = &arg.description {
298        description.clone()
299    } else if let Some(type_) = arg.python_type_hint(hierarchy) {
300        type_
301    } else {
302        return String::new();
303    };
304    format!("\n@return {description}")
305}
306
307fn generate_docstring_examples(module_name: &str, func: &Function) -> String {
308    let example_path = format!("src/{}/code/{}.R", &module_name, &func.name);
309    match fs::read_to_string(example_path) {
310        Ok(example) => format!("\n@examples\n{example}"),
311        Err(_) => "".to_string(),
312    }
313}
314
315/// generate the function body, consisting of type args formatters, data converters, and the call
316/// - type arg formatters make every type arg a RuntimeType, and construct derived RuntimeTypes
317/// - data converters convert from python to c representations according to the formatted type args
318/// - the call constructs and retrieves the ffi function name, sets ctypes,
319///     makes the call, handles errors, and converts the response to python
320fn generate_r_body(module_name: &str, func: &Function) -> String {
321    format!(
322        r#"{deprecated}{flag_checker}{type_arg_formatter}{logger}{assert_is_similar}
323{make_call}
324output"#,
325        deprecated = generate_deprecated(func),
326        flag_checker = generate_flag_check(&func.features),
327        type_arg_formatter = generate_type_arg_formatter(func),
328        assert_is_similar = generate_assert_is_similar(func),
329        logger = generate_logger(module_name, func, false),
330        make_call = generate_wrapper_call(module_name, func)
331    )
332}
333
334/// generate code that provides an example of the type of the type_arg
335fn generate_public_example(func: &Function, type_arg: &Argument) -> Option<String> {
336    // the json has supplied explicit instructions to find an example
337    if let Some(example) = &type_arg.example {
338        return Some(example.to_r(Some(&func.type_names())));
339    }
340
341    let type_name = type_arg.name.as_ref().unwrap();
342
343    // rewrite args to remove references to derived types
344    let mut args = func.args.clone();
345    args.iter_mut()
346        .filter(|arg| arg.rust_type.is_some())
347        .for_each(|arg| {
348            arg.rust_type = Some(flatten_type_recipe(
349                arg.rust_type.as_ref().unwrap(),
350                &func.derived_types,
351            ))
352        });
353
354    // code generation
355    args.iter().find_map(|arg| match &arg.rust_type {
356        Some(TypeRecipe::Name(name)) => (name == type_name).then(|| arg.name()),
357        Some(TypeRecipe::Nest { origin, args }) => {
358            if origin != "Vec" {
359                return None;
360            }
361            let TypeRecipe::Name(arg_name) = &args[0] else {
362                return None;
363            };
364            if arg_name != type_name {
365                return None;
366            }
367
368            Some(format!("get_first({name})", name = arg.name()))
369        }
370        _ => None,
371    })
372}
373
374/// the generated code ensures every type arg is a RuntimeType, and constructs derived RuntimeTypes
375fn generate_type_arg_formatter(func: &Function) -> String {
376    let type_names = func.type_names();
377    let type_arg_formatter: String = func.args.iter()
378        .filter(|arg| arg.is_type)
379        .map(|type_arg| {
380            let name = sanitize_r(type_arg.name(), type_arg.is_type);
381            let generics = if type_arg.generics.is_empty() {
382                "".to_string()
383            } else {
384                format!(", generics = list({})", type_arg.generics.iter()
385                    .map(|v| format!("\"{}\"", sanitize_r(v, true)))
386                    .collect::<Vec<_>>().join(", "))
387            };
388            if let Some(example) = generate_public_example(func, type_arg) {
389                format!(r#"{name} <- parse_or_infer(type_name = {name}, public_example = {example}{generics})"#)
390            } else {
391                format!(r#"{name} <- rt_parse(type_name = {name}{generics})"#)
392            }
393        })
394        // additional types that are constructed by introspecting existing types
395        .chain(func.derived_types.iter()
396            .map(|type_spec|
397                format!("{name} <- {derivation}",
398                        name = sanitize_r(type_spec.name(), true),
399                        derivation = type_spec.rust_type.as_ref().unwrap().to_r(Some(&type_names)))))
400        // substitute concrete types in for generics
401        .chain(func.args.iter()
402            .filter(|arg| !arg.generics.is_empty())
403            .map(|arg|
404                format!("{name} <- rt_substitute({name}, {args})",
405                        name=sanitize_r(arg.name(), true),
406                        args=arg.generics.iter()
407                            .map(|generic| format!("{generic} = {generic}", generic = sanitize_r(generic, true)))
408                            .collect::<Vec<_>>().join(", "))))
409        // determine types of arguments that are not type args
410        .chain(func.args.iter().filter(|arg| arg.has_implicit_type())
411            .map(|arg| {
412                let name = sanitize_r(arg.name(), arg.is_type);
413                let converter = arg.rust_type.as_ref().unwrap().to_r(Some(&type_names));
414                format!(r#".T.{name} <- {converter}"#)
415            }))
416        .collect::<Vec<_>>()
417        .join("\n");
418
419    if type_arg_formatter.is_empty() {
420        "# No type arguments to standardize.".to_string()
421    } else {
422        format!(
423            r#"# Standardize type arguments.
424{type_arg_formatter}
425"#
426        )
427    }
428}
429
430fn generate_assert_is_similar(func: &Function) -> String {
431    let assert_is_similar: String = (func.args.iter())
432        .filter(|arg| !arg.is_type)
433        .map(|arg| {
434            let expected = if arg.has_implicit_type() {
435                let name = sanitize_r(arg.name(), arg.is_type);
436                format!(".T.{name}")
437            } else {
438                arg.rust_type
439                    .as_ref()
440                    .unwrap()
441                    .to_r(Some(&func.type_names()))
442            };
443            let inferred = {
444                let name = sanitize_r(arg.name(), arg.is_type);
445                let generics = if arg.generics.is_empty() {
446                    "".to_string()
447                } else {
448                    format!(
449                        ", generics = list({})",
450                        arg.generics
451                            .iter()
452                            .map(|v| format!("\"{}\"", sanitize_r(v, true)))
453                            .collect::<Vec<_>>()
454                            .join(", ")
455                    )
456                };
457                format!(
458                    "rt_infer({name}{generics})",
459                    name = name,
460                    generics = generics
461                )
462            };
463            (expected, inferred)
464        })
465        .filter(|(expected, _)| expected != "NULL")
466        .map(|(expected, inferred)| {
467            format!("rt_assert_is_similar(expected = {expected}, inferred = {inferred})")
468        })
469        .collect::<Vec<_>>()
470        .join("\n");
471
472    if assert_is_similar.is_empty() {
473        "".to_string()
474    } else {
475        format!(
476            r#"
477# Assert that arguments are correctly typed.
478{assert_is_similar}
479"#
480        )
481    }
482}
483
484// generates the `log_ <- ...` code that tracks arguments
485fn generate_logger(module_name: &str, func: &Function, then: bool) -> String {
486    let func_name = if then {
487        func.name.replacen("make_", "then_", 1)
488    } else {
489        func.name.clone()
490    };
491    let offset = if then && func.supports_partial { 2 } else { 0 };
492    let keys = func.args[offset..]
493        .iter()
494        .map(|arg| format!("\"{}\"", arg.name()))
495        .collect::<Vec<_>>()
496        .join(", ");
497
498    let vals = func.args[offset..]
499        .iter()
500        .map(|arg| {
501            let r_name = sanitize_r(arg.name().clone(), arg.is_type);
502            generate_recipe_logger(arg.rust_type.clone().unwrap_or(TypeRecipe::None), r_name)
503        })
504        .collect::<Vec<_>>()
505        .join(", ");
506
507    format!(
508        r#"
509log_ <- new_constructor_log("{func_name}", "{module_name}", new_hashtab(
510  list({keys}),
511  list({vals})
512))
513"#
514    )
515}
516
517// adds unboxes wherever possible to `name`, based on its type `recipe`
518fn generate_recipe_logger(recipe: TypeRecipe, name: String) -> String {
519    match recipe {
520        TypeRecipe::Name(ty) if ATOM_TYPES.contains(&ty.as_str()) => format!("unbox2({name})"),
521        TypeRecipe::Nest { origin, .. } if origin == "Tuple" => {
522            format!("lapply({name}, unbox2)")
523        }
524        _ => name,
525    }
526}
527
528/// the generated code
529/// - constructs and retrieves the ffi function name
530/// - calls the C wrapper with arguments *and* extra type info
531fn generate_wrapper_call(module_name: &str, func: &Function) -> String {
532    let args = (func.args.iter())
533        .chain(func.derived_types.iter())
534        .map(|arg| sanitize_r(arg.name(), arg.is_type))
535        .chain(
536            (func.args.iter().filter(|arg| arg.has_implicit_type()))
537                .map(|arg| sanitize_r(arg.name(), arg.is_type))
538                .map(|name| format!("rt_parse(.T.{name})")),
539        )
540        .map(|name| format!("{name},"))
541        .collect::<Vec<_>>()
542        .join(" ");
543
544    let args_str = if args.is_empty() {
545        "".to_string()
546    } else {
547        format!("\n  {args}")
548    };
549
550    let call = format!(
551        r#".Call(
552  "{module_name}__{name}",{args_str}
553  log_, PACKAGE = "opendp")"#,
554        name = func.name
555    );
556    format!(
557        r#"# Call wrapper function.
558output <- {call}"#
559    )
560}
561
562// generate call to ".Deprecated()" if needed
563fn generate_deprecated(func: &Function) -> String {
564    if let Some(deprecation) = &func.deprecation {
565        format!(
566            ".Deprecated(msg = \"{}\")\n",
567            deprecation.note.replace("\"", "\\\"")
568        )
569    } else {
570        String::default()
571    }
572}
573
// generate code that checks that a set of feature flags are enabled
/// Emits `assert_features("a", "b", ...)` followed by a blank line,
/// or nothing at all when `features` is empty.
fn generate_flag_check(features: &Vec<String>) -> String {
    if features.is_empty() {
        return String::default();
    }

    let mut quoted = String::new();
    for (index, feature) in features.iter().enumerate() {
        if index > 0 {
            quoted.push_str(", ");
        }
        quoted.push('"');
        quoted.push_str(feature);
        quoted.push('"');
    }
    format!("assert_features({quoted})\n\n")
}
589
// ensures that `name` is a valid R variable name, and prefixes types with dots
/// - when `is_type` is set, the name is prefixed with a dot
/// - a resulting name that is an R reserved word (or the `T`/`F` shorthands)
///   gets a trailing underscore, because it cannot be used as a variable or argument name
fn sanitize_r<T: AsRef<str> + ToString>(name: T, is_type: bool) -> String {
    // R's reserved words, plus the `T`/`F` aliases for TRUE/FALSE
    const RESERVED: &[&str] = &[
        "if",
        "else",
        "repeat",
        "while",
        "function",
        "for",
        "in",
        "next",
        "break",
        "TRUE",
        "FALSE",
        "NULL",
        "Inf",
        "NaN",
        "NA",
        "NA_integer_",
        "NA_real_",
        "NA_complex_",
        "NA_character_",
        "T",
        "F",
    ];

    let mut name = name.to_string();
    if is_type {
        name = format!(".{}", name);
    }
    if RESERVED.contains(&name.as_str()) {
        name.push('_');
    }
    name
}
602
603impl Argument {
604    // R wants to resolve all types on the R side, before passing them to C
605    // some function arguments contain their own nontrivial types/TypeRecipes,
606    //     like `Vec<T>` or `measurement_input_carrier_type(this)`
607    pub fn has_implicit_type(&self) -> bool {
608        !self.is_type && !matches!(self.rust_type, Some(TypeRecipe::Name(_) | TypeRecipe::None))
609    }
610}
611
612impl Function {
613    pub fn type_names(&self) -> Vec<String> {
614        (self.args.iter())
615            .filter(|arg| arg.is_type)
616            .map(|arg| arg.name())
617            .chain(self.derived_types.iter().map(Argument::name))
618            .collect()
619    }
620}
621
622impl TypeRecipe {
623    /// translate the abstract derived_types info into R RuntimeType constructors
624    pub fn to_r(&self, sanitize_types: Option<&[String]>) -> String {
625        match self {
626            Self::Name(name) => sanitize_types
627                .map(|types| sanitize_r(name, types.contains(name)))
628                .unwrap_or_else(|| name.to_string()),
629            Self::Function { function, params } => format!(
630                "{function}({params})",
631                function = function,
632                params = params
633                    .iter()
634                    .map(|v| v.to_r(sanitize_types))
635                    .collect::<Vec<_>>()
636                    .join(", ")
637            ),
638            Self::Nest { origin, args } => format!(
639                "new_runtime_type(origin = \"{origin}\", args = list({args}))",
640                origin = origin,
641                args = args
642                    .iter()
643                    .map(|arg| arg.to_r(sanitize_types))
644                    .collect::<Vec<_>>()
645                    .join(", ")
646            ),
647            Self::None => "NULL".to_string(),
648        }
649    }
650}