1#![warn(missing_docs, nonstandard_style)]
2#![doc = include_str!(concat!(env!("OUT_DIR"), "/README.md"))]
3
4use std::rc::Rc;
5
6pub use arma_rs_proc::{FromArma, IntoArma, arma};
7
8#[cfg(feature = "extension")]
9use crossbeam_channel::{Receiver, Sender, unbounded};
10#[cfg(feature = "extension")]
11pub use libc;
12
13#[cfg(all(target_os = "windows", target_arch = "x86"))]
14pub use link_args;
15
16#[cfg(feature = "extension")]
17#[macro_use]
18extern crate log;
19
20mod flags;
21
22mod value;
23pub use value::{DirectReturn, FromArma, FromArmaError, IntoArma, Value, loadout};
24
25#[cfg(feature = "extension")]
26mod call_context;
27#[cfg(feature = "extension")]
28use call_context::{ArmaCallContext, ArmaContextManager};
29#[cfg(feature = "extension")]
30pub use call_context::{CallContext, CallContextStackTrace, Caller, Mission, Server, Source};
31#[cfg(feature = "extension")]
32mod ext_result;
33#[cfg(feature = "extension")]
34pub use ext_result::IntoExtResult;
35#[cfg(feature = "extension")]
36mod command;
37#[cfg(feature = "extension")]
38pub use command::*;
39#[cfg(feature = "extension")]
40pub mod context;
41#[cfg(feature = "extension")]
42pub use context::*;
43#[cfg(feature = "extension")]
44mod group;
45#[cfg(feature = "extension")]
46pub use group::Group;
47#[cfg(feature = "extension")]
48pub mod testing;
49#[cfg(feature = "extension")]
50pub use testing::Result;
51
/// Signature of the host callback stored by `Extension::register_callback`.
///
/// The worker thread started by `Extension::run_callbacks` invokes it with
/// three NUL-terminated C strings (name, function, data) and retries the call
/// for as long as the returned value is negative.
#[cfg(feature = "extension")]
#[doc(hidden)]
pub type Callback = extern "system" fn(
    *const libc::c_char,
    *const libc::c_char,
    *const libc::c_char,
) -> libc::c_int;
/// Signature of the argument-less function handed to the context manager.
///
/// In release builds `ExtensionBuilder::finish` resolves it from the
/// `RVExtensionRequestContext` symbol of the running process; otherwise a
/// no-op stand-in is used.
pub type ContextRequest = unsafe extern "system" fn();
62
/// Messages travelling over the callback channel to the worker thread that
/// `Extension::run_callbacks` spawns.
#[cfg(feature = "extension")]
enum CallbackMessage {
    /// Invoke the registered callback with a name, a function and optional data.
    Call(String, String, Option<Value>),
    /// Stop the worker thread; sent when the `Extension` is dropped.
    Terminate,
}
68
/// Type map restricted to `Send + Sync` values, produced by the `state`
/// crate's `TypeMap!` macro.
// NOTE(review): presumably the container behind `ExtensionBuilder::state` and
// the group state handed to the contexts — confirm in `group`/`context`.
#[cfg(feature = "extension")]
pub type State = state::TypeMap![Send + Sync];
72
/// Flipped to `true` by the first `::console` call so that the console
/// allocation in `Extension::handle_call` is attempted at most once.
#[cfg(windows)]
static CONSOLE_ALLOCATED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
76
/// Feature-flag word exported under its exact (unmangled) name, initialised
/// to `flags::RV_CONTEXT_NO_DEFAULT_CALL`.
// NOTE(review): presumably read by the host application when it loads the
// extension — confirm against the RVExtension interface documentation.
#[unsafe(no_mangle)]
#[allow(non_upper_case_globals, reason = "This is a C API")]
pub static mut RVExtensionFeatureFlags: u64 = flags::RV_CONTEXT_NO_DEFAULT_CALL;
81
/// A fully configured extension: the command tree plus everything needed to
/// dispatch calls and deliver callbacks.
///
/// Created with `Extension::build()` followed by `ExtensionBuilder::finish()`.
#[cfg(feature = "extension")]
pub struct Extension {
    /// Version string reported by `Extension::version`.
    version: String,
    /// Root command group that `handle_call` dispatches into.
    group: group::InternalGroup,
    /// Opt-in flag set through the builder and exposed by `Extension::allow_no_args`.
    allow_no_args: bool,
    /// Host callback, if one has been registered via `register_callback`.
    callback: Option<Callback>,
    /// Callback channel: the sender is cloned into every `Context`, the
    /// receiver is drained by the callback worker thread.
    callback_channel: (Sender<CallbackMessage>, Receiver<CallbackMessage>),
    /// Handle of the callback worker thread once `run_callbacks` has started it;
    /// joined when the extension is dropped.
    callback_thread: Option<std::thread::JoinHandle<()>>,
    /// Holds the call context most recently stored by `handle_call_context`.
    context_manager: Rc<ArmaContextManager>,
    /// `true` when `RVExtensionRequestContext` could not be resolved at build
    /// time; `handle_call` then never clears the stored call context.
    pre218_clear_context_override: bool,
}
95
#[cfg(feature = "extension")]
impl Extension {
    /// Starts configuring a new extension.
    ///
    /// The returned builder begins with version `"0.0.0"`, a fresh root
    /// group and `allow_no_args` switched off.
    #[must_use]
    pub fn build() -> ExtensionBuilder {
        let version = "0.0.0".to_owned();
        let group = Group::new();
        ExtensionBuilder {
            version,
            group,
            allow_no_args: false,
        }
    }
}
108
#[cfg(feature = "extension")]
impl Extension {
    /// Returns the version string configured on the builder.
    #[must_use]
    pub fn version(&self) -> &str {
        self.version.as_str()
    }

