1mod linux;
7
8pub use linux::LinuxPlatform;
9
10use std::collections::HashMap;
11
/// The kind of source-level evidence a [`DetectionPattern`] describes.
#[derive(Debug, Clone, PartialEq)]
pub enum DetectionKind {
    /// The pattern refers to an `#include` directive.
    Include,
    /// The pattern refers to a macro.
    Macro,
    /// The pattern refers to a function call.
    FunctionCall,
    /// The pattern refers to a type name.
    TypeName,
}

/// A single weighted hint used when deciding which platform a source file
/// targets.
#[derive(Debug, Clone)]
pub struct DetectionPattern {
    /// What sort of construct `pattern` is meant to match.
    pub kind: DetectionKind,
    /// The text to look for.
    pub pattern: String,
    /// Contribution of this pattern to the overall score when it matches.
    pub weight: f32,
}

impl DetectionPattern {
    /// Shared constructor body: copies `pattern` into an owned `String` and
    /// tags it with `kind`.
    fn with_kind(kind: DetectionKind, pattern: &str, weight: f32) -> Self {
        Self {
            kind,
            pattern: String::from(pattern),
            weight,
        }
    }

    /// Builds a pattern of kind [`DetectionKind::Include`].
    pub fn include(pattern: &str, weight: f32) -> Self {
        Self::with_kind(DetectionKind::Include, pattern, weight)
    }

    /// Builds a pattern of kind [`DetectionKind::Macro`].
    pub fn macro_pattern(pattern: &str, weight: f32) -> Self {
        Self::with_kind(DetectionKind::Macro, pattern, weight)
    }

    /// Builds a pattern of kind [`DetectionKind::FunctionCall`].
    pub fn function_call(pattern: &str, weight: f32) -> Self {
        Self::with_kind(DetectionKind::FunctionCall, pattern, weight)
    }

    /// Builds a pattern of kind [`DetectionKind::TypeName`].
    pub fn type_name(pattern: &str, weight: f32) -> Self {
        Self::with_kind(DetectionKind::TypeName, pattern, weight)
    }
}
66
/// Classification attached to a callback field of an ops struct (see
/// [`OpsFieldDef::category`]).
///
/// NOTE(review): the precise meaning of each variant is implied only by its
/// name; this file does not show how categories are assigned or consumed.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CallbackCategory {
    // Lifecycle-style names.
    Init,
    Cleanup,
    Open,
    Close,
    // I/O-style names.
    Read,
    Write,
    Ioctl,
    Mmap,
    Poll,
    // Device/power-management-style names.
    Probe,
    Remove,
    Suspend,
    Resume,
    // Deferred/asynchronous-style names.
    Interrupt,
    Timer,
    Workqueue,
    /// Fallback for callbacks that fit none of the named categories
    /// (presumably; confirm against the code that assigns categories).
    Other,
}
88
89#[derive(Debug, Clone)]
91pub struct OpsFieldDef {
92 pub name: String,
93 pub category: CallbackCategory,
94}
95
96#[derive(Debug, Clone)]
98pub struct OpsStructDef {
99 pub struct_name: String,
100 pub fields: Vec<OpsFieldDef>,
101}
102
/// In-memory table of replacement ("stub") header contents, keyed by the path
/// exactly as it is written inside an `#include` directive.
#[derive(Debug, Clone, Default)]
pub struct HeaderStubs {
    /// Include path (e.g. `linux/types.h`) mapped to the stub text.
    headers: HashMap<String, String>,
}

impl HeaderStubs {
    /// Creates an empty stub table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `content` as the stub for `path`, replacing any stub that
    /// was previously stored under the same path.
    pub fn add(&mut self, path: &str, content: &str) {
        self.headers.insert(path.to_string(), content.to_string());
    }

