Skip to main content

savvy_bindgen/gen/c.rs

1use syn::ext::IdentExt;
2
3use crate::{MergedResult, SavvyFn, SavvyFnArg, SavvyFnType};
4
5impl SavvyFnArg {
6    pub fn arg_name_c(&self) -> String {
7        // to avoid conflict with C keywords, add a unique prefix
8        format!("c_arg__{}", self.pat.unraw())
9    }
10}
11
12impl SavvyFn {
13    // The return value is (pat, ty)
14    pub fn get_c_args(&self) -> Vec<(String, String)> {
15        let mut out: Vec<_> = self
16            .args
17            .iter()
18            .map(|arg| {
19                let arg_name = arg.arg_name_c();
20                let ty = arg.to_c_type_string();
21                (arg_name, ty)
22            })
23            .collect();
24
25        if matches!(&self.fn_type, SavvyFnType::Method { .. }) {
26            out.insert(0, ("self__".to_string(), "SEXP".to_string()))
27        }
28        out
29    }
30
31    /// Generate C function signature
32    fn to_c_function_for_header(&self) -> String {
33        let fn_name = self.fn_name_c_header();
34        let args = self.get_c_args();
35
36        let args_sig = if args.is_empty() {
37            "void".to_string()
38        } else {
39            args.iter()
40                .map(|(arg_name, ty)| format!("{ty} {arg_name}"))
41                .collect::<Vec<String>>()
42                .join(", ")
43        };
44
45        format!("SEXP {fn_name}({args_sig});")
46    }
47
48    /// Generate C function implementation
49    fn to_c_function_impl(&self) -> String {
50        let fn_name_ffi = self.fn_name_c_header();
51        let fn_name_c = self.fn_name_c_impl();
52        let args = self.get_c_args();
53
54        let (args_sig, args_call) = if args.is_empty() {
55            ("void".to_string(), "".to_string())
56        } else {
57            let args_sig = args
58                .iter()
59                .map(|(arg_name, ty)| format!("{ty} {arg_name}"))
60                .collect::<Vec<String>>()
61                .join(", ");
62
63            let args_call = args
64                .iter()
65                .map(|(arg_name, _)| arg_name.as_str())
66                .collect::<Vec<&str>>()
67                .join(", ");
68
69            (args_sig, args_call)
70        };
71
72        format!(
73            "SEXP {fn_name_c}({args_sig}) {{
74    SEXP res = {fn_name_ffi}({args_call});
75    return handle_result(res);
76}}
77"
78        )
79    }
80
81    /// Generate C function call entry
82    fn to_c_function_call_entry(&self) -> String {
83        let fn_name_c = self.fn_name_c_impl();
84        let n_args = self.get_c_args().len();
85        format!(r#"    {{"{fn_name_c}", (DL_FUNC) &{fn_name_c}, {n_args}}},"#)
86    }
87}
88
89pub fn generate_c_header_file(result: &MergedResult) -> String {
90    let bare_fns = result
91        .bare_fns
92        .iter()
93        .map(|x| x.to_c_function_for_header())
94        .collect::<Vec<String>>()
95        .join("\n");
96
97    let impls = result
98        .impls
99        .iter()
100        .map(|(ty, i)| {
101            let fns = i
102                .fns
103                .iter()
104                .map(|x| x.to_c_function_for_header())
105                .collect::<Vec<String>>()
106                .join("\n");
107
108            format!("\n// methods and associated functions for {ty}\n{fns}")
109        })
110        .collect::<Vec<String>>()
111        .join("\n");
112
113    format!("{bare_fns}\n{impls}\n")
114}
115
116fn generate_c_function_impl(fns: &[SavvyFn]) -> String {
117    fns.iter()
118        .map(|x| x.to_c_function_impl())
119        .collect::<Vec<String>>()
120        .join("\n")
121}
122
123fn generate_c_function_call_entry(fns: &[SavvyFn]) -> String {
124    fns.iter()
125        // initializaion functions don't need the R interface
126        .filter(|x| !matches!(x.fn_type, SavvyFnType::InitFunction))
127        .map(|x| x.to_c_function_call_entry())
128        .collect::<Vec<String>>()
129        .join("\n")
130}
131
132fn generate_c_initialization(fns: &[SavvyFn]) -> String {
133    fns.iter()
134        .filter(|x| matches!(x.fn_type, SavvyFnType::InitFunction))
135        .map(|x| format!("    {}(dll);", x.fn_name_c_impl()))
136        .collect::<Vec<String>>()
137        .join("\n")
138}
139
140pub fn generate_c_impl_file(result: &MergedResult, pkg_name: &str) -> String {
141    let common_part = r#"
142// clang-format sorts includes unless SortIncludes: Never. However, the ordering
143// does matter here. So, we need to disable clang-format for safety.
144
145// clang-format off
146#include <stdint.h>
147#include <Rinternals.h>
148#include <R_ext/Parse.h>
149// clang-format on
150
151#include "rust/api.h"
152
153static uintptr_t TAGGED_POINTER_MASK = (uintptr_t)1;
154
155SEXP handle_result(SEXP res_) {
156    uintptr_t res = (uintptr_t)res_;
157
158    // An error is indicated by tag.
159    if ((res & TAGGED_POINTER_MASK) == 1) {
160        // Remove tag
161        SEXP res_aligned = (SEXP)(res & ~TAGGED_POINTER_MASK);
162
163        // Currently, there are two types of error cases:
164        //
165        //   1. Error from Rust code
166        //   2. Error from R's C API, which is caught by R_UnwindProtect()
167        //
168        if (TYPEOF(res_aligned) == CHARSXP) {
169            // In case 1, the result is an error message that can be passed to
170            // Rf_errorcall() directly.
171            Rf_errorcall(R_NilValue, "%s", CHAR(res_aligned));
172        } else {
173            // In case 2, the result is the token to restart the
174            // cleanup process on R's side.
175            R_ContinueUnwind(res_aligned);
176        }
177    }
178
179    return (SEXP)res;
180}
181"#;
182
183    let mut c_fns: Vec<String> = Vec::new();
184    let mut call_entries: Vec<String> = Vec::new();
185
186    c_fns.push(generate_c_function_impl(&result.bare_fns));
187    call_entries.push(generate_c_function_call_entry(&result.bare_fns));
188
189    for (_, i) in result.impls.iter() {
190        c_fns.push(generate_c_function_impl(i.fns.as_slice()));
191        call_entries.push(generate_c_function_call_entry(i.fns.as_slice()));
192    }
193
194    let c_fns = c_fns.join("\n");
195    let call_entries = call_entries.join("\n");
196
197    let initialization = generate_c_initialization(&result.bare_fns);
198
199    format!(
200        "{common_part}
201{c_fns}
202
203static const R_CallMethodDef CallEntries[] = {{
204{call_entries}
205    {{NULL, NULL, 0}}
206}};
207
208void R_init_{pkg_name}(DllInfo *dll) {{
209    R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
210    R_useDynamicSymbols(dll, FALSE);
211
212    // Functions for initialization, if any.
213{initialization}
214}}
215"
216    )
217}