// car_inference/outcome.rs
//! Outcome tracking — learn from inference results to improve routing.
//!
//! Two observation channels:
//! 1. **Conversation signals** — implicit feedback from what happens after an inference
//!    call (user moved on = accepted, user corrected = rejected, re-asked = rejected).
//! 2. **Git-diff tracking** — for code generation, compare suggestions to actual commits
//!    (ground truth, no classification model needed).
//!
//! Every inference call produces an `InferenceOutcome`. Outcomes accumulate into
//! `ModelProfile`s with per-task statistics. The adaptive router uses profiles
//! to make data-driven model selection.

use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};
17
/// Task type for outcome tracking. Maps to ModelCapability but at the call level.
///
/// Serialized in snake_case, which matches the strings produced by the
/// `Display` impl below — `Display` output is used as the `task_stats` map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InferenceTask {
    Generate,
    Embed,
    Classify,
    Code,
    Reasoning,
}
28
29impl std::fmt::Display for InferenceTask {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            InferenceTask::Generate => write!(f, "generate"),
33            InferenceTask::Embed => write!(f, "embed"),
34            InferenceTask::Classify => write!(f, "classify"),
35            InferenceTask::Code => write!(f, "code"),
36            InferenceTask::Reasoning => write!(f, "reasoning"),
37        }
38    }
39}
40
/// A single inference invocation record.
///
/// Created by `OutcomeTracker::record_start` with zeroed timing/token fields;
/// `record_complete` fills those in, and one of the outcome channels
/// (conversation signal or git diff) later resolves the record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceOutcome {
    /// Unique trace ID for this invocation.
    pub trace_id: String,
    /// Model that was used.
    pub model_id: String,
    /// Task type.
    pub task: InferenceTask,
    /// How the model was selected.
    pub routing_reason: String,
    /// Wall-clock latency in milliseconds (0 until `record_complete`).
    pub latency_ms: u64,
    /// Input tokens (estimated).
    pub input_tokens: usize,
    /// Output tokens (estimated).
    pub output_tokens: usize,
    /// Outcome from conversation signal inference.
    pub inferred_outcome: Option<InferredOutcome>,
    /// Outcome from git-diff tracking (code only).
    pub code_outcome: Option<CodeOutcome>,
    /// Error message if inference failed.
    pub error: Option<String>,
    /// Unix timestamp in seconds (set at `record_start`).
    pub timestamp: u64,
}
67
/// Outcome inferred from conversation flow (implicit feedback).
///
/// Serialized internally tagged: `{"type": "accepted", "confidence": ...}`.
/// `confidence` is the strength of the signal — presumably in [0, 1], since
/// `quality_score` uses it directly as a score; confirm with the signal producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferredOutcome {
    /// User moved on, built on the response.
    Accepted { confidence: f64 },
    /// User used the result but modified it.
    AcceptedWithEdits { confidence: f64 },
    /// User corrected, re-asked, or explicitly rejected.
    Rejected { confidence: f64 },
    /// No follow-up signal (session ended, inconclusive).
    Inconclusive,
}
81
82impl InferredOutcome {
83    /// Convert to a quality score (0.0 = bad, 1.0 = good).
84    pub fn quality_score(&self) -> Option<f64> {
85        match self {
86            InferredOutcome::Accepted { confidence } => Some(*confidence),
87            InferredOutcome::AcceptedWithEdits { confidence } => Some(confidence * 0.7),
88            InferredOutcome::Rejected { confidence } => Some((1.0 - confidence) * 0.3),
89            InferredOutcome::Inconclusive => None,
90        }
91    }
92
93    pub fn is_success(&self) -> Option<bool> {
94        match self {
95            InferredOutcome::Accepted { .. } => Some(true),
96            InferredOutcome::AcceptedWithEdits { .. } => Some(true),
97            InferredOutcome::Rejected { .. } => Some(false),
98            InferredOutcome::Inconclusive => None,
99        }
100    }
101}
102
/// Outcome from git-diff comparison (code generation ground truth).
///
/// `Applied`/`Modified`/`Ignored` come from the text-diff fallback;
/// `SignatureChanged`/`BodyModified`/`SymbolAdded` are produced by the
/// AST structural diff (`check_git_outcomes_ast`, `ast` feature only).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CodeOutcome {
    /// Suggestion was applied as-is (exact or near-exact match in diff).
    Applied,
    /// User changed the same file but differently (partial adoption).
    Modified,
    /// File unchanged despite suggestion (rejected / not used).
    Ignored,
    /// AST structural diff: signature was changed (breaking change).
    SignatureChanged,
    /// AST structural diff: body was modified but signature preserved (non-breaking).
    BodyModified,
    /// AST structural diff: new symbol was added.
    SymbolAdded,
}
120
121impl CodeOutcome {
122    pub fn quality_score(&self) -> f64 {
123        match self {
124            CodeOutcome::Applied => 1.0,
125            CodeOutcome::SignatureChanged => 0.8,
126            CodeOutcome::BodyModified => 0.7,
127            CodeOutcome::SymbolAdded => 0.7,
128            CodeOutcome::Modified => 0.6,
129            CodeOutcome::Ignored => 0.1,
130        }
131    }
132
133    pub fn is_success(&self) -> bool {
134        !matches!(self, CodeOutcome::Ignored)
135    }
136}
137
/// Per-task statistics within a model profile.
///
/// NOTE: `Default` initializes `ema_quality` to 0.0, unlike the profile-level
/// EMA which starts at the 0.5 neutral prior in `ModelProfile::new`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskStats {
    // Completed calls for this task (incremented in record_complete).
    pub calls: u64,
    // Outcomes resolved as success.
    pub successes: u64,
    // Inference errors plus outcomes resolved as failure.
    pub failures: u64,
    /// Running average latency in ms.
    pub avg_latency_ms: f64,
    /// Exponential moving average of quality score.
    pub ema_quality: f64,
}
149
150impl TaskStats {
151    pub fn success_rate(&self) -> f64 {
152        let total = self.successes + self.failures;
153        if total == 0 { return 0.5; } // prior: assume neutral
154        self.successes as f64 / total as f64
155    }
156}
157
/// Per-model performance profile, built from observed outcomes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProfile {
    pub model_id: String,
    // Completed calls (incremented in record_complete).
    pub total_calls: u64,
    // Outcomes resolved as success.
    pub success_count: u64,
    // Inference errors plus outcomes resolved as failure.
    pub fail_count: u64,
    // Sum of latency_ms over completed calls; divided by total_calls in avg_latency_ms().
    pub total_latency_ms: u64,
    /// Per-task statistics, keyed by `InferenceTask`'s `Display` string.
    pub task_stats: HashMap<String, TaskStats>,
    /// Overall EMA quality score (0.0 - 1.0), starts at the 0.5 neutral prior.
    pub ema_quality: f64,
    /// Last updated (unix timestamp).
    pub updated_at: u64,
}
173
174impl ModelProfile {
175    pub fn new(model_id: String) -> Self {
176        Self {
177            model_id,
178            total_calls: 0,
179            success_count: 0,
180            fail_count: 0,
181            total_latency_ms: 0,
182            task_stats: HashMap::new(),
183            ema_quality: 0.5, // neutral prior
184            updated_at: now_unix(),
185        }
186    }
187
188    pub fn success_rate(&self) -> f64 {
189        let total = self.success_count + self.fail_count;
190        if total == 0 { return 0.5; }
191        self.success_count as f64 / total as f64
192    }
193
194    pub fn avg_latency_ms(&self) -> f64 {
195        if self.total_calls == 0 { return 0.0; }
196        self.total_latency_ms as f64 / self.total_calls as f64
197    }
198
199    /// Same degradation pattern as SkillStats: fail_count > success_count + threshold.
200    pub fn should_degrade(&self, threshold: u64) -> bool {
201        self.fail_count > self.success_count + threshold
202    }
203
204    /// Get stats for a specific task type.
205    pub fn task_stats(&self, task: InferenceTask) -> Option<&TaskStats> {
206        self.task_stats.get(&task.to_string())
207    }
208}
209
/// EMA smoothing factor. Higher = more weight on recent observations.
/// Applied everywhere as `ema = ema * (1 - EMA_ALPHA) + observation * EMA_ALPHA`.
const EMA_ALPHA: f64 = 0.2;
212
/// Tracks inference outcomes and builds performance profiles.
pub struct OutcomeTracker {
    /// In-memory profiles, keyed by model_id.
    profiles: HashMap<String, ModelProfile>,
    /// Pending outcomes: completed inference calls awaiting outcome signal.
    /// Keyed by trace_id.
    pending: HashMap<String, InferenceOutcome>,
    /// Monotonic counter making generated trace IDs unique within a run
    /// (combined with the start timestamp in `record_start`).
    trace_counter: u64,
}
223
224impl OutcomeTracker {
225    pub fn new() -> Self {
226        Self {
227            profiles: HashMap::new(),
228            pending: HashMap::new(),
229            trace_counter: 0,
230        }
231    }
232
233    /// Record that an inference call started. Returns a trace_id.
234    pub fn record_start(&mut self, model_id: &str, task: InferenceTask, routing_reason: &str) -> String {
235        self.trace_counter += 1;
236        let trace_id = format!("t-{}-{}", now_unix(), self.trace_counter);
237
238        let outcome = InferenceOutcome {
239            trace_id: trace_id.clone(),
240            model_id: model_id.to_string(),
241            task,
242            routing_reason: routing_reason.to_string(),
243            latency_ms: 0,
244            input_tokens: 0,
245            output_tokens: 0,
246            inferred_outcome: None,
247            code_outcome: None,
248            error: None,
249            timestamp: now_unix(),
250        };
251
252        self.pending.insert(trace_id.clone(), outcome);
253        trace_id
254    }
255
256    /// Record completion of an inference call (timing + token counts).
257    pub fn record_complete(
258        &mut self,
259        trace_id: &str,
260        latency_ms: u64,
261        input_tokens: usize,
262        output_tokens: usize,
263    ) {
264        if let Some(outcome) = self.pending.get_mut(trace_id) {
265            outcome.latency_ms = latency_ms;
266            outcome.input_tokens = input_tokens;
267            outcome.output_tokens = output_tokens;
268
269            // Update profile with timing data
270            let profile = self.profiles
271                .entry(outcome.model_id.clone())
272                .or_insert_with(|| ModelProfile::new(outcome.model_id.clone()));
273
274            profile.total_calls += 1;
275            profile.total_latency_ms += latency_ms;
276
277            let task_key = outcome.task.to_string();
278            let ts = profile.task_stats.entry(task_key).or_default();
279            ts.calls += 1;
280            ts.avg_latency_ms = ts.avg_latency_ms + (latency_ms as f64 - ts.avg_latency_ms) / ts.calls as f64;
281
282            profile.updated_at = now_unix();
283        }
284    }
285
286    /// Record a failure.
287    pub fn record_failure(&mut self, trace_id: &str, error: &str) {
288        if let Some(outcome) = self.pending.get_mut(trace_id) {
289            outcome.error = Some(error.to_string());
290
291            let profile = self.profiles
292                .entry(outcome.model_id.clone())
293                .or_insert_with(|| ModelProfile::new(outcome.model_id.clone()));
294
295            profile.fail_count += 1;
296            profile.ema_quality = profile.ema_quality * (1.0 - EMA_ALPHA) + 0.0 * EMA_ALPHA;
297
298            let task_key = outcome.task.to_string();
299            let ts = profile.task_stats.entry(task_key).or_default();
300            ts.failures += 1;
301            ts.ema_quality = ts.ema_quality * (1.0 - EMA_ALPHA);
302
303            profile.updated_at = now_unix();
304        }
305
306        // Failed outcomes don't need further tracking
307        self.pending.remove(trace_id);
308    }
309
310    /// Record an inferred outcome from conversation signals.
311    pub fn record_inferred_outcome(&mut self, trace_id: &str, outcome: InferredOutcome) {
312        if let Some(pending) = self.pending.remove(trace_id) {
313            self.apply_outcome(&pending, outcome.quality_score(), outcome.is_success());
314        }
315    }
316
317    /// Record an outcome from git-diff comparison (code generation).
318    pub fn record_code_outcome(&mut self, trace_id: &str, outcome: CodeOutcome) {
319        if let Some(pending) = self.pending.remove(trace_id) {
320            self.apply_outcome(&pending, Some(outcome.quality_score()), Some(outcome.is_success()));
321        }
322    }
323
324    /// Resolve all pending outcomes for a completed conversation turn.
325    /// Called with the inferred outcomes from conversation signal analysis.
326    pub fn resolve_pending_from_signals(&mut self, outcomes: Vec<(String, InferredOutcome)>) {
327        for (trace_id, inferred) in outcomes {
328            self.record_inferred_outcome(&trace_id, inferred);
329        }
330    }
331
332    /// Infer outcomes from a sequence of action results.
333    ///
334    /// In a reasoning session, each action's output feeds the next. If action N
335    /// produced output and action N+1 succeeded using it, N was implicitly accepted.
336    /// If N produced empty output or N+1 failed, N was implicitly rejected.
337    ///
338    /// Returns (trace_id, inferred_outcome) pairs ready for `resolve_pending_from_signals`.
339    pub fn infer_outcomes_from_action_sequence(
340        &self,
341        action_results: &[(String, bool, f64, String)], // (trace_id, success, confidence, output)
342    ) -> Vec<(String, InferredOutcome)> {
343        let mut outcomes = Vec::new();
344
345        for (i, (trace_id, success, confidence, output)) in action_results.iter().enumerate() {
346            if trace_id.is_empty() {
347                continue; // No trace (e.g., memgine-only action)
348            }
349
350            if !success {
351                outcomes.push((
352                    trace_id.clone(),
353                    InferredOutcome::Rejected { confidence: *confidence },
354                ));
355                continue;
356            }
357
358            // Check if the next action used this one's output (implicit acceptance)
359            let next_succeeded = action_results.get(i + 1)
360                .map(|(_, s, _, _)| *s)
361                .unwrap_or(true); // Last action: assume accepted if successful
362
363            let has_output = !output.trim().is_empty();
364
365            if has_output && next_succeeded {
366                outcomes.push((
367                    trace_id.clone(),
368                    InferredOutcome::Accepted { confidence: *confidence },
369                ));
370            } else if has_output && !next_succeeded {
371                // Output existed but downstream failed — may not be this action's fault
372                outcomes.push((
373                    trace_id.clone(),
374                    InferredOutcome::AcceptedWithEdits { confidence: confidence * 0.7 },
375                ));
376            } else {
377                outcomes.push((
378                    trace_id.clone(),
379                    InferredOutcome::Inconclusive,
380                ));
381            }
382        }
383
384        outcomes
385    }
386
387    /// Get the profile for a model.
388    pub fn profile(&self, model_id: &str) -> Option<&ModelProfile> {
389        self.profiles.get(model_id)
390    }
391
392    /// Get all profiles.
393    pub fn all_profiles(&self) -> &HashMap<String, ModelProfile> {
394        &self.profiles
395    }
396
397    /// Get pending trace IDs (for conversation signal analysis).
398    pub fn pending_trace_ids(&self) -> Vec<String> {
399        self.pending.keys().cloned().collect()
400    }
401
402    /// Get a pending outcome by trace_id.
403    pub fn get_pending(&self, trace_id: &str) -> Option<&InferenceOutcome> {
404        self.pending.get(trace_id)
405    }
406
407    /// Export profiles for serialization / persistence.
408    pub fn export_profiles(&self) -> Vec<ModelProfile> {
409        self.profiles.values().cloned().collect()
410    }
411
412    /// Import profiles (from persistence).
413    pub fn import_profiles(&mut self, profiles: Vec<ModelProfile>) {
414        for p in profiles {
415            self.profiles.insert(p.model_id.clone(), p);
416        }
417    }
418
419    /// Apply a quality signal to the model's profile.
420    fn apply_outcome(
421        &mut self,
422        pending: &InferenceOutcome,
423        quality: Option<f64>,
424        success: Option<bool>,
425    ) {
426        let profile = self.profiles
427            .entry(pending.model_id.clone())
428            .or_insert_with(|| ModelProfile::new(pending.model_id.clone()));
429
430        if let Some(q) = quality {
431            profile.ema_quality = profile.ema_quality * (1.0 - EMA_ALPHA) + q * EMA_ALPHA;
432
433            let task_key = pending.task.to_string();
434            let ts = profile.task_stats.entry(task_key).or_default();
435            ts.ema_quality = ts.ema_quality * (1.0 - EMA_ALPHA) + q * EMA_ALPHA;
436        }
437
438        if let Some(ok) = success {
439            if ok {
440                profile.success_count += 1;
441                let task_key = pending.task.to_string();
442                let ts = profile.task_stats.entry(task_key).or_default();
443                ts.successes += 1;
444            } else {
445                profile.fail_count += 1;
446                let task_key = pending.task.to_string();
447                let ts = profile.task_stats.entry(task_key).or_default();
448                ts.failures += 1;
449            }
450        }
451
452        profile.updated_at = now_unix();
453    }
454
455    /// Check git diff for pending code suggestions and resolve outcomes.
456    ///
457    /// Two strategies:
458    /// 1. **AST structural diff** (when `ast` feature is enabled): parse the old
459    ///    and new versions of changed files and compare at the symbol level.
460    ///    This gives precise outcomes: SignatureChanged, BodyModified, SymbolAdded.
461    /// 2. **Text diff fallback**: token matching against the combined git diff.
462    pub fn check_git_outcomes(&mut self, repo_dir: &std::path::Path) {
463        let diff = match std::process::Command::new("git")
464            .args(["diff", "--no-color"])
465            .current_dir(repo_dir)
466            .output()
467        {
468            Ok(output) => String::from_utf8_lossy(&output.stdout).to_string(),
469            Err(_) => return,
470        };
471
472        let staged_diff = match std::process::Command::new("git")
473            .args(["diff", "--cached", "--no-color"])
474            .current_dir(repo_dir)
475            .output()
476        {
477            Ok(output) => String::from_utf8_lossy(&output.stdout).to_string(),
478            Err(_) => String::new(),
479        };
480
481        let combined_diff = format!("{}\n{}", diff, staged_diff);
482
483        if combined_diff.trim().is_empty() {
484            return; // No changes at all
485        }
486
487        // Try AST structural diff on changed files
488        #[cfg(feature = "ast")]
489        let ast_outcome = Self::check_git_outcomes_ast(repo_dir);
490
491        let code_traces: Vec<(String, String)> = self.pending.iter()
492            .filter(|(_, o)| matches!(o.task, InferenceTask::Code))
493            .map(|(id, o)| (id.clone(), o.model_id.clone()))
494            .collect();
495
496        for (trace_id, _model_id) in code_traces {
497            if let Some(pending) = self.pending.get(&trace_id) {
498                // Try AST-based outcome first
499                #[cfg(feature = "ast")]
500                if let Some(ref ast_out) = ast_outcome {
501                    let pending_clone = pending.clone();
502                    self.apply_outcome(&pending_clone, Some(ast_out.quality_score()), Some(ast_out.is_success()));
503                    continue;
504                }
505
506                // Fallback: text token matching
507                let output_tokens: Vec<&str> = pending.routing_reason
508                    .split_whitespace()
509                    .filter(|t| t.len() > 5)
510                    .collect();
511
512                let outcome = if output_tokens.iter().any(|t| combined_diff.contains(t)) {
513                    CodeOutcome::Applied
514                } else {
515                    CodeOutcome::Modified
516                };
517
518                let pending_clone = pending.clone();
519                self.apply_outcome(&pending_clone, Some(outcome.quality_score()), Some(outcome.is_success()));
520            }
521        }
522    }
523
    /// AST-based git outcome: parse changed files before and after, diff symbols.
    ///
    /// Returns a single, most-significant outcome across all changed files
    /// (removal or signature change > body change > addition), or `None` when
    /// git fails, no parseable files changed, or no structural change was found.
    #[cfg(feature = "ast")]
    fn check_git_outcomes_ast(repo_dir: &std::path::Path) -> Option<CodeOutcome> {
        // Get list of changed files
        let name_only = std::process::Command::new("git")
            .args(["diff", "--name-only"])
            .current_dir(repo_dir)
            .output()
            .ok()?;
        let changed_files: Vec<&str> = std::str::from_utf8(&name_only.stdout)
            .ok()?
            .lines()
            .filter(|f| !f.is_empty())
            .collect();

        if changed_files.is_empty() {
            return None;
        }

        // Flags aggregated across every changed file.
        let mut has_sig_change = false;
        let mut has_body_change = false;
        let mut has_addition = false;

        for file in &changed_files {
            // Only parse files tree-sitter supports
            if car_ast::Language::from_filename(file).is_none() {
                continue;
            }

            // Get the HEAD version (None if the file is new or git fails).
            let old_content = std::process::Command::new("git")
                .args(["show", &format!("HEAD:{}", file)])
                .current_dir(repo_dir)
                .output()
                .ok()
                .and_then(|o| if o.status.success() {
                    String::from_utf8(o.stdout).ok()
                } else {
                    None
                });

            // Get the working tree version
            let new_path = repo_dir.join(file);
            let new_content = std::fs::read_to_string(&new_path).ok();

            match (old_content, new_content) {
                (Some(old), Some(new)) => {
                    let old_parsed = car_ast::parse_file(&old, file);
                    let new_parsed = car_ast::parse_file(&new, file);

                    if let (Some(old_p), Some(new_p)) = (old_parsed, new_parsed) {
                        let changes = car_ast::diff_symbols(&old_p, &new_p);
                        for change in &changes {
                            match change {
                                car_ast::SymbolChange::Added(_) => has_addition = true,
                                car_ast::SymbolChange::Modified { signature_changed, .. } => {
                                    if *signature_changed {
                                        has_sig_change = true;
                                    } else {
                                        has_body_change = true;
                                    }
                                }
                                // A removed symbol breaks callers, like a signature change.
                                car_ast::SymbolChange::Removed(_) => has_sig_change = true,
                            }
                        }
                    }
                }
                (None, Some(_)) => has_addition = true, // New file
                _ => {}
            }
        }

        // Return the most significant outcome
        if has_sig_change {
            Some(CodeOutcome::SignatureChanged)
        } else if has_body_change {
            Some(CodeOutcome::BodyModified)
        } else if has_addition {
            Some(CodeOutcome::SymbolAdded)
        } else {
            None // No structural changes detected (maybe non-code files changed)
        }
    }
607}
608
609impl Default for OutcomeTracker {
610    fn default() -> Self {
611        Self::new()
612    }
613}
614
/// Current unix time in whole seconds; 0 if the clock reads before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
621
622#[cfg(test)]
623mod tests {
624    use super::*;
625
626    #[test]
627    fn lifecycle() {
628        let mut tracker = OutcomeTracker::new();
629
630        // Start an inference call
631        let trace = tracker.record_start(
632            "qwen/qwen3-4b:q4_k_m",
633            InferenceTask::Code,
634            "Code task -> Qwen3-4B",
635        );
636
637        // Complete it
638        tracker.record_complete(&trace, 1200, 100, 50);
639
640        // Profile should have 1 call
641        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
642        assert_eq!(profile.total_calls, 1);
643        assert_eq!(profile.avg_latency_ms(), 1200.0);
644
645        // Record positive outcome
646        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.9 });
647
648        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
649        assert_eq!(profile.success_count, 1);
650        assert!(profile.ema_quality > 0.5); // should have gone up from 0.5
651    }
652
653    #[test]
654    fn failure_degrades() {
655        let mut tracker = OutcomeTracker::new();
656
657        // Simulate 5 failures
658        for i in 0..5 {
659            let trace = tracker.record_start(
660                "bad-model",
661                InferenceTask::Generate,
662                "test",
663            );
664            tracker.record_complete(&trace, 100, 10, 5);
665            tracker.record_failure(&format!("t-fail-{i}"), "timeout");
666        }
667
668        // But record_failure removes from pending, so we need to use the actual trace_ids
669        // Let's redo this properly
670        let mut tracker = OutcomeTracker::new();
671        for _ in 0..5 {
672            let trace = tracker.record_start("bad-model", InferenceTask::Generate, "test");
673            tracker.record_complete(&trace, 100, 10, 5);
674            tracker.record_failure(&trace, "timeout");
675        }
676
677        let profile = tracker.profile("bad-model").unwrap();
678        assert_eq!(profile.fail_count, 5);
679        assert!(profile.should_degrade(2)); // 5 > 0 + 2
680        assert!(profile.ema_quality < 0.3); // decayed toward 0
681    }
682
683    #[test]
684    fn code_outcome_ground_truth() {
685        let mut tracker = OutcomeTracker::new();
686
687        let trace = tracker.record_start("qwen/qwen3-4b:q4_k_m", InferenceTask::Code, "code");
688        tracker.record_complete(&trace, 500, 200, 100);
689        tracker.record_code_outcome(&trace, CodeOutcome::Applied);
690
691        let profile = tracker.profile("qwen/qwen3-4b:q4_k_m").unwrap();
692        assert_eq!(profile.success_count, 1);
693        // EMA should reflect Applied quality (1.0): 0.5 * 0.8 + 1.0 * 0.2 = 0.6
694        assert!((profile.ema_quality - 0.6).abs() < 0.01);
695    }
696
697    #[test]
698    fn per_task_stats() {
699        let mut tracker = OutcomeTracker::new();
700
701        // Two code calls, one generate call
702        for _ in 0..2 {
703            let trace = tracker.record_start("m1", InferenceTask::Code, "code");
704            tracker.record_complete(&trace, 1000, 100, 50);
705            tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.8 });
706        }
707        let trace = tracker.record_start("m1", InferenceTask::Generate, "gen");
708        tracker.record_complete(&trace, 500, 50, 25);
709        tracker.record_inferred_outcome(&trace, InferredOutcome::Rejected { confidence: 0.9 });
710
711        let profile = tracker.profile("m1").unwrap();
712        assert_eq!(profile.total_calls, 3);
713
714        let code_stats = profile.task_stats(InferenceTask::Code).unwrap();
715        assert_eq!(code_stats.calls, 2);
716        assert_eq!(code_stats.successes, 2);
717
718        let gen_stats = profile.task_stats(InferenceTask::Generate).unwrap();
719        assert_eq!(gen_stats.calls, 1);
720        assert_eq!(gen_stats.failures, 1);
721    }
722
723    #[test]
724    fn export_import() {
725        let mut tracker = OutcomeTracker::new();
726        let trace = tracker.record_start("m1", InferenceTask::Generate, "test");
727        tracker.record_complete(&trace, 100, 10, 5);
728        tracker.record_inferred_outcome(&trace, InferredOutcome::Accepted { confidence: 0.9 });
729
730        let exported = tracker.export_profiles();
731        assert_eq!(exported.len(), 1);
732
733        let mut tracker2 = OutcomeTracker::new();
734        tracker2.import_profiles(exported);
735        assert!(tracker2.profile("m1").is_some());
736    }
737}