    /// Returns `true` when the builder enabled `allow_no_args`.
    #[must_use]
    pub const fn allow_no_args(&self) -> bool {
        self.allow_no_args
    }

    /// Stores the host callback.
    ///
    /// `run_callbacks` copies the stored value when it starts its worker, so
    /// a callback registered afterwards is not seen by an already running
    /// worker.
    #[doc(hidden)]
    pub fn register_callback(&mut self, callback: Callback) {
        self.callback = Some(callback);
    }

    /// Builds a call context from the raw arguments and stores it in the
    /// context manager, replacing whatever was stored before.
    ///
    /// # Safety
    /// `args` and `count` are forwarded untouched to
    /// `ArmaCallContext::from_arma`; the caller must satisfy whatever that
    /// function requires of them.
    // NOTE(review): presumably `args` must point to `count` valid C strings —
    // confirm against `ArmaCallContext::from_arma`.
    #[doc(hidden)]
    pub unsafe fn handle_call_context(&mut self, args: *mut *mut i8, count: libc::c_int) {
        let call_context = ArmaCallContext::from_arma(args, count);
        self.context_manager.replace(Some(call_context));
    }

    /// Creates a fresh `Context` wired to this extension's callback channel
    /// and group state.
    #[must_use]
    pub fn context(&self) -> Context {
        let sender = self.callback_channel.0.clone();
        let global = GlobalContext::new(self.version.clone(), self.group.state.clone());
        let group = GroupContext::new(self.group.state.clone());
        Context::new(sender, global, group)
    }

    /// Dispatches one call coming from the host.
    ///
    /// When `clear_call_context` is set (and the pre-2.18 override is not
    /// active) the stored call context is cleared first. The function name
    /// is then decoded; `1` is returned if it is not valid UTF-8. On Windows
    /// the special name `::console` allocates a console at most once and
    /// returns `0`. Every other name is handed to the root group together
    /// with the output buffer and raw arguments, and the group's return
    /// code is passed through.
    ///
    /// # Safety
    /// `function` must point to a valid NUL-terminated C string. `output`,
    /// `size`, `args` and `count` are forwarded unchanged to the group
    /// handler and must be valid for it.
    #[doc(hidden)]
    pub unsafe fn handle_call(
        &self,
        function: *mut libc::c_char,
        output: *mut libc::c_char,
        size: libc::size_t,
        args: Option<*mut *mut i8>,
        count: Option<libc::c_int>,
        clear_call_context: bool,
    ) -> libc::c_int {
        let should_clear = clear_call_context && !self.pre218_clear_context_override;
        if should_clear {
            self.context_manager.replace(None);
        }

        // SAFETY: the caller guarantees `function` is a valid NUL-terminated
        // C string.
        let decoded = unsafe { std::ffi::CStr::from_ptr(function).to_str() };
        let Ok(decoded) = decoded else {
            return 1;
        };
        let function = decoded.to_string();

        match function.as_str() {
            #[cfg(windows)]
            "::console" => {
                let already_allocated =
                    CONSOLE_ALLOCATED.swap(true, std::sync::atomic::Ordering::SeqCst);
                if !already_allocated {
                    let _ = windows::Win32::System::Console::AllocConsole();
                }
                0
            }
            _ => {
                let context = self.context().with_buffer_size(size);
                self.group.handle(
                    context,
                    self.context_manager.as_ref(),
                    &function,
                    output,
                    size,
                    args,
                    count,
                )
            }
        }
    }

