Skip to main content

briefcase_core/replay/
executor.rs

1//! Model execution traits for replay
2//!
3//! These traits define how to re-execute a model during replay.
4//! Implementors connect the replay engine to actual model inference.
5
6use super::ReplayError;
7use crate::models::{ExecutionContext, Input, ModelParameters, Output};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[cfg(feature = "async")]
12use async_trait::async_trait;
13
/// Result of model execution during replay
///
/// Returned by `execute` on the executor traits in this module. The built-in
/// `NoOpExecutor` and `EchoExecutor` return an instance with every field empty
/// (no outputs, `0.0` ms, no metadata, no raw response).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
    /// Outputs produced by the re-executed model; these are what get compared
    /// against the recorded outputs via `compare_outputs`.
    pub outputs: Vec<Output>,
    /// Time the execution took, in milliseconds, as reported by the executor.
    pub execution_time_ms: f64,
    /// Free-form, executor-specific metadata keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
    /// The executor's unprocessed response, if it chooses to keep one.
    pub raw_response: Option<serde_json::Value>,
}
22
/// Configuration for how to execute during replay
///
/// The values are handed to an executor's `execute`; honoring them is up to
/// each executor (the built-in `NoOpExecutor` and `EchoExecutor` ignore the
/// config entirely). Deserialization requires every field to be present,
/// since no `#[serde(default)]` is applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionConfig {
    /// Maximum time to wait for execution (ms)
    pub timeout_ms: u64,
    /// Whether to use cached results if available
    pub use_cache: bool,
    /// Whether to record the execution for auditing
    pub record_execution: bool,
    /// Environment overrides for execution
    pub env_overrides: HashMap<String, String>,
    /// Custom parameters passed to executor; their meaning is executor-specific
    pub custom_params: HashMap<String, serde_json::Value>,
}
37
38impl Default for ExecutionConfig {
39    fn default() -> Self {
40        Self {
41            timeout_ms: 30_000,
42            use_cache: false,
43            record_execution: true,
44            env_overrides: HashMap::new(),
45            custom_params: HashMap::new(),
46        }
47    }
48}
49
50impl ExecutionConfig {
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
56        self.timeout_ms = timeout_ms;
57        self
58    }
59
60    pub fn with_cache(mut self, use_cache: bool) -> Self {
61        self.use_cache = use_cache;
62        self
63    }
64
65    pub fn with_recording(mut self, record: bool) -> Self {
66        self.record_execution = record;
67        self
68    }
69}
70
/// Comparison result between original and replayed outputs
///
/// Produced by `compare_outputs` on the executor traits. The field notes
/// below describe how this module's default comparison fills them in;
/// custom `compare_outputs` overrides may differ.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// `true` only when every per-field comparison matched.
    pub is_match: bool,
    /// Overall similarity in 0.0 - 1.0: the mean of the per-field
    /// similarities, `0.0` when output counts differ, `1.0` when there are
    /// no outputs at all.
    pub similarity_score: f64,
    /// One entry per positionally-paired output (empty on a count mismatch).
    pub field_comparisons: Vec<FieldComparison>,
    /// Human-readable one-line summary of the outcome.
    pub summary: String,
}
79
/// Comparison of a single original output against its replayed counterpart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldComparison {
    /// Name of the output, taken from the original side.
    pub field_name: String,
    /// Value recorded in the original snapshot.
    pub original_value: serde_json::Value,
    /// Value produced by the replay.
    pub replayed_value: serde_json::Value,
    /// Whether `similarity` reached the requested tolerance.
    pub is_match: bool,
    /// Similarity score for this pair, in 0.0 - 1.0.
    pub similarity: f64,
}
88
/// Async trait for model re-execution during replay.
///
/// Implementors provide the bridge between the replay engine and actual
/// model inference (OpenAI, Anthropic, local models, etc.).
///
/// Only compiled with the `async` cargo feature; the async method is expanded
/// by the `async_trait` macro. `SyncModelExecutor` is the blocking
/// counterpart and is always available.
///
/// # Example
///
/// ```rust,ignore
/// struct OpenAIExecutor { client: OpenAIClient }
///
/// #[async_trait]
/// impl ModelExecutor for OpenAIExecutor {
///     async fn execute(
///         &self, inputs: &[Input], model_params: Option<&ModelParameters>,
///         context: &ExecutionContext, config: &ExecutionConfig,
///     ) -> Result<ExecutionResult, ReplayError> {
///         let response = self.client.chat(inputs, model_params).await?;
///         Ok(ExecutionResult { outputs: vec![...], ... })
///     }
/// }
/// ```
#[cfg(feature = "async")]
#[async_trait]
pub trait ModelExecutor: Send + Sync {
    /// Execute a model with the given inputs and parameters
    ///
    /// `config` carries the replay settings (timeout, caching, recording,
    /// overrides); honoring them is the implementor's responsibility.
    ///
    /// # Errors
    ///
    /// Returns a [`ReplayError`] when the model cannot be re-executed.
    async fn execute(
        &self,
        inputs: &[Input],
        model_params: Option<&ModelParameters>,
        context: &ExecutionContext,
        config: &ExecutionConfig,
    ) -> Result<ExecutionResult, ReplayError>;

