hyperlight_js/sandbox/
proto_js_sandbox.rs1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::time::SystemTime;
4
5use anyhow::Context;
6use hyperlight_host::sandbox::SandboxConfiguration;
7use hyperlight_host::{new_error, GuestBinary, Result, UninitializedSandbox};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use tracing::{instrument, Level};
11
12use super::js_sandbox::JSSandbox;
13use super::sandbox_builder::SandboxBuilder;
14use crate::sandbox::host_fn::{Function, HostModule};
15use crate::sandbox::metrics::SandboxMetricsGuard;
16use crate::HostPrintFn;
17
18pub struct ProtoJSSandbox {
21 inner: UninitializedSandbox,
22 host_modules: HashMap<String, HostModule>,
23 _metric_guard: SandboxMetricsGuard<ProtoJSSandbox>,
25}
26
27impl ProtoJSSandbox {
28 #[instrument(err(Debug), skip_all, level=Level::INFO, fields(version= env!("CARGO_PKG_VERSION")))]
29 pub(super) fn new(
30 guest_binary: GuestBinary,
31 cfg: Option<SandboxConfiguration>,
32 host_print_writer: Option<HostPrintFn>,
33 ) -> Result<Self> {
34 let mut usbox: UninitializedSandbox = UninitializedSandbox::new(guest_binary, cfg)?;
35
36 if let Some(host_print_writer) = host_print_writer {
38 usbox.register_print(host_print_writer)?;
39 }
40
41 fn current_time_micros() -> hyperlight_host::Result<u64> {
43 Ok(SystemTime::now()
44 .duration_since(SystemTime::UNIX_EPOCH)
45 .with_context(|| "Unable to get duration since epoch")
46 .map(|d| d.as_micros() as u64)?)
47 }
48
49 usbox.register("CurrentTimeMicros", current_time_micros)?;
50
51 Ok(Self {
52 inner: usbox,
53 host_modules: HashMap::new(),
54 _metric_guard: SandboxMetricsGuard::new(),
55 })
56 }
57
58 #[instrument(err(Debug), skip_all, level=Level::INFO)]
62 pub fn set_module_loader<Fs: crate::resolver::FileSystem + Clone + 'static>(
63 mut self,
64 file_system: Fs,
65 ) -> Result<Self> {
66 use std::path::PathBuf;
67
68 use oxc_resolver::{ResolveOptions, ResolverGeneric};
69
70 let resolver = ResolverGeneric::new_with_file_system(
71 file_system.clone(),
72 ResolveOptions {
73 extensions: vec![".js".into(), ".mjs".into()],
74 condition_names: vec!["import".into(), "module".into()],
75 ..Default::default()
76 },
77 );
78
79 self.inner.register(
80 "ResolveModule",
81 move |base: String, specifier: String| -> hyperlight_host::Result<String> {
82 tracing::debug!(
83 base = %base,
84 specifier = %specifier,
85 "Resolving module"
86 );
87
88 let resolved = resolver.resolve(&base, &specifier).map_err(|e| {
89 new_error!(
90 "Failed to resolve module '{}' from '{}': {:?}",
91 specifier,
92 base,
93 e
94 )
95 })?;
96
97 Ok(resolved.path().to_string_lossy().to_string())
98 },
99 )?;
100
101 self.inner.register(
102 "LoadModule",
103 move |path: String| -> hyperlight_host::Result<String> {
104 tracing::debug!(path = %path, "Loading module");
105 let path_buf = PathBuf::from(&path);
106 let source = file_system
107 .read_to_string(&path_buf)
108 .map_err(|e| new_error!("Failed to read module '{}': {}", path, e))?;
109
110 Ok(source)
111 },
112 )?;
113
114 Ok(self)
115 }
116
117 #[instrument(err(Debug), skip(self), level=Level::INFO)]
119 pub fn load_runtime(mut self) -> Result<JSSandbox> {
120 let host_modules = self.host_modules;
121
122 let host_modules_json = serde_json::to_string(&host_modules)?;
123
124 self.inner.register(
125 "CallHostJsFunction",
126 move |module_name: String, func_name: String, args: String| -> Result<String> {
127 let module = host_modules
128 .get(&module_name)
129 .ok_or_else(|| new_error!("Host module '{}' not found", module_name))?;
130 let func = module.get(&func_name).ok_or_else(|| {
131 new_error!(
132 "Host function '{}' not found in module '{}'",
133 func_name,
134 module_name
135 )
136 })?;
137 func(args)
138 },
139 )?;
140
141 let mut multi_use_sandbox = self.inner.evolve()?;
142
143 let _: () = multi_use_sandbox.call("RegisterHostModules", host_modules_json)?;
144
145 JSSandbox::new(multi_use_sandbox)
146 }
147
148 #[instrument(skip(self), level=Level::INFO)]
181 pub fn host_module(&mut self, name: impl Into<String> + Debug) -> &mut HostModule {
182 self.host_modules.entry(name.into()).or_default()
183 }
184
185 #[instrument(err(Debug), skip(self, func), level=Level::INFO)]
191 pub fn register<Output: Serialize, Args: DeserializeOwned>(
192 &mut self,
193 module: impl Into<String> + Debug,
194 name: impl Into<String> + Debug,
195 func: impl Function<Output, Args> + Send + Sync + 'static,
196 ) -> Result<()> {
197 self.host_module(module).register(name, func);
198 Ok(())
199 }
200}
201
202impl std::fmt::Debug for ProtoJSSandbox {
203 #[instrument(skip_all, level=Level::TRACE)]
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("ProtoJsSandbox").finish()
206 }
207}
208
209impl Default for ProtoJSSandbox {
210 #[instrument(skip_all, level=Level::INFO)]
211 fn default() -> Self {
212 #[allow(clippy::unwrap_used)]
215 SandboxBuilder::new().build().unwrap()
216 }
217}