arma_rs/
lib.rs

1#![warn(missing_docs, nonstandard_style)]
2#![doc = include_str!(concat!(env!("OUT_DIR"), "/README.md"))]
3
4use std::rc::Rc;
5
6pub use arma_rs_proc::{FromArma, IntoArma, arma};
7
8#[cfg(feature = "extension")]
9use crossbeam_channel::{Receiver, Sender, unbounded};
10#[cfg(feature = "extension")]
11pub use libc;
12
13#[cfg(all(target_os = "windows", target_arch = "x86"))]
14pub use link_args;
15
16#[cfg(feature = "extension")]
17#[macro_use]
18extern crate log;
19
20mod flags;
21
22mod value;
23pub use value::{DirectReturn, FromArma, FromArmaError, IntoArma, Value, loadout};
24
25#[cfg(feature = "extension")]
26mod call_context;
27#[cfg(feature = "extension")]
28use call_context::{ArmaCallContext, ArmaContextManager};
29#[cfg(feature = "extension")]
30pub use call_context::{CallContext, CallContextStackTrace, Caller, Mission, Server, Source};
31#[cfg(feature = "extension")]
32mod ext_result;
33#[cfg(feature = "extension")]
34pub use ext_result::IntoExtResult;
35#[cfg(feature = "extension")]
36mod command;
37#[cfg(feature = "extension")]
38pub use command::*;
39#[cfg(feature = "extension")]
40pub mod context;
41#[cfg(feature = "extension")]
42pub use context::*;
43#[cfg(feature = "extension")]
44mod group;
45#[cfg(feature = "extension")]
46pub use group::Group;
47#[cfg(feature = "extension")]
48pub mod testing;
49#[cfg(feature = "extension")]
50pub use testing::Result;
51
#[cfg(feature = "extension")]
#[doc(hidden)]
/// Used by generated code to call back into Arma
///
/// The function receives three C string pointers. The callback thread started by
/// `Extension::run_callbacks` passes them in the order name, function, data and
/// repeats the call for as long as the returned value is negative.
pub type Callback = extern "system" fn(
    *const libc::c_char,
    *const libc::c_char,
    *const libc::c_char,
) -> libc::c_int;
/// Requests a call context from Arma
///
/// In release builds `ExtensionBuilder::finish` looks this function up in the current
/// process under the symbol name `RVExtensionRequestContext`; when the symbol is missing
/// (or in debug builds) a no-op function is used instead.
pub type ContextRequest = unsafe extern "system" fn();
62
#[cfg(feature = "extension")]
/// Messages sent over the callback channel to the callback thread.
enum CallbackMessage {
    /// A callback to deliver: name, function and optional data, in that order.
    Call(String, String, Option<Value>),
    /// Stops the callback thread's receive loop; sent when the `Extension` is dropped.
    Terminate,
}
68
#[cfg(feature = "extension")]
/// State `TypeMap` that can hold at most one value per type key.
///
/// The map is declared with `Send + Sync` bounds, matching the bounds required by
/// `ExtensionBuilder::state`.
pub type State = state::TypeMap![Send + Sync];
72
#[cfg(windows)]
/// Tracks whether a console has already been allocated for the extension.
///
/// The `::console` command in `Extension::handle_call` swaps this to `true` and only
/// calls `AllocConsole` when the previous value was `false`, so the console is
/// requested at most once.
static CONSOLE_ALLOCATED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
76
#[unsafe(no_mangle)]
#[allow(non_upper_case_globals, reason = "This is a C API")]
/// Feature flags read on each callExtension call.
///
/// Exported under its unmangled name and initialised to
/// `flags::RV_CONTEXT_NO_DEFAULT_CALL`.
pub static mut RVExtensionFeatureFlags: u64 = flags::RV_CONTEXT_NO_DEFAULT_CALL;
81
/// Contains all the information about your extension
/// This is used by the generated code to interface with Arma
#[cfg(feature = "extension")]
pub struct Extension {
    /// Version string returned by [`Extension::version`].
    version: String,
    /// Root group that calls are dispatched to in `handle_call`.
    group: group::InternalGroup,
    /// Whether the extension may be called without any arguments.
    allow_no_args: bool,
    /// Callback into Arma, set by `register_callback`; `None` until then.
    callback: Option<Callback>,
    /// Callback channel: the sender is cloned into every `Context`, the receiver is
    /// drained by the callback thread.
    callback_channel: (Sender<CallbackMessage>, Receiver<CallbackMessage>),
    /// Handle of the thread spawned by `run_callbacks`; joined when the extension is dropped.
    callback_thread: Option<std::thread::JoinHandle<()>>,
    /// Holds the current call context; replaced by `handle_call_context` and
    /// cleared by `handle_call` when requested.
    context_manager: Rc<ArmaContextManager>,
    /// Set when the `RVExtensionRequestContext` symbol could not be found; while `true`,
    /// `handle_call` never clears the stored call context.
    pre218_clear_context_override: bool,
}
95
#[cfg(feature = "extension")]
impl Extension {
    /// Starts building a new extension.
    ///
    /// The returned builder begins with version `"0.0.0"`, a fresh [`Group`] and
    /// argument-less calls disallowed.
    #[must_use]
    pub fn build() -> ExtensionBuilder {
        let version = "0.0.0".to_owned();
        let group = Group::new();
        ExtensionBuilder {
            version,
            group,
            allow_no_args: false,
        }
    }
}
108
#[cfg(feature = "extension")]
impl Extension {
    #[must_use]
    /// Returns the version of the extension.
    pub fn version(&self) -> &str {
        &self.version
    }

