1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Category of fix inferred from a diff by `analyze_fix_strategy`.
///
/// Derives `Hash`/`Eq` and the serde traits; `ImportStats` uses it as the key of
/// its per-strategy counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FixStrategy {
    /// The fix introduces an owned copy (e.g. `.clone()` or `.to_owned()`).
    AddClone,
    /// The fix turns a value into a reference (e.g. `x: &T`, `&self`, `&mut self`).
    AddBorrow,
    /// The fix adds lifetime annotations (e.g. `<'a>`, `'static`, `'_`).
    AddLifetime,
    /// The fix introduces `Option` handling (e.g. `Option<..>`, `Some(..)`, `.is_none()`).
    WrapInOption,
    /// The fix introduces `Result` handling (e.g. `Result<..>`, `Ok(..)`, `Err(..)`).
    WrapInResult,
    /// The fix adds or changes a `name: Type` annotation.
    AddTypeAnnotation,
    /// No heuristic matched; `smart_import_filter` rejects fixes in this category.
    Unknown,
}
31
/// Outcome of `smart_import_filter` for a single fix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportDecision {
    /// Import the fix as-is.
    Accept,
    /// Import the fix, but keep the attached message (e.g. "Verify ...") for review.
    AcceptWithWarning(String),
    /// Do not import the fix; the message explains why.
    Reject(String),
}
42
43impl ImportDecision {
44 pub fn allows_import(&self) -> bool {
46 matches!(self, ImportDecision::Accept | ImportDecision::AcceptWithWarning(_))
47 }
48}
49
/// Running tally of import decisions, broken down by [`FixStrategy`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImportStats {
    /// Count of `Accept` and `AcceptWithWarning` decisions per strategy, as kept by `record`.
    pub accepted_by_strategy: HashMap<FixStrategy, usize>,
    /// Count of `Reject` decisions per strategy, as kept by `record`.
    pub rejected_by_strategy: HashMap<FixStrategy, usize>,
    /// Number of `AcceptWithWarning` decisions recorded, across all strategies.
    pub warnings: usize,
    /// Total number of decisions passed to `record`.
    pub total_evaluated: usize,
}
62
63impl ImportStats {
64 pub fn new() -> Self {
66 Self::default()
67 }
68
69 pub fn record(&mut self, strategy: FixStrategy, decision: &ImportDecision) {
71 self.total_evaluated += 1;
72 match decision {
73 ImportDecision::Accept => {
74 *self.accepted_by_strategy.entry(strategy).or_insert(0) += 1;
75 }
76 ImportDecision::AcceptWithWarning(_) => {
77 *self.accepted_by_strategy.entry(strategy).or_insert(0) += 1;
78 self.warnings += 1;
79 }
80 ImportDecision::Reject(_) => {
81 *self.rejected_by_strategy.entry(strategy).or_insert(0) += 1;
82 }
83 }
84 }
85
86 pub fn acceptance_rate(&self, strategy: FixStrategy) -> f32 {
88 let accepted = self.accepted_by_strategy.get(&strategy).copied().unwrap_or(0);
89 let rejected = self.rejected_by_strategy.get(&strategy).copied().unwrap_or(0);
90 let total = accepted + rejected;
91 if total == 0 {
92 0.0
93 } else {
94 accepted as f32 / total as f32
95 }
96 }
97
98 pub fn overall_acceptance_rate(&self) -> f32 {
100 let accepted: usize = self.accepted_by_strategy.values().sum();
101 if self.total_evaluated == 0 {
102 0.0
103 } else {
104 accepted as f32 / self.total_evaluated as f32
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct SmartImportConfig {
112 pub source_language: SourceLanguage,
114 pub min_confidence: f32,
116 pub allow_warnings: bool,
118}
119
120impl Default for SmartImportConfig {
121 fn default() -> Self {
122 Self { source_language: SourceLanguage::Python, min_confidence: 0.5, allow_warnings: true }
123 }
124}
125
/// Language recorded in `SmartImportConfig::source_language`.
///
/// `smart_import_filter` only distinguishes `Python` from every other variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SourceLanguage {
    /// Python.
    Python,
    /// C.
    C,
    /// C++.
    Cpp,
    /// Any language not listed above.
    Other,
}
134
135pub fn analyze_fix_strategy(fix_diff: &str) -> FixStrategy {
137 if fix_diff.contains(".clone()") || fix_diff.contains(".to_owned()") {
142 return FixStrategy::AddClone;
143 }
144
145 if fix_diff.contains("<'a>")
147 || fix_diff.contains("'static")
148 || fix_diff.contains("'_")
149 || (fix_diff.contains("'a") && fix_diff.contains("fn "))
150 {
151 return FixStrategy::AddLifetime;
152 }
153
154 if fix_diff.contains(": &mut ")
157 || fix_diff.contains(": &")
158 || fix_diff.contains("(&self)")
159 || fix_diff.contains("(&mut self)")
160 || fix_diff.contains("(x: &")
161 || fix_diff.contains("(y: &")
162 || fix_diff.contains("(z: &")
163 || (fix_diff.contains("&") && fix_diff.contains("+ fn"))
164 {
165 return FixStrategy::AddBorrow;
166 }
167
168 if fix_diff.contains("Option<")
170 || fix_diff.contains("Some(")
171 || fix_diff.contains(".unwrap()")
172 || fix_diff.contains(".is_none()")
173 || fix_diff.contains(".is_some()")
174 {
175 return FixStrategy::WrapInOption;
176 }
177
178 if fix_diff.contains("Result<") || fix_diff.contains("Ok(") || fix_diff.contains("Err(") {
180 return FixStrategy::WrapInResult;
181 }
182
183 if fix_diff.contains(": i32")
185 || fix_diff.contains(": String")
186 || (fix_diff.contains(": ") && !fix_diff.contains(": &"))
187 {
188 return FixStrategy::AddTypeAnnotation;
189 }
190
191 FixStrategy::Unknown
192}
193
194pub fn smart_import_filter(
196 fix_diff: &str,
197 metadata: &HashMap<String, String>,
198 config: &SmartImportConfig,
199) -> ImportDecision {
200 let strategy = analyze_fix_strategy(fix_diff);
201
202 match strategy {
203 FixStrategy::AddClone => {
204 if config.source_language == SourceLanguage::Python {
206 if let Some(construct) = metadata.get("source_construct") {
207 if construct.contains("list") || construct.contains("dict") {
208 return ImportDecision::Reject(
209 "Python collection copy != Rust clone".to_string(),
210 );
211 }
212 }
213 }
214 ImportDecision::Accept
215 }
216 FixStrategy::AddBorrow => {
217 ImportDecision::Accept
219 }
220 FixStrategy::AddLifetime => {
221 ImportDecision::Accept
223 }
224 FixStrategy::WrapInOption => {
225 if config.source_language == SourceLanguage::Python {
227 let has_null_handling = fix_diff.contains("NULL")
229 || fix_diff.contains("nullptr")
230 || fix_diff.contains("null")
231 || fix_diff.contains(".is_none()")
232 || fix_diff.contains(".is_some()")
233 || fix_diff.contains(".unwrap_or");
234
235 if has_null_handling {
236 ImportDecision::Accept
237 } else {
238 ImportDecision::AcceptWithWarning(
239 "Verify NULL handling for C context".to_string(),
240 )
241 }
242 } else {
243 ImportDecision::Accept
244 }
245 }
246 FixStrategy::WrapInResult => {
247 ImportDecision::Accept
249 }
250 FixStrategy::AddTypeAnnotation => {
251 if config.source_language == SourceLanguage::Python {
253 ImportDecision::AcceptWithWarning("Verify type mapping for C context".to_string())
254 } else {
255 ImportDecision::Accept
256 }
257 }
258 FixStrategy::Unknown => ImportDecision::Reject("Unknown fix strategy".to_string()),
259 }
260}
261
// Unit tests for strategy classification, the smart import filter, and ImportStats.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- analyze_fix_strategy: one representative diff per strategy ----

    #[test]
    fn test_analyze_strategy_add_clone() {
        let diff = "- let x = value;\n+ let x = value.clone();";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::AddClone);
    }

    #[test]
    fn test_analyze_strategy_to_owned() {
        let diff = "- let s = str_slice;\n+ let s = str_slice.to_owned();";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::AddClone);
    }

    #[test]
    fn test_analyze_strategy_add_borrow() {
        let diff = "- fn foo(x: String)\n+ fn foo(x: &String)";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::AddBorrow);
    }

    #[test]
    fn test_analyze_strategy_add_mut_borrow() {
        let diff = "- fn foo(x: Vec<i32>)\n+ fn foo(x: &mut Vec<i32>)";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::AddBorrow);
    }

    #[test]
    fn test_analyze_strategy_add_lifetime() {
        let diff = "- fn foo(x: &str) -> &str\n+ fn foo<'a>(x: &'a str) -> &'a str";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::AddLifetime);
    }

    #[test]
    fn test_analyze_strategy_wrap_option() {
        let diff = "- let x: *const T\n+ let x: Option<&T>";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::WrapInOption);
    }

    #[test]
    fn test_analyze_strategy_wrap_result() {
        let diff = "- fn foo() -> i32\n+ fn foo() -> Result<i32, Error>";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::WrapInResult);
    }

    #[test]
    fn test_analyze_strategy_unknown() {
        let diff = "- some random change\n+ another random change";
        assert_eq!(analyze_fix_strategy(diff), FixStrategy::Unknown);
    }

    // ---- ImportDecision ----

    #[test]
    fn test_import_decision_allows_import() {
        assert!(ImportDecision::Accept.allows_import());
        assert!(ImportDecision::AcceptWithWarning("warning".into()).allows_import());
        assert!(!ImportDecision::Reject("reason".into()).allows_import());
    }

    // ---- smart_import_filter: per-strategy rules, mostly for a Python source ----

    #[test]
    fn test_smart_filter_accepts_borrow_from_python() {
        let diff = "- fn foo(x: String)\n+ fn foo(x: &String)";
        let metadata = HashMap::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let decision = smart_import_filter(diff, &metadata, &config);
        assert_eq!(decision, ImportDecision::Accept);
    }

    #[test]
    fn test_smart_filter_rejects_python_list_clone() {
        let diff = "- let x = lst;\n+ let x = lst.clone();";
        let mut metadata = HashMap::new();
        // "list_copy" contains "list", which triggers the collection-copy rejection.
        metadata.insert("source_construct".into(), "list_copy".into());
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let decision = smart_import_filter(diff, &metadata, &config);
        assert!(matches!(decision, ImportDecision::Reject(_)));
    }

    #[test]
    fn test_smart_filter_accepts_clone_without_list_context() {
        let diff = "- let x = value;\n+ let x = value.clone();";
        let metadata = HashMap::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let decision = smart_import_filter(diff, &metadata, &config);
        assert_eq!(decision, ImportDecision::Accept);
    }

    #[test]
    fn test_smart_filter_warns_on_option_without_null() {
        let diff = "- let x = value\n+ let x = Some(value)";
        let metadata = HashMap::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let decision = smart_import_filter(diff, &metadata, &config);
        assert!(matches!(decision, ImportDecision::AcceptWithWarning(_)));
    }

    #[test]
    fn test_smart_filter_accepts_option_with_null() {
        let diff = "- if (ptr == NULL)\n+ if ptr.is_none()";
        let metadata = HashMap::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let decision = smart_import_filter(diff, &metadata, &config);
        assert!(decision.allows_import());
    }

    #[test]
    fn test_smart_filter_rejects_unknown_strategy() {
        let diff = "random gibberish change";
        let metadata = HashMap::new();
        let config = SmartImportConfig::default();

        let decision = smart_import_filter(diff, &metadata, &config);
        assert!(matches!(decision, ImportDecision::Reject(_)));
    }

    #[test]
    fn test_smart_filter_accepts_lifetime_from_any_source() {
        let diff = "- fn foo(x: &str)\n+ fn foo<'a>(x: &'a str)";
        let metadata = HashMap::new();

        let config_py =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };
        assert_eq!(smart_import_filter(diff, &metadata, &config_py), ImportDecision::Accept);

        let config_c =
            SmartImportConfig { source_language: SourceLanguage::C, ..Default::default() };
        assert_eq!(smart_import_filter(diff, &metadata, &config_c), ImportDecision::Accept);
    }

    // ---- ImportStats bookkeeping and rates ----

    #[test]
    fn test_import_stats_new() {
        let stats = ImportStats::new();
        assert_eq!(stats.total_evaluated, 0);
        assert_eq!(stats.warnings, 0);
    }

    #[test]
    fn test_import_stats_record_accept() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);

        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.accepted_by_strategy.get(&FixStrategy::AddBorrow), Some(&1));
    }

    #[test]
    fn test_import_stats_record_reject() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddClone, &ImportDecision::Reject("reason".into()));

        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.rejected_by_strategy.get(&FixStrategy::AddClone), Some(&1));
    }

    #[test]
    fn test_import_stats_record_warning() {
        let mut stats = ImportStats::new();
        stats.record(
            FixStrategy::WrapInOption,
            &ImportDecision::AcceptWithWarning("warning".into()),
        );

        // A warned decision counts both as a warning and as an acceptance.
        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.warnings, 1);
        assert_eq!(stats.accepted_by_strategy.get(&FixStrategy::WrapInOption), Some(&1));
    }

    #[test]
    fn test_import_stats_acceptance_rate() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Reject("reason".into()));

        // 3 accepted out of 4 recorded for AddBorrow.
        let rate = stats.acceptance_rate(FixStrategy::AddBorrow);
        assert!((rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_import_stats_overall_acceptance_rate() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        stats.record(FixStrategy::AddClone, &ImportDecision::Accept);
        stats.record(FixStrategy::Unknown, &ImportDecision::Reject("reason".into()));

        // 2 accepted out of 3 recorded in total.
        let rate = stats.overall_acceptance_rate();
        assert!((rate - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_import_stats_empty_acceptance_rate() {
        let stats = ImportStats::new();
        assert_eq!(stats.acceptance_rate(FixStrategy::AddBorrow), 0.0);
        assert_eq!(stats.overall_acceptance_rate(), 0.0);
    }

    // ---- Expected acceptance-rate targets per strategy ----

    #[test]
    fn test_expected_acceptance_rates_add_borrow() {
        let mut stats = ImportStats::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let borrow_diffs = [
            "- fn foo(x: String)\n+ fn foo(x: &String)",
            "- fn bar(y: Vec<i32>)\n+ fn bar(y: &Vec<i32>)",
            "- fn baz(z: T)\n+ fn baz(z: &mut T)",
        ];

        for diff in &borrow_diffs {
            let decision = smart_import_filter(diff, &HashMap::new(), &config);
            // NOTE(review): records AddBorrow unconditionally rather than
            // `analyze_fix_strategy(diff)`, so misclassification is not caught here.
            stats.record(FixStrategy::AddBorrow, &decision);
        }

        assert!(
            stats.acceptance_rate(FixStrategy::AddBorrow) >= 0.95,
            "AddBorrow should have >=95% acceptance rate, got {}",
            stats.acceptance_rate(FixStrategy::AddBorrow)
        );
    }

    #[test]
    fn test_expected_acceptance_rates_add_lifetime() {
        let mut stats = ImportStats::new();
        let config =
            SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() };

        let lifetime_diffs = [
            "- fn foo(x: &str)\n+ fn foo<'a>(x: &'a str)",
            "- struct Foo { x: &str }\n+ struct Foo<'a> { x: &'a str }",
        ];

        for diff in &lifetime_diffs {
            let decision = smart_import_filter(diff, &HashMap::new(), &config);
            // NOTE(review): same as above, the strategy recorded is fixed, not analyzed.
            stats.record(FixStrategy::AddLifetime, &decision);
        }

        assert!(
            stats.acceptance_rate(FixStrategy::AddLifetime) >= 0.90,
            "AddLifetime should have >=90% acceptance rate"
        );
    }
}
537}