    /// Wraps this extension in the testing harness.
    #[must_use]
    pub fn testing(self) -> testing::Extension {
        testing::Extension::new(self)
    }

    /// Starts the worker thread that delivers queued callbacks to the host.
    ///
    /// The worker receives messages until it sees `Terminate` or the channel
    /// is disconnected. Each `Call` is converted into three C strings and
    /// passed to the registered callback, retrying every millisecond while
    /// the callback returns a negative value. Messages whose parts contain
    /// an interior NUL are logged and skipped; messages are consumed and
    /// ignored when no callback was registered.
    #[doc(hidden)]
    pub fn run_callbacks(&mut self) {
        let callback = self.callback;
        let rx = self.callback_channel.1.clone();
        let worker = std::thread::spawn(move || {
            loop {
                let (name, func, data) = match rx.recv() {
                    Ok(CallbackMessage::Call(name, func, data)) => (name, func, data),
                    Ok(CallbackMessage::Terminate) | Err(_) => break,
                };
                let Some(c) = callback else {
                    continue;
                };

                let Ok(name) = std::ffi::CString::new(name) else {
                    error!("callback name was not valid");
                    continue;
                };
                let Ok(func) = std::ffi::CString::new(func) else {
                    error!("callback func was not valid");
                    continue;
                };
                // Plain strings are sent as-is; every other value goes
                // through its `to_string` rendering.
                let payload = match data {
                    None => String::new(),
                    Some(Value::String(text)) => text,
                    Some(other) => other.to_string(),
                };
                let Ok(data) = std::ffi::CString::new(payload) else {
                    error!("callback data was not valid");
                    continue;
                };

                // The three C strings stay alive until the end of this
                // iteration, so the pointers remain valid across retries.
                while c(name.as_ptr(), func.as_ptr(), data.as_ptr()) < 0 {
                    std::thread::sleep(std::time::Duration::from_millis(1));
                }
            }
        });
        self.callback_thread = Some(worker);
    }
}
245
#[cfg(feature = "extension")]
impl Drop for Extension {
    /// Stops the callback worker thread, if one was started, and waits for
    /// it to finish.
    ///
    /// Failures are logged instead of panicking: a panic raised inside
    /// `drop` while the thread is already unwinding aborts the process.
    fn drop(&mut self) {
        let Some(thread) = self.callback_thread.take() else {
            return;
        };

        let (tx, _) = &self.callback_channel;
        if tx.send(CallbackMessage::Terminate).is_err() {
            // A failed send means every receiver is gone, so the worker has
            // already left its receive loop and the join below cannot block.
            error!("Failed to send terminate message to callback thread");
        }
        if thread.join().is_err() {
            error!("Failed to join callback thread");
        }
    }
}
258
/// Builder for an `Extension`, obtained from `Extension::build()` and
/// consumed by `ExtensionBuilder::finish()`.
#[cfg(feature = "extension")]
pub struct ExtensionBuilder {
    /// Version string the finished extension will report.
    version: String,
    /// Root group collecting commands, sub-groups and state.
    group: Group,
    /// Value copied into the finished extension's `allow_no_args` flag.
    allow_no_args: bool,
}
266
#[cfg(feature = "extension")]
impl ExtensionBuilder {
    /// Sets the version string reported by the finished extension.
    #[inline]
    #[must_use]
    pub fn version(mut self, version: String) -> Self {
        self.version = version;
        self
    }