    #[must_use]
    /// Returns if the extension can be called without any arguments.
    /// Example:
    /// ```sqf
    /// "my_ext" callExtension "my_func"
    /// ```
    pub const fn allow_no_args(&self) -> bool {
        self.allow_no_args
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Stores the callback used to call back into Arma. The callback thread copies
    /// this value when [`Extension::run_callbacks`] is called, so it must be registered
    /// before that point to be used.
    pub fn register_callback(&mut self, callback: Callback) {
        self.callback = Some(callback);
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Replaces the stored call context with one built from `args` / `count`.
    ///
    /// # Safety
    /// This function is unsafe because it interacts with the C API.
    pub unsafe fn handle_call_context(&mut self, args: *mut *mut i8, count: libc::c_int) {
        self.context_manager
            .replace(Some(ArmaCallContext::from_arma(args, count)));
    }

    #[must_use]
    /// Get a context for interacting with Arma
    pub fn context(&self) -> Context {
        Context::new(
            self.callback_channel.0.clone(),
            GlobalContext::new(self.version.clone(), self.group.state.clone()),
            GroupContext::new(self.group.state.clone()),
        )
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Dispatches a call to the command named by `function`.
    ///
    /// Returns `1` when `function` is not valid UTF-8. On Windows the built-in
    /// `::console` command allocates a console (at most once) and returns `0`.
    /// Every other name is forwarded to the root group, whose return value is
    /// passed through unchanged.
    ///
    /// When `clear_call_context` is set and the pre-2.18 override is not active,
    /// the stored call context is cleared before dispatching.
    ///
    /// # Safety
    /// This function is unsafe because it interacts with the C API.
    pub unsafe fn handle_call(
        &self,
        function: *mut libc::c_char,
        output: *mut libc::c_char,
        size: libc::size_t,
        args: Option<*mut *mut i8>,
        count: Option<libc::c_int>,
        clear_call_context: bool,
    ) -> libc::c_int {
        if clear_call_context && !self.pre218_clear_context_override {
            self.context_manager.replace(None);
        }
        // SAFETY: the caller guarantees `function` points to a valid NUL-terminated string.
        let raw_name = unsafe { std::ffi::CStr::from_ptr(function) };
        let Ok(function) = raw_name.to_str().map(str::to_owned) else {
            return 1;
        };
        match function.as_str() {
            #[cfg(windows)]
            "::console" => {
                // `swap` returns the previous value, so only the first caller allocates.
                if !CONSOLE_ALLOCATED.swap(true, std::sync::atomic::Ordering::SeqCst) {
                    let _ = windows::Win32::System::Console::AllocConsole();
                }
                0
            }
            _ => self.group.handle(
                self.context().with_buffer_size(size),
                self.context_manager.as_ref(),
                &function,
                output,
                size,
                args,
                count,
            ),
        }
    }

    #[must_use]
    /// Create a version of the extension that can be used in tests.
    pub fn testing(self) -> testing::Extension {
        testing::Extension::new(self)
    }

    #[doc(hidden)]
    /// Called by generated code, do not call directly.
    ///
    /// Spawns the callback thread. The thread receives [`CallbackMessage::Call`]
    /// messages and forwards each one to the registered callback, retrying every
    /// millisecond while the callback returns a negative value. Messages are
    /// discarded when no callback was registered at the time of this call, and a
    /// message whose name, function or data contains an interior NUL byte is logged
    /// and skipped. The loop ends on any other message or when the channel closes.
    pub fn run_callbacks(&mut self) {
        let callback = self.callback;
        let rx = self.callback_channel.1.clone();
        self.callback_thread = Some(std::thread::spawn(move || {
            while let Ok(CallbackMessage::Call(name, func, data)) = rx.recv() {
                let Some(c) = callback else {
                    continue;
                };
                let Ok(name) = std::ffi::CString::new(name) else {
                    error!("callback name was not valid");
                    continue;
                };
                let Ok(func) = std::ffi::CString::new(func) else {
                    error!("callback func was not valid");
                    continue;
                };
                // Strings are sent as-is; every other value uses its `to_string` form.
                let data = data.map_or_else(String::new, |value| match value {
                    Value::String(s) => s,
                    v => v.to_string(),
                });
                let Ok(data) = std::ffi::CString::new(data) else {
                    error!("callback data was not valid");
                    continue;
                };

                // The `CString`s are only borrowed, so they stay alive across every
                // retry and are freed normally at the end of the iteration.
                while c(name.as_ptr(), func.as_ptr(), data.as_ptr()) < 0 {
                    std::thread::sleep(std::time::Duration::from_millis(1));
                }
            }
        }));
    }
}
245
#[cfg(feature = "extension")]
impl Drop for Extension {
    // Arma never unloads the extension through this path; the implementation exists so
    // that Rust tests shut the callback thread down cleanly.
    fn drop(&mut self) {
        // Nothing to stop when the callback thread was never started.
        let Some(worker) = self.callback_thread.take() else {
            return;
        };
        self.callback_channel
            .0
            .send(CallbackMessage::Terminate)
            .expect("Failed to send terminate message to callback thread");
        worker.join().expect("Failed to join callback thread");
    }
}
258
/// Used to build an extension.
#[cfg(feature = "extension")]
pub struct ExtensionBuilder {
    /// Version the finished extension will report; `Extension::build` starts it at `"0.0.0"`.
    version: String,
    /// Root group collecting the commands, sub-groups and state added to the builder.
    group: Group,
    /// Whether the finished extension may be called without any arguments.
    allow_no_args: bool,
}
266
#[cfg(feature = "extension")]
impl ExtensionBuilder {
    #[inline]
    #[must_use]
    /// Sets the version of the extension.
    pub fn version(mut self, version: String) -> Self {
        self.version = version;
        self
    }

    #[inline]
    #[must_use]
    /// Add a group to the extension.
    pub fn group<S>(mut self, name: S, group: Group) -> Self
    where
        S: Into<String>,
    {
        self.group = self.group.group(name.into(), group);
        self
    }

    #[inline]
    #[must_use]
    /// Add a new state value to the extension if it has not been added already
    pub fn state<T>(mut self, state: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.group = self.group.state(state);
        self
    }

    #[inline]
    #[must_use]
    /// Freeze the extension's state, preventing the state from changing, allowing for faster reads
    pub fn freeze_state(mut self) -> Self {
        self.group = self.group.freeze_state();
        self
    }

    #[inline]
    #[must_use]
    /// Allows the extension to be called without any arguments.
    /// Example:
    /// ```sqf
    /// "my_ext" callExtension "my_func"
    /// ```
    pub const fn allow_no_args(mut self) -> Self {
        self.allow_no_args = true;
        self
    }

    #[inline]
    #[must_use]
    /// Add a command to the extension.
    pub fn command<S, F, I, R>(mut self, name: S, handler: F) -> Self
    where
        S: Into<String>,
        F: Factory<I, R> + 'static,
    {
        self.group = self.group.command(name, handler);
        self
    }

    #[inline]
    #[must_use]
    /// Builds the extension.
    ///
    /// In release builds the `RVExtensionRequestContext` symbol is looked up in the
    /// current process. When it cannot be found, a no-op request function is used and
    /// the extension is flagged so that `handle_call` never clears the stored call
    /// context. Debug builds always use the no-op request function.
    ///
    /// # Panics
    /// In release builds, panics if a handle to the current process cannot be obtained.
    pub fn finish(self) -> Extension {
        // `pre218` is only mutated in release builds, so the lint only fires (and the
        // expectation is only fulfilled) when debug assertions are enabled.
        #[cfg_attr(
            debug_assertions,
            expect(unused_mut, reason = "Only mutated in release builds")
        )]
        let mut pre218 = false;
        #[allow(unused_variables, reason = "Only read in release builds")]
        let function_name = c"RVExtensionRequestContext";
        #[cfg(all(windows, not(debug_assertions)))]
        let request_context: ContextRequest = {
            let handle = unsafe { winapi::um::libloaderapi::GetModuleHandleW(std::ptr::null()) };
            assert!(!handle.is_null(), "GetModuleHandleW failed");
            let func_address =
                unsafe { winapi::um::libloaderapi::GetProcAddress(handle, function_name.as_ptr()) };
            if func_address.is_null() {
                pre218 = true;
                empty_request_context
            } else {
                // SAFETY: the non-null address was resolved for the context request symbol.
                unsafe { std::mem::transmute(func_address) }
            }
        };
        #[cfg(all(not(windows), not(debug_assertions)))]
        let request_context: ContextRequest = {
            let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_LAZY) };
            assert!(
                !handle.is_null(),
                "Failed to open handle to current process"
            );
            let func_address = unsafe { libc::dlsym(handle, function_name.as_ptr()) };
            // Release the handle on both paths; previously it was only closed when the
            // symbol had been found.
            unsafe { libc::dlclose(handle) };
            if func_address.is_null() {
                pre218 = true;
                empty_request_context
            } else {
                // SAFETY: the non-null address was resolved for the context request symbol.
                unsafe { std::mem::transmute(func_address) }
            }
        };

