codegraph_c/platform/
mod.rs

1//! Platform detection and abstraction for C codebases
2//!
3//! This module provides automatic detection of the target platform (Linux, FreeBSD, Darwin)
4//! based on source code patterns, and provides platform-specific configurations for parsing.
5
6mod linux;
7
8pub use linux::LinuxPlatform;
9
10use std::collections::HashMap;
11
/// Detection pattern kind
///
/// Selects how a [`DetectionPattern`]'s text is compared against source code.
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`, allowing it to be used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DetectionKind {
    /// Include directive pattern (e.g., "linux/")
    Include,
    /// Macro definition or usage
    Macro,
    /// Function call pattern
    FunctionCall,
    /// Type name pattern
    TypeName,
}
24
25/// A pattern used to detect platform
26#[derive(Debug, Clone)]
27pub struct DetectionPattern {
28    pub kind: DetectionKind,
29    pub pattern: String,
30    pub weight: f32,
31}
32
33impl DetectionPattern {
34    pub fn include(pattern: &str, weight: f32) -> Self {
35        Self {
36            kind: DetectionKind::Include,
37            pattern: pattern.to_string(),
38            weight,
39        }
40    }
41
42    pub fn macro_pattern(pattern: &str, weight: f32) -> Self {
43        Self {
44            kind: DetectionKind::Macro,
45            pattern: pattern.to_string(),
46            weight,
47        }
48    }
49
50    pub fn function_call(pattern: &str, weight: f32) -> Self {
51        Self {
52            kind: DetectionKind::FunctionCall,
53            pattern: pattern.to_string(),
54            weight,
55        }
56    }
57
58    pub fn type_name(pattern: &str, weight: f32) -> Self {
59        Self {
60            kind: DetectionKind::TypeName,
61            pattern: pattern.to_string(),
62            weight,
63        }
64    }
65}
66
/// Category of callback functions in ops structures
///
/// Each variant names the role a callback field plays; see
/// [`OpsFieldDef::category`] for where the value is attached to a field.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CallbackCategory {
    Init,
    Cleanup,
    Open,
    Close,
    Read,
    Write,
    Ioctl,
    Mmap,
    Poll,
    Probe,
    Remove,
    Suspend,
    Resume,
    Interrupt,
    Timer,
    Workqueue,
    // NOTE(review): presumably the fallback for callbacks that fit none of
    // the categories above — confirm against the platform definitions.
    Other,
}
88
89/// Definition of an ops struct field
90#[derive(Debug, Clone)]
91pub struct OpsFieldDef {
92    pub name: String,
93    pub category: CallbackCategory,
94}
95
96/// Definition of an ops struct (like file_operations, pci_driver)
97#[derive(Debug, Clone)]
98pub struct OpsStructDef {
99    pub struct_name: String,
100    pub fields: Vec<OpsFieldDef>,
101}
102
/// Header stub definitions - actual type definitions to inject
///
/// Maps a header path, exactly as it is written inside an `#include`
/// directive, to the stub text that should stand in for that header.
#[derive(Debug, Clone, Default)]
pub struct HeaderStubs {
    /// Stub content keyed by header path
    headers: HashMap<String, String>,
}

impl HeaderStubs {
    /// Create an empty stub collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a stub for a header path
    ///
    /// Adding a path that already has a stub replaces the previous content.
    pub fn add(&mut self, path: &str, content: &str) {
        self.headers.insert(path.to_string(), content.to_string());
    }

    /// Get stub content for all matching includes in source
    ///
    /// Scans `source` line by line for include directives and concatenates
    /// the stubs registered for the included paths, each preceded by a
    /// `/* Stub for <path> */` marker line. Stubs appear in the order their
    /// header is first included; a header included more than once
    /// contributes its stub only once, so definitions are never duplicated.
    /// Returns an empty string when nothing matches.
    pub fn get_for_includes(&self, source: &str) -> String {
        let mut stubs = String::new();
        // Paths whose stub has already been written to `stubs`.
        let mut emitted = std::collections::HashSet::new();

        let included = source
            .lines()
            .filter_map(|line| Self::extract_include_path(line.trim()));

        for path in included {
            if let Some((key, stub)) = self.headers.get_key_value(path.as_str()) {
                // `insert` returns false when the path was seen before.
                if emitted.insert(key.as_str()) {
                    stubs.push_str("/* Stub for ");
                    stubs.push_str(key);
                    stubs.push_str(" */\n");
                    stubs.push_str(stub);
                    stubs.push('\n');
                }
            }
        }

        stubs
    }