    /// Adds `group` to the root group under `name`.
    #[inline]
    #[must_use]
    pub fn group<S>(mut self, name: S, group: Group) -> Self
    where
        S: Into<String>,
    {
        self.group = self.group.group(name.into(), group);
        self
    }

    /// Adds a state value to the root group.
    #[inline]
    #[must_use]
    pub fn state<T>(mut self, state: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.group = self.group.state(state);
        self
    }

    /// Forwards to `Group::freeze_state` on the root group.
    // NOTE(review): presumably prevents further state from being added —
    // confirm against `Group::freeze_state`.
    #[inline]
    #[must_use]
    pub fn freeze_state(mut self) -> Self {
        self.group = self.group.freeze_state();
        self
    }

    /// Turns on the `allow_no_args` flag of the finished extension.
    #[inline]
    #[must_use]
    pub const fn allow_no_args(mut self) -> Self {
        self.allow_no_args = true;
        self
    }

    /// Registers `handler` under `name` on the root group.
    #[inline]
    #[must_use]
    pub fn command<S, F, I, R>(mut self, name: S, handler: F) -> Self
    where
        S: Into<String>,
        F: Factory<I, R> + 'static,
    {
        self.group = self.group.command(name, handler);
        self
    }

    /// Consumes the builder and produces the `Extension`.
    ///
    /// In release builds the `RVExtensionRequestContext` symbol is looked up
    /// in the running process. When it is missing, a no-op request function
    /// is used instead and the extension is flagged so that `handle_call`
    /// never clears the stored call context. Debug builds skip the lookup
    /// and always use the no-op function.
    // NOTE(review): the `pre218` name suggests hosts older than Arma 2.18 do
    // not export the symbol — confirm.
    ///
    /// # Panics
    /// In release builds, panics if a handle to the current process cannot
    /// be obtained.
    #[inline]
    #[must_use]
    pub fn finish(self) -> Extension {
        // Only the release-build lookups below ever assign to this flag, so
        // the binding is needlessly `mut` in debug builds and nowhere else.
        #[cfg_attr(
            debug_assertions,
            expect(unused_mut, reason = "Only reassigned by the release-build symbol lookup")
        )]
        let mut pre218 = false;

        #[cfg(not(debug_assertions))]
        let function_name = c"RVExtensionRequestContext";

        #[cfg(all(windows, not(debug_assertions)))]
        let request_context: ContextRequest = {
            // SAFETY: a null module name asks for the module of the calling
            // process itself; no pointer is dereferenced.
            let handle = unsafe { winapi::um::libloaderapi::GetModuleHandleW(std::ptr::null()) };
            assert!(!handle.is_null(), "GetModuleHandleW failed");
            // SAFETY: `handle` was checked to be non-null and
            // `function_name` is a NUL-terminated literal.
            let func_address =
                unsafe { winapi::um::libloaderapi::GetProcAddress(handle, function_name.as_ptr()) };
            if func_address.is_null() {
                pre218 = true;
                empty_request_context
            } else {
                // SAFETY: the address was resolved from the symbol that is
                // expected to have the `ContextRequest` signature.
                unsafe { std::mem::transmute::<_, ContextRequest>(func_address) }
            }
        };
        #[cfg(all(not(windows), not(debug_assertions)))]
        let request_context: ContextRequest = {
            // SAFETY: a null path asks for a handle to the running program.
            let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_LAZY) };
            assert!(
                !handle.is_null(),
                "Failed to open handle to current process"
            );
            // SAFETY: `handle` was checked to be non-null and
            // `function_name` is a NUL-terminated literal.
            let func_address = unsafe { libc::dlsym(handle, function_name.as_ptr()) };
            // Release the handle whether or not the symbol exists, so the
            // not-found path no longer leaks it.
            // SAFETY: `handle` came from the successful `dlopen` above and is
            // not used again.
            unsafe { libc::dlclose(handle) };
            if func_address.is_null() {
                pre218 = true;
                empty_request_context
            } else {
                // SAFETY: the address was resolved from the symbol that is
                // expected to have the `ContextRequest` signature.
                unsafe { std::mem::transmute::<*mut libc::c_void, ContextRequest>(func_address) }
            }
        };

        #[cfg(debug_assertions)]
        let request_context = empty_request_context;

        Extension {
            version: self.version,
            group: self.group.into(),
            allow_no_args: self.allow_no_args,
            callback: None,
            callback_channel: unbounded(),
            callback_thread: None,
            context_manager: Rc::new(ArmaContextManager::new(request_context)),
            pre218_clear_context_override: pre218,
        }
    }
}
387
/// No-op stand-in with the `ContextRequest` signature, used by
/// `ExtensionBuilder::finish` in debug builds and whenever the
/// `RVExtensionRequestContext` symbol cannot be resolved.
const unsafe extern "system" fn empty_request_context() {}
389
/// Copies `string` into the buffer at `ptr` as a NUL-terminated C string.
///
/// Returns the number of bytes written, not counting the terminator.
/// An empty `string` yields `Some(0)` and leaves the buffer untouched.
/// `None` is returned, again without touching the buffer, when `string`
/// contains an interior NUL byte or when it does not fit into `buf_size`
/// bytes together with its terminator.
///
/// # Safety
/// For a non-empty `string`, `ptr` must be valid for writes of `buf_size`
/// bytes.
#[doc(hidden)]
#[cfg(feature = "extension")]
pub unsafe fn write_cstr(
    string: String,
    ptr: *mut libc::c_char,
    buf_size: libc::size_t,
) -> Option<libc::size_t> {
    if string.is_empty() {
        return Some(0);
    }

    let Ok(terminated) = std::ffi::CString::new(string) else {
        return None;
    };
    let payload_len = terminated.as_bytes().len();
    // One extra byte is needed for the terminator.
    if buf_size <= payload_len {
        return None;
    }

    // SAFETY: `terminated` owns `payload_len + 1` readable bytes (payload plus
    // its NUL), and the check above guarantees `payload_len + 1 <= buf_size`,
    // which the caller promises is writable at `ptr`.
    unsafe { std::ptr::copy(terminated.as_ptr(), ptr, payload_len + 1) };
    Some(payload_len)
}
418
#[cfg(all(test, feature = "extension"))]
mod tests {
    use super::*;

