arma_rs/
lib.rs

1#![warn(missing_docs, nonstandard_style)]
2#![doc = include_str!(concat!(env!("OUT_DIR"), "/README.md"))]
3
4use std::rc::Rc;
5
6pub use arma_rs_proc::{arma, FromArma, IntoArma};
7
8#[cfg(feature = "extension")]
9use crossbeam_channel::{unbounded, Receiver, Sender};
10#[cfg(feature = "extension")]
11pub use libc;
12
13#[cfg(all(target_os = "windows", target_arch = "x86"))]
14pub use link_args;
15
16#[cfg(feature = "extension")]
17#[macro_use]
18extern crate log;
19
20mod flags;
21
22mod value;
23pub use value::{loadout, DirectReturn, FromArma, FromArmaError, IntoArma, Value};
24
25#[cfg(feature = "extension")]
26mod call_context;
27#[cfg(feature = "extension")]
28use call_context::{ArmaCallContext, ArmaContextManager};
29#[cfg(feature = "extension")]
30pub use call_context::{CallContext, CallContextStackTrace, Caller, Mission, Server, Source};
31#[cfg(feature = "extension")]
32mod ext_result;
33#[cfg(feature = "extension")]
34pub use ext_result::IntoExtResult;
35#[cfg(feature = "extension")]
36mod command;
37#[cfg(feature = "extension")]
38pub use command::*;
39#[cfg(feature = "extension")]
40pub mod context;
41#[cfg(feature = "extension")]
42pub use context::*;
43#[cfg(feature = "extension")]
44mod group;
45#[cfg(feature = "extension")]
46pub use group::Group;
47#[cfg(feature = "extension")]
48pub mod testing;
49#[cfg(feature = "extension")]
50pub use testing::Result;
51
/// Function pointer Arma hands to the extension so events can be pushed back into SQF.
///
/// It is invoked as `callback(name, function, data)`; hidden from the docs because only the
/// generated glue code should deal with it.
#[doc(hidden)]
#[cfg(all(windows, feature = "extension"))]
pub type Callback = extern "stdcall" fn(
    *const libc::c_char, // name
    *const libc::c_char, // function
    *const libc::c_char, // data
) -> libc::c_int;

/// Function pointer Arma hands to the extension so events can be pushed back into SQF.
///
/// It is invoked as `callback(name, function, data)`; hidden from the docs because only the
/// generated glue code should deal with it.
#[doc(hidden)]
#[cfg(all(not(windows), feature = "extension"))]
pub type Callback = extern "C" fn(
    *const libc::c_char, // name
    *const libc::c_char, // function
    *const libc::c_char, // data
) -> libc::c_int;

/// Signature of the function used to request a call context from Arma
/// (the host's `RVExtensionRequestContext` export, or a no-op fallback).
pub type ContextRequest = unsafe extern "C" fn();
67
/// Messages consumed by the callback thread started in `Extension::run_callbacks`.
#[cfg(feature = "extension")]
enum CallbackMessage {
    /// Forward `(name, function, data)` to Arma's callback.
    Call(String, String, Option<Value>),
    /// Ask the callback thread to stop.
    Terminate,
}

/// State TypeMap that can hold at most one value per type key.
#[cfg(feature = "extension")]
pub type State = state::TypeMap![Send + Sync];
77
/// Set on the first `::console` call so that later calls do not try to allocate
/// another console for the extension.
///
/// Only read by `Extension::handle_call`, hence gated on the `extension` feature as well.
#[cfg(all(windows, feature = "extension"))]
static CONSOLE_ALLOCATED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
81
82#[no_mangle]
83#[allow(non_upper_case_globals, reason = "This is a C API")]
84/// Feature flags read on each callExtension call.
85pub static mut RVExtensionFeatureFlags: u64 = flags::RV_CONTEXT_NO_DEFAULT_CALL;
86
/// Everything the generated glue code needs to drive your extension from Arma.
///
/// Build one with `Extension::build()`.
#[cfg(feature = "extension")]
pub struct Extension {
    /// Extension version; also handed to every `GlobalContext`.
    version: String,
    /// Root command group that incoming calls are dispatched to.
    group: group::InternalGroup,
    /// Whether commands may be invoked without any arguments.
    allow_no_args: bool,
    /// Arma's callback function, once registered by the glue code.
    callback: Option<Callback>,
    /// Queue feeding callback messages from contexts to the callback thread.
    callback_channel: (Sender<CallbackMessage>, Receiver<CallbackMessage>),
    /// Handle of the thread spawned by `run_callbacks`, if it was started.
    callback_thread: Option<std::thread::JoinHandle<()>>,
    /// Holder for the most recent call context received from Arma.
    context_manager: Rc<ArmaContextManager>,
    /// When set, the stored call context is kept instead of being cleared per call.
    pre218_clear_context_override: bool,
}
100
#[cfg(feature = "extension")]
impl Extension {
    /// Creates a new extension.
    ///
    /// Starts an [`ExtensionBuilder`] with version `"0.0.0"`, a fresh `Group::new()` and
    /// argument-less calls disallowed.
    #[must_use]
    pub fn build() -> ExtensionBuilder {
        let group = Group::new();
        ExtensionBuilder {
            version: "0.0.0".to_owned(),
            group,
            allow_no_args: false,
        }
    }
}
113
#[cfg(feature = "extension")]
impl Extension {
    /// Returns the version of the extension.
    #[must_use]
    pub fn version(&self) -> &str {
        &self.version
    }