    /// Extract the header path from an include directive.
    ///
    /// `line` must already be trimmed: it has to begin with `#`. Both
    /// `#include <path>` and `#include "path"` are recognized, whitespace is
    /// allowed between `#` and `include` and before the opening delimiter,
    /// and anything after the closing delimiter (such as a trailing comment)
    /// is ignored. Returns `None` for any other line, including directives
    /// with a missing closing delimiter.
    fn extract_include_path(line: &str) -> Option<String> {
        let rest = line
            .strip_prefix('#')?
            .trim_start()
            .strip_prefix("include")?
            .trim_start();

        let (body, closer) = if let Some(body) = rest.strip_prefix('<') {
            (body, '>')
        } else if let Some(body) = rest.strip_prefix('"') {
            (body, '"')
        } else {
            // Covers macro-expanded includes and directives such as
            // `#include_next`, whose remainder does not start with a delimiter.
            return None;
        };

        // Stop at the first closing delimiter so trailing text is dropped.
        let end = body.find(closer)?;
        Some(body[..end].to_string())
    }

    /// Check if stubs exist for a header
    pub fn has_stub(&self, path: &str) -> bool {
        self.headers.contains_key(path)
    }

    /// Get all available stub headers
    ///
    /// The order of the returned paths is unspecified (hash map iteration order).
    pub fn available_headers(&self) -> Vec<&str> {
        self.headers.keys().map(|s| s.as_str()).collect()
    }
}
165
166/// Trait for platform-specific modules
167pub trait PlatformModule: Send + Sync {
168    /// Unique identifier for this platform
169    fn id(&self) -> &'static str;
170
171    /// Human-readable name
172    fn name(&self) -> &'static str;
173
174    /// Get detection patterns for this platform
175    fn detection_patterns(&self) -> Vec<DetectionPattern>;
176
177    /// Get header stubs for this platform
178    fn header_stubs(&self) -> &HeaderStubs;
179
180    /// Get attributes that should be stripped for this platform
181    fn attributes_to_strip(&self) -> &[&'static str];
182
183    /// Get ops struct definitions for callback resolution
184    fn ops_structs(&self) -> &[OpsStructDef];
185
186    /// Get call normalization mappings (platform-specific → unified)
187    fn call_normalizations(&self) -> &HashMap<&'static str, &'static str>;
188}
189
/// Detection result with confidence score
///
/// Describes the outcome of scoring source code against a platform.
#[derive(Clone, Debug)]
pub struct DetectionResult {
    /// Identifier of the platform this result refers to
    pub platform_id: String,
    /// Confidence score for the platform
    pub confidence: f32,
    /// Text of every detection pattern that matched
    pub matched_patterns: Vec<String>,
}
197
198/// Registry of available platforms
199pub struct PlatformRegistry {
200    platforms: Vec<Box<dyn PlatformModule>>,
201}
202
203impl Default for PlatformRegistry {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl PlatformRegistry {
210    pub fn new() -> Self {
211        let mut registry = Self {
212            platforms: Vec::new(),
213        };
214        // Register default platforms
215        registry.register(Box::new(LinuxPlatform::new()));
216        registry
217    }
218
219    /// Register a platform module
220    pub fn register(&mut self, platform: Box<dyn PlatformModule>) {
221        self.platforms.push(platform);
222    }
223
224    /// Detect platform from source code
225    pub fn detect(&self, source: &str) -> DetectionResult {
226        let mut best_result = DetectionResult {
227            platform_id: "generic".to_string(),
228            confidence: 0.0,
229            matched_patterns: Vec::new(),
230        };
231
232        for platform in &self.platforms {
233            let result = self.score_platform(source, platform.as_ref());
234            if result.confidence > best_result.confidence {
235                best_result = result;
236            }
237        }
238
239        best_result
240    }
241
242    /// Get a platform by ID
243    pub fn get(&self, id: &str) -> Option<&dyn PlatformModule> {
244        self.platforms
245            .iter()
246            .find(|p| p.id() == id)
247            .map(|p| p.as_ref())
248    }
249
250    fn score_platform(&self, source: &str, platform: &dyn PlatformModule) -> DetectionResult {
251        let mut total_weight = 0.0;
252        let mut matched_patterns = Vec::new();
253
254        let source_lower = source.to_lowercase();
255
256        for pattern in platform.detection_patterns() {
257            let matched = match pattern.kind {
258                DetectionKind::Include => {
259                    // Check for #include with this path
260                    source.contains(&format!("#include <{}", pattern.pattern))
261                        || source.contains(&format!("#include \"{}", pattern.pattern))
262                }
263                DetectionKind::Macro => {
264                    // Check for macro usage or definition
265                    source.contains(&pattern.pattern)
266                }
267                DetectionKind::FunctionCall => {
268                    // Check for function call pattern
269                    source.contains(&format!("{}(", pattern.pattern))
270                }
271                DetectionKind::TypeName => {
272                    // Check for type usage (case-insensitive for some)
273                    source_lower.contains(&pattern.pattern.to_lowercase())
274                }
275            };
276
277            if matched {
278                total_weight += pattern.weight;
279                matched_patterns.push(pattern.pattern.clone());
280            }
281        }
282
283        // Normalize confidence to 0.0-1.0 range (cap at 1.0)
284        let confidence = (total_weight / 10.0).min(1.0);
285
286        DetectionResult {
287            platform_id: platform.id().to_string(),
288            confidence,
289            matched_patterns,
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Both delimiter styles are parsed; input that does not begin with the
    /// directive is rejected.
    #[test]
    fn test_header_stubs_extract_include_path() {
        let cases = [
            ("#include <linux/types.h>", Some("linux/types.h")),
            ("#include \"myheader.h\"", Some("myheader.h")),
            // The helper expects an already-trimmed line, so a line that
            // still carries leading whitespace is not recognized.
            ("  #include <sys/param.h>  ", None),
        ];

        for &(line, expected) in cases.iter() {
            assert_eq!(
                HeaderStubs::extract_include_path(line),
                expected.map(String::from)
            );
        }
    }

    /// Only headers with a registered stub contribute to the output.
    #[test]
    fn test_header_stubs_get_for_includes() {
        let mut stubs = HeaderStubs::new();
        stubs.add("linux/types.h", "typedef unsigned int u32;");
        stubs.add("linux/kernel.h", "typedef unsigned long size_t;");

        let source = r#"
#include <linux/types.h>
#include <linux/module.h>
#include <linux/kernel.h>
"#;

        let combined = stubs.get_for_includes(source);
        assert!(combined.contains("typedef unsigned int u32"));
        assert!(combined.contains("typedef unsigned long size_t"));
        // linux/module.h has no stub, so nothing mentions it.
        assert!(!combined.contains("module"));
    }

    /// Constructors set the kind, text and weight they are given.
    #[test]
    fn test_detection_pattern_creation() {
        let include = DetectionPattern::include("linux/", 2.0);
        assert_eq!(include.kind, DetectionKind::Include);
        assert_eq!(include.pattern, "linux/");
        assert!((include.weight - 2.0).abs() < f32::EPSILON);

        let license = DetectionPattern::macro_pattern("MODULE_LICENSE", 3.0);
        assert_eq!(license.kind, DetectionKind::Macro);
    }

    /// A typical kernel module is attributed to the Linux platform.
    #[test]
    fn test_platform_registry_detect_linux() {
        let linux_source = r#"
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Test");

static int __init my_init(void) {
    printk(KERN_INFO "Hello\n");
    return 0;
}
module_init(my_init);
"#;

        let detected = PlatformRegistry::new().detect(linux_source);
        assert_eq!(detected.platform_id, "linux");
        assert!(detected.confidence > 0.5);
        assert!(!detected.matched_patterns.is_empty());
    }

    /// Plain userspace C must not score highly for any platform.
    #[test]
    fn test_platform_registry_generic_code() {
        let generic_source = r#"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv) {
    printf("Hello, World!\n");
    return 0;
}
"#;

        let detected = PlatformRegistry::new().detect(generic_source);
        assert!(detected.confidence < 0.3);
    }

    /// Lookup by identifier finds registered platforms only.
    #[test]
    fn test_platform_registry_get() {
        let registry = PlatformRegistry::new();

        let linux = registry.get("linux");
        assert!(linux.is_some());
        assert_eq!(linux.unwrap().name(), "Linux Kernel");

        assert!(registry.get("unknown").is_none());
    }
}
398}