Skip to main content

car_inference/
outcome.rs

//! Outcome tracking — learn from inference results to improve routing.
//!
//! Two observation channels:
//! 1. **Conversation signals** — implicit feedback from what happens after an inference
//!    call (user moved on = accepted, user corrected = rejected, re-asked = rejected).
//! 2. **Git-diff tracking** — for code generation, compare suggestions against the
//!    repository's staged and unstaged changes (ground truth, no classification model needed).
//!
//! Every inference call produces an `InferenceOutcome`. Outcomes accumulate into
//! `ModelProfile`s with per-task statistics. The adaptive router uses these profiles
//! to make data-driven model selections.
12
13use std::collections::{HashMap, HashSet};
14use std::time::{SystemTime, UNIX_EPOCH};
15
16use serde::{Deserialize, Serialize};
17
18/// Task type for outcome tracking. Maps to ModelCapability but at the call level.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum InferenceTask {
22    Generate,
23    Embed,
24    Classify,
25    Code,
26    Reasoning,
27}
28
29impl std::fmt::Display for InferenceTask {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            InferenceTask::Generate => write!(f, "generate"),
33            InferenceTask::Embed => write!(f, "embed"),
34            InferenceTask::Classify => write!(f, "classify"),
35            InferenceTask::Code => write!(f, "code"),
36            InferenceTask::Reasoning => write!(f, "reasoning"),
37        }
38    }
39}
40
41/// A single inference invocation record.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct InferenceOutcome {
44    /// Unique trace ID for this invocation.
45    pub trace_id: String,
46    /// Model that was used.
47    pub model_id: String,
48    /// Task type.
49    pub task: InferenceTask,
50    /// How the model was selected.
51    pub routing_reason: String,
52    /// Wall-clock latency in milliseconds.
53    pub latency_ms: u64,
54    /// Input tokens (estimated).
55    pub input_tokens: usize,
56    /// Output tokens (estimated).
57    pub output_tokens: usize,
58    /// Outcome from conversation signal inference.
59    pub inferred_outcome: Option<InferredOutcome>,
60    /// Outcome from git-diff tracking (code only).
61    pub code_outcome: Option<CodeOutcome>,
62    /// Error message if inference failed.
63    pub error: Option<String>,
64    /// Unix timestamp.
65    pub timestamp: u64,
66}
67
68/// Outcome inferred from conversation flow (implicit feedback).
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "snake_case")]
71pub enum InferredOutcome {
72    /// User moved on, built on the response.
73    Accepted { confidence: f64 },
74    /// User used the result but modified it.
75    AcceptedWithEdits { confidence: f64 },
76    /// User corrected, re-asked, or explicitly rejected.
77    Rejected { confidence: f64 },
78    /// No follow-up signal (session ended, inconclusive).
79    Inconclusive,
80}
81
82impl InferredOutcome {
83    /// Convert to a quality score (0.0 = bad, 1.0 = good).
84    pub fn quality_score(&self) -> Option<f64> {
85        match self {
86            InferredOutcome::Accepted { confidence } => Some(*confidence),
87            InferredOutcome::AcceptedWithEdits { confidence } => Some(confidence * 0.7),
88            InferredOutcome::Rejected { confidence } => Some((1.0 - confidence) * 0.3),
89            InferredOutcome::Inconclusive => None,
90        }
91    }
92
93    pub fn is_success(&self) -> Option<bool> {
94        match self {
95            InferredOutcome::Accepted { .. } => Some(true),
96            InferredOutcome::AcceptedWithEdits { .. } => Some(true),
97            InferredOutcome::Rejected { .. } => Some(false),
98            InferredOutcome::Inconclusive => None,
99        }
100    }
101}
102
103/// Outcome from git-diff comparison (code generation ground truth).
104#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(tag = "type", rename_all = "snake_case")]
106pub enum CodeOutcome {
107    /// Suggestion was applied as-is (exact or near-exact match in diff).
108    Applied,
109    /// User changed the same file but differently (partial adoption).
110    Modified,
111    /// File unchanged despite suggestion (rejected / not used).
112    Ignored,
113    /// AST structural diff: signature was changed (breaking change).
114    SignatureChanged,
115    /// AST structural diff: body was modified but signature preserved (non-breaking).
116    BodyModified,
117    /// AST structural diff: new symbol was added.
118    SymbolAdded,
119}
120
121impl CodeOutcome {
122    pub fn quality_score(&self) -> f64 {
123        match self {
124            CodeOutcome::Applied => 1.0,
125            CodeOutcome::SignatureChanged => 0.8,
126            CodeOutcome::BodyModified => 0.7,
127            CodeOutcome::SymbolAdded => 0.7,
128            CodeOutcome::Modified => 0.6,
129            CodeOutcome::Ignored => 0.1,
130        }
131    }
132
133    pub fn is_success(&self) -> bool {
134        !matches!(self, CodeOutcome::Ignored)
135    }
136}
137
138/// Per-task statistics within a model profile.
139#[derive(Debug, Clone, Default, Serialize, Deserialize)]
140pub struct TaskStats {
141    pub calls: u64,
142    pub successes: u64,
143    pub failures: u64,
144    /// Running average latency in ms.
145    pub avg_latency_ms: f64,
146    /// Exponential moving average of quality score.
147    pub ema_quality: f64,
148}
149
150impl TaskStats {
151    pub fn success_rate(&self) -> f64 {
152        let total = self.successes + self.failures;
153        if total == 0 {
154            return 0.5;
155        } // prior: assume neutral
156        self.successes as f64 / total as f64
157    }
158}
159
160/// Per-model performance profile, built from observed outcomes.
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ModelProfile {
163    pub model_id: String,
164    pub total_calls: u64,
165    pub success_count: u64,
166    pub fail_count: u64,
167    pub total_latency_ms: u64,
168    /// Total estimated input tokens across all calls.
169    #[serde(default)]
170    pub total_input_tokens: u64,
171    /// Total estimated output tokens across all calls.
172    #[serde(default)]
173    pub total_output_tokens: u64,
174    /// Per-task statistics.
175    pub task_stats: HashMap<String, TaskStats>,
176    /// Overall EMA quality score (0.0 - 1.0).
177    pub ema_quality: f64,
178    /// Derived metric: quality per 1K total tokens. Populated on export
179    /// (not on every update) so it always reflects the latest snapshot.
180    /// Inspired by Meta-Harness: context-token efficiency is a first-class
181    /// optimization target, so it needs to be visible in model_stats.
182    #[serde(default)]
183    pub quality_per_1k_tokens: f64,
184    /// Last updated (unix timestamp).
185    pub updated_at: u64,
186}
187
188impl ModelProfile {
189    pub fn new(model_id: String) -> Self {
190        Self {
191            model_id,
192            total_calls: 0,
193            success_count: 0,
194            fail_count: 0,
195            total_latency_ms: 0,
196            total_input_tokens: 0,
197            total_output_tokens: 0,
198            task_stats: HashMap::new(),
199            ema_quality: 0.5, // neutral prior
200            quality_per_1k_tokens: 0.0,
201            updated_at: now_unix(),
202        }
203    }
204
205    pub fn success_rate(&self) -> f64 {
206        let total = self.success_count + self.fail_count;
207        if total == 0 {
208            return 0.5;
209        }
210        self.success_count as f64 / total as f64
211    }
212
213    pub fn avg_latency_ms(&self) -> f64 {
214        if self.total_calls == 0 {
215            return 0.0;
216        }
217        self.total_latency_ms as f64 / self.total_calls as f64
218    }
219
220    /// Same degradation pattern as SkillStats: fail_count > success_count + threshold.
221    pub fn should_degrade(&self, threshold: u64) -> bool {
222        self.fail_count > self.success_count + threshold
223    }
224
225    /// Get stats for a specific task type.
226    pub fn task_stats(&self, task: InferenceTask) -> Option<&TaskStats> {
227        self.task_stats.get(&task.to_string())
228    }
229
230    /// Total tokens observed across all calls (input + output).
231    pub fn total_tokens(&self) -> u64 {
232        self.total_input_tokens + self.total_output_tokens
233    }
234
235    /// Quality per 1000 tokens: `ema_quality * 1000 / total_tokens`.
236    /// Returns 0.0 before any tokens have been observed.
237    pub fn compute_quality_per_1k_tokens(&self) -> f64 {
238        let total = self.total_tokens();
239        if total == 0 {
240            return 0.0;
241        }
242        self.ema_quality * 1000.0 / total as f64
243    }
244}
245
246/// EMA smoothing factor. Higher = more weight on recent observations.
247const EMA_ALPHA: f64 = 0.2;
248
249/// Tracks inference outcomes and builds performance profiles.
250pub struct OutcomeTracker {
251    /// In-memory profiles, keyed by model_id.
252    profiles: HashMap<String, ModelProfile>,
253    /// Pending outcomes: completed inference calls awaiting outcome signal.
254    /// Keyed by trace_id.
255    pending: HashMap<String, InferenceOutcome>,
256    /// Counter for generating trace IDs.
257    trace_counter: u64,
258    /// Models excluded for this session (429/rate-limited). Hard exclusion.
259    excluded: HashSet<String>,
260}
261
262impl OutcomeTracker {
263    pub fn new() -> Self {
264        Self {
265            profiles: HashMap::new(),
266            pending: HashMap::new(),
267            trace_counter: 0,
268            excluded: HashSet::new(),
269        }
270    }
271
272    /// Check if a model is excluded (rate-limited) for this session.
273    pub fn is_excluded(&self, model_id: &str) -> bool {
274        self.excluded.contains(model_id)
275    }
276
277    /// Record that an inference call started. Returns a trace_id.
278    pub fn record_start(
279        &mut self,
280        model_id: &str,
281        task: InferenceTask,
282        routing_reason: &str,
283    ) -> String {
284        self.trace_counter += 1;
285        let trace_id = format!("t-{}-{}", now_unix(), self.trace_counter);
286
287        let outcome = InferenceOutcome {
288            trace_id: trace_id.clone(),
289            model_id: model_id.to_string(),
290            task,
291            routing_reason: routing_reason.to_string(),
292            latency_ms: 0,
293            input_tokens: 0,
294            output_tokens: 0,
295            inferred_outcome: None,
296            code_outcome: None,
297            error: None,
298            timestamp: now_unix(),
299        };
300
301        self.pending.insert(trace_id.clone(), outcome);
302        trace_id
303    }
304
305    /// Record completion of an inference call (timing + token counts).
306    pub fn record_complete(
307        &mut self,
308        trace_id: &str,
309        latency_ms: u64,
310        input_tokens: usize,
311        output_tokens: usize,
312    ) {
313        if let Some(outcome) = self.pending.get_mut(trace_id) {
314            outcome.latency_ms = latency_ms;
315            outcome.input_tokens = input_tokens;
316            outcome.output_tokens = output_tokens;
317
318            // Update profile with timing data
319            let profile = self
320                .profiles
321                .entry(outcome.model_id.clone())
322                .or_insert_with(|| ModelProfile::new(outcome.model_id.clone()));
323
324            profile.total_calls += 1;
325            profile.total_latency_ms += latency_ms;
326            profile.total_input_tokens += input_tokens as u64;
327            profile.total_output_tokens += output_tokens as u64;
328
329            let task_key = outcome.task.to_string();
330            let ts = profile.task_stats.entry(task_key).or_default();
331            ts.calls += 1;
332            ts.avg_latency_ms =
333                ts.avg_latency_ms + (latency_ms as f64 - ts.avg_latency_ms) / ts.calls as f64;
334
335            profile.updated_at = now_unix();
336        }
337    }
338
339    /// Record a failure.
340    pub fn record_failure(&mut self, trace_id: &str, error: &str) {
341        if let Some(outcome) = self.pending.get_mut(trace_id) {
342            outcome.error = Some(error.to_string());
343
344            let profile = self
345                .profiles
346                .entry(outcome.model_id.clone())
347                .or_insert_with(|| ModelProfile::new(outcome.model_id.clone()));
348
349            profile.fail_count += 1;
350
351            // Rate-limit errors (429) get a harsher penalty — the model is
352            // guaranteed to fail again, so drop quality aggressively.
353            let is_rate_limited = error.contains("429") || error.contains("RESOURCE_EXHAUSTED");
354            if is_rate_limited {
355                // Hard-exclude for the rest of this session (#13)
356                self.excluded.insert(outcome.model_id.clone());
357                profile.ema_quality *= 0.1;
358            } else {
359                profile.ema_quality = profile.ema_quality * (1.0 - EMA_ALPHA) + 0.0 * EMA_ALPHA;
360            }
361
362            let task_key = outcome.task.to_string();
363            let ts = profile.task_stats.entry(task_key).or_default();
364            ts.failures += 1;
365            if is_rate_limited {
366                ts.ema_quality *= 0.1;
367            } else {
368                ts.ema_quality = ts.ema_quality * (1.0 - EMA_ALPHA);
369            }
370
371            profile.updated_at = now_unix();
372        }
373
374        // Failed outcomes don't need further tracking
375        self.pending.remove(trace_id);
376    }
377
378    /// Record an inferred outcome from conversation signals.
379    pub fn record_inferred_outcome(&mut self, trace_id: &str, outcome: InferredOutcome) {
380        if let Some(pending) = self.pending.remove(trace_id) {
381            self.apply_outcome(&pending, outcome.quality_score(), outcome.is_success());
382        }
383    }
384
385    /// Record an outcome from git-diff comparison (code generation).
386    pub fn record_code_outcome(&mut self, trace_id: &str, outcome: CodeOutcome) {
387        if let Some(pending) = self.pending.remove(trace_id) {
388            self.apply_outcome(
389                &pending,
390                Some(outcome.quality_score()),
391                Some(outcome.is_success()),
392            );
393        }
394    }
395
396    /// Resolve all pending outcomes for a completed conversation turn.
397    /// Called with the inferred outcomes from conversation signal analysis.
398    pub fn resolve_pending_from_signals(&mut self, outcomes: Vec<(String, InferredOutcome)>) {
399        for (trace_id, inferred) in outcomes {
400            self.record_inferred_outcome(&trace_id, inferred);
401        }
402    }
403
404    /// Infer outcomes from a sequence of action results.
405    ///
406    /// In a reasoning session, each action's output feeds the next. If action N
407    /// produced output and action N+1 succeeded using it, N was implicitly accepted.
408    /// If N produced empty output or N+1 failed, N was implicitly rejected.
409    ///
410    /// Returns (trace_id, inferred_outcome) pairs ready for `resolve_pending_from_signals`.
411    pub fn infer_outcomes_from_action_sequence(
412        &self,
413        action_results: &[(String, bool, f64, String)], // (trace_id, success, confidence, output)
414    ) -> Vec<(String, InferredOutcome)> {
415        let mut outcomes = Vec::new();
416
417        for (i, (trace_id, success, confidence, output)) in action_results.iter().enumerate() {
418            if trace_id.is_empty() {
419                continue; // No trace (e.g., memgine-only action)
420            }
421
422            if !success {
423                outcomes.push((
424                    trace_id.clone(),
425                    InferredOutcome::Rejected {
426                        confidence: *confidence,
427                    },
428                ));
429                continue;
430            }
431
432            // Check if the next action used this one's output (implicit acceptance)
433            let next_succeeded = action_results
434                .get(i + 1)
435                .map(|(_, s, _, _)| *s)
436                .unwrap_or(true); // Last action: assume accepted if successful
437
438            let has_output = !output.trim().is_empty();
439
440            if has_output && next_succeeded {
441                outcomes.push((
442                    trace_id.clone(),
443                    InferredOutcome::Accepted {
444                        confidence: *confidence,
445                    },
446                ));
447            } else if has_output && !next_succeeded {
448                // Output existed but downstream failed — may not be this action's fault
449                outcomes.push((
450                    trace_id.clone(),
451                    InferredOutcome::AcceptedWithEdits {
452                        confidence: confidence * 0.7,
453                    },
454                ));
455            } else {
456                outcomes.push((trace_id.clone(), InferredOutcome::Inconclusive));
457            }
458        }
459
460        outcomes
461    }
462
463    /// Get the profile for a model.
464    pub fn profile(&self, model_id: &str) -> Option<&ModelProfile> {
465        self.profiles.get(model_id)
466    }
467
468    /// Get all profiles.
469    pub fn all_profiles(&self) -> &HashMap<String, ModelProfile> {
470        &self.profiles
471    }
472
473    /// Get pending trace IDs (for conversation signal analysis).
474    pub fn pending_trace_ids(&self) -> Vec<String> {
475        self.pending.keys().cloned().collect()
476    }
477
478    /// Get a pending outcome by trace_id.
479    pub fn get_pending(&self, trace_id: &str) -> Option<&InferenceOutcome> {
480        self.pending.get(trace_id)
481    }
482
483    /// Export profiles for serialization / persistence. Derived metrics
484    /// (quality_per_1k_tokens) are recomputed on the way out so callers
485    /// always see a consistent snapshot.
486    pub fn export_profiles(&self) -> Vec<ModelProfile> {
487        self.profiles
488            .values()
489            .cloned()
490            .map(|mut p| {
491                p.quality_per_1k_tokens = p.compute_quality_per_1k_tokens();
492                p
493            })
494            .collect()
495    }
496
497    /// Import profiles (from persistence).
498    pub fn import_profiles(&mut self, profiles: Vec<ModelProfile>) {
499        for p in profiles {
500            self.profiles.insert(p.model_id.clone(), p);
501        }
502    }
503
504    /// Save profiles to a JSON file for cross-session persistence (#13).
505    pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
506        let profiles = self.export_profiles();
507        let json = serde_json::to_string_pretty(&profiles)
508            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
509        if let Some(parent) = path.parent() {
510            std::fs::create_dir_all(parent)?;
511        }
512        std::fs::write(path, json)
513    }
514
515    /// Load profiles from a JSON file for cross-session persistence (#13).
516    pub fn load_from_file(&mut self, path: &std::path::Path) -> Result<usize, std::io::Error> {
517        if !path.exists() {
518            return Ok(0);
519        }
520        let json = std::fs::read_to_string(path)?;
521        let profiles: Vec<ModelProfile> = serde_json::from_str(&json)
522            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
523        let count = profiles.len();
524        self.import_profiles(profiles);
525        Ok(count)
526    }
527
528    /// Apply a quality signal to the model's profile.
529    fn apply_outcome(
530        &mut self,
531        pending: &InferenceOutcome,
532        quality: Option<f64>,
533        success: Option<bool>,
534    ) {
535        let profile = self
536            .profiles
537            .entry(pending.model_id.clone())
538            .or_insert_with(|| ModelProfile::new(pending.model_id.clone()));
539
540        if let Some(q) = quality {
541            profile.ema_quality = profile.ema_quality * (1.0 - EMA_ALPHA) + q * EMA_ALPHA;
542
543            let task_key = pending.task.to_string();
544            let ts = profile.task_stats.entry(task_key).or_default();
545            ts.ema_quality = ts.ema_quality * (1.0 - EMA_ALPHA) + q * EMA_ALPHA;
546        }
547
548        if let Some(ok) = success {
549            if ok {
550                profile.success_count += 1;
551                let task_key = pending.task.to_string();
552                let ts = profile.task_stats.entry(task_key).or_default();
553                ts.successes += 1;
554            } else {
555                profile.fail_count += 1;
556                let task_key = pending.task.to_string();
557                let ts = profile.task_stats.entry(task_key).or_default();
558                ts.failures += 1;
559            }
560        }
561
562        profile.updated_at = now_unix();
563    }
564
565    /// Check git diff for pending code suggestions and resolve outcomes.
566    ///
567    /// Two strategies:
568    /// 1. **AST structural diff** (when `ast` feature is enabled): parse the old
569    ///    and new versions of changed files and compare at the symbol level.
570    ///    This gives precise outcomes: SignatureChanged, BodyModified, SymbolAdded.
571    /// 2. **Text diff fallback**: token matching against the combined git diff.
572    pub fn check_git_outcomes(&mut self, repo_dir: &std::path::Path) {
573        let diff = match std::process::Command::new("git")
574            .args(["diff", "--no-color"])
575            .current_dir(repo_dir)
576            .output()
577        {
578            Ok(output) => String::from_utf8_lossy(&output.stdout).to_string(),
579            Err(_) => return,
580        };
581
582        let staged_diff = match std::process::Command::new("git")
583            .args(["diff", "--cached", "--no-color"])
584            .current_dir(repo_dir)
585            .output()
586        {
587            Ok(output) => String::from_utf8_lossy(&output.stdout).to_string(),
588            Err(_) => String::new(),
589        };
590
591        let combined_diff = format!("{}\n{}", diff, staged_diff);
592
593        if combined_diff.trim().is_empty() {
594            return; // No changes at all
595        }
596
597        // Try AST structural diff on changed files
598        #[cfg(feature = "ast")]
599        let ast_outcome = Self::check_git_outcomes_ast(repo_dir);
600
601        let code_traces: Vec<(String, String)> = self
602            .pending
603            .iter()
604            .filter(|(_, o)| matches!(o.task, InferenceTask::Code))
605            .map(|(id, o)| (id.clone(), o.model_id.clone()))
606            .collect();
607
608        for (trace_id, _model_id) in code_traces {
609            if let Some(pending) = self.pending.get(&trace_id) {
610                // Try AST-based outcome first
611                #[cfg(feature = "ast")]
612                if let Some(ref ast_out) = ast_outcome {
613                    let pending_clone = pending.clone();
614                    self.apply_outcome(
615                        &pending_clone,
616                        Some(ast_out.quality_score()),
617                        Some(ast_out.is_success()),
618                    );
619                    continue;
620                }
621
622                // Fallback: text token matching
623                let output_tokens: Vec<&str> = pending
624                    .routing_reason
625                    .split_whitespace()
626                    .filter(|t| t.len() > 5)
627                    .collect();
628
629                let outcome = if output_tokens.iter().any(|t| combined_diff.contains(t)) {
630                    CodeOutcome::Applied
631                } else {
632                    CodeOutcome::Modified
633                };
634
635                let pending_clone = pending.clone();
636                self.apply_outcome(
637                    &pending_clone,
638                    Some(outcome.quality_score()),
639                    Some(outcome.is_success()),
640                );
641            }
642        }
643    }
644
645    /// AST-based git outcome: parse changed files before and after, diff symbols.
646    #[cfg(feature = "ast")]
647    fn check_git_outcomes_ast(repo_dir: &std::path::Path) -> Option<CodeOutcome> {
648        // Get list of changed files
649        let name_only = std::process::Command::new("git")
650            .args(["diff", "--name-only"])
651            .current_dir(repo_dir)
652            .output()
653            .ok()?;
654        let changed_files: Vec<&str> = std::str::from_utf8(&name_only.stdout)
655            .ok()?
656            .lines()
657            .filter(|f| !f.is_empty())
658            .collect();
659
660        if changed_files.is_empty() {
661            return None;
662        }
663
664        let mut has_sig_change = false;
665        let mut has_body_change = false;
666        let mut has_addition = false;
667
668        for file in &changed_files {
669            // Only parse files tree-sitter supports
670            if car_ast::Language::from_filename(file).is_none() {
671                continue;
672            }
673
674            // Get the HEAD version
675            let old_content = std::process::Command::new("git")
676                .args(["show", &format!("HEAD:{}", file)])
677                .current_dir(repo_dir)
678                .output()
679                .ok()
680                .and_then(|o| {
681                    if o.status.success() {
682                        String::from_utf8(o.stdout).ok()
683                    } else {
684                        None
685                    }
686                });
687
688            // Get the working tree version
689            let new_path = repo_dir.join(file);
690            let new_content = std::fs::read_to_string(&new_path).ok();
691
692            match (old_content, new_content) {
693                (Some(old), Some(new)) => {
694                    let old_parsed = car_ast::parse_file(&old, file);
695                    let new_parsed = car_ast::parse_file(&new, file);
696
697                    if let (Some(old_p), Some(new_p)) = (old_parsed, new_parsed) {
698                        let changes = car_ast::diff_symbols(&old_p, &new_p);
699                        for change in &changes {
700                            match change {
701                                car_ast::SymbolChange::Added(_) => has_addition = true,
702                                car_ast::SymbolChange::Modified {
703                                    signature_changed, ..
704                                } => {
705                                    if *signature_changed {
706                                        has_sig_change = true;
707                                    } else {
708                                        has_body_change = true;
709                                    }
710                                }
711                                car_ast::SymbolChange::Removed(_) => has_sig_change = true,
712                            }
713                        }
714                    }
715                }
716                (None, Some(_)) => has_addition = true, // New file
717                _ => {}
718            }
719        }
720
721        // Return the most significant outcome
722        if has_sig_change {
723            Some(CodeOutcome::SignatureChanged)
724        } else if has_body_change {
725            Some(CodeOutcome::BodyModified)
726        } else if has_addition {
727            Some(CodeOutcome::SymbolAdded)
728        } else {
729            None // No structural changes detected (maybe non-code files changed)
730        }
731    }
732}
733
734impl Default for OutcomeTracker {
735    fn default() -> Self {
736        Self::new()
737    }
738}
739
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of panicking.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
746
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lifecycle() {
        let mut tracker = OutcomeTracker::new();

        // Start an inference call
        let trace = tracker.record_start(
            "qwen/qwen3-4b:q4_k_m",
            InferenceTask::Code,
            "Code task -> Qwen3-4B",
        );

        // Complete it
        tracker.record_complete(&trace, 1200, 100, 50);

        // Profile should have 1 call
        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
        assert_eq!(profile.total_calls, 1);
        assert_eq!(profile.avg_latency_ms(), 1200.0);

        // Record positive outcome
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.9 });

        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
        assert_eq!(profile.success_count, 1);
        assert!(profile.ema_quality > 0.5); // should have gone up from 0.5
    }

    #[test]
    fn failure_degrades() {
        let mut tracker = OutcomeTracker::new();

        // Simulate 5 ordinary (non-rate-limit) failures. record_failure
        // resolves the call, so the real trace IDs must be passed to it.
        for _ in 0..5 {
            let trace = tracker.record_start("bad-model", InferenceTask::Generate, "test");
            tracker.record_complete(&trace, 100, 10, 5);
            tracker.record_failure(&trace, "timeout");
        }

        let profile = tracker.profile("bad-model").unwrap();
        assert_eq!(profile.fail_count, 5);
        assert!(profile.should_degrade(2)); // 5 > 0 + 2
        assert!(profile.ema_quality < 0.3); // 0.5 * 0.8^5 ≈ 0.164
        assert!(!tracker.is_excluded("bad-model")); // timeouts are not rate limits
        assert!(tracker.pending_trace_ids().is_empty());
    }

    #[test]
    fn rate_limit_excludes_model() {
        let mut tracker = OutcomeTracker::new();
        let trace = tracker.record_start("limited", InferenceTask::Generate, "test");
        tracker.record_failure(&trace, "HTTP 429 Too Many Requests");

        assert!(tracker.is_excluded("limited"));
        assert!(!tracker.is_excluded("other"));
        assert!(tracker.get_pending(&trace).is_none());

        let profile = tracker.profile("limited").unwrap();
        assert_eq!(profile.fail_count, 1);
        // Harsh penalty: 0.5 prior cut to 10%.
        assert!((profile.ema_quality - 0.05).abs() < 1e-9);
    }

    #[test]
    fn inconclusive_resolves_without_counting() {
        let mut tracker = OutcomeTracker::new();
        let trace = tracker.record_start("m1", InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 10, 5);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Inconclusive);

        assert!(tracker.pending_trace_ids().is_empty());
        let profile = tracker.profile("m1").unwrap();
        assert_eq!(profile.success_count, 0);
        assert_eq!(profile.fail_count, 0);
        assert!((profile.ema_quality - 0.5).abs() < 1e-12);
    }

    #[test]
    fn action_sequence_inference() {
        let tracker = OutcomeTracker::new();
        let results = vec![
            ("a".to_string(), true, 0.9, "out".to_string()), // next fails
            ("b".to_string(), false, 0.8, String::new()),    // failed
            (String::new(), true, 1.0, "x".to_string()),     // no trace: skipped
            ("c".to_string(), true, 0.6, "  ".to_string()),  // blank output
            ("d".to_string(), true, 0.7, "done".to_string()), // last action
        ];

        let outcomes = tracker.infer_outcomes_from_action_sequence(&results);
        assert_eq!(outcomes.len(), 4);

        assert_eq!(outcomes[0].0, "a");
        assert!(matches!(
            outcomes[0].1,
            InferredOutcome::AcceptedWithEdits { confidence } if (confidence - 0.63).abs() < 1e-9
        ));
        assert_eq!(outcomes[1].0, "b");
        assert!(matches!(
            outcomes[1].1,
            InferredOutcome::Rejected { confidence } if (confidence - 0.8).abs() < 1e-12
        ));
        assert_eq!(outcomes[2].0, "c");
        assert!(matches!(outcomes[2].1, InferredOutcome::Inconclusive));
        assert_eq!(outcomes[3].0, "d");
        assert!(matches!(
            outcomes[3].1,
            InferredOutcome::Accepted { confidence } if (confidence - 0.7).abs() < 1e-12
        ));
    }

    #[test]
    fn code_outcome_ground_truth() {
        let mut tracker = OutcomeTracker::new();

        let trace = tracker.record_start("qwen/qwen3-4b:q4_k_m", InferenceTask::Code, "code");
        tracker.record_complete(&trace, 500, 200, 100);
        tracker.record_code_outcome(&trace, CodeOutcome::Applied);

        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
        assert_eq!(profile.success_count, 1);
        // EMA should reflect Applied quality (1.0): 0.5 * 0.8 + 1.0 * 0.2 = 0.6
        assert!((profile.ema_quality - 0.6).abs() < 0.01);
    }

    #[test]
    fn per_task_stats() {
        let mut tracker = OutcomeTracker::new();

        // Two code calls, one generate call
        for _ in 0..2 {
            let trace = tracker.record_start("m1", InferenceTask::Code, "code");
            tracker.record_complete(&trace, 1000, 100, 50);
            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.8 });
        }
        let trace = tracker.record_start("m1", InferenceTask::Generate, "gen");
        tracker.record_complete(&trace, 500, 50, 25);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Rejected { confidence: 0.9 });

        let profile = tracker.profile("m1").unwrap();
        assert_eq!(profile.total_calls, 3);

        let code_stats = profile.task_stats(InferenceTask::Code).unwrap();
        assert_eq!(code_stats.calls, 2);
        assert_eq!(code_stats.successes, 2);

        let gen_stats = profile.task_stats(InferenceTask::Generate).unwrap();
        assert_eq!(gen_stats.calls, 1);
        assert_eq!(gen_stats.failures, 1);
    }

    #[test]
    fn export_populates_quality_per_1k_tokens() {
        let mut tracker = OutcomeTracker::new();
        let trace = tracker.record_start("m1", InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 800, 200); // 1000 tokens total
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 1.0 });

        let exported = tracker.export_profiles();
        assert_eq!(exported.len(), 1);
        let p = &exported[0];
        // ema_quality after one Accepted{1.0}: 0.5 * 0.8 + 1.0 * 0.2 = 0.6
        // quality_per_1k = 0.6 * 1000 / 1000 = 0.6
        assert!(
            (p.quality_per_1k_tokens - 0.6).abs() < 1e-6,
            "got {}",
            p.quality_per_1k_tokens
        );
    }

    #[test]
    fn quality_per_1k_tokens_zero_without_tokens() {
        let profile = ModelProfile::new("x".into());
        assert_eq!(profile.compute_quality_per_1k_tokens(), 0.0);
    }

    #[test]
    fn export_import() {
        let mut tracker = OutcomeTracker::new();
        let trace = tracker.record_start("m1", InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 10, 5);
        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.9 });

        let exported = tracker.export_profiles();
        assert_eq!(exported.len(), 1);

        let mut tracker2 = OutcomeTracker::new();
        tracker2.import_profiles(exported);
        let imported = tracker2.profile("m1").unwrap();
        assert_eq!(imported.success_count, 1);
        assert_eq!(imported.total_calls, 1);
    }
}