    /// Returns if the extension can be called without any arguments.
    /// Example:
    /// ```sqf
    /// "my_ext" callExtension "my_func"
    /// ```
    #[must_use]
    pub const fn allow_no_args(&self) -> bool {
        self.allow_no_args
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Stores Arma's callback function. `run_callbacks` copies it when it spawns the
    /// callback thread, so register it before that call.
    pub fn register_callback(&mut self, callback: Callback) {
        self.callback = Some(callback);
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Replaces the stored call context with one built from Arma's `args`/`count`.
    ///
    /// # Safety
    /// This function is unsafe because it interacts with the C API.
    /// NOTE(review): `args` presumably must point to `count` valid C strings — confirm
    /// against `ArmaCallContext::from_arma`.
    pub unsafe fn handle_call_context(&mut self, args: *mut *mut i8, count: libc::c_int) {
        self.context_manager
            .replace(Some(ArmaCallContext::from_arma(args, count)));
    }

    /// Get a context for interacting with Arma
    #[must_use]
    pub fn context(&self) -> Context {
        Context::new(
            self.callback_channel.0.clone(),
            GlobalContext::new(self.version.clone(), self.group.state.clone()),
            GroupContext::new(self.group.state.clone()),
        )
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Dispatches `function` to the command tree and writes the result into `output`.
    /// Returns `1` when `function` is not valid UTF-8; otherwise the return code comes from
    /// the group handler (or `0` for the built-in Windows `::console` command).
    ///
    /// # Safety
    /// This function is unsafe because it interacts with the C API.
    /// `function` must be a valid NUL-terminated C string.
    pub unsafe fn handle_call(
        &self,
        function: *mut libc::c_char,
        output: *mut libc::c_char,
        size: libc::size_t,
        args: Option<*mut *mut i8>,
        count: Option<libc::c_int>,
        clear_call_context: bool,
    ) -> libc::c_int {
        // Without `RVExtensionRequestContext` (see `ExtensionBuilder::finish`) the context
        // cannot be re-requested, so the last one received is kept instead of cleared.
        if clear_call_context && !self.pre218_clear_context_override {
            self.context_manager.replace(None);
        }
        let Ok(function) = std::ffi::CStr::from_ptr(function).to_str() else {
            return 1;
        };
        let function = function.to_owned();
        match function.as_str() {
            #[cfg(windows)]
            "::console" => {
                // Only the first `::console` call allocates; the flag is set even if it fails.
                if !CONSOLE_ALLOCATED.swap(true, std::sync::atomic::Ordering::SeqCst) {
                    let _ = windows::Win32::System::Console::AllocConsole();
                }
                0
            }
            _ => self.group.handle(
                self.context().with_buffer_size(size),
                self.context_manager.as_ref(),
                &function,
                output,
                size,
                args,
                count,
            ),
        }
    }

    /// Create a version of the extension that can be used in tests.
    #[must_use]
    pub fn testing(self) -> testing::Extension {
        testing::Extension::new(self)
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Spawns the thread that forwards queued `CallbackMessage::Call`s to Arma's callback.
    /// The thread exits on `CallbackMessage::Terminate` or once the channel is disconnected.
    /// Messages received while no callback is registered are drained and dropped.
    pub fn run_callbacks(&mut self) {
        let callback = self.callback;
        let rx = self.callback_channel.1.clone();
        self.callback_thread = Some(std::thread::spawn(move || {
            while let Ok(CallbackMessage::Call(name, func, data)) = rx.recv() {
                let Some(c) = callback else {
                    continue;
                };
                let Ok(name) = std::ffi::CString::new(name) else {
                    error!("callback name was not valid");
                    continue;
                };
                let Ok(func) = std::ffi::CString::new(func) else {
                    error!("callback func was not valid");
                    continue;
                };
                // Strings are passed through unquoted; every other value uses its Display form.
                let data = match data {
                    Some(Value::String(s)) => s,
                    Some(v) => v.to_string(),
                    None => String::new(),
                };
                let Ok(data) = std::ffi::CString::new(data) else {
                    error!("callback data was not valid");
                    continue;
                };
                // Retry every millisecond until the callback returns a non-negative value
                // (presumably Arma refusing the event while busy — NOTE(review): confirm).
                // The CStrings outlive this loop, so the borrowed pointers stay valid.
                while c(name.as_ptr(), func.as_ptr(), data.as_ptr()) < 0 {
                    std::thread::sleep(std::time::Duration::from_millis(1));
                }
            }
        }));
    }
}
254
#[cfg(feature = "extension")]
impl Drop for Extension {
    // Arma never runs this destructor; it exists so Rust tests shut the callback thread
    // down cleanly: signal it to stop, then wait for it to exit.
    fn drop(&mut self) {
        let Some(handle) = self.callback_thread.take() else {
            return;
        };
        self.callback_channel
            .0
            .send(CallbackMessage::Terminate)
            .unwrap();
        handle.join().unwrap();
    }
}
266
/// Used to build an extension.
///
/// Obtained from [`Extension::build`] and turned into an [`Extension`] by
/// [`ExtensionBuilder::finish`].
#[cfg(feature = "extension")]
pub struct ExtensionBuilder {
    /// Version handed to the finished extension.
    version: String,
    /// Root group collecting commands, sub-groups and state.
    group: Group,
    /// Whether commands may be invoked without any arguments.
    allow_no_args: bool,
}
274
#[cfg(feature = "extension")]
impl ExtensionBuilder {
    /// Sets the version of the extension.
    #[inline]
    #[must_use]
    pub fn version(mut self, version: String) -> Self {
        self.version = version;
        self
    }

    /// Add a group to the extension.
    #[inline]
    #[must_use]
    pub fn group<S>(mut self, name: S, group: Group) -> Self
    where
        S: Into<String>,
    {
        self.group = self.group.group(name.into(), group);
        self
    }

    /// Add a new state value to the extension if it has not been added already
    #[inline]
    #[must_use]
    pub fn state<T>(mut self, state: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.group = self.group.state(state);
        self
    }

    /// Freeze the extension's state, preventing the state from changing, allowing for faster reads
    #[inline]
    #[must_use]
    pub fn freeze_state(mut self) -> Self {
        self.group = self.group.freeze_state();
        self
    }

    /// Allows the extension to be called without any arguments.
    /// Example:
    /// ```sqf
    /// "my_ext" callExtension "my_func"
    /// ```
    #[inline]
    #[must_use]
    pub const fn allow_no_args(mut self) -> Self {
        self.allow_no_args = true;
        self
    }

    /// Add a command to the extension.
    #[inline]
    #[must_use]
    pub fn command<S, F, I, R>(mut self, name: S, handler: F) -> Self
    where
        S: Into<String>,
        F: Factory<I, R> + 'static,
    {
        self.group = self.group.command(name, handler);
        self
    }

    /// Builds the extension.
    ///
    /// Release builds look up `RVExtensionRequestContext` in the host process. When it is
    /// missing (the "pre-2.18" case), a no-op request function is used and the extension
    /// keeps its stored call context instead of clearing it on each call. Debug builds
    /// always use the no-op function and never enable that override.
    ///
    /// # Panics
    /// In release builds, panics if no handle to the host process can be obtained.
    #[inline]
    #[must_use]
    pub fn finish(self) -> Extension {
        let (request_context, pre218) = Self::resolve_request_context();
        Extension {
            version: self.version,
            group: self.group.into(),
            allow_no_args: self.allow_no_args,
            callback: None,
            callback_channel: unbounded(),
            callback_thread: None,
            context_manager: Rc::new(ArmaContextManager::new(request_context)),
            pre218_clear_context_override: pre218,
        }
    }

    /// Windows release: resolve `RVExtensionRequestContext` from the host executable.
    /// Returns the request function and whether the export was missing.
    #[cfg(all(windows, not(debug_assertions)))]
    fn resolve_request_context() -> (ContextRequest, bool) {
        let function_name =
            std::ffi::CString::new("RVExtensionRequestContext").expect("CString::new failed");
        // SAFETY: a null module name requests the handle of the process executable.
        let handle = unsafe { winapi::um::libloaderapi::GetModuleHandleW(std::ptr::null()) };
        if handle.is_null() {
            panic!("GetModuleHandleW failed");
        }
        // SAFETY: `handle` is a valid module handle and `function_name` is NUL-terminated.
        let func_address =
            unsafe { winapi::um::libloaderapi::GetProcAddress(handle, function_name.as_ptr()) };
        if func_address.is_null() {
            let fallback: ContextRequest = empty_request_context;
            return (fallback, true);
        }
        // SAFETY: the export is assumed to match the `ContextRequest` signature.
        let request: ContextRequest = unsafe { std::mem::transmute(func_address) };
        (request, false)
    }

    /// Non-Windows release: resolve `RVExtensionRequestContext` from the main program.
    /// Returns the request function and whether the symbol was missing.
    #[cfg(all(not(windows), not(debug_assertions)))]
    fn resolve_request_context() -> (ContextRequest, bool) {
        let function_name =
            std::ffi::CString::new("RVExtensionRequestContext").expect("CString::new failed");
        // SAFETY: a null filename yields a handle for the main program.
        let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_LAZY) };
        if handle.is_null() {
            panic!("Failed to open handle to current process");
        }
        // SAFETY: `handle` came from a successful `dlopen`; `function_name` is NUL-terminated.
        let func_address = unsafe { libc::dlsym(handle, function_name.as_ptr()) };
        let resolved = if func_address.is_null() {
            let fallback: ContextRequest = empty_request_context;
            (fallback, true)
        } else {
            // SAFETY: the symbol is assumed to match the `ContextRequest` signature.
            let request: ContextRequest = unsafe { std::mem::transmute(func_address) };
            (request, false)
        };
        // Close the handle on both paths. The main program is never unloaded, so the
        // resolved function pointer stays valid afterwards.
        // SAFETY: `handle` is the live handle from `dlopen` above and is closed exactly once.
        unsafe { libc::dlclose(handle) };
        resolved
    }

    /// Debug builds never query the host: always the no-op request, no pre-2.18 override.
    #[cfg(debug_assertions)]
    fn resolve_request_context() -> (ContextRequest, bool) {
        let fallback: ContextRequest = empty_request_context;
        (fallback, false)
    }
}
395
/// No-op [`ContextRequest`] used when the host does not export `RVExtensionRequestContext`,
/// and in debug builds.
///
/// Only referenced by feature-gated code, so it is gated too (avoids a dead-code warning
/// when the `extension` feature is disabled).
#[cfg(feature = "extension")]
unsafe extern "C" fn empty_request_context() {}
397
#[doc(hidden)]
/// Called by generated code, do not call directly.
///
/// Copies `string` into the buffer at `ptr` as a NUL-terminated C string and returns the
/// number of bytes written, not counting the terminator.
///
/// Returns `None`, leaving the buffer untouched, when `string` contains an interior NUL
/// byte or does not fit together with its terminator. An empty `string` always yields
/// `Some(0)`; when `buf_size > 0` a lone terminator is written so the buffer reads back as
/// an empty C string instead of keeping stale contents.
///
/// # Safety
/// This function is unsafe because it interacts with the C API.
/// `ptr` must be valid for writes of `buf_size` bytes.
///
/// # Note
/// This function assumes `buf_size` includes space for a single terminating zero byte at the end.
#[cfg(feature = "extension")]
pub unsafe fn write_cstr(
    string: String,
    ptr: *mut libc::c_char,
    buf_size: libc::size_t,
) -> Option<libc::size_t> {
    let bytes = string.as_bytes();
    if bytes.is_empty() {
        if buf_size > 0 {
            ptr.write(0x00);
        }
        return Some(0);
    }

    // An interior NUL would silently truncate the string on the C side; a string that
    // needs `buf_size` bytes or more leaves no room for the terminator.
    if bytes.contains(&0) || bytes.len() >= buf_size {
        return None;
    }

    // SAFETY: the caller guarantees `ptr` is writable for `buf_size` bytes and
    // `bytes.len() + 1 <= buf_size` was checked above. `bytes` belongs to an owned
    // `String`, so it cannot overlap the caller's buffer.
    ptr.cast::<u8>()
        .copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());
    ptr.add(bytes.len()).write(0x00);
    Some(bytes.len())
}
426
#[cfg(all(test, feature = "extension"))]
mod tests {
    use super::*;

    /// Converts a byte literal into the platform's `c_char` array so the expectations
    /// compile whether `c_char` is signed (x86) or unsigned (e.g. aarch64 Linux).
    fn c_chars<const N: usize>(bytes: [u8; N]) -> [libc::c_char; N] {
        bytes.map(|b| b as libc::c_char)
    }

    #[test]
    fn write_size_zero() {
        const BUF_SIZE: libc::size_t = 0;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_size_zero_empty() {
        const BUF_SIZE: libc::size_t = 0;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_size_one() {
        const BUF_SIZE: libc::size_t = 1;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_size_one_empty() {
        const BUF_SIZE: libc::size_t = 1;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_empty() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_half() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foo".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(3));
        assert_eq!(buf, c_chars(*b"foo\0\0\0\0"));
    }

    #[test]
    fn write_full() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foobar".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(6));
        assert_eq!(buf, c_chars(*b"foobar\0"));
    }

    #[test]
    fn write_overflow() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foo bar".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    #[test]
    fn write_overwrite() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = c_chars(*b"zzzzzz\0");
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(1));
        assert_eq!(buf, c_chars(*b"a\0zzzz\0"));
    }
}