    /// Check if this executor supports the given model
    ///
    /// NOTE(review): the caller that consults this is not in this file;
    /// presumably the engine checks it before calling `execute`.
    fn supports_model(&self, model_name: &str) -> bool;

    /// Get executor name for logging/auditing (e.g. `"noop"`, `"echo"`)
    fn executor_name(&self) -> &str;

    /// Compare original outputs with replayed outputs
    ///
    /// The default implementation pairs outputs by position. Identical values
    /// score 1.0, two strings score their normalized Levenshtein similarity,
    /// two numbers score by relative difference, and anything else scores 0.0.
    /// A field matches when its score is `>= tolerance`, and the overall result
    /// matches only if every field does. Differing output counts are an
    /// immediate non-match with score 0.0. Override for domain-specific rules.
    fn compare_outputs(
        &self,
        original: &[Output],
        replayed: &[Output],
        tolerance: f64,
    ) -> ComparisonResult {
        default_compare_outputs(original, replayed, tolerance)
    }
}
138
/// Sync version of ModelExecutor for non-async contexts
///
/// Same contract as the async `ModelExecutor` trait (which only exists with
/// the `async` cargo feature), except that `execute` returns its result
/// directly instead of a future. This trait is always compiled.
pub trait SyncModelExecutor: Send + Sync {
    /// Execute a model with the given inputs and parameters.
    ///
    /// `config` carries the replay settings (timeout, caching, recording,
    /// overrides); honoring them is the implementor's responsibility.
    ///
    /// # Errors
    ///
    /// Returns a [`ReplayError`] when the model cannot be re-executed.
    fn execute(
        &self,
        inputs: &[Input],
        model_params: Option<&ModelParameters>,
        context: &ExecutionContext,
        config: &ExecutionConfig,
    ) -> Result<ExecutionResult, ReplayError>;

    /// Check if this executor supports the given model.
    fn supports_model(&self, model_name: &str) -> bool;

    /// Get executor name for logging/auditing (e.g. `"noop"`, `"echo"`).
    fn executor_name(&self) -> &str;