    /// Collects the stubs for every header that `source` includes.
    ///
    /// `source` is scanned line by line. Each `#include` whose path has a
    /// registered stub contributes a `/* Stub for <path> */` banner followed
    /// by the stub text and a newline, in order of first appearance. Includes
    /// without a registered stub are ignored.
    ///
    /// A header that is included several times contributes its stub only
    /// once, so the returned text never repeats a definition.
    pub fn get_for_includes(&self, source: &str) -> String {
        let mut stubs = String::new();
        // Paths already written to `stubs`; borrowed from the map's own keys.
        let mut emitted = std::collections::HashSet::new();

        for line in source.lines() {
            // The extractor expects a trimmed line (see its documentation).
            if let Some(path) = Self::extract_include_path(line.trim()) {
                if let Some((key, stub)) = self.headers.get_key_value(path.as_str()) {
                    // `insert` returns false when the path was seen before.
                    if emitted.insert(key.as_str()) {
                        stubs.push_str("/* Stub for ");
                        stubs.push_str(key);
                        stubs.push_str(" */\n");
                        stubs.push_str(stub);
                        stubs.push('\n');
                    }
                }
            }
        }

        stubs
    }

    /// Extracts the path from a single `#include <path>` or
    /// `#include "path"` line.
    ///
    /// `line` must already be trimmed: it has to start with `#`, otherwise
    /// `None` is returned. Whitespace between `#`, `include` and the opening
    /// delimiter is tolerated, and anything after the closing delimiter
    /// (such as a trailing comment) is ignored.
    ///
    /// Returns `None` when the line is not an include directive, uses neither
    /// delimiter style, or lacks the closing delimiter.
    fn extract_include_path(line: &str) -> Option<String> {
        let rest = line
            .strip_prefix('#')?
            .trim_start()
            .strip_prefix("include")?
            .trim_start();

        let closing = match rest.chars().next()? {
            '<' => '>',
            '"' => '"',
            _ => return None,
        };

        // The opening delimiter is a single ASCII byte, so slicing at 1 is
        // always on a character boundary.
        let after_open = &rest[1..];
        let end = after_open.find(closing)?;
        Some(after_open[..end].to_string())
    }

    /// Returns `true` if a stub is registered under exactly `path`.
    pub fn has_stub(&self, path: &str) -> bool {
        self.headers.contains_key(path)
    }

