dprint_core/plugins/wasm/
mod.rs

/// Version number of the contract between the dprint host and Wasm plugins.
///
/// Incremented whenever the plugin system changes in a way that is not
/// backwards compatible. Plugins generated by `generate_plugin_code!` report
/// it from their `dprint_plugin_version_4` export.
pub const PLUGIN_SYSTEM_SCHEMA_VERSION: u32 = 4;
4
// Host-provided output primitive, only linked on `wasm32-unknown-unknown`
// builds, where `WasiPrintFd` forwards its bytes through it.
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
extern "C" {
  /// Writes the `iovec_count` buffers described by `iovecs` to descriptor `fd`
  /// and stores the number of bytes written in `bytes_written`.
  /// A return value of `0` means success; anything else is an error code.
  fn fd_write(fd: i32, iovecs: *const Iovec, iovec_count: i32, bytes_written: *mut i32) -> i32;
}
9
10#[derive(serde::Serialize, serde::Deserialize)]
11#[serde(tag = "kind", content = "data")]
12pub enum JsonResponse {
13  #[serde(rename = "ok")]
14  Ok(serde_json::Value),
15  #[serde(rename = "err")]
16  Err(String),
17}
18
/// A `std::io::Write` sink for a WASI-numbered output descriptor
/// (`1` = stdout, `2` = stderr).
///
/// On `wasm32-unknown-unknown` the bytes are forwarded to the host through the
/// imported `fd_write`. On every other target they are written to this
/// process's own stdout/stderr.
#[derive(Debug, Clone, Copy)]
pub struct WasiPrintFd(pub i32);

