Skip to main content

decy_oracle/
import.rs

//! Smart import filter for cross-project pattern transfer
//!
//! Filters imported patterns by fix strategy applicability, not just error code.
//! Python ownership issues differ from C issues (reference counting vs pointer aliasing).
//!
//! # References
//! - training-oracle-spec.md §3.1.2: Smart Import Filter (Yokoten Enhancement)
//! - Gemini Review: "smart import is better than blind bulk import"
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// Fix strategy type derived from diff analysis
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub enum FixStrategy {
16    /// Add .clone() call
17    AddClone,
18    /// Add borrow (&T or &mut T)
19    AddBorrow,
20    /// Add lifetime annotation
21    AddLifetime,
22    /// Wrap in Option<T>
23    WrapInOption,
24    /// Wrap in Result<T, E>
25    WrapInResult,
26    /// Add explicit type annotation
27    AddTypeAnnotation,
28    /// Unknown or complex strategy
29    Unknown,
30}
31
/// Verdict reached when a pattern is evaluated for import.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportDecision {
    /// The pattern is imported as-is.
    Accept,
    /// The pattern is imported, but the attached message should be reviewed.
    AcceptWithWarning(String),
    /// The pattern is not imported; the attached message gives the reason.
    Reject(String),
}
42
43impl ImportDecision {
44    /// Check if decision allows import
45    pub fn allows_import(&self) -> bool {
46        matches!(self, ImportDecision::Accept | ImportDecision::AcceptWithWarning(_))
47    }
48}
49
50/// Statistics for import operations
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
52pub struct ImportStats {
53    /// Patterns accepted by strategy
54    pub accepted_by_strategy: HashMap<FixStrategy, usize>,
55    /// Patterns rejected by strategy
56    pub rejected_by_strategy: HashMap<FixStrategy, usize>,
57    /// Patterns accepted with warnings
58    pub warnings: usize,
59    /// Total patterns evaluated
60    pub total_evaluated: usize,
61}
62
63impl ImportStats {
64    /// Create new empty stats
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    /// Record an import decision
70    pub fn record(&mut self, strategy: FixStrategy, decision: &ImportDecision) {
71        self.total_evaluated += 1;
72        match decision {
73            ImportDecision::Accept => {
74                *self.accepted_by_strategy.entry(strategy).or_insert(0) += 1;
75            }
76            ImportDecision::AcceptWithWarning(_) => {
77                *self.accepted_by_strategy.entry(strategy).or_insert(0) += 1;
78                self.warnings += 1;
79            }
80            ImportDecision::Reject(_) => {
81                *self.rejected_by_strategy.entry(strategy).or_insert(0) += 1;
82            }
83        }
84    }
85
86    /// Get acceptance rate for a strategy
87    pub fn acceptance_rate(&self, strategy: FixStrategy) -> f32 {
88        let accepted = self.accepted_by_strategy.get(&strategy).copied().unwrap_or(0);
89        let rejected = self.rejected_by_strategy.get(&strategy).copied().unwrap_or(0);
90        let total = accepted + rejected;
91        if total == 0 {
92            0.0
93        } else {
94            accepted as f32 / total as f32
95        }
96    }
97
98    /// Get overall acceptance rate
99    pub fn overall_acceptance_rate(&self) -> f32 {
100        let accepted: usize = self.accepted_by_strategy.values().sum();
101        if self.total_evaluated == 0 {
102            0.0
103        } else {
104            accepted as f32 / self.total_evaluated as f32
105        }
106    }
107}
108
109/// Smart import filter configuration
110#[derive(Debug, Clone)]
111pub struct SmartImportConfig {
112    /// Source language of patterns (for context-aware filtering)
113    pub source_language: SourceLanguage,
114    /// Minimum confidence threshold for patterns
115    pub min_confidence: f32,
116    /// Allow patterns with warnings
117    pub allow_warnings: bool,
118}
119
120impl Default for SmartImportConfig {
121    fn default() -> Self {
122        Self { source_language: SourceLanguage::Python, min_confidence: 0.5, allow_warnings: true }
123    }
124}
125
/// Language that an imported pattern originates from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SourceLanguage {
    /// Pattern comes from a Python project.
    Python,
    /// Pattern comes from a C project.
    C,
    /// Pattern comes from a C++ project.
    Cpp,
    /// Pattern comes from any other language.
    Other,
}
134
135/// Analyze fix diff to determine strategy type
136pub fn analyze_fix_strategy(fix_diff: &str) -> FixStrategy {
137    // Pattern matching on common fix patterns
138    // Order matters: more specific patterns first
139
140    // Clone patterns
141    if fix_diff.contains(".clone()") || fix_diff.contains(".to_owned()") {
142        return FixStrategy::AddClone;
143    }
144
145    // Lifetime patterns (check before borrow since 'a appears in both)
146    if fix_diff.contains("<'a>")
147        || fix_diff.contains("'static")
148        || fix_diff.contains("'_")
149        || (fix_diff.contains("'a") && fix_diff.contains("fn "))
150    {
151        return FixStrategy::AddLifetime;
152    }
153
154    // Borrow patterns - check for borrow additions in function signatures
155    // Look for patterns like ": &String" or ": &mut Vec" or "(x: &"
156    if fix_diff.contains(": &mut ")
157        || fix_diff.contains(": &")
158        || fix_diff.contains("(&self)")
159        || fix_diff.contains("(&mut self)")
160        || fix_diff.contains("(x: &")
161        || fix_diff.contains("(y: &")
162        || fix_diff.contains("(z: &")
163        || (fix_diff.contains("&") && fix_diff.contains("+ fn"))
164    {
165        return FixStrategy::AddBorrow;
166    }
167
168    // Option patterns
169    if fix_diff.contains("Option<")
170        || fix_diff.contains("Some(")
171        || fix_diff.contains(".unwrap()")
172        || fix_diff.contains(".is_none()")
173        || fix_diff.contains(".is_some()")
174    {
175        return FixStrategy::WrapInOption;
176    }
177
178    // Result patterns
179    if fix_diff.contains("Result<") || fix_diff.contains("Ok(") || fix_diff.contains("Err(") {
180        return FixStrategy::WrapInResult;
181    }
182
183    // Type annotation patterns (only if no borrow/option/result)
184    if fix_diff.contains(": i32")
185        || fix_diff.contains(": String")
186        || (fix_diff.contains(": ") && !fix_diff.contains(": &"))
187    {
188        return FixStrategy::AddTypeAnnotation;
189    }
190
191    FixStrategy::Unknown
192}
193
194/// Evaluate whether a pattern should be imported based on fix strategy
195pub fn smart_import_filter(
196    fix_diff: &str,
197    metadata: &HashMap<String, String>,
198    config: &SmartImportConfig,
199) -> ImportDecision {
200    let strategy = analyze_fix_strategy(fix_diff);
201
202    match strategy {
203        FixStrategy::AddClone => {
204            // Clone semantics differ: Python shallow copy vs Rust deep clone
205            if config.source_language == SourceLanguage::Python {
206                if let Some(construct) = metadata.get("source_construct") {
207                    if construct.contains("list") || construct.contains("dict") {
208                        return ImportDecision::Reject(
209                            "Python collection copy != Rust clone".to_string(),
210                        );
211                    }
212                }
213            }
214            ImportDecision::Accept
215        }
216        FixStrategy::AddBorrow => {
217            // Borrow semantics are largely universal
218            ImportDecision::Accept
219        }
220        FixStrategy::AddLifetime => {
221            // Lifetime patterns transfer well
222            ImportDecision::Accept
223        }
224        FixStrategy::WrapInOption => {
225            // Python None vs C NULL have different semantics
226            if config.source_language == SourceLanguage::Python {
227                // Check if pattern handles C NULL pointer checks or uses idiomatic Option methods
228                let has_null_handling = fix_diff.contains("NULL")
229                    || fix_diff.contains("nullptr")
230                    || fix_diff.contains("null")
231                    || fix_diff.contains(".is_none()")
232                    || fix_diff.contains(".is_some()")
233                    || fix_diff.contains(".unwrap_or");
234
235                if has_null_handling {
236                    ImportDecision::Accept
237                } else {
238                    ImportDecision::AcceptWithWarning(
239                        "Verify NULL handling for C context".to_string(),
240                    )
241                }
242            } else {
243                ImportDecision::Accept
244            }
245        }
246        FixStrategy::WrapInResult => {
247            // Error handling patterns are largely universal
248            ImportDecision::Accept
249        }
250        FixStrategy::AddTypeAnnotation => {
251            // Type annotation patterns depend on type system differences
252            if config.source_language == SourceLanguage::Python {
253                ImportDecision::AcceptWithWarning("Verify type mapping for C context".to_string())
254            } else {
255                ImportDecision::Accept
256            }
257        }
258        FixStrategy::Unknown => ImportDecision::Reject("Unknown fix strategy".to_string()),
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for strategy analysis, import decisions, the smart filter and statistics.

    /// Configuration for patterns whose source language is Python.
    fn python_config() -> SmartImportConfig {
        SmartImportConfig { source_language: SourceLanguage::Python, ..Default::default() }
    }

    /// Runs the filter on `diff` with empty metadata and a Python source configuration.
    fn filter_from_python(diff: &str) -> ImportDecision {
        let no_metadata = HashMap::new();
        smart_import_filter(diff, &no_metadata, &python_config())
    }

    // ---------- FixStrategy analysis ----------

    #[test]
    fn test_analyze_strategy_add_clone() {
        let strategy = analyze_fix_strategy("- let x = value;\n+ let x = value.clone();");
        assert_eq!(strategy, FixStrategy::AddClone);
    }

    #[test]
    fn test_analyze_strategy_to_owned() {
        let strategy =
            analyze_fix_strategy("- let s = str_slice;\n+ let s = str_slice.to_owned();");
        assert_eq!(strategy, FixStrategy::AddClone);
    }

    #[test]
    fn test_analyze_strategy_add_borrow() {
        let strategy = analyze_fix_strategy("- fn foo(x: String)\n+ fn foo(x: &String)");
        assert_eq!(strategy, FixStrategy::AddBorrow);
    }

    #[test]
    fn test_analyze_strategy_add_mut_borrow() {
        let strategy = analyze_fix_strategy("- fn foo(x: Vec<i32>)\n+ fn foo(x: &mut Vec<i32>)");
        assert_eq!(strategy, FixStrategy::AddBorrow);
    }

    #[test]
    fn test_analyze_strategy_add_lifetime() {
        let strategy = analyze_fix_strategy(
            "- fn foo(x: &str) -> &str\n+ fn foo<'a>(x: &'a str) -> &'a str",
        );
        assert_eq!(strategy, FixStrategy::AddLifetime);
    }

    #[test]
    fn test_analyze_strategy_wrap_option() {
        let strategy = analyze_fix_strategy("- let x: *const T\n+ let x: Option<&T>");
        assert_eq!(strategy, FixStrategy::WrapInOption);
    }

    #[test]
    fn test_analyze_strategy_wrap_result() {
        let strategy =
            analyze_fix_strategy("- fn foo() -> i32\n+ fn foo() -> Result<i32, Error>");
        assert_eq!(strategy, FixStrategy::WrapInResult);
    }

    #[test]
    fn test_analyze_strategy_unknown() {
        let strategy = analyze_fix_strategy("- some random change\n+ another random change");
        assert_eq!(strategy, FixStrategy::Unknown);
    }

    // ---------- ImportDecision ----------

    #[test]
    fn test_import_decision_allows_import() {
        let accept = ImportDecision::Accept;
        let warned = ImportDecision::AcceptWithWarning("warning".into());
        let rejected = ImportDecision::Reject("reason".into());

        assert!(accept.allows_import());
        assert!(warned.allows_import());
        assert!(!rejected.allows_import());
    }

    // ---------- smart_import_filter ----------

    #[test]
    fn test_smart_filter_accepts_borrow_from_python() {
        let decision = filter_from_python("- fn foo(x: String)\n+ fn foo(x: &String)");
        assert_eq!(decision, ImportDecision::Accept);
    }

    #[test]
    fn test_smart_filter_rejects_python_list_clone() {
        let metadata: HashMap<String, String> =
            std::iter::once(("source_construct".to_string(), "list_copy".to_string())).collect();

        let decision = smart_import_filter(
            "- let x = lst;\n+ let x = lst.clone();",
            &metadata,
            &python_config(),
        );
        assert!(matches!(decision, ImportDecision::Reject(_)));
    }

    #[test]
    fn test_smart_filter_accepts_clone_without_list_context() {
        let decision = filter_from_python("- let x = value;\n+ let x = value.clone();");
        assert_eq!(decision, ImportDecision::Accept);
    }

    #[test]
    fn test_smart_filter_warns_on_option_without_null() {
        let decision = filter_from_python("- let x = value\n+ let x = Some(value)");
        assert!(matches!(decision, ImportDecision::AcceptWithWarning(_)));
    }

    #[test]
    fn test_smart_filter_accepts_option_with_null() {
        // NULL in the diff means the pattern is applicable to a C context.
        let decision = filter_from_python("- if (ptr == NULL)\n+ if ptr.is_none()");
        assert!(decision.allows_import());
    }

    #[test]
    fn test_smart_filter_rejects_unknown_strategy() {
        let decision = smart_import_filter(
            "random gibberish change",
            &HashMap::new(),
            &SmartImportConfig::default(),
        );
        assert!(matches!(decision, ImportDecision::Reject(_)));
    }

    #[test]
    fn test_smart_filter_accepts_lifetime_from_any_source() {
        let diff = "- fn foo(x: &str)\n+ fn foo<'a>(x: &'a str)";
        let metadata = HashMap::new();

        // Python first, then C: both must be accepted outright.
        for language in [SourceLanguage::Python, SourceLanguage::C].iter() {
            let config =
                SmartImportConfig { source_language: *language, ..Default::default() };
            assert_eq!(smart_import_filter(diff, &metadata, &config), ImportDecision::Accept);
        }
    }

    // ---------- ImportStats ----------

    #[test]
    fn test_import_stats_new() {
        let stats = ImportStats::new();
        assert_eq!(stats.total_evaluated, 0);
        assert_eq!(stats.warnings, 0);
    }

    #[test]
    fn test_import_stats_record_accept() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);

        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.accepted_by_strategy.get(&FixStrategy::AddBorrow), Some(&1));
    }

    #[test]
    fn test_import_stats_record_reject() {
        let mut stats = ImportStats::new();
        let rejection = ImportDecision::Reject("reason".into());
        stats.record(FixStrategy::AddClone, &rejection);

        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.rejected_by_strategy.get(&FixStrategy::AddClone), Some(&1));
    }

    #[test]
    fn test_import_stats_record_warning() {
        let mut stats = ImportStats::new();
        let warned = ImportDecision::AcceptWithWarning("warning".into());
        stats.record(FixStrategy::WrapInOption, &warned);

        assert_eq!(stats.total_evaluated, 1);
        assert_eq!(stats.warnings, 1);
        assert_eq!(stats.accepted_by_strategy.get(&FixStrategy::WrapInOption), Some(&1));
    }

    #[test]
    fn test_import_stats_acceptance_rate() {
        let mut stats = ImportStats::new();
        // Three accepts followed by one reject, all for AddBorrow.
        for _ in 0..3 {
            stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        }
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Reject("reason".into()));

        let rate = stats.acceptance_rate(FixStrategy::AddBorrow);
        assert!((rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_import_stats_overall_acceptance_rate() {
        let mut stats = ImportStats::new();
        stats.record(FixStrategy::AddBorrow, &ImportDecision::Accept);
        stats.record(FixStrategy::AddClone, &ImportDecision::Accept);
        stats.record(FixStrategy::Unknown, &ImportDecision::Reject("reason".into()));

        let rate = stats.overall_acceptance_rate();
        assert!((rate - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_import_stats_empty_acceptance_rate() {
        let stats = ImportStats::new();
        assert_eq!(stats.acceptance_rate(FixStrategy::AddBorrow), 0.0);
        assert_eq!(stats.overall_acceptance_rate(), 0.0);
    }

    // ---------- Expected acceptance rates from the spec ----------

    #[test]
    fn test_expected_acceptance_rates_add_borrow() {
        // The spec expects a 95% acceptance rate for AddBorrow; these typical borrow
        // diffs from a Python source should therefore all be accepted.
        let borrow_diffs = [
            "- fn foo(x: String)\n+ fn foo(x: &String)",
            "- fn bar(y: Vec<i32>)\n+ fn bar(y: &Vec<i32>)",
            "- fn baz(z: T)\n+ fn baz(z: &mut T)",
        ];

        let mut stats = ImportStats::new();
        for diff in borrow_diffs.iter() {
            stats.record(FixStrategy::AddBorrow, &filter_from_python(diff));
        }

        let rate = stats.acceptance_rate(FixStrategy::AddBorrow);
        assert!(rate >= 0.95, "AddBorrow should have >=95% acceptance rate, got {}", rate);
    }

    #[test]
    fn test_expected_acceptance_rates_add_lifetime() {
        // The spec expects a 90% acceptance rate for AddLifetime.
        let lifetime_diffs = [
            "- fn foo(x: &str)\n+ fn foo<'a>(x: &'a str)",
            "- struct Foo { x: &str }\n+ struct Foo<'a> { x: &'a str }",
        ];

        let mut stats = ImportStats::new();
        for diff in lifetime_diffs.iter() {
            stats.record(FixStrategy::AddLifetime, &filter_from_python(diff));
        }

        let rate = stats.acceptance_rate(FixStrategy::AddLifetime);
        assert!(rate >= 0.90, "AddLifetime should have >=90% acceptance rate");
    }
}