    /// Runs `write_cstr` against `buf`, advertising exactly its length as the
    /// buffer size, and hands back the result together with the buffer.
    fn write_into<const N: usize>(
        text: &str,
        mut buf: [libc::c_char; N],
    ) -> (Option<libc::size_t>, [libc::c_char; N]) {
        // SAFETY: `buf` is a live array of exactly `N` elements.
        let written = unsafe { write_cstr(text.to_string(), buf.as_mut_ptr(), N) };
        (written, buf)
    }

    #[test]
    fn write_size_zero() {
        let (result, buf) = write_into("a", [0; 0]);

        assert_eq!(result, None);
        assert_eq!(buf, [0; 0]);
    }

    #[test]
    fn write_size_zero_empty() {
        let (result, buf) = write_into("", [0; 0]);

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; 0]);
    }

    #[test]
    fn write_size_one() {
        let (result, buf) = write_into("a", [0; 1]);

        assert_eq!(result, None);
        assert_eq!(buf, [0; 1]);
    }

    #[test]
    fn write_size_one_empty() {
        let (result, buf) = write_into("", [0; 1]);

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; 1]);
    }

    #[test]
    fn write_empty() {
        let (result, buf) = write_into("", [0; 7]);

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; 7]);
    }

    #[test]
    fn write_half() {
        let (result, buf) = write_into("foo", [0; 7]);

        assert_eq!(result, Some(3));
        assert_eq!(buf, (b"foo\0\0\0\0").map(u8::cast_signed));
    }

    #[test]
    fn write_full() {
        let (result, buf) = write_into("foobar", [0; 7]);

        assert_eq!(result, Some(6));
        assert_eq!(buf, (b"foobar\0").map(u8::cast_signed));
    }

    #[test]
    fn write_overflow() {
        let (result, buf) = write_into("foo bar", [0; 7]);

        assert_eq!(result, None);
        assert_eq!(buf, [0; 7]);
    }

    #[test]
    fn write_overwrite() {
        // Bytes after the new terminator must keep their previous contents.
        let (result, buf) = write_into("a", (b"zzzzzz\0").map(u8::cast_signed));

        assert_eq!(result, Some(1));
        assert_eq!(buf, (b"a\0zzzz\0").map(u8::cast_signed));
    }
}