impl std::io::Write for WasiPrintFd {
  /// Writes `buf` to the wrapped descriptor.
  ///
  /// # Errors
  /// On Wasm, a non-zero `fd_write` result is returned as a raw OS error.
  /// Natively, any descriptor other than `1` or `2` yields
  /// `ErrorKind::InvalidInput`; stdout/stderr errors are propagated.
  fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
    {
      let iovec = Iovec {
        buf: buf.as_ptr(),
        buf_len: buf.len() as u32,
      };
      let mut nwritten: i32 = 0;
      // SAFETY: `iovec` describes `buf`, which outlives the call, exactly one
      // iovec is passed, and `nwritten` is a valid writable local.
      let result = unsafe { fd_write(self.0, &iovec, 1, &mut nwritten) };
      if result == 0 {
        Ok(nwritten as usize)
      } else {
        Err(std::io::Error::from_raw_os_error(result))
      }
    }
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    {
      // Implementing `Write` does not bring its methods into scope for other
      // types, so import it for `write_all` on stdout/stderr.
      use std::io::Write as _;

      // Use the same numbering as the Wasm path (WASI/POSIX): 0 is stdin,
      // 1 is stdout, 2 is stderr.
      match self.0 {
        1 => std::io::stdout().write_all(buf)?,
        2 => std::io::stderr().write_all(buf)?,
        _ => return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)),
      }
      Ok(buf.len())
    }
  }

  /// Flushes the native stdout/stderr stream backing this descriptor.
  ///
  /// Writes on Wasm go straight to the host, and unknown descriptors have
  /// nothing to flush, so both are no-ops.
  fn flush(&mut self) -> std::io::Result<()> {
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    {
      use std::io::Write as _;

      match self.0 {
        1 => return std::io::stdout().flush(),
        2 => return std::io::stderr().flush(),
        _ => {}
      }
    }
    Ok(())
  }
}
53
/// One buffer descriptor passed to the host's `fd_write`.
///
/// `#[repr(C)]` fixes the field order and layout so the host can read it
/// across the FFI boundary. The pointer is borrowed, so the bytes it points to
/// must stay alive for as long as the descriptor is in use.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct Iovec {
  /// Start of the byte buffer.
  pub buf: *const u8,
  /// Number of bytes available at `buf`.
  pub buf_len: u32,
}
59
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod macros {
  /// Generates the `#[no_mangle]` exports and host glue a dprint Wasm plugin needs.
  ///
  /// Arguments:
  /// * `$wasm_plugin_struct`: the plugin type. An instance is kept in a `static`.
  ///   The generated exports call its `plugin_info`, `license_text`,
  ///   `resolve_config`, `check_config_updates` and `format` methods.
  /// * `$wasm_plugin_creation`: expression that builds the plugin. It initializes
  ///   a `static`, so it must be usable in a const context.
  /// * `$wasm_plugin_config`: the resolved configuration type. When omitted, the
  ///   name `Configuration` is used, as resolved at the call site.
  ///
  /// Bytes are exchanged with the host through one shared buffer. Exports either
  /// consume what the host placed there, or leave their result there and return
  /// its length.
  ///
  /// The generated code refers to `serde_json` and `anyhow` by name, so the
  /// invoking crate must depend on both.
  #[macro_export]
  macro_rules! generate_plugin_code {
    ($wasm_plugin_struct:ident, $wasm_plugin_creation:expr) => {
      // Go through `$crate` so this resolves even when the caller invoked the
      // macro by path (e.g. `dprint_core::generate_plugin_code!`) without
      // importing it by name.
      $crate::generate_plugin_code!($wasm_plugin_struct, $wasm_plugin_creation, Configuration);
    };
    ($wasm_plugin_struct:ident, $wasm_plugin_creation:expr, $wasm_plugin_config:ident) => {
      /// Lazily-initialized global that hands out `&mut T` from a shared
      /// reference. Like `StaticCell`, this relies on the plugin running on a
      /// single thread.
      struct RefStaticCell<T: Default>(std::cell::OnceCell<StaticCell<T>>);

      impl<T: Default> RefStaticCell<T> {
        pub const fn new() -> Self {
          RefStaticCell(std::cell::OnceCell::new())
        }

        /// Returns the value, creating it with `T::default()` on first access.
        ///
        /// # Safety
        /// No other reference obtained from this cell may be alive while the
        /// returned one is used.
        // `&self -> &mut T` is the point of this cell; see the Safety contract.
        #[allow(clippy::mut_from_ref)]
        unsafe fn get(&self) -> &mut T {
          let inner = self.0.get_or_init(Default::default);
          inner.get()
        }

        /// Replaces the value (initializing it first if needed) and returns the old one.
        fn replace(&self, value: T) -> T {
          let inner = self.0.get_or_init(Default::default);
          inner.replace(value)
        }
      }

      unsafe impl<T: Default> Sync for RefStaticCell<T> {}

      // This is ok to do because Wasm plugins are only ever executed on a single thread.
      // https://github.com/rust-lang/rust/issues/53639#issuecomment-790091647
      struct StaticCell<T>(std::cell::UnsafeCell<T>);

      impl<T: Default> Default for StaticCell<T> {
        fn default() -> Self {
          StaticCell(std::cell::UnsafeCell::new(T::default()))
        }
      }

      impl<T> StaticCell<T> {
        const fn new(value: T) -> Self {
          StaticCell(std::cell::UnsafeCell::new(value))
        }

        /// Returns a mutable reference to the stored value.
        ///
        /// # Safety
        /// No other reference obtained from this cell may be alive while the
        /// returned one is used.
        // `&self -> &mut T` is the point of this cell; see the Safety contract.
        #[allow(clippy::mut_from_ref)]
        unsafe fn get(&self) -> &mut T {
          &mut *self.0.get()
        }

        /// Swaps in `value` and returns the previous value.
        fn replace(&self, value: T) -> T {
          // SAFETY: the `&mut T` is used only for this swap and not retained.
          std::mem::replace(unsafe { self.get() }, value)
        }
      }

      unsafe impl<T> Sync for StaticCell<T> {}

      /// The plugin instance every export delegates to.
      static WASM_PLUGIN: StaticCell<$wasm_plugin_struct> = StaticCell::new($wasm_plugin_creation);

      // HOST FORMATTING

      #[link(wasm_import_module = "dprint")]
      extern "C" {
        /// Returns `1` when the host has cancelled the current request.
        fn host_has_cancelled() -> i32;
      }

      /// Asks the host to format `request`.
      ///
      /// Returns `Ok(None)` when the host reports no change, `Ok(Some(bytes))`
      /// with the new text when it changed, or an error carrying the host's
      /// message.
      ///
      /// # Panics
      /// Panics if the host returns an unknown status code or non-UTF-8 error text.
      fn format_with_host(request: $crate::plugins::SyncHostFormatRequest) -> anyhow::Result<Option<Vec<u8>>> {
        use std::borrow::Cow;

        #[link(wasm_import_module = "dprint")]
        extern "C" {
          fn host_write_buffer(pointer: *const u8);
          fn host_format(
            file_path_ptr: *const u8,
            file_path_len: u32,
            start_range: u32,
            end_range: u32,
            override_config_ptr: *const u8,
            override_config_len: u32,
            file_text_ptr: *const u8,
            file_text_len: u32,
          ) -> u8;
          fn host_get_formatted_text() -> u32;
          fn host_get_error_text() -> u32;
        }

        /// Sizes the shared buffer to `length` zero bytes, lets the host write
        /// into it, then takes the contents.
        fn get_bytes_from_host(length: u32) -> Vec<u8> {
          let ptr = clear_shared_bytes(length as usize);
          // SAFETY: `ptr` points at the start of a freshly allocated buffer of
          // `length` bytes owned by `SHARED_BYTES`.
          unsafe {
            host_write_buffer(ptr);
          }
          take_from_shared_bytes()
        }

        fn get_string_from_host(length: u32) -> String {
          String::from_utf8(get_bytes_from_host(length)).unwrap()
        }

        let file_path = request.file_path.to_string_lossy();
        let override_config = if !request.override_config.is_empty() {
          Cow::Owned(serde_json::to_string(request.override_config).unwrap())
        } else {
          Cow::Borrowed("")
        };
        // Without an explicit range, the whole file is sent.
        let range = request.range.unwrap_or(0..request.file_bytes.len());

        // SAFETY: every pointer/length pair describes a buffer owned by this
        // function or by `request`, all of which outlive the call.
        let status = unsafe {
          host_format(
            file_path.as_ptr(),
            file_path.len() as u32,
            range.start as u32,
            range.end as u32,
            override_config.as_ptr(),
            override_config.len() as u32,
            request.file_bytes.as_ptr(),
            request.file_bytes.len() as u32,
          )
        };

        match status {
          // no change
          0 => Ok(None),
          // change
          1 => {
            let length = unsafe { host_get_formatted_text() };
            Ok(Some(get_bytes_from_host(length)))
          }
          // error
          2 => {
            let length = unsafe { host_get_error_text() };
            let error_text = get_string_from_host(length);
            Err(anyhow::anyhow!("{}", error_text))
          }
          value => panic!("unknown host format value: {}", value),
        }
      }

      // FORMATTING

      static OVERRIDE_CONFIG: StaticCell<Option<$crate::configuration::ConfigKeyMap>> = StaticCell::new(None);
      static FILE_PATH: StaticCell<Option<std::path::PathBuf>> = StaticCell::new(None);
      static FORMATTED_TEXT: StaticCell<Option<Vec<u8>>> = StaticCell::new(None);
      static ERROR_TEXT: StaticCell<Option<String>> = StaticCell::new(None);

      /// Reads a JSON override config from the shared buffer; it is consumed by
      /// the next `format`/`format_range` call.
      #[no_mangle]
      pub fn set_override_config() {
        let bytes = take_from_shared_bytes();
        let config = serde_json::from_slice(&bytes).unwrap();
        unsafe { OVERRIDE_CONFIG.get().replace(config) };
      }

      /// Reads the path of the next file to format from the shared buffer.
      #[no_mangle]
      pub fn set_file_path() {
        // convert windows back slashes to forward slashes so it works with PathBuf
        let text = take_string_from_shared_bytes().replace('\\', "/");
        unsafe { FILE_PATH.get().replace(std::path::PathBuf::from(text)) };
      }

      /// Formats the whole file.
      ///
      /// Returns `0` (no change), `1` (changed; fetch with `get_formatted_text`)
      /// or `2` (error; fetch with `get_error_text`).
      #[no_mangle]
      pub fn format(config_id: u32) -> u8 {
        format_inner(config_id, None)
      }

      /// Like `format`, but restricted to `range_start..range_end`.
      #[no_mangle]
      pub fn format_range(config_id: u32, range_start: u32, range_end: u32) -> u8 {
        format_inner(config_id, Some(range_start as usize..range_end as usize))
      }

      /// Shared body of `format` and `format_range`.
      ///
      /// Consumes the pending override config, file path and file bytes, runs
      /// the plugin, and stores the result for the host to fetch.
      fn format_inner(config_id: u32, range: $crate::plugins::FormatRange) -> u8 {
        #[derive(Debug)]
        struct HostCancellationToken;

        impl $crate::plugins::CancellationToken for HostCancellationToken {
          fn is_cancelled(&self) -> bool {
            unsafe { host_has_cancelled() == 1 }
          }
        }

        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        ensure_initialized(config_id);
        // A one-off override produces a freshly resolved config; otherwise
        // borrow the cached resolution for this id.
        let config = unsafe {
          if let Some(override_config) = OVERRIDE_CONFIG.get().take() {
            std::borrow::Cow::Owned(create_resolved_config_result(config_id, override_config).config)
          } else {
            std::borrow::Cow::Borrowed(&get_resolved_config_result(config_id).config)
          }
        };
        let file_path = unsafe { FILE_PATH.get().take().expect("Expected the file path to be set.") };
        let file_bytes = take_from_shared_bytes();

        let request = $crate::plugins::SyncFormatRequest::<$wasm_plugin_config> {
          file_path: &file_path,
          file_bytes,
          config: &config,
          config_id,
          range,
          token: &HostCancellationToken,
        };
        let formatted_text = unsafe { WASM_PLUGIN.get().format(request, format_with_host) };
        match formatted_text {
          Ok(None) => 0, // no change
          Ok(Some(formatted_text)) => {
            unsafe { FORMATTED_TEXT.get().replace(formatted_text) };
            1 // change
          }
          Err(err_text) => {
            unsafe { ERROR_TEXT.get().replace(err_text.to_string()) };
            2 // error
          }
        }
      }

      /// Moves the pending formatted text into the shared buffer and returns its length.
      ///
      /// # Panics
      /// Panics if no formatted text is pending.
      #[no_mangle]
      pub fn get_formatted_text() -> usize {
        let formatted_text = unsafe { FORMATTED_TEXT.get().take().expect("Expected to have formatted text.") };
        set_shared_bytes(formatted_text)
      }

      /// Moves the pending error text into the shared buffer and returns its length.
      ///
      /// # Panics
      /// Panics if no error text is pending.
      #[no_mangle]
      pub fn get_error_text() -> usize {
        let error_text = unsafe { ERROR_TEXT.get().take().expect("Expected to have error text.") };
        set_shared_bytes_str(error_text)
      }

      // INFORMATION & CONFIGURATION

      /// Resolved configurations, cached per config id. Entries are dropped by
      /// `register_config` and `release_config`.
      static RESOLVE_CONFIGURATION_RESULT: RefStaticCell<
        std::collections::HashMap<$crate::plugins::FormatConfigId, $crate::plugins::PluginResolveConfigurationResult<$wasm_plugin_config>>,
      > = RefStaticCell::new();

      /// Writes the plugin info as JSON into the shared buffer and returns its length.
      #[no_mangle]
      pub fn get_plugin_info() -> usize {
        use $crate::plugins::PluginInfo;
        let plugin_info = unsafe { WASM_PLUGIN.get().plugin_info() };
        let info_json = serde_json::to_string(&plugin_info).unwrap();
        set_shared_bytes_str(info_json)
      }

      /// Writes the license text into the shared buffer and returns its length.
      #[no_mangle]
      pub fn get_license_text() -> usize {
        set_shared_bytes_str(unsafe { WASM_PLUGIN.get().license_text() })
      }

      /// Writes the resolved config for `config_id` as JSON into the shared buffer.
      #[no_mangle]
      pub fn get_resolved_config(config_id: u32) -> usize {
        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        let bytes = serde_json::to_vec(&get_resolved_config_result(config_id).config).unwrap();
        set_shared_bytes(bytes)
      }

      /// Writes the config diagnostics for `config_id` as JSON into the shared buffer.
      #[no_mangle]
      pub fn get_config_diagnostics(config_id: u32) -> usize {
        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        let bytes = serde_json::to_vec(&get_resolved_config_result(config_id).diagnostics).unwrap();
        set_shared_bytes(bytes)
      }

      /// Writes the file-matching info for `config_id` as JSON into the shared buffer.
      #[no_mangle]
      pub fn get_config_file_matching(config_id: u32) -> usize {
        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        let bytes = serde_json::to_vec(&get_resolved_config_result(config_id).file_matching).unwrap();
        set_shared_bytes(bytes)
      }

      /// Returns the cached resolution for `config_id`, resolving it first if needed.
      fn get_resolved_config_result<'a>(
        config_id: $crate::plugins::FormatConfigId,
      ) -> &'a $crate::plugins::PluginResolveConfigurationResult<$wasm_plugin_config> {
        ensure_initialized(config_id);
        // SAFETY: single-threaded plugin; `ensure_initialized` just inserted or
        // confirmed the entry, so the lookup cannot fail.
        unsafe { RESOLVE_CONFIGURATION_RESULT.get().get(&config_id).unwrap() }
      }

      /// Resolves and caches the config for `config_id` if it is not cached yet.
      fn ensure_initialized(config_id: $crate::plugins::FormatConfigId) {
        // The cache is borrowed only briefly on either side of the resolve call,
        // so no `&mut` into it is alive while plugin code runs.
        let is_cached = unsafe { RESOLVE_CONFIGURATION_RESULT.get().contains_key(&config_id) };
        if !is_cached {
          let config_result = create_resolved_config_result(config_id, $crate::configuration::ConfigKeyMap::new());
          unsafe { RESOLVE_CONFIGURATION_RESULT.get().insert(config_id, config_result) };
        }
      }

      /// Resolves the registered raw config for `config_id`, with `override_config`
      /// entries inserted over the plugin section.
      ///
      /// # Panics
      /// Panics if no config was registered for `config_id`.
      fn create_resolved_config_result(
        config_id: $crate::plugins::FormatConfigId,
        override_config: $crate::configuration::ConfigKeyMap,
      ) -> $crate::plugins::PluginResolveConfigurationResult<$wasm_plugin_config> {
        unsafe {
          if let Some(config) = UNRESOLVED_CONFIG.get().get(&config_id) {
            let mut plugin_config = config.plugin.clone();
            for (key, value) in override_config {
              plugin_config.insert(key, value);
            }
            return WASM_PLUGIN.get().resolve_config(plugin_config, &config.global);
          }
        }

        panic!("Plugin must have config set before use (id: {:?}).", config_id);
      }

      // INITIALIZATION

      /// Raw configs registered by the host, keyed by config id.
      static UNRESOLVED_CONFIG: RefStaticCell<std::collections::HashMap<$crate::plugins::FormatConfigId, $crate::plugins::RawFormatConfig>> =
        RefStaticCell::new();

      /// Stores the raw config (JSON from the shared buffer) for `config_id` and
      /// drops any cached resolution for that id.
      #[no_mangle]
      pub fn register_config(config_id: u32) {
        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        let bytes = take_from_shared_bytes();
        let config: $crate::plugins::RawFormatConfig = serde_json::from_slice(&bytes).unwrap();
        unsafe {
          UNRESOLVED_CONFIG.get().insert(config_id, config);
          RESOLVE_CONFIGURATION_RESULT.get().remove(&config_id); // clear
        }
      }

      /// Forgets both the raw and the resolved config for `config_id`.
      #[no_mangle]
      pub fn release_config(config_id: u32) {
        let config_id = $crate::plugins::FormatConfigId::from_raw(config_id);
        unsafe {
          UNRESOLVED_CONFIG.get().remove(&config_id);
          RESOLVE_CONFIGURATION_RESULT.get().remove(&config_id);
        }
      }

      /// Runs the plugin's `check_config_updates` on the JSON message in the
      /// shared buffer, replaces the buffer with a JSON `JsonResponse`, and
      /// returns its length.
      #[no_mangle]
      pub fn check_config_updates() -> usize {
        fn try_check_config_updates(bytes: &[u8]) -> anyhow::Result<serde_json::Value> {
          let message: $crate::plugins::CheckConfigUpdatesMessage = serde_json::from_slice(bytes)?;
          let result = unsafe { WASM_PLUGIN.get().check_config_updates(message) }?;
          Ok(serde_json::to_value(&result)?)
        }

        let bytes = take_from_shared_bytes();
        let response = match try_check_config_updates(&bytes) {
          Ok(value) => $crate::plugins::wasm::JsonResponse::Ok(value),
          Err(err) => $crate::plugins::wasm::JsonResponse::Err(err.to_string()),
        };
        set_shared_bytes(serde_json::to_vec(&response).unwrap())
      }

      // LOW LEVEL SENDING AND RECEIVING

      /// The single buffer used to pass bytes between host and plugin.
      static SHARED_BYTES: StaticCell<Vec<u8>> = StaticCell::new(Vec::new());

      /// Reports the plugin system schema version (also encoded in this export's name).
      #[no_mangle]
      pub fn dprint_plugin_version_4() -> u32 {
        $crate::plugins::wasm::PLUGIN_SYSTEM_SCHEMA_VERSION
      }

      /// Returns a pointer to the start of the shared buffer.
      #[no_mangle]
      pub fn get_shared_bytes_ptr() -> *const u8 {
        unsafe { SHARED_BYTES.get().as_ptr() }
      }

      /// Replaces the shared buffer with `size` zero bytes and returns a pointer to it.
      #[no_mangle]
      pub fn clear_shared_bytes(size: usize) -> *const u8 {
        SHARED_BYTES.replace(vec![0; size]);
        unsafe { SHARED_BYTES.get().as_ptr() }
      }

      /// Takes the shared buffer as a `String`, panicking on invalid UTF-8.
      fn take_string_from_shared_bytes() -> String {
        String::from_utf8(take_from_shared_bytes()).unwrap()
      }

      /// Takes the shared buffer's contents, leaving it empty.
      fn take_from_shared_bytes() -> Vec<u8> {
        SHARED_BYTES.replace(Vec::new())
      }

      /// Stores `text` in the shared buffer and returns its length in bytes.
      fn set_shared_bytes_str(text: String) -> usize {
        set_shared_bytes(text.into_bytes())
      }

      /// Stores `bytes` in the shared buffer and returns their length.
      fn set_shared_bytes(bytes: Vec<u8>) -> usize {
        let length = bytes.len();
        SHARED_BYTES.replace(bytes);
        length
      }
    };
  }
}