    /// Compare original outputs with replayed outputs.
    ///
    /// The default implementation pairs outputs by position. Identical values
    /// score 1.0, two strings score their normalized Levenshtein similarity,
    /// two numbers score by relative difference, and anything else scores 0.0.
    /// A field matches when its score is `>= tolerance`, and the overall result
    /// matches only if every field does. Differing output counts are an
    /// immediate non-match with score 0.0. Override for domain-specific rules.
    fn compare_outputs(
        &self,
        original: &[Output],
        replayed: &[Output],
        tolerance: f64,
    ) -> ComparisonResult {
        default_compare_outputs(original, replayed, tolerance)
    }
}
162
/// No-op executor that returns empty outputs.
/// Used when no real model execution is available.
///
/// It accepts every model name and reports itself as `"noop"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpExecutor;
166
/// Async adapter for `NoOpExecutor`: every method forwards to the
/// `SyncModelExecutor` implementation, so the async and sync variants
/// cannot drift apart.
#[cfg(feature = "async")]
#[async_trait]
impl ModelExecutor for NoOpExecutor {
    async fn execute(
        &self,
        inputs: &[Input],
        model_params: Option<&ModelParameters>,
        context: &ExecutionContext,
        config: &ExecutionConfig,
    ) -> Result<ExecutionResult, ReplayError> {
        // The sync implementation only builds an empty result, so calling it
        // inline from an async context does not block.
        SyncModelExecutor::execute(self, inputs, model_params, context, config)
    }

    fn supports_model(&self, model_name: &str) -> bool {
        SyncModelExecutor::supports_model(self, model_name)
    }

    fn executor_name(&self) -> &str {
        SyncModelExecutor::executor_name(self)
    }
}
193
194impl SyncModelExecutor for NoOpExecutor {
195    fn execute(
196        &self,
197        _inputs: &[Input],
198        _model_params: Option<&ModelParameters>,
199        _context: &ExecutionContext,
200        _config: &ExecutionConfig,
201    ) -> Result<ExecutionResult, ReplayError> {
202        Ok(ExecutionResult {
203            outputs: vec![],
204            execution_time_ms: 0.0,
205            metadata: HashMap::new(),
206            raw_response: None,
207        })
208    }
209
210    fn supports_model(&self, _model_name: &str) -> bool {
211        true
212    }
213
214    fn executor_name(&self) -> &str {
215        "noop"
216    }
217}
218
/// Echo executor, intended as a stand-in model for testing the replay
/// pipeline without a real model.
///
/// Its `execute` implementations currently return an empty
/// [`ExecutionResult`] (no outputs); they do not see the snapshot, so any
/// echoing of the original outputs has to happen in the caller.
/// NOTE(review): confirm where the replay engine substitutes snapshot outputs.
///
/// It accepts every model name and reports itself as `"echo"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EchoExecutor;
222
/// Async adapter for `EchoExecutor`: every method forwards to the
/// `SyncModelExecutor` implementation, so the async and sync variants
/// cannot drift apart.
#[cfg(feature = "async")]
#[async_trait]
impl ModelExecutor for EchoExecutor {
    async fn execute(
        &self,
        inputs: &[Input],
        model_params: Option<&ModelParameters>,
        context: &ExecutionContext,
        config: &ExecutionConfig,
    ) -> Result<ExecutionResult, ReplayError> {
        // Like the sync version, this returns empty outputs; snapshot outputs
        // are expected to be substituted by the caller. The sync call only
        // builds an empty result, so it does not block the async context.
        SyncModelExecutor::execute(self, inputs, model_params, context, config)
    }

    fn supports_model(&self, model_name: &str) -> bool {
        SyncModelExecutor::supports_model(self, model_name)
    }

