Skip to main content

memlink_runtime/
validation.rs

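//! Module validation: ABI-version checks, probing of required and optional library exports, and a keyed cache of validation results.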
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use libloading::Library;
7
8use crate::abi::{validate_abi_version, AbiVersionError, MEMLINK_ABI_VERSION};
9use crate::exports::REQUIRED_EXPORTS;
10use crate::instance::ModuleInstance;
11
/// Outcome of a validation pass: an overall verdict plus the diagnostics collected.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Overall verdict; cleared as soon as an error is recorded.
    pub valid: bool,
    /// Non-fatal diagnostics. Recording one leaves `valid` unchanged.
    pub warnings: Vec<String>,
    /// Fatal diagnostics. Recording one sets `valid` to `false`.
    pub errors: Vec<String>,
    /// ABI version recorded by the check that produced this result, if any.
    pub abi_version: Option<u32>,
    /// Whether at least one optional export was detected.
    pub has_optional_exports: bool,
}

impl ValidationResult {
    /// Creates a passing result with no diagnostics and no ABI version.
    pub fn valid() -> Self {
        Self {
            valid: true,
            warnings: Vec::new(),
            errors: Vec::new(),
            abi_version: None,
            has_optional_exports: false,
        }
    }

    /// Creates a failing result with no diagnostics and no ABI version.
    pub fn invalid() -> Self {
        Self {
            valid: false,
            ..Self::valid()
        }
    }

    /// Appends a non-fatal diagnostic; the verdict is not affected.
    pub fn add_warning(&mut self, warning: impl Into<String>) {
        let message: String = warning.into();
        self.warnings.push(message);
    }

    /// Appends a fatal diagnostic and marks the whole result as invalid.
    pub fn add_error(&mut self, error: impl Into<String>) {
        let message: String = error.into();
        self.errors.push(message);
        self.valid = false;
    }
}
51
52pub fn validate_module(_instance: &ModuleInstance) -> crate::Result<ValidationResult> {
53    let mut result = ValidationResult::valid();
54    result.abi_version = Some(MEMLINK_ABI_VERSION);
55    Ok(result)
56}
57
58pub fn validate_exports(library: &Library) -> crate::Result<ValidationResult> {
59    let mut result = ValidationResult::valid();
60    let mut found_required = 0;
61    let mut found_optional = 0;
62
63    for &export_name in REQUIRED_EXPORTS {
64        let symbol_name = format!("{}\0", export_name);
65        let check: std::result::Result<libloading::Symbol<unsafe extern "C" fn()>, _> =
66            unsafe { library.get(symbol_name.as_bytes()) };
67
68        if check.is_ok() {
69            found_required += 1;
70        } else {
71            result.add_error(format!("Missing required export: {}", export_name));
72        }
73    }
74
75    for &export_name in crate::exports::OPTIONAL_EXPORTS {
76        let symbol_name = format!("{}\0", export_name);
77        let check: std::result::Result<libloading::Symbol<unsafe extern "C" fn()>, _> =
78            unsafe { library.get(symbol_name.as_bytes()) };
79
80        if check.is_ok() {
81            found_optional += 1;
82        }
83    }
84
85    result.has_optional_exports = found_optional > 0;
86
87    if found_required == REQUIRED_EXPORTS.len() {
88        result.add_warning(format!(
89            "All {} required exports present, {} optional exports found",
90            found_required, found_optional
91        ));
92    }
93
94    Ok(result)
95}
96
97pub fn validate_abi(module_version: u32) -> ValidationResult {
98    let mut result = ValidationResult::valid();
99    result.abi_version = Some(module_version);
100
101    match validate_abi_version(module_version) {
102        Ok(()) => {
103            if module_version < MEMLINK_ABI_VERSION {
104                result.add_warning(format!(
105                    "Module ABI version {} is older than current version {}. \
106                     Compatibility is not guaranteed.",
107                    module_version, MEMLINK_ABI_VERSION
108                ));
109            }
110        }
111        Err(AbiVersionError::TooOld { module, min_supported }) => {
112            result.add_error(format!(
113                "Module ABI version {} is too old (minimum supported: {})",
114                module, min_supported
115            ));
116        }
117        Err(AbiVersionError::TooNew { module, max_supported }) => {
118            result.add_error(format!(
119                "Module ABI version {} is too new (maximum supported: {})",
120                module, max_supported
121            ));
122        }
123    }
124
125    result
126}
127
128#[derive(Debug, Clone)]
129pub struct CachedValidation {
130    pub result: ValidationResult,
131    pub validated_at: u128,
132}
133
134impl CachedValidation {
135    pub fn new(result: ValidationResult) -> Self {
136        CachedValidation {
137            result,
138            validated_at: std::time::Instant::now().elapsed().as_nanos(),
139        }
140    }
141
142    pub fn is_fresh(&self) -> bool {
143        true
144    }
145}
146
147#[derive(Debug, Default)]
148pub struct ValidationCache {
149    cache: HashMap<String, Arc<CachedValidation>>,
150}
151
152impl ValidationCache {
153    pub fn new() -> Self {
154        ValidationCache {
155            cache: HashMap::new(),
156        }
157    }
158
159    pub fn get(&self, key: &str) -> Option<&CachedValidation> {
160        self.cache.get(key).map(|arc| arc.as_ref())
161    }
162
163    pub fn insert(&mut self, key: impl Into<String>, result: ValidationResult) {
164        self.cache.insert(key.into(), Arc::new(CachedValidation::new(result)));
165    }
166
167    pub fn remove(&mut self, key: &str) {
168        self.cache.remove(key);
169    }
170
171    pub fn clear(&mut self) {
172        self.cache.clear();
173    }
174
175    pub fn len(&self) -> usize {
176        self.cache.len()
177    }
178
179    pub fn is_empty(&self) -> bool {
180        self.cache.is_empty()
181    }
182}