use anyhow::{anyhow, Context};
use std::{
    env, fs,
    io::Write,
    num::NonZeroUsize,
    path::{Path, PathBuf},
    process::Command,
};
use which::which;

use crate::bundle::Bundle;
12
13pub fn policy(name: impl Into<String>) -> WasmPolicyBuilder {
14 WasmPolicyBuilder::new(name)
15}
16
/// How the Wasm module produced by `opa build` is ahead-of-time compiled.
///
/// Only consulted when the `wasmtime-aot` feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AotMode {
    /// Precompile by running the `wasmtime compile` executable found on `PATH`.
    Executable,
    /// Precompile in-process with Cranelift through the `wasmtime` crate.
    #[cfg(feature = "wasmtime-cranelift")]
    Cranelift,
    /// Skip precompilation; an empty `.cwasm` file is written instead.
    None,
}

impl Default for AotMode {
    /// Precompilation is opt-in, so the default mode is `AotMode::None`.
    fn default() -> Self {
        Self::None
    }
}

/// Ahead-of-time compilation settings stored on the policy builder.
#[cfg(feature = "wasmtime-aot")]
#[derive(Default)]
struct WasmTimeAotOptions {
    /// Selected precompilation mode; defaults to `AotMode::None`.
    mode: AotMode,
}
44
/// Builder that compiles Rego sources into an OPA Wasm bundle from a build script.
///
/// Create one with `policy` or `WasmPolicyBuilder::new`, add sources and
/// entrypoints, then call `compile`.
pub struct WasmPolicyBuilder {
    /// File stem of the generated artifacts (`<name>.tar.gz`, `<name>.cwasm`).
    name: String,
    /// Rego files or directories; relative entries are resolved against
    /// `CARGO_MANIFEST_DIR` by `compile`.
    paths: Vec<String>,
    /// Entrypoints handed to `opa build -e` (dots are rewritten to `/`).
    entrypoints: Vec<String>,
    /// Value for `opa build -O`; `None` leaves the flag off.
    opt_level: Option<NonZeroUsize>,
    /// Ahead-of-time compilation settings.
    #[cfg(feature = "wasmtime-aot")]
    aot: WasmTimeAotOptions,
}
53
54impl WasmPolicyBuilder {
55 pub fn new(name: impl Into<String>) -> Self {
56 Self {
57 name: name.into(),
58 paths: Vec::default(),
59 entrypoints: Vec::default(),
60 opt_level: None,
61 #[cfg(feature = "wasmtime-aot")]
62 aot: WasmTimeAotOptions::default(),
63 }
64 }
65
66 #[cfg(feature = "wasmtime-aot")]
68 #[must_use]
69 pub fn precompile_wasm(mut self, mode: AotMode) -> Self {
70 self.aot.mode = mode;
71 self
72 }
73
74 #[must_use]
75 pub fn add_entrypoint(mut self, ep: impl Into<String>) -> Self {
76 self.entrypoints.push(ep.into());
77 self
78 }
79
80 #[must_use]
81 pub fn add_entrypoints<S, I>(mut self, eps: I) -> Self
82 where
83 I: IntoIterator<Item = S>,
84 S: Into<String>,
85 {
86 self.entrypoints.extend(eps.into_iter().map(Into::into));
87 self
88 }
89
90 #[must_use]
91 pub fn add_source(mut self, path: impl Into<String>) -> Self {
92 self.paths.push(path.into());
93 self
94 }
95
96 #[must_use]
97 pub fn add_sources<S, I>(mut self, paths: I) -> Self
98 where
99 I: IntoIterator<Item = S>,
100 S: Into<String>,
101 {
102 self.paths.extend(paths.into_iter().map(Into::into));
103 self
104 }
105
106 #[must_use]
107 #[allow(clippy::missing_panics_doc)]
108 pub fn opt_level(mut self, level: usize) -> Self {
109 if level == 0 {
110 self.opt_level = None;
111 } else {
112 self.opt_level = Some(level.try_into().unwrap());
113 }
114
115 self
116 }
117
118 #[allow(clippy::missing_panics_doc, clippy::too_many_lines)]
124 pub fn compile(self) -> Result<(), anyhow::Error> {
125 if self.paths.is_empty() {
126 return Err(anyhow!("no sources provided"));
127 }
128
129 if self.entrypoints.is_empty() {
130 return Err(anyhow!("no entrypoints provided"));
131 }
132
133 let opa_executable = which("opa")?;
134
135 let root_dir = env::var("CARGO_MANIFEST_DIR")?;
136 let out_dir = env::var("OUT_DIR")?;
137 println!("cargo:rustc-env=OUT_DIR={out_dir}");
138 let out_dir = Path::new(&out_dir).join("opa");
139
140 let mut opa_cmd = Command::new(&opa_executable);
141
142 let mut input_paths = Vec::new();
143
144 for path in self.paths {
145 let p = Path::new(&path);
146
147 let input_file_path: PathBuf = if p.is_absolute() {
148 p.into()
149 } else {
150 Path::new(&root_dir).join(p)
151 };
152
153 if input_file_path.is_dir() {
154 for entry in walkdir::WalkDir::new(&input_file_path)
155 .into_iter()
156 .filter_map(Result::ok)
157 {
158 if !entry.path().extension().map_or(false, |s| s == "rego") {
159 continue;
160 }
161 input_paths.push(entry.path().into());
162 }
163 } else {
164 input_paths.push(input_file_path);
165 }
166 }
167
168 for path in &mut input_paths {
169 println!("cargo:rerun-if-changed={}", path.to_str().unwrap());
170
171 if !path.extension().map_or(false, |s| s == "rego") {
172 return Err(anyhow!("the policy file must have `.rego` extension"));
173 }
174
175 *path = path.canonicalize()?;
176 }
177
178 let output_file_name = self.name;
179 let output_file_path = out_dir.join(&format!("{output_file_name}.tar.gz"));
180
181 opa_cmd.args([
182 "build",
183 "-t",
184 "wasm",
185 "-o",
186 output_file_path.to_str().unwrap(),
187 ]);
188
189 if let Some(opt) = self.opt_level {
190 opa_cmd.arg("-O");
191 opa_cmd.arg(opt.to_string());
192 }
193
194 for entrypoint in self.entrypoints {
195 opa_cmd.arg("-e");
196 opa_cmd.arg(&entrypoint.replace('.', "/"));
197 }
198
199 for input_path in input_paths {
200 opa_cmd.arg(input_path.to_str().unwrap());
201 }
202
203 fs::create_dir_all(&out_dir)?;
204 let out = opa_cmd.output()?;
205
206 if !out.status.success() {
207 let o = String::from_utf8_lossy(&out.stdout).to_string()
208 + String::from_utf8_lossy(&out.stderr).as_ref();
209 return Err(anyhow!("opa error: {o}"));
210 }
211
212 #[cfg(feature = "wasmtime-aot")]
213 {
214 let cwasm_output_path = out_dir.join(format!("{output_file_name}.cwasm"));
215
216 match self.aot.mode {
217 AotMode::Executable => {
218 let mut bundle = Bundle::from_file(&output_file_path).unwrap();
219
220 let mut f = tempfile::NamedTempFile::new().unwrap();
221
222 f.write_all(&bundle.wasm_policies.pop().unwrap().bytes)
223 .unwrap();
224
225 let p = f.into_temp_path();
226
227 let wasmtime_executable = which("wasmtime")?;
228
229 let mut wasmtime_cmd = Command::new(wasmtime_executable);
230
231 wasmtime_cmd.args([
232 "compile",
233 "-o",
234 cwasm_output_path.to_str().unwrap(),
235 p.to_str().unwrap(),
236 ]);
237
238 let out = wasmtime_cmd.output()?;
239
240 if !out.status.success() {
241 let o = String::from_utf8_lossy(&out.stdout).to_string()
242 + String::from_utf8_lossy(&out.stderr).as_ref();
243 return Err(anyhow!("wasmtime error: {o}"));
244 }
245 }
246 #[cfg(feature = "wasmtime-cranelift")]
247 AotMode::Cranelift => {
248 let mut bundle = Bundle::from_file(&output_file_path)?;
249 let engine = wasmtime::Engine::new(
250 wasmtime::Config::default()
251 .cranelift_opt_level(wasmtime::OptLevel::SpeedAndSize),
252 )?;
253 let m = engine.precompile_module(&bundle.wasm_policies.pop().unwrap().bytes)?;
254 std::fs::write(cwasm_output_path, m)?;
255 }
256 AotMode::None => {
257 std::fs::File::create(cwasm_output_path).unwrap();
259 }
260 }
261 }
262
263 Ok(())
264 }
265}