        #[cfg(debug_assertions)]
        let request_context = empty_request_context;

        Extension {
            version: self.version,
            group: self.group.into(),
            allow_no_args: self.allow_no_args,
            callback: None,
            callback_channel: unbounded(),
            callback_thread: None,
            context_manager: Rc::new(ArmaContextManager::new(request_context)),
            pre218_clear_context_override: pre218,
        }
    }
}
387
/// No-op stand-in for a [`ContextRequest`].
///
/// `ExtensionBuilder::finish` uses it in debug builds and whenever the
/// `RVExtensionRequestContext` symbol cannot be found.
const unsafe extern "system" fn empty_request_context() {}
389
#[doc(hidden)]
/// Called by generated code, do not call directly.
///
/// Copies `string` into the buffer at `ptr` and appends a terminating zero byte.
///
/// Returns the number of bytes copied (excluding the terminator). An empty string
/// returns `Some(0)` without touching the buffer. `None` is returned, and nothing is
/// written, when the string contains an interior NUL byte or does not fit into
/// `buf_size` together with its terminator.
///
/// # Safety
/// This function is unsafe because it interacts with the C API.
///
/// # Note
/// This function assumes `buf_size` includes space for a single terminating zero byte at the end.
#[cfg(feature = "extension")]
pub unsafe fn write_cstr(
    string: String,
    ptr: *mut libc::c_char,
    buf_size: libc::size_t,
) -> Option<libc::size_t> {
    let bytes = string.into_bytes();
    if bytes.is_empty() {
        return Some(0);
    }

    // A C string cannot carry an interior NUL byte.
    if bytes.contains(&0) {
        return None;
    }

    // One byte of the buffer is reserved for the terminator.
    let written = bytes.len();
    if written >= buf_size {
        return None;
    }

    // SAFETY: the caller guarantees `ptr` is valid for `buf_size` bytes, and
    // `written + 1 <= buf_size` was checked above.
    unsafe {
        std::ptr::copy(bytes.as_ptr().cast::<libc::c_char>(), ptr, written);
        ptr.add(written).write(0);
    }
    Some(written)
}
418
// Unit tests for `write_cstr`, covering empty input, exact fit, overflow and overwrite.
#[cfg(all(test, feature = "extension"))]
mod tests {
    use super::*;