    /// Lists every path that has a stub. The order is unspecified because the
    /// paths come straight from a `HashMap`.
    pub fn available_headers(&self) -> Vec<&str> {
        self.headers.keys().map(|s| s.as_str()).collect()
    }
}
165
166pub trait PlatformModule: Send + Sync {
168 fn id(&self) -> &'static str;
170
171 fn name(&self) -> &'static str;
173
174 fn detection_patterns(&self) -> Vec<DetectionPattern>;
176
177 fn header_stubs(&self) -> &HeaderStubs;
179
180 fn attributes_to_strip(&self) -> &[&'static str];
182
183 fn ops_structs(&self) -> &[OpsStructDef];
185
186 fn call_normalizations(&self) -> &HashMap<&'static str, &'static str>;
188}
189
/// Outcome of scoring source text against registered platforms.
#[derive(Clone, Debug)]
pub struct DetectionResult {
    /// Identifier of the platform this result refers to.
    pub platform_id: String,
    /// Score attached to the result; `PlatformRegistry` caps it at `1.0`.
    pub confidence: f32,
    /// Text of every detection pattern that matched.
    pub matched_patterns: Vec<String>,
}
197
198pub struct PlatformRegistry {
200 platforms: Vec<Box<dyn PlatformModule>>,
201}
202
203impl Default for PlatformRegistry {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209impl PlatformRegistry {
210 pub fn new() -> Self {
211 let mut registry = Self {
212 platforms: Vec::new(),
213 };
214 registry.register(Box::new(LinuxPlatform::new()));
216 registry
217 }
218
219 pub fn register(&mut self, platform: Box<dyn PlatformModule>) {
221 self.platforms.push(platform);
222 }
223
224 pub fn detect(&self, source: &str) -> DetectionResult {
226 let mut best_result = DetectionResult {
227 platform_id: "generic".to_string(),
228 confidence: 0.0,
229 matched_patterns: Vec::new(),
230 };
231
232 for platform in &self.platforms {
233 let result = self.score_platform(source, platform.as_ref());
234 if result.confidence > best_result.confidence {
235 best_result = result;
236 }
237 }
238
239 best_result
240 }
241
242 pub fn get(&self, id: &str) -> Option<&dyn PlatformModule> {
244 self.platforms
245 .iter()
246 .find(|p| p.id() == id)
247 .map(|p| p.as_ref())
248 }
249
250 fn score_platform(&self, source: &str, platform: &dyn PlatformModule) -> DetectionResult {
251 let mut total_weight = 0.0;
252 let mut matched_patterns = Vec::new();
253
254 let source_lower = source.to_lowercase();
255
256 for pattern in platform.detection_patterns() {
257 let matched = match pattern.kind {
258 DetectionKind::Include => {
259 source.contains(&format!("#include <{}", pattern.pattern))
261 || source.contains(&format!("#include \"{}", pattern.pattern))
262 }
263 DetectionKind::Macro => {
264 source.contains(&pattern.pattern)
266 }
267 DetectionKind::FunctionCall => {
268 source.contains(&format!("{}(", pattern.pattern))
270 }
271 DetectionKind::TypeName => {
272 source_lower.contains(&pattern.pattern.to_lowercase())
274 }
275 };
276
277 if matched {
278 total_weight += pattern.weight;
279 matched_patterns.push(pattern.pattern.clone());
280 }
281 }
282
283 let confidence = (total_weight / 10.0).min(1.0);
285
286 DetectionResult {
287 platform_id: platform.id().to_string(),
288 confidence,
289 matched_patterns,
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Both delimiter styles are recognised; the helper expects its input to
    /// be trimmed already.
    #[test]
    fn test_header_stubs_extract_include_path() {
        let angled = HeaderStubs::extract_include_path("#include <linux/types.h>");
        assert_eq!(angled, Some(String::from("linux/types.h")));

        let quoted = HeaderStubs::extract_include_path("#include \"myheader.h\"");
        assert_eq!(quoted, Some(String::from("myheader.h")));

        // Leading whitespace is not stripped by the helper itself, so an
        // untrimmed line is rejected.
        let untrimmed = HeaderStubs::extract_include_path(" #include <sys/param.h> ");
        assert_eq!(untrimmed, None);
    }

    /// Only headers with a registered stub contribute to the output.
    #[test]
    fn test_header_stubs_get_for_includes() {
        let mut table = HeaderStubs::new();
        table.add("linux/types.h", "typedef unsigned int u32;");
        table.add("linux/kernel.h", "typedef unsigned long size_t;");

        let source = r#"
#include <linux/types.h>
#include <linux/module.h>
#include <linux/kernel.h>
"#;

        let combined = table.get_for_includes(source);
        assert!(combined.contains("typedef unsigned int u32"));
        assert!(combined.contains("typedef unsigned long size_t"));
        // `linux/module.h` has no stub, so nothing mentioning it is emitted.
        assert!(!combined.contains("module"));
    }

    /// The named constructors set the matching kind and keep their arguments.
    #[test]
    fn test_detection_pattern_creation() {
        let include_pattern = DetectionPattern::include("linux/", 2.0);
        assert_eq!(include_pattern.kind, DetectionKind::Include);
        assert_eq!(include_pattern.pattern, "linux/");
        assert!((include_pattern.weight - 2.0).abs() < f32::EPSILON);

        let macro_pattern = DetectionPattern::macro_pattern("MODULE_LICENSE", 3.0);
        assert_eq!(macro_pattern.kind, DetectionKind::Macro);
    }

    /// A typical kernel module is detected as Linux with high confidence.
    #[test]
    fn test_platform_registry_detect_linux() {
        let registry = PlatformRegistry::new();

        let linux_source = r#"
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Test");

static int __init my_init(void) {
    printk(KERN_INFO "Hello\n");
    return 0;
}
module_init(my_init);
"#;

        let outcome = registry.detect(linux_source);
        assert_eq!(outcome.platform_id, "linux");
        assert!(outcome.confidence > 0.5);
        assert!(!outcome.matched_patterns.is_empty());
    }

    /// Plain userspace C must not score highly for any platform.
    #[test]
    fn test_platform_registry_generic_code() {
        let registry = PlatformRegistry::new();

        let generic_source = r#"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv) {
    printf("Hello, World!\n");
    return 0;
}
"#;

        let outcome = registry.detect(generic_source);
        assert!(outcome.confidence < 0.3);
    }

    /// Lookup by id finds the built-in Linux platform and nothing else.
    #[test]
    fn test_platform_registry_get() {
        let registry = PlatformRegistry::new();

        let found = registry.get("linux");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name(), "Linux Kernel");

        let missing = registry.get("unknown");
        assert!(missing.is_none());
    }
}