1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4use tokio::sync::mpsc;
5
6use crate::core::review::{CommitStatus, Issue};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct AnalysisRequest {
10 pub file_path: String,
11 pub content: String,
12 pub language: String,
13 pub commit_status: CommitStatus,
14}
15
/// Progress notification for a file that is being analyzed.
///
/// The analyzer in this file sends `0.0` when it starts a file and `100.0` when it
/// finishes, together with a human-readable `stage` label.
#[derive(Clone, Debug)]
pub struct ProgressUpdate {
    /// Path of the file the update refers to.
    pub current_file: String,
    /// Completion value as sent by the producer.
    pub progress: f64,
    /// Short description of the current stage.
    pub stage: String,
}
22
/// Compute backend that the analyzer reports having selected.
#[derive(Clone, Debug, PartialEq)]
pub enum GpuBackend {
    /// Apple Metal (chosen on Apple Silicon macOS).
    Metal,
    /// NVIDIA CUDA.
    Cuda,
    /// Intel Math Kernel Library.
    Mkl,
    /// Plain CPU execution without acceleration.
    Cpu,
}

impl std::fmt::Display for GpuBackend {
    /// Writes the backend's display label; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Metal => "Metal",
            Self::Cuda => "CUDA",
            Self::Mkl => "MKL",
            Self::Cpu => "CPU",
        };
        f.write_str(label)
    }
}
41
/// Rule-based code analyzer that also records which compute backend was selected.
pub struct AIAnalyzer {
    // Backend chosen in `AIAnalyzer::new`; it is only reported, because AI inference is
    // currently disabled (see the start-up messages in `new`).
    backend: GpuBackend,
}
45
46impl AIAnalyzer {
47 pub async fn new(use_gpu: bool) -> Result<Self> {
48 println!("š§ Initializing AI analyzer...");
49
50 let backend = if use_gpu {
52 Self::detect_gpu_backend()
53 } else {
54 GpuBackend::Cpu
55 };
56
57 println!("š§ Using backend: {backend:?}");
58
59 println!("š AI inference currently disabled due to token sampling issues");
60 println!("š§ Using enhanced rule-based analysis for comprehensive code review");
61
62 let analyzer = AIAnalyzer { backend };
63
64 println!(
66 "š§ AI Analyzer initialized with {} backend",
67 analyzer.get_backend()
68 );
69
70 Ok(analyzer)
71 }
72
73 pub fn get_backend(&self) -> &GpuBackend {
75 &self.backend
76 }
77
78 fn detect_gpu_backend() -> GpuBackend {
79 if cfg!(target_os = "macos") && Self::is_apple_silicon() {
81 println!("š Apple Silicon detected, using Metal backend");
82 GpuBackend::Metal
83 }
84 else if Self::has_cuda_support() {
86 println!("š¢ NVIDIA CUDA detected, using CUDA backend");
87 GpuBackend::Cuda
88 }
89 else if Self::has_mkl_support() {
91 println!("šµ Intel MKL detected, using MKL backend");
92 GpuBackend::Mkl
93 }
94 else {
96 println!("š» No GPU acceleration detected, falling back to CPU");
97 GpuBackend::Cpu
98 }
99 }
100
101 fn is_apple_silicon() -> bool {
102 cfg!(target_arch = "aarch64") && cfg!(target_os = "macos")
104 }
105
106 fn has_cuda_support() -> bool {
107 std::process::Command::new("nvidia-smi")
110 .output()
111 .map(|output| output.status.success())
112 .unwrap_or(false)
113 }
114
115 fn has_mkl_support() -> bool {
116 cfg!(target_arch = "x86_64")
119 }
120
121 pub async fn analyze_file(
122 &self,
123 request: AnalysisRequest,
124 progress_tx: Option<mpsc::UnboundedSender<ProgressUpdate>>,
125 ) -> Result<Vec<Issue>> {
126 let _language = self.detect_language(&request.file_path);
127
128 if let Some(ref tx) = progress_tx {
129 let _ = tx.send(ProgressUpdate {
130 current_file: request.file_path.clone(),
131 progress: 0.0,
132 stage: "Starting analysis".to_string(),
133 });
134 }
135
136 let mut issues = Vec::new();
137
138 issues.extend(self.rule_based_analysis(&request)?);
141
142 if let Some(ref tx) = progress_tx {
146 let _ = tx.send(ProgressUpdate {
147 current_file: request.file_path.clone(),
148 progress: 100.0,
149 stage: "Analysis complete".to_string(),
150 });
151 }
152
153 Ok(issues)
154 }
155
156 fn rule_based_analysis(&self, request: &AnalysisRequest) -> Result<Vec<Issue>> {
157 let mut issues = Vec::new();
158
159 for (line_num, line) in request.content.lines().enumerate() {
160 let line_number = line_num + 1;
161 let line_lower = line.to_lowercase();
162
163 if (line_lower.contains("password")
167 || line_lower.contains("api_key")
168 || line_lower.contains("secret"))
169 && line.contains("=")
170 && (line.contains("\"") || line.contains("'"))
171 {
172 issues.push(Issue {
173 file: request.file_path.clone(),
174 line: line_number,
175 severity: "Critical".to_string(),
176 category: "Security".to_string(),
177 description: "Hardcoded credentials detected - use environment variables"
178 .to_string(),
179 commit_status: request.commit_status.clone(),
180 });
181 }
182
183 if line.contains("eval(") || line.contains("exec(") {
185 issues.push(Issue {
186 file: request.file_path.clone(),
187 line: line_number,
188 severity: "Critical".to_string(),
189 category: "Security".to_string(),
190 description: "Code injection vulnerability - avoid eval/exec".to_string(),
191 commit_status: request.commit_status.clone(),
192 });
193 }
194
195 if line.contains("query")
197 && line.contains("format!")
198 && (line.contains("SELECT") || line.contains("INSERT") || line.contains("UPDATE"))
199 {
200 issues.push(Issue {
201 file: request.file_path.clone(),
202 line: line_number,
203 severity: "Critical".to_string(),
204 category: "Security".to_string(),
205 description: "Potential SQL injection - use parameterized queries".to_string(),
206 commit_status: request.commit_status.clone(),
207 });
208 }
209
210 if (line.contains("Command::new")
212 || line.contains("subprocess")
213 || line.contains("system("))
214 && (line.contains("format!")
215 || line.contains("user_input")
216 || line.contains("args"))
217 {
218 issues.push(Issue {
219 file: request.file_path.clone(),
220 line: line_number,
221 severity: "Critical".to_string(),
222 category: "Security".to_string(),
223 description: "Command injection vulnerability - sanitize inputs".to_string(),
224 commit_status: request.commit_status.clone(),
225 });
226 }
227
228 if line.contains("../")
230 && (line.contains("read") || line.contains("open") || line.contains("file"))
231 {
232 issues.push(Issue {
233 file: request.file_path.clone(),
234 line: line_number,
235 severity: "High".to_string(),
236 category: "Security".to_string(),
237 description: "Path traversal vulnerability - validate file paths".to_string(),
238 commit_status: request.commit_status.clone(),
239 });
240 }
241
242 if line.contains("for") && line.trim().starts_with("for") {
246 let lines: Vec<&str> = request.content.lines().collect();
248 for (idx, _) in lines
249 .iter()
250 .enumerate()
251 .take(std::cmp::min(line_num + 10, lines.len()))
252 .skip(line_num + 1)
253 {
254 if lines[idx].trim().starts_with("for") {
255 issues.push(Issue {
256 file: request.file_path.clone(),
257 line: line_number,
258 severity: "Medium".to_string(),
259 category: "Performance".to_string(),
260 description: "Nested loops detected - consider optimization"
261 .to_string(),
262 commit_status: request.commit_status.clone(),
263 });
264 break;
265 }
266 }
267 }
268
269 match request.language.as_str() {
271 "rust" => {
272 if line.contains("unsafe") {
274 issues.push(Issue {
275 file: request.file_path.clone(),
276 line: line_number,
277 severity: "High".to_string(),
278 category: "Security".to_string(),
279 description: "Unsafe code block - requires justification and review"
280 .to_string(),
281 commit_status: request.commit_status.clone(),
282 });
283 }
284
285 if line.contains("std::ptr::null") {
286 issues.push(Issue {
287 file: request.file_path.clone(),
288 line: line_number,
289 severity: "Critical".to_string(),
290 category: "Security".to_string(),
291 description: "Null pointer dereference - will cause segfault"
292 .to_string(),
293 commit_status: request.commit_status.clone(),
294 });
295 }
296
297 if line.contains("unwrap()") && !line.contains("expect(") {
299 issues.push(Issue {
300 file: request.file_path.clone(),
301 line: line_number,
302 severity: "Medium".to_string(),
303 category: "Error Handling".to_string(),
304 description:
305 "Use expect() or proper error handling instead of unwrap()"
306 .to_string(),
307 commit_status: request.commit_status.clone(),
308 });
309 }
310
311 if line.contains(".clone()") && line.contains("&") {
313 issues.push(Issue {
314 file: request.file_path.clone(),
315 line: line_number,
316 severity: "Low".to_string(),
317 category: "Performance".to_string(),
318 description: "Unnecessary clone - consider borrowing instead"
319 .to_string(),
320 commit_status: request.commit_status.clone(),
321 });
322 }
323 }
324 "python" => {
325 if line.contains("pickle.loads") && !line.contains("trusted") {
327 issues.push(Issue {
328 file: request.file_path.clone(),
329 line: line_number,
330 severity: "Critical".to_string(),
331 category: "Security".to_string(),
332 description: "Unsafe deserialization - pickle.loads is dangerous"
333 .to_string(),
334 commit_status: request.commit_status.clone(),
335 });
336 }
337
338 if line.contains("yaml.load") && !line.contains("safe_load") {
339 issues.push(Issue {
340 file: request.file_path.clone(),
341 line: line_number,
342 severity: "High".to_string(),
343 category: "Security".to_string(),
344 description: "Use yaml.safe_load instead of yaml.load".to_string(),
345 commit_status: request.commit_status.clone(),
346 });
347 }
348
349 if line.contains("+=") && (line.contains("\"") || line.contains("'")) {
351 issues.push(Issue {
352 file: request.file_path.clone(),
353 line: line_number,
354 severity: "Medium".to_string(),
355 category: "Performance".to_string(),
356 description:
357 "String concatenation in loop - use join() for better performance"
358 .to_string(),
359 commit_status: request.commit_status.clone(),
360 });
361 }
362 }
363 "javascript" | "typescript" => {
364 if line.contains("innerHTML") && line.contains("+") {
366 issues.push(Issue {
367 file: request.file_path.clone(),
368 line: line_number,
369 severity: "High".to_string(),
370 category: "Security".to_string(),
371 description: "XSS vulnerability - validate before setting innerHTML"
372 .to_string(),
373 commit_status: request.commit_status.clone(),
374 });
375 }
376
377 if line.contains("document.getElementById") && line.contains("for") {
379 issues.push(Issue {
380 file: request.file_path.clone(),
381 line: line_number,
382 severity: "Medium".to_string(),
383 category: "Performance".to_string(),
384 description: "DOM query in loop - cache the element reference"
385 .to_string(),
386 commit_status: request.commit_status.clone(),
387 });
388 }
389 }
390 _ => {}
391 }
392
393 if line.contains("TODO") || line.contains("FIXME") || line.contains("HACK") {
396 issues.push(Issue {
397 file: request.file_path.clone(),
398 line: line_number,
399 severity: "Low".to_string(),
400 category: "Code Quality".to_string(),
401 description: "Code comment indicates incomplete implementation".to_string(),
402 commit_status: request.commit_status.clone(),
403 });
404 }
405
406 if line.len() > 120 {
408 issues.push(Issue {
409 file: request.file_path.clone(),
410 line: line_number,
411 severity: "Low".to_string(),
412 category: "Code Quality".to_string(),
413 description: format!(
414 "Line too long ({} chars) - consider breaking into multiple lines",
415 line.len()
416 ),
417 commit_status: request.commit_status.clone(),
418 });
419 }
420 }
421
422 Ok(issues)
423 }
424
425 fn detect_language(&self, file_path: &str) -> String {
426 let path = Path::new(file_path);
427 match path.extension().and_then(|ext| ext.to_str()) {
428 Some("rs") => "rust".to_string(),
429 Some("js") => "javascript".to_string(),
430 Some("ts") => "typescript".to_string(),
431 Some("py") => "python".to_string(),
432 Some("java") => "java".to_string(),
433 Some("cpp") | Some("cc") | Some("cxx") => "cpp".to_string(),
434 Some("c") => "c".to_string(),
435 Some("go") => "go".to_string(),
436 Some("php") => "php".to_string(),
437 Some("rb") => "ruby".to_string(),
438 Some("cs") => "csharp".to_string(),
439 _ => "unknown".to_string(),
440 }
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::review::CommitStatus;

    /// Builds an `AnalysisRequest` for `file` with the given content and language, always
    /// using `CommitStatus::Modified`.
    fn make_request(file: &str, content: &str, language: &str) -> AnalysisRequest {
        AnalysisRequest {
            file_path: file.to_string(),
            content: content.to_string(),
            language: language.to_string(),
            commit_status: CommitStatus::Modified,
        }
    }

    /// Extension-based language detection for common extensions and the "unknown" fallback.
    #[test]
    fn test_detect_language_variants() {
        let analyzer = AIAnalyzer {
            backend: GpuBackend::Cpu,
        };
        assert_eq!(analyzer.detect_language("src/main.rs"), "rust");
        assert_eq!(analyzer.detect_language("a/b/c.py"), "python");
        assert_eq!(analyzer.detect_language("index.ts"), "typescript");
        assert_eq!(analyzer.detect_language("script.js"), "javascript");
        assert_eq!(analyzer.detect_language("unknown.foo"), "unknown");
    }

    /// A Rust fixture containing one trigger per rule group must yield Security,
    /// Performance and Code Quality findings.
    #[test]
    fn test_rule_based_analysis_rust_patterns() {
        let analyzer = AIAnalyzer {
            backend: GpuBackend::Cpu,
        };
        // Fixture text is data: the last line must stay longer than the long-line limit.
        let content = r#"
        // SECURITY
        let password = "secret";
        let _ = eval("2+2");
        let query = format!("SELECT * FROM users");
        std::process::Command::new("sh").arg(format!("{}", user_input));
        let _ = std::fs::read("../etc/passwd");
        // PERFORMANCE
        for i in 0..10 {
        for j in 0..10 {}
        }
        // RUST SPECIFIC
        unsafe { /* do unsafe things */ }
        let p = std::ptr::null();
        let _ = something.unwrap();
        let _y = &x.clone();
        // QUALITY
        // TODO: fix
        // Long line next
        aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
        "#;
        let req = make_request("file.rs", content, "rust");
        let issues = analyzer.rule_based_analysis(&req).unwrap();
        assert!(!issues.is_empty());
        assert!(issues.iter().any(|i| i.category == "Security"));
        assert!(issues.iter().any(|i| i.category == "Performance"));
        assert!(issues.iter().any(|i| i.category == "Code Quality"));
    }

    /// Python-specific rules: pickle/yaml deserialization (Security) and string `+=`
    /// (Performance).
    #[test]
    fn test_rule_based_analysis_python_patterns() {
        let analyzer = AIAnalyzer {
            backend: GpuBackend::Cpu,
        };
        let content = r#"
        import pickle
        data = pickle.loads(b"...")
        import yaml
        result = yaml.load("x: 1")
        s = "";
        for i in range(10): s += "x"
        "#;
        let req = make_request("script.py", content, "python");
        let issues = analyzer.rule_based_analysis(&req).unwrap();
        assert!(issues.iter().any(|i| i.category == "Security"));
        assert!(issues.iter().any(|i| i.category == "Performance"));
    }

    /// JavaScript-specific rules: `innerHTML` concatenation (Security) and a DOM query in a
    /// loop (Performance).
    #[test]
    fn test_rule_based_analysis_js_patterns() {
        let analyzer = AIAnalyzer {
            backend: GpuBackend::Cpu,
        };
        let content = r#"
        let x = "user";
        element.innerHTML = "<div>" + x;
        for (let i = 0; i < 10; i++) { document.getElementById("id"); }
        "#;
        let req = make_request("script.js", content, "javascript");
        let issues = analyzer.rule_based_analysis(&req).unwrap();
        assert!(issues.iter().any(|i| i.category == "Security"));
        assert!(issues.iter().any(|i| i.category == "Performance"));
    }

    /// End-to-end `analyze_file` on a dedicated Tokio runtime: it returns findings and sends
    /// at least one progress update on the supplied channel.
    #[test]
    fn test_analyze_file_emits_progress_and_issues() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let analyzer = AIAnalyzer::new(false).await.unwrap();
            let (tx, mut rx) = mpsc::unbounded_channel::<ProgressUpdate>();
            let req = make_request("file.rs", "let password = \"x\";", "rust");
            let issues = analyzer.analyze_file(req, Some(tx)).await.unwrap();
            assert!(!issues.is_empty());
            let mut got_any = false;
            // Updates are sent before `analyze_file` returns, so a few non-blocking polls
            // suffice.
            for _ in 0..4 {
                if rx.try_recv().is_ok() {
                    got_any = true;
                    break;
                }
            }
            assert!(got_any, "expected at least one progress message");
        });
    }
}