    // A non-empty string cannot fit into a zero-sized buffer.
    #[test]
    fn write_size_zero() {
        const BUF_SIZE: libc::size_t = 0;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // An empty string succeeds even with a zero-sized buffer because nothing is written.
    #[test]
    fn write_size_zero_empty() {
        const BUF_SIZE: libc::size_t = 0;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr(String::new(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // A one-byte buffer only has room for the terminator, so one character does not fit.
    #[test]
    fn write_size_one() {
        const BUF_SIZE: libc::size_t = 1;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // An empty string with a one-byte buffer reports zero bytes written.
    #[test]
    fn write_size_one_empty() {
        const BUF_SIZE: libc::size_t = 1;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr(String::new(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // An empty string leaves a larger buffer untouched.
    #[test]
    fn write_empty() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr(String::new(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(0));
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // A string shorter than the buffer is copied and followed by a terminator.
    #[test]
    fn write_half() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foo".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(3));
        assert_eq!(buf, (b"foo\0\0\0\0").map(u8::cast_signed));
    }

    // A string of exactly `BUF_SIZE - 1` bytes fills the buffer including the terminator.
    #[test]
    fn write_full() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foobar".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(6));
        assert_eq!(buf, (b"foobar\0").map(u8::cast_signed));
    }

    // A string of `BUF_SIZE` bytes leaves no room for the terminator and is rejected.
    #[test]
    fn write_overflow() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = [0; BUF_SIZE];
        let result = unsafe { write_cstr("foo bar".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, None);
        assert_eq!(buf, [0; BUF_SIZE]);
    }

    // Only the copied bytes and the terminator are overwritten; the rest of the buffer is kept.
    #[test]
    fn write_overwrite() {
        const BUF_SIZE: libc::size_t = 7;
        let mut buf = (b"zzzzzz\0").map(u8::cast_signed);
        let result = unsafe { write_cstr("a".to_string(), buf.as_mut_ptr(), BUF_SIZE) };

        assert_eq!(result, Some(1));
        assert_eq!(buf, (b"a\0zzzz\0").map(u8::cast_signed));
    }
}