    fn executor_name(&self) -> &str {
        SyncModelExecutor::executor_name(self)
    }
}
250
251impl SyncModelExecutor for EchoExecutor {
252    fn execute(
253        &self,
254        _inputs: &[Input],
255        _model_params: Option<&ModelParameters>,
256        _context: &ExecutionContext,
257        _config: &ExecutionConfig,
258    ) -> Result<ExecutionResult, ReplayError> {
259        Ok(ExecutionResult {
260            outputs: vec![],
261            execution_time_ms: 0.0,
262            metadata: HashMap::new(),
263            raw_response: None,
264        })
265    }
266
267    fn supports_model(&self, _model_name: &str) -> bool {
268        true
269    }
270
271    fn executor_name(&self) -> &str {
272        "echo"
273    }
274}
275
276/// Default comparison implementation using string similarity
277fn default_compare_outputs(
278    original: &[Output],
279    replayed: &[Output],
280    tolerance: f64,
281) -> ComparisonResult {
282    if original.len() != replayed.len() {
283        return ComparisonResult {
284            is_match: false,
285            similarity_score: 0.0,
286            field_comparisons: vec![],
287            summary: format!(
288                "Output count mismatch: {} vs {}",
289                original.len(),
290                replayed.len()
291            ),
292        };
293    }
294
295    let mut comparisons = Vec::new();
296    let mut total_similarity = 0.0;
297
298    for (orig, replay) in original.iter().zip(replayed.iter()) {
299        let is_exact = orig.value == replay.value;
300        let similarity = if is_exact {
301            1.0
302        } else {
303            // Compute string similarity for string values
304            match (&orig.value, &replay.value) {
305                (serde_json::Value::String(a), serde_json::Value::String(b)) => {
306                    strsim::normalized_levenshtein(a, b)
307                }
308                (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
309                    // Numeric comparison with tolerance
310                    let a_f = a.as_f64().unwrap_or(0.0);
311                    let b_f = b.as_f64().unwrap_or(0.0);
312                    if a_f == 0.0 && b_f == 0.0 {
313                        1.0
314                    } else {
315                        let max = a_f.abs().max(b_f.abs());
316                        if max == 0.0 {
317                            1.0
318                        } else {
319                            1.0 - ((a_f - b_f).abs() / max).min(1.0)
320                        }
321                    }
322                }
323                _ => {
324                    if is_exact {
325                        1.0
326                    } else {
327                        0.0
328                    }
329                }
330            }
331        };
332
333        total_similarity += similarity;
334        comparisons.push(FieldComparison {
335            field_name: orig.name.clone(),
336            original_value: orig.value.clone(),
337            replayed_value: replay.value.clone(),
338            is_match: similarity >= tolerance,
339            similarity,
340        });
341    }
342
343    let avg_similarity = if comparisons.is_empty() {
344        1.0
345    } else {
346        total_similarity / comparisons.len() as f64
347    };
348    let all_match = comparisons.iter().all(|c| c.is_match);
349
350    ComparisonResult {
351        is_match: all_match,
352        similarity_score: avg_similarity,
353        field_comparisons: comparisons.clone(),
354        summary: if all_match {
355            format!(
356                "All outputs match (similarity: {:.2}%)",
357                avg_similarity * 100.0
358            )
359        } else {
360            let mismatched: Vec<_> = comparisons
361                .iter()
362                .filter(|c| !c.is_match)
363                .map(|c| c.field_name.as_str())
364                .collect();
365            format!("Mismatched fields: {}", mismatched.join(", "))
366        },
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_execution_config_default() {
        let config = ExecutionConfig::default();
        assert_eq!(config.timeout_ms, 30_000);
        assert!(!config.use_cache);
        assert!(config.record_execution);
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::new()
            .with_timeout(60_000)
            .with_cache(true)
            .with_recording(false);

        assert_eq!(config.timeout_ms, 60_000);
        assert!(config.use_cache);
        assert!(!config.record_execution);
    }

    #[test]
    fn test_noop_executor_sync() {
        let executor = NoOpExecutor;
        assert!(SyncModelExecutor::supports_model(&executor, "any-model"));
        assert_eq!(SyncModelExecutor::executor_name(&executor), "noop");

        let result = SyncModelExecutor::execute(
            &executor,
            &[],
            None,
            &ExecutionContext::new(),
            &ExecutionConfig::default(),
        )
        .unwrap();
        assert!(result.outputs.is_empty());
    }

    #[test]
    fn test_echo_executor_sync() {
        let executor = EchoExecutor;
        assert!(SyncModelExecutor::supports_model(&executor, "any-model"));
        assert_eq!(SyncModelExecutor::executor_name(&executor), "echo");

        let result = SyncModelExecutor::execute(
            &executor,
            &[],
            None,
            &ExecutionContext::new(),
            &ExecutionConfig::default(),
        )
        .unwrap();
        assert!(result.outputs.is_empty());
    }

    #[test]
    fn test_default_compare_outputs_exact_match() {
        let original = vec![Output::new("output", json!("hello"), "string")];
        let replayed = vec![Output::new("output", json!("hello"), "string")];

        let result = default_compare_outputs(&original, &replayed, 0.9);

        assert!(result.is_match);
        assert!(result.similarity_score >= 0.99);
        assert_eq!(result.field_comparisons.len(), 1);
    }

    #[test]
    fn test_default_compare_outputs_mismatch() {
        let original = vec![Output::new("output", json!("hello"), "string")];
        let replayed = vec![Output::new("output", json!("world"), "string")];

        let result = default_compare_outputs(&original, &replayed, 0.95);

        assert!(!result.is_match);
        assert!(result.similarity_score < 1.0);
    }

    #[test]
    fn test_default_compare_outputs_count_mismatch() {
        let original = vec![
            Output::new("output1", json!("hello"), "string"),
            Output::new("output2", json!("world"), "string"),
        ];
        let replayed = vec![Output::new("output1", json!("hello"), "string")];

        let result = default_compare_outputs(&original, &replayed, 0.9);

        assert!(!result.is_match);
        assert_eq!(result.similarity_score, 0.0);
    }

    #[test]
    fn test_default_compare_outputs_numeric() {
        let original = vec![Output::new("number", json!(100), "number")];
        let replayed = vec![Output::new("number", json!(101), "number")];

        let result = default_compare_outputs(&original, &replayed, 0.95);

        assert!(result.is_match); // 1 - 1/101 ≈ 0.990 > 0.95
        assert!(result.similarity_score > 0.99);
    }

    #[test]
    fn test_default_compare_outputs_empty_slices_match() {
        let result = default_compare_outputs(&[], &[], 0.9);

        assert!(result.is_match);
        assert_eq!(result.similarity_score, 1.0);
        assert!(result.field_comparisons.is_empty());
    }

    #[test]
    fn test_default_compare_outputs_type_mismatch_scores_zero() {
        let original = vec![Output::new("answer", json!("1"), "string")];
        let replayed = vec![Output::new("answer", json!(1), "number")];

        let result = default_compare_outputs(&original, &replayed, 0.5);

        assert!(!result.is_match);
        assert_eq!(result.similarity_score, 0.0);
        assert_eq!(result.summary, "Mismatched fields: answer");
    }

    #[test]
    fn test_default_compare_outputs_zero_int_vs_zero_float() {
        // `0` and `0.0` are distinct JSON values but numerically identical.
        let original = vec![Output::new("n", json!(0), "number")];
        let replayed = vec![Output::new("n", json!(0.0), "number")];

        let result = default_compare_outputs(&original, &replayed, 1.0);

        assert!(result.is_match);
        assert_eq!(result.field_comparisons[0].similarity, 1.0);
    }

    #[test]
    fn test_default_compare_outputs_opposite_signs_score_zero() {
        let original = vec![Output::new("n", json!(-5), "number")];
        let replayed = vec![Output::new("n", json!(5), "number")];

        let result = default_compare_outputs(&original, &replayed, 0.1);

        assert!(!result.is_match);
        assert_eq!(result.similarity_score, 0.0);
    }

    #[test]
    fn test_trait_compare_outputs_uses_default() {
        let executor = NoOpExecutor;
        let outputs = vec![Output::new("output", json!("same"), "string")];

        let result = SyncModelExecutor::compare_outputs(&executor, &outputs, &outputs, 0.99);

        assert!(result.is_match);
        assert!(result.summary.starts_with("All outputs match"));
    }
}