1use syn::ext::IdentExt;
2
3use crate::{MergedResult, SavvyFn, SavvyFnArg, SavvyFnType};
4
5impl SavvyFnArg {
6 pub fn arg_name_c(&self) -> String {
7 format!("c_arg__{}", self.pat.unraw())
9 }
10}
11
12impl SavvyFn {
13 pub fn get_c_args(&self) -> Vec<(String, String)> {
15 let mut out: Vec<_> = self
16 .args
17 .iter()
18 .map(|arg| {
19 let arg_name = arg.arg_name_c();
20 let ty = arg.to_c_type_string();
21 (arg_name, ty)
22 })
23 .collect();
24
25 if matches!(&self.fn_type, SavvyFnType::Method { .. }) {
26 out.insert(0, ("self__".to_string(), "SEXP".to_string()))
27 }
28 out
29 }
30
31 fn to_c_function_for_header(&self) -> String {
33 let fn_name = self.fn_name_c_header();
34 let args = self.get_c_args();
35
36 let args_sig = if args.is_empty() {
37 "void".to_string()
38 } else {
39 args.iter()
40 .map(|(arg_name, ty)| format!("{ty} {arg_name}"))
41 .collect::<Vec<String>>()
42 .join(", ")
43 };
44
45 format!("SEXP {fn_name}({args_sig});")
46 }
47
48 fn to_c_function_impl(&self) -> String {
50 let fn_name_ffi = self.fn_name_c_header();
51 let fn_name_c = self.fn_name_c_impl();
52 let args = self.get_c_args();
53
54 let (args_sig, args_call) = if args.is_empty() {
55 ("void".to_string(), "".to_string())
56 } else {
57 let args_sig = args
58 .iter()
59 .map(|(arg_name, ty)| format!("{ty} {arg_name}"))
60 .collect::<Vec<String>>()
61 .join(", ");
62
63 let args_call = args
64 .iter()
65 .map(|(arg_name, _)| arg_name.as_str())
66 .collect::<Vec<&str>>()
67 .join(", ");
68
69 (args_sig, args_call)
70 };
71
72 format!(
73 "SEXP {fn_name_c}({args_sig}) {{
74 SEXP res = {fn_name_ffi}({args_call});
75 return handle_result(res);
76}}
77"
78 )
79 }
80
81 fn to_c_function_call_entry(&self) -> String {
83 let fn_name_c = self.fn_name_c_impl();
84 let n_args = self.get_c_args().len();
85 format!(r#" {{"{fn_name_c}", (DL_FUNC) &{fn_name_c}, {n_args}}},"#)
86 }
87}
88
89pub fn generate_c_header_file(result: &MergedResult) -> String {
90 let bare_fns = result
91 .bare_fns
92 .iter()
93 .map(|x| x.to_c_function_for_header())
94 .collect::<Vec<String>>()
95 .join("\n");
96
97 let impls = result
98 .impls
99 .iter()
100 .map(|(ty, i)| {
101 let fns = i
102 .fns
103 .iter()
104 .map(|x| x.to_c_function_for_header())
105 .collect::<Vec<String>>()
106 .join("\n");
107
108 format!("\n// methods and associated functions for {ty}\n{fns}")
109 })
110 .collect::<Vec<String>>()
111 .join("\n");
112
113 format!("{bare_fns}\n{impls}\n")
114}
115
116fn generate_c_function_impl(fns: &[SavvyFn]) -> String {
117 fns.iter()
118 .map(|x| x.to_c_function_impl())
119 .collect::<Vec<String>>()
120 .join("\n")
121}
122
123fn generate_c_function_call_entry(fns: &[SavvyFn]) -> String {
124 fns.iter()
125 .filter(|x| !matches!(x.fn_type, SavvyFnType::InitFunction))
127 .map(|x| x.to_c_function_call_entry())
128 .collect::<Vec<String>>()
129 .join("\n")
130}
131
132fn generate_c_initialization(fns: &[SavvyFn]) -> String {
133 fns.iter()
134 .filter(|x| matches!(x.fn_type, SavvyFnType::InitFunction))
135 .map(|x| format!(" {}(dll);", x.fn_name_c_impl()))
136 .collect::<Vec<String>>()
137 .join("\n")
138}
139
140pub fn generate_c_impl_file(result: &MergedResult, pkg_name: &str) -> String {
141 let common_part = r#"
142// clang-format sorts includes unless SortIncludes: Never. However, the ordering
143// does matter here. So, we need to disable clang-format for safety.
144
145// clang-format off
146#include <stdint.h>
147#include <Rinternals.h>
148#include <R_ext/Parse.h>
149// clang-format on
150
151#include "rust/api.h"
152
153static uintptr_t TAGGED_POINTER_MASK = (uintptr_t)1;
154
155SEXP handle_result(SEXP res_) {
156 uintptr_t res = (uintptr_t)res_;
157
158 // An error is indicated by tag.
159 if ((res & TAGGED_POINTER_MASK) == 1) {
160 // Remove tag
161 SEXP res_aligned = (SEXP)(res & ~TAGGED_POINTER_MASK);
162
163 // Currently, there are two types of error cases:
164 //
165 // 1. Error from Rust code
166 // 2. Error from R's C API, which is caught by R_UnwindProtect()
167 //
168 if (TYPEOF(res_aligned) == CHARSXP) {
169 // In case 1, the result is an error message that can be passed to
170 // Rf_errorcall() directly.
171 Rf_errorcall(R_NilValue, "%s", CHAR(res_aligned));
172 } else {
173 // In case 2, the result is the token to restart the
174 // cleanup process on R's side.
175 R_ContinueUnwind(res_aligned);
176 }
177 }
178
179 return (SEXP)res;
180}
181"#;
182
183 let mut c_fns: Vec<String> = Vec::new();
184 let mut call_entries: Vec<String> = Vec::new();
185
186 c_fns.push(generate_c_function_impl(&result.bare_fns));
187 call_entries.push(generate_c_function_call_entry(&result.bare_fns));
188
189 for (_, i) in result.impls.iter() {
190 c_fns.push(generate_c_function_impl(i.fns.as_slice()));
191 call_entries.push(generate_c_function_call_entry(i.fns.as_slice()));
192 }
193
194 let c_fns = c_fns.join("\n");
195 let call_entries = call_entries.join("\n");
196
197 let initialization = generate_c_initialization(&result.bare_fns);
198
199 format!(
200 "{common_part}
201{c_fns}
202
203static const R_CallMethodDef CallEntries[] = {{
204{call_entries}
205 {{NULL, NULL, 0}}
206}};
207
208void R_init_{pkg_name}(DllInfo *dll) {{
209 R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
210 R_useDynamicSymbols(dll, FALSE);
211
212 // Functions for initialization, if any.
213{initialization}
214}}
215"
216 )
217}