// briefcase_core/models.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use uuid::Uuid;
6
7/// Input to an AI decision point
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9pub struct Input {
10    pub name: String,
11    pub value: serde_json::Value,
12    pub data_type: String,
13    #[serde(default = "default_schema_version")]
14    pub schema_version: String,
15}
16
17impl Input {
18    pub fn new(
19        name: impl Into<String>,
20        value: serde_json::Value,
21        data_type: impl Into<String>,
22    ) -> Self {
23        Self {
24            name: name.into(),
25            value,
26            data_type: data_type.into(),
27            schema_version: default_schema_version(),
28        }
29    }
30}
31
32/// Output from an AI decision point
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub struct Output {
35    pub name: String,
36    pub value: serde_json::Value,
37    pub data_type: String,
38    pub confidence: Option<f64>, // 0.0 - 1.0
39    #[serde(default = "default_schema_version")]
40    pub schema_version: String,
41}
42
43impl Output {
44    pub fn new(
45        name: impl Into<String>,
46        value: serde_json::Value,
47        data_type: impl Into<String>,
48    ) -> Self {
49        Self {
50            name: name.into(),
51            value,
52            data_type: data_type.into(),
53            confidence: None,
54            schema_version: default_schema_version(),
55        }
56    }
57
58    pub fn with_confidence(mut self, confidence: f64) -> Self {
59        self.confidence = Some(confidence.clamp(0.0, 1.0));
60        self
61    }
62}
63
64/// Model parameters for reproducibility
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66pub struct ModelParameters {
67    pub model_name: String,
68    pub model_version: Option<String>,
69    pub provider: Option<String>, // openai, anthropic, etc.
70    #[serde(default)]
71    pub parameters: HashMap<String, serde_json::Value>,
72    #[serde(default)]
73    pub hyperparameters: HashMap<String, serde_json::Value>,
74    pub weights_hash: Option<String>, // For reproducibility verification
75}
76
77impl ModelParameters {
78    pub fn new(model_name: impl Into<String>) -> Self {
79        Self {
80            model_name: model_name.into(),
81            model_version: None,
82            provider: None,
83            parameters: HashMap::new(),
84            hyperparameters: HashMap::new(),
85            weights_hash: None,
86        }
87    }
88
89    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
90        self.provider = Some(provider.into());
91        self
92    }
93
94    pub fn with_version(mut self, version: impl Into<String>) -> Self {
95        self.model_version = Some(version.into());
96        self
97    }
98
99    pub fn with_parameter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
100        self.parameters.insert(key.into(), value);
101        self
102    }
103
104    pub fn with_hyperparameter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
105        self.hyperparameters.insert(key.into(), value);
106        self
107    }
108
109    pub fn with_weights_hash(mut self, hash: impl Into<String>) -> Self {
110        self.weights_hash = Some(hash.into());
111        self
112    }
113}
114
115/// Execution context for deterministic replay
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
117pub struct ExecutionContext {
118    pub runtime_version: Option<String>, // Python 3.11, Node 20, etc.
119    #[serde(default)]
120    pub dependencies: HashMap<String, String>,
121    pub random_seed: Option<i64>,
122    #[serde(default)]
123    pub environment_variables: HashMap<String, String>,
124    #[serde(default)]
125    pub hardware_info: HashMap<String, serde_json::Value>,
126}
127
128impl ExecutionContext {
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    pub fn with_runtime_version(mut self, version: impl Into<String>) -> Self {
134        self.runtime_version = Some(version.into());
135        self
136    }
137
138    pub fn with_dependency(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
139        self.dependencies.insert(name.into(), version.into());
140        self
141    }
142
143    pub fn with_random_seed(mut self, seed: i64) -> Self {
144        self.random_seed = Some(seed);
145        self
146    }
147
148    pub fn with_env_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
149        self.environment_variables.insert(key.into(), value.into());
150        self
151    }
152}
153
154/// Metadata for a snapshot
155#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
156pub struct SnapshotMetadata {
157    pub snapshot_id: Uuid,
158    pub timestamp: DateTime<Utc>,
159    #[serde(default = "default_schema_version")]
160    pub schema_version: String,
161    pub sdk_version: String,
162    pub created_by: Option<String>,
163    pub checksum: Option<String>,
164}
165
166impl SnapshotMetadata {
167    pub fn new() -> Self {
168        Self {
169            snapshot_id: Uuid::new_v4(),
170            timestamp: Utc::now(),
171            schema_version: default_schema_version(),
172            sdk_version: env!("CARGO_PKG_VERSION").to_string(),
173            created_by: None,
174            checksum: None,
175        }
176    }
177
178    pub fn with_created_by(mut self, created_by: impl Into<String>) -> Self {
179        self.created_by = Some(created_by.into());
180        self
181    }
182
183    pub fn compute_checksum(&mut self, data: &[u8]) {
184        let mut hasher = Sha256::new();
185        hasher.update(data);
186        let result = hasher.finalize();
187        self.checksum = Some(format!("{:x}", result));
188    }
189}
190
191impl Default for SnapshotMetadata {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197/// A single AI decision capture
198#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
199pub struct DecisionSnapshot {
200    pub metadata: SnapshotMetadata,
201    pub context: ExecutionContext,
202    pub function_name: String,
203    pub module_name: Option<String>,
204    pub inputs: Vec<Input>,
205    pub outputs: Vec<Output>,
206    pub model_parameters: Option<ModelParameters>,
207    pub execution_time_ms: Option<f64>,
208    pub error: Option<String>,
209    pub error_type: Option<String>,
210    #[serde(default)]
211    pub tags: HashMap<String, String>,
212    #[serde(default)]
213    pub custom_data: HashMap<String, serde_json::Value>,
214}
215
216impl DecisionSnapshot {
217    pub fn new(function_name: impl Into<String>) -> Self {
218        let metadata = SnapshotMetadata::new();
219        let snapshot = Self {
220            metadata,
221            context: ExecutionContext::new(),
222            function_name: function_name.into(),
223            module_name: None,
224            inputs: Vec::new(),
225            outputs: Vec::new(),
226            model_parameters: None,
227            execution_time_ms: None,
228            error: None,
229            error_type: None,
230            tags: HashMap::new(),
231            custom_data: HashMap::new(),
232        };
233
234        // Update checksum after creation
235        let mut result = snapshot;
236        result.update_checksum();
237        result
238    }
239
240    pub fn with_module(mut self, module_name: impl Into<String>) -> Self {
241        self.module_name = Some(module_name.into());
242        self.update_checksum();
243        self
244    }
245
246    pub fn with_context(mut self, context: ExecutionContext) -> Self {
247        self.context = context;
248        self.update_checksum();
249        self
250    }
251
252    pub fn add_input(mut self, input: Input) -> Self {
253        self.inputs.push(input);
254        self.update_checksum();
255        self
256    }
257
258    pub fn add_output(mut self, output: Output) -> Self {
259        self.outputs.push(output);
260        self.update_checksum();
261        self
262    }
263
264    pub fn with_model_parameters(mut self, params: ModelParameters) -> Self {
265        self.model_parameters = Some(params);
266        self.update_checksum();
267        self
268    }
269
270    pub fn with_execution_time(mut self, time_ms: f64) -> Self {
271        self.execution_time_ms = Some(time_ms);
272        self.update_checksum();
273        self
274    }
275
276    pub fn with_error(mut self, error: impl Into<String>, error_type: Option<String>) -> Self {
277        self.error = Some(error.into());
278        self.error_type = error_type;
279        self.update_checksum();
280        self
281    }
282
283    pub fn add_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
284        self.tags.insert(key.into(), value.into());
285        self.update_checksum();
286        self
287    }
288
289    pub fn add_custom_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
290        self.custom_data.insert(key.into(), value);
291        self.update_checksum();
292        self
293    }
294
295    fn update_checksum(&mut self) {
296        if let Ok(json_bytes) = serde_json::to_vec(self) {
297            self.metadata.compute_checksum(&json_bytes);
298        }
299    }
300
301    /// Serialize to canonical JSON for consistent checksums
302    pub fn to_canonical_json(&self) -> Result<String, serde_json::Error> {
303        // Create a copy without the checksum for canonical representation
304        let mut copy = self.clone();
305        copy.metadata.checksum = None;
306
307        // Use sorted keys for consistent JSON
308        let value = serde_json::to_value(&copy)?;
309        serde_json::to_string(&value)
310    }
311}
312
313/// Root snapshot containing multiple decisions (e.g., a session)
314#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
315pub struct Snapshot {
316    pub metadata: SnapshotMetadata,
317    pub decisions: Vec<DecisionSnapshot>,
318    pub snapshot_type: SnapshotType,
319}
320
321impl Snapshot {
322    pub fn new(snapshot_type: SnapshotType) -> Self {
323        let metadata = SnapshotMetadata::new();
324        let snapshot = Self {
325            metadata,
326            decisions: Vec::new(),
327            snapshot_type,
328        };
329
330        let mut result = snapshot;
331        result.update_checksum();
332        result
333    }
334
335    pub fn add_decision(&mut self, decision: DecisionSnapshot) {
336        self.decisions.push(decision);
337        self.update_checksum();
338    }
339
340    pub fn with_created_by(mut self, created_by: impl Into<String>) -> Self {
341        self.metadata.created_by = Some(created_by.into());
342        self.update_checksum();
343        self
344    }
345
346    fn update_checksum(&mut self) {
347        if let Ok(json_bytes) = serde_json::to_vec(self) {
348            self.metadata.compute_checksum(&json_bytes);
349        }
350    }
351
352    /// Serialize to canonical JSON for consistent checksums
353    pub fn to_canonical_json(&self) -> Result<String, serde_json::Error> {
354        let mut copy = self.clone();
355        copy.metadata.checksum = None;
356
357        let value = serde_json::to_value(&copy)?;
358        serde_json::to_string(&value)
359    }
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
363pub enum SnapshotType {
364    Decision,
365    Batch,
366    Session,
367}
368
/// Schema version assigned to new values and used as the serde default when
/// a `schema_version` field is missing.
fn default_schema_version() -> String {
    String::from("1.0")
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `Input::new` stores its arguments and applies the default schema version.
    #[test]
    fn test_input_creation() {
        let created = Input::new("test", json!("value"), "string");

        assert_eq!(created.name, "test");
        assert_eq!(created.value, json!("value"));
        assert_eq!(created.data_type, "string");
        assert_eq!(created.schema_version, "1.0");
    }

    /// An in-range confidence is stored unchanged.
    #[test]
    fn test_output_with_confidence() {
        let scored = Output::new("result", json!(42), "number").with_confidence(0.95);
        assert_eq!(scored.confidence, Some(0.95));
    }

    /// Out-of-range confidences are clamped to the nearest bound.
    #[test]
    fn test_output_confidence_clamping() {
        for (requested, stored) in [(1.5, 1.0), (-0.5, 0.0)] {
            let scored = Output::new("result", json!(42), "number").with_confidence(requested);
            assert_eq!(scored.confidence, Some(stored));
        }
    }

    /// Each `ModelParameters` builder method sets the matching field.
    #[test]
    fn test_model_parameters_builder() {
        let built = ModelParameters::new("gpt-4")
            .with_provider("openai")
            .with_version("1.0")
            .with_parameter("temperature", json!(0.7))
            .with_hyperparameter("max_tokens", json!(1000));

        assert_eq!(built.model_name, "gpt-4");
        assert_eq!(built.provider.as_deref(), Some("openai"));
        assert_eq!(built.model_version.as_deref(), Some("1.0"));

        let temperature = json!(0.7);
        let max_tokens = json!(1000);
        assert_eq!(built.parameters.get("temperature"), Some(&temperature));
        assert_eq!(built.hyperparameters.get("max_tokens"), Some(&max_tokens));
    }

    /// Each `ExecutionContext` builder method sets the matching field.
    #[test]
    fn test_execution_context_builder() {
        let built = ExecutionContext::new()
            .with_runtime_version("Python 3.11")
            .with_dependency("numpy", "1.24.0")
            .with_random_seed(42)
            .with_env_var("DEBUG", "true");

        assert_eq!(built.runtime_version.as_deref(), Some("Python 3.11"));
        assert_eq!(
            built.dependencies.get("numpy").map(String::as_str),
            Some("1.24.0")
        );
        assert_eq!(built.random_seed, Some(42));
        assert_eq!(
            built.environment_variables.get("DEBUG").map(String::as_str),
            Some("true")
        );
    }

    /// A new decision snapshot keeps its function name and has a checksum.
    #[test]
    fn test_decision_snapshot_creation() {
        let decision = DecisionSnapshot::new("test_function");

        assert_eq!(decision.function_name, "test_function");
        assert!(decision.metadata.checksum.is_some());
    }

    /// The `DecisionSnapshot` builder chain fills in every touched field.
    #[test]
    fn test_decision_snapshot_builder() {
        let decision = DecisionSnapshot::new("add")
            .with_module("math")
            .add_input(Input::new("x", json!(1), "number"))
            .add_output(Output::new("result", json!(2), "number"))
            .with_model_parameters(ModelParameters::new("gpt-4"))
            .with_execution_time(100.5)
            .add_tag("version", "1.0");

        assert_eq!(decision.function_name, "add");
        assert_eq!(decision.module_name.as_deref(), Some("math"));
        assert_eq!(decision.inputs.len(), 1);
        assert_eq!(decision.outputs.len(), 1);
        assert!(decision.model_parameters.is_some());
        assert_eq!(decision.execution_time_ms, Some(100.5));
        assert_eq!(
            decision.tags.get("version").map(String::as_str),
            Some("1.0")
        );
    }

    /// A root snapshot records its type, its decisions and a checksum.
    #[test]
    fn test_snapshot_creation() {
        let mut session = Snapshot::new(SnapshotType::Session);
        session.add_decision(DecisionSnapshot::new("test"));

        assert_eq!(session.snapshot_type, SnapshotType::Session);
        assert_eq!(session.decisions.len(), 1);
        assert!(session.metadata.checksum.is_some());
    }

    /// Serializing to JSON and back yields an equal snapshot.
    #[test]
    fn test_json_serialization_roundtrip() {
        let model = ModelParameters::new("gpt-4")
            .with_provider("openai")
            .with_parameter("temperature", json!(0.8));
        let original = DecisionSnapshot::new("process")
            .add_input(Input::new("test", json!({"key": "value"}), "object"))
            .add_output(Output::new("result", json!([1, 2, 3]), "array"))
            .with_model_parameters(model);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DecisionSnapshot = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// Canonical JSON of the same snapshot is identical across calls.
    #[test]
    fn test_canonical_json_consistency() {
        let decision = DecisionSnapshot::new("test")
            .add_tag("key1", "value1")
            .add_tag("key2", "value2");

        let first = decision.to_canonical_json().unwrap();
        let second = decision.to_canonical_json().unwrap();

        assert_eq!(first, second);
    }

    /// Changing the snapshot contents changes the checksum.
    #[test]
    fn test_checksum_updates() {
        let untagged = DecisionSnapshot::new("test");
        let before = untagged.metadata.checksum.clone();

        let tagged = untagged.add_tag("new", "tag");
        let after = tagged.metadata.checksum.clone();

        assert_ne!(before, after);
    }
}