// ai_code_buddy/core/ai_analyzer.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4use tokio::sync::mpsc;
5
6use crate::core::review::{CommitStatus, Issue};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct AnalysisRequest {
10    pub file_path: String,
11    pub content: String,
12    pub language: String,
13    pub commit_status: CommitStatus,
14}
15
/// Progress notification delivered over the optional channel given to
/// [`AIAnalyzer::analyze_file`].
#[derive(Clone, Debug)]
pub struct ProgressUpdate {
    /// File the update refers to.
    pub current_file: String,
    /// Completion percentage in the range `0.0..=100.0`.
    pub progress: f64,
    /// Short human-readable name of the current stage.
    pub stage: String,
}
22
/// Compute backend selected for the analyzer when it is constructed.
#[derive(Clone, Debug, PartialEq)]
pub enum GpuBackend {
    /// Apple Metal (chosen on Apple Silicon macOS builds).
    Metal,
    /// NVIDIA CUDA.
    Cuda,
    /// Intel Math Kernel Library.
    Mkl,
    /// Plain CPU execution without acceleration.
    Cpu,
}

impl std::fmt::Display for GpuBackend {
    /// Writes the backend's conventional short name (`Metal`, `CUDA`, `MKL`, `CPU`).
    ///
    /// Width/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Metal => "Metal",
            Self::Cuda => "CUDA",
            Self::Mkl => "MKL",
            Self::Cpu => "CPU",
        };
        f.write_str(label)
    }
}
41
42pub struct AIAnalyzer {
43    backend: GpuBackend,
44}
45
46impl AIAnalyzer {
47    pub async fn new(use_gpu: bool) -> Result<Self> {
48        println!("🧠 Initializing AI analyzer...");
49
50        // Detect and configure GPU backend
51        let backend = if use_gpu {
52            Self::detect_gpu_backend()
53        } else {
54            GpuBackend::Cpu
55        };
56
57        println!("šŸ”§ Using backend: {backend:?}");
58
59        println!("šŸ” AI inference currently disabled due to token sampling issues");
60        println!("šŸ”§ Using enhanced rule-based analysis for comprehensive code review");
61
62        let analyzer = AIAnalyzer { backend };
63
64        // Display the configured backend for diagnostics
65        println!(
66            "šŸ”§ AI Analyzer initialized with {} backend",
67            analyzer.get_backend()
68        );
69
70        Ok(analyzer)
71    }
72
73    /// Get the GPU backend being used by this analyzer
74    pub fn get_backend(&self) -> &GpuBackend {
75        &self.backend
76    }
77
78    fn detect_gpu_backend() -> GpuBackend {
79        // Check if we're on Apple Silicon (Metal support)
80        if cfg!(target_os = "macos") && Self::is_apple_silicon() {
81            println!("šŸŽ Apple Silicon detected, using Metal backend");
82            GpuBackend::Metal
83        }
84        // Check for CUDA support (NVIDIA)
85        else if Self::has_cuda_support() {
86            println!("🟢 NVIDIA CUDA detected, using CUDA backend");
87            GpuBackend::Cuda
88        }
89        // Check for Intel MKL support
90        else if Self::has_mkl_support() {
91            println!("šŸ”µ Intel MKL detected, using MKL backend");
92            GpuBackend::Mkl
93        }
94        // Fallback to CPU
95        else {
96            println!("šŸ’» No GPU acceleration detected, falling back to CPU");
97            GpuBackend::Cpu
98        }
99    }
100
101    fn is_apple_silicon() -> bool {
102        // Check if we're running on Apple Silicon
103        cfg!(target_arch = "aarch64") && cfg!(target_os = "macos")
104    }
105
106    fn has_cuda_support() -> bool {
107        // Check for NVIDIA GPU presence
108        // This is a simplified check - in production you might want to check for actual CUDA runtime
109        std::process::Command::new("nvidia-smi")
110            .output()
111            .map(|output| output.status.success())
112            .unwrap_or(false)
113    }
114
115    fn has_mkl_support() -> bool {
116        // Check for Intel processor
117        // This is a simplified check
118        cfg!(target_arch = "x86_64")
119    }
120
121    pub async fn analyze_file(
122        &self,
123        request: AnalysisRequest,
124        progress_tx: Option<mpsc::UnboundedSender<ProgressUpdate>>,
125    ) -> Result<Vec<Issue>> {
126        let _language = self.detect_language(&request.file_path);
127
128        if let Some(ref tx) = progress_tx {
129            let _ = tx.send(ProgressUpdate {
130                current_file: request.file_path.clone(),
131                progress: 0.0,
132                stage: "Starting analysis".to_string(),
133            });
134        }
135
136        let mut issues = Vec::new();
137
138        // AI inference is currently disabled due to token sampling issues
139        // Using enhanced rule-based analysis which provides comprehensive coverage
140        issues.extend(self.rule_based_analysis(&request)?);
141
142        // TODO: Re-enable AI analysis once token sampling issues are resolved
143        // The AI methods are preserved below for future use
144
145        if let Some(ref tx) = progress_tx {
146            let _ = tx.send(ProgressUpdate {
147                current_file: request.file_path.clone(),
148                progress: 100.0,
149                stage: "Analysis complete".to_string(),
150            });
151        }
152
153        Ok(issues)
154    }
155
156    fn rule_based_analysis(&self, request: &AnalysisRequest) -> Result<Vec<Issue>> {
157        let mut issues = Vec::new();
158
159        for (line_num, line) in request.content.lines().enumerate() {
160            let line_number = line_num + 1;
161            let line_lower = line.to_lowercase();
162
163            // SECURITY PATTERNS
164
165            // Hardcoded credentials
166            if (line_lower.contains("password")
167                || line_lower.contains("api_key")
168                || line_lower.contains("secret"))
169                && line.contains("=")
170                && (line.contains("\"") || line.contains("'"))
171            {
172                issues.push(Issue {
173                    file: request.file_path.clone(),
174                    line: line_number,
175                    severity: "Critical".to_string(),
176                    category: "Security".to_string(),
177                    description: "Hardcoded credentials detected - use environment variables"
178                        .to_string(),
179                    commit_status: request.commit_status.clone(),
180                });
181            }
182
183            // Code injection
184            if line.contains("eval(") || line.contains("exec(") {
185                issues.push(Issue {
186                    file: request.file_path.clone(),
187                    line: line_number,
188                    severity: "Critical".to_string(),
189                    category: "Security".to_string(),
190                    description: "Code injection vulnerability - avoid eval/exec".to_string(),
191                    commit_status: request.commit_status.clone(),
192                });
193            }
194
195            // SQL injection patterns
196            if line.contains("query")
197                && line.contains("format!")
198                && (line.contains("SELECT") || line.contains("INSERT") || line.contains("UPDATE"))
199            {
200                issues.push(Issue {
201                    file: request.file_path.clone(),
202                    line: line_number,
203                    severity: "Critical".to_string(),
204                    category: "Security".to_string(),
205                    description: "Potential SQL injection - use parameterized queries".to_string(),
206                    commit_status: request.commit_status.clone(),
207                });
208            }
209
210            // Command injection patterns
211            if (line.contains("Command::new")
212                || line.contains("subprocess")
213                || line.contains("system("))
214                && (line.contains("format!")
215                    || line.contains("user_input")
216                    || line.contains("args"))
217            {
218                issues.push(Issue {
219                    file: request.file_path.clone(),
220                    line: line_number,
221                    severity: "Critical".to_string(),
222                    category: "Security".to_string(),
223                    description: "Command injection vulnerability - sanitize inputs".to_string(),
224                    commit_status: request.commit_status.clone(),
225                });
226            }
227
228            // Path traversal patterns
229            if line.contains("../")
230                && (line.contains("read") || line.contains("open") || line.contains("file"))
231            {
232                issues.push(Issue {
233                    file: request.file_path.clone(),
234                    line: line_number,
235                    severity: "High".to_string(),
236                    category: "Security".to_string(),
237                    description: "Path traversal vulnerability - validate file paths".to_string(),
238                    commit_status: request.commit_status.clone(),
239                });
240            }
241
242            // PERFORMANCE PATTERNS
243
244            // Nested loops (O(n²) complexity)
245            if line.contains("for") && line.trim().starts_with("for") {
246                // Check if there's another for loop nearby (simple heuristic)
247                let lines: Vec<&str> = request.content.lines().collect();
248                for (idx, _) in lines
249                    .iter()
250                    .enumerate()
251                    .take(std::cmp::min(line_num + 10, lines.len()))
252                    .skip(line_num + 1)
253                {
254                    if lines[idx].trim().starts_with("for") {
255                        issues.push(Issue {
256                            file: request.file_path.clone(),
257                            line: line_number,
258                            severity: "Medium".to_string(),
259                            category: "Performance".to_string(),
260                            description: "Nested loops detected - consider optimization"
261                                .to_string(),
262                            commit_status: request.commit_status.clone(),
263                        });
264                        break;
265                    }
266                }
267            }
268
269            // Language-specific analysis
270            match request.language.as_str() {
271                "rust" => {
272                    // Security
273                    if line.contains("unsafe") {
274                        issues.push(Issue {
275                            file: request.file_path.clone(),
276                            line: line_number,
277                            severity: "High".to_string(),
278                            category: "Security".to_string(),
279                            description: "Unsafe code block - requires justification and review"
280                                .to_string(),
281                            commit_status: request.commit_status.clone(),
282                        });
283                    }
284
285                    if line.contains("std::ptr::null") {
286                        issues.push(Issue {
287                            file: request.file_path.clone(),
288                            line: line_number,
289                            severity: "Critical".to_string(),
290                            category: "Security".to_string(),
291                            description: "Null pointer dereference - will cause segfault"
292                                .to_string(),
293                            commit_status: request.commit_status.clone(),
294                        });
295                    }
296
297                    // Error handling
298                    if line.contains("unwrap()") && !line.contains("expect(") {
299                        issues.push(Issue {
300                            file: request.file_path.clone(),
301                            line: line_number,
302                            severity: "Medium".to_string(),
303                            category: "Error Handling".to_string(),
304                            description:
305                                "Use expect() or proper error handling instead of unwrap()"
306                                    .to_string(),
307                            commit_status: request.commit_status.clone(),
308                        });
309                    }
310
311                    // Performance
312                    if line.contains(".clone()") && line.contains("&") {
313                        issues.push(Issue {
314                            file: request.file_path.clone(),
315                            line: line_number,
316                            severity: "Low".to_string(),
317                            category: "Performance".to_string(),
318                            description: "Unnecessary clone - consider borrowing instead"
319                                .to_string(),
320                            commit_status: request.commit_status.clone(),
321                        });
322                    }
323                }
324                "python" => {
325                    // Security
326                    if line.contains("pickle.loads") && !line.contains("trusted") {
327                        issues.push(Issue {
328                            file: request.file_path.clone(),
329                            line: line_number,
330                            severity: "Critical".to_string(),
331                            category: "Security".to_string(),
332                            description: "Unsafe deserialization - pickle.loads is dangerous"
333                                .to_string(),
334                            commit_status: request.commit_status.clone(),
335                        });
336                    }
337
338                    if line.contains("yaml.load") && !line.contains("safe_load") {
339                        issues.push(Issue {
340                            file: request.file_path.clone(),
341                            line: line_number,
342                            severity: "High".to_string(),
343                            category: "Security".to_string(),
344                            description: "Use yaml.safe_load instead of yaml.load".to_string(),
345                            commit_status: request.commit_status.clone(),
346                        });
347                    }
348
349                    // Performance
350                    if line.contains("+=") && (line.contains("\"") || line.contains("'")) {
351                        issues.push(Issue {
352                            file: request.file_path.clone(),
353                            line: line_number,
354                            severity: "Medium".to_string(),
355                            category: "Performance".to_string(),
356                            description:
357                                "String concatenation in loop - use join() for better performance"
358                                    .to_string(),
359                            commit_status: request.commit_status.clone(),
360                        });
361                    }
362                }
363                "javascript" | "typescript" => {
364                    // Security
365                    if line.contains("innerHTML") && line.contains("+") {
366                        issues.push(Issue {
367                            file: request.file_path.clone(),
368                            line: line_number,
369                            severity: "High".to_string(),
370                            category: "Security".to_string(),
371                            description: "XSS vulnerability - validate before setting innerHTML"
372                                .to_string(),
373                            commit_status: request.commit_status.clone(),
374                        });
375                    }
376
377                    // Performance
378                    if line.contains("document.getElementById") && line.contains("for") {
379                        issues.push(Issue {
380                            file: request.file_path.clone(),
381                            line: line_number,
382                            severity: "Medium".to_string(),
383                            category: "Performance".to_string(),
384                            description: "DOM query in loop - cache the element reference"
385                                .to_string(),
386                            commit_status: request.commit_status.clone(),
387                        });
388                    }
389                }
390                _ => {}
391            }
392
393            // CODE QUALITY PATTERNS
394
395            if line.contains("TODO") || line.contains("FIXME") || line.contains("HACK") {
396                issues.push(Issue {
397                    file: request.file_path.clone(),
398                    line: line_number,
399                    severity: "Low".to_string(),
400                    category: "Code Quality".to_string(),
401                    description: "Code comment indicates incomplete implementation".to_string(),
402                    commit_status: request.commit_status.clone(),
403                });
404            }
405
406            // Long line detection
407            if line.len() > 120 {
408                issues.push(Issue {
409                    file: request.file_path.clone(),
410                    line: line_number,
411                    severity: "Low".to_string(),
412                    category: "Code Quality".to_string(),
413                    description: format!(
414                        "Line too long ({} chars) - consider breaking into multiple lines",
415                        line.len()
416                    ),
417                    commit_status: request.commit_status.clone(),
418                });
419            }
420        }
421
422        Ok(issues)
423    }
424
425    fn detect_language(&self, file_path: &str) -> String {
426        let path = Path::new(file_path);
427        match path.extension().and_then(|ext| ext.to_str()) {
428            Some("rs") => "rust".to_string(),
429            Some("js") => "javascript".to_string(),
430            Some("ts") => "typescript".to_string(),
431            Some("py") => "python".to_string(),
432            Some("java") => "java".to_string(),
433            Some("cpp") | Some("cc") | Some("cxx") => "cpp".to_string(),
434            Some("c") => "c".to_string(),
435            Some("go") => "go".to_string(),
436            Some("php") => "php".to_string(),
437            Some("rb") => "ruby".to_string(),
438            Some("cs") => "csharp".to_string(),
439            _ => "unknown".to_string(),
440        }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::review::{CommitStatus, Issue};

    /// Builds a request for `file` with the given text and language tag,
    /// marked as `CommitStatus::Modified`.
    fn make_request(file: &str, content: &str, language: &str) -> AnalysisRequest {
        AnalysisRequest {
            file_path: String::from(file),
            content: String::from(content),
            language: String::from(language),
            commit_status: CommitStatus::Modified,
        }
    }

    /// Analyzer pinned to the CPU backend, constructed without `new()`.
    fn cpu_analyzer() -> AIAnalyzer {
        AIAnalyzer {
            backend: GpuBackend::Cpu,
        }
    }

    /// `true` if at least one issue belongs to `category`.
    fn has_category(issues: &[Issue], category: &str) -> bool {
        issues.iter().any(|issue| issue.category == category)
    }

    #[test]
    fn test_detect_language_variants() {
        let analyzer = cpu_analyzer();
        let cases = [
            ("src/main.rs", "rust"),
            ("a/b/c.py", "python"),
            ("index.ts", "typescript"),
            ("script.js", "javascript"),
            ("unknown.foo", "unknown"),
        ];
        for &(path, expected) in &cases {
            assert_eq!(analyzer.detect_language(path), expected, "language for {path}");
        }
    }

    #[test]
    fn test_rule_based_analysis_rust_patterns() {
        let content = r#"
            // SECURITY
            let password = "secret";
            let _ = eval("2+2");
            let query = format!("SELECT * FROM users");
            std::process::Command::new("sh").arg(format!("{}", user_input));
            let _ = std::fs::read("../etc/passwd");
            // PERFORMANCE
            for i in 0..10 {
                for j in 0..10 {}
            }
            // RUST SPECIFIC
            unsafe { /* do unsafe things */ }
            let p = std::ptr::null();
            let _ = something.unwrap();
            let _y = &x.clone();
            // QUALITY
            // TODO: fix
            // Long line next
            aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
        "#;
        let issues = cpu_analyzer()
            .rule_based_analysis(&make_request("file.rs", content, "rust"))
            .unwrap();
        assert!(!issues.is_empty());
        // Every rule family should fire at least once on this sample.
        for category in ["Security", "Performance", "Code Quality"] {
            assert!(has_category(&issues, category), "missing {category} issue");
        }
    }

    #[test]
    fn test_rule_based_analysis_python_patterns() {
        let content = r#"
            import pickle
            data = pickle.loads(b"...")
            import yaml
            result = yaml.load("x: 1")
            s = "";
            for i in range(10): s += "x"
        "#;
        let issues = cpu_analyzer()
            .rule_based_analysis(&make_request("script.py", content, "python"))
            .unwrap();
        for category in ["Security", "Performance"] {
            assert!(has_category(&issues, category), "missing {category} issue");
        }
    }

    #[test]
    fn test_rule_based_analysis_js_patterns() {
        let content = r#"
            let x = "user";
            element.innerHTML = "<div>" + x;
            for (let i = 0; i < 10; i++) { document.getElementById("id"); }
        "#;
        let issues = cpu_analyzer()
            .rule_based_analysis(&make_request("script.js", content, "javascript"))
            .unwrap();
        for category in ["Security", "Performance"] {
            assert!(has_category(&issues, category), "missing {category} issue");
        }
    }

    #[test]
    fn test_analyze_file_emits_progress_and_issues() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let analyzer = AIAnalyzer::new(false).await.unwrap();
            let (tx, mut rx) = mpsc::unbounded_channel::<ProgressUpdate>();
            let request = make_request("file.rs", "let password = \"x\";", "rust");
            let issues = analyzer.analyze_file(request, Some(tx)).await.unwrap();
            assert!(!issues.is_empty());
            // Poll a few times without blocking; a single message is enough.
            let got_any = (0..4).any(|_| rx.try_recv().is_ok());
            assert!(got_any, "expected at least one progress message");
        });
    }
}