ipfrs_semantic/
provenance.rs

1//! Provenance Tracking for Embeddings
2//!
3//! This module tracks the provenance and lineage of embeddings:
4//! - Source tracking for embedding generation
5//! - Version control for embeddings
6//! - Immutable audit trails
7//! - Explanation generation for search results
8
9use ipfrs_core::{Cid, Error, Result};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
/// Source of an embedding
///
/// Records how an embedding came to exist. CID-valued fields are serialized as their
/// string form via the `serialize_cid`/`deserialize_cid` and
/// `serialize_cid_vec`/`deserialize_cid_vec` helpers in this module.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EmbeddingSource {
    /// Generated by a machine learning model
    Model {
        /// Model name
        name: String,
        /// Model version
        version: String,
        /// Model parameters/config (free-form key/value pairs)
        config: HashMap<String, String>,
    },
    /// Manually created
    Manual {
        /// Creator identifier
        creator: String,
        /// Description
        description: String,
    },
    /// Derived from another embedding
    Derived {
        /// Source embedding CID (serialized as a string)
        #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
        source_cid: Cid,
        /// Transformation applied
        transformation: String,
    },
    /// Aggregated from multiple embeddings
    Aggregated {
        /// Source embedding CIDs (serialized as a list of strings)
        #[serde(
            serialize_with = "serialize_cid_vec",
            deserialize_with = "deserialize_cid_vec"
        )]
        source_cids: Vec<Cid>,
        /// Aggregation method
        method: String,
    },
}
54
55/// Helper function to serialize a CID
56fn serialize_cid<S>(cid: &Cid, serializer: S) -> std::result::Result<S::Ok, S::Error>
57where
58    S: serde::Serializer,
59{
60    serializer.serialize_str(&cid.to_string())
61}
62
63/// Helper function to deserialize a CID
64fn deserialize_cid<'de, D>(deserializer: D) -> std::result::Result<Cid, D::Error>
65where
66    D: serde::Deserializer<'de>,
67{
68    let s = String::deserialize(deserializer)?;
69    s.parse().map_err(serde::de::Error::custom)
70}
71
72/// Helper function to serialize a Vec<CID>
73fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> std::result::Result<S::Ok, S::Error>
74where
75    S: serde::Serializer,
76{
77    let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
78    strings.serialize(serializer)
79}
80
81/// Helper function to deserialize a Vec<CID>
82fn deserialize_cid_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Cid>, D::Error>
83where
84    D: serde::Deserializer<'de>,
85{
86    let strings: Vec<String> = Vec::deserialize(deserializer)?;
87    strings
88        .into_iter()
89        .map(|s| s.parse().map_err(serde::de::Error::custom))
90        .collect()
91}
92
/// Embedding metadata with provenance
///
/// Describes one embedding, identified by its CID. `EmbeddingMetadata::new` starts
/// `version` at 1 and stamps `created_at` with the current time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
    /// Content identifier (serialized as a string)
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub cid: Cid,
    /// Embedding version (1 for a newly created embedding; incremented on each update)
    pub version: u32,
    /// Source of embedding
    pub source: EmbeddingSource,
    /// Creation timestamp (Unix epoch ms)
    pub created_at: u64,
    /// Input data reference (opaque to this module)
    pub input_reference: Option<String>,
    /// Embedding dimension
    pub dimension: usize,
    /// Additional metadata (free-form key/value pairs)
    pub extra: HashMap<String, String>,
}
112
113impl EmbeddingMetadata {
114    /// Create new embedding metadata
115    pub fn new(cid: Cid, dimension: usize, source: EmbeddingSource) -> Self {
116        Self {
117            cid,
118            version: 1,
119            source,
120            created_at: current_timestamp_ms(),
121            input_reference: None,
122            dimension,
123            extra: HashMap::new(),
124        }
125    }
126
127    /// Set input reference
128    pub fn with_input_reference(mut self, reference: impl Into<String>) -> Self {
129        self.input_reference = Some(reference.into());
130        self
131    }
132
133    /// Add extra metadata
134    pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135        self.extra.insert(key.into(), value.into());
136        self
137    }
138}
139
/// Audit log entry for embedding operations
///
/// Entries are created by `ProvenanceTracker` and appended to its in-memory audit log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
    /// Entry ID (taken from a tracker-wide counter that starts at 0)
    pub id: u64,
    /// Timestamp (Unix epoch ms)
    pub timestamp: u64,
    /// Operation type
    pub operation: AuditOperation,
    /// CID affected (serialized as a string)
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub cid: Cid,
    /// User/system identifier (the tracker itself records `"system"` or `"query"`)
    pub actor: String,
    /// Additional context (e.g. `previous_cid`, `result_count`, `query_cid`)
    pub context: HashMap<String, String>,
}
157
/// Type of audit operation
///
/// NOTE(review): the `ProvenanceTracker` in this module emits `Create`, `Update`, `Query`
/// and `Access`; no code visible here emits `Delete`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AuditOperation {
    /// Embedding created
    Create,
    /// Embedding updated (logged against the new CID)
    Update,
    /// Embedding deleted
    Delete,
    /// Embedding queried (logged against the query CID)
    Query,
    /// Embedding accessed (logged once per search result)
    Access,
}
172
/// Version history for an embedding
///
/// `ProvenanceTracker::update_embedding` appends new versions to the history stored under
/// the pre-update CID. `cid` therefore identifies the embedding the history was opened for,
/// not necessarily the CID of the latest version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionHistory {
    /// Embedding CID the history is keyed by (serialized as a string)
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub cid: Cid,
    /// All versions, in the order they were recorded
    pub versions: Vec<EmbeddingVersion>,
}
182
183/// A single version of an embedding
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct EmbeddingVersion {
186    /// Version number
187    pub version: u32,
188    /// Timestamp
189    pub timestamp: u64,
190    /// Change description
191    pub change_log: String,
192    /// Previous version CID (if updated)
193    #[serde(
194        skip_serializing_if = "Option::is_none",
195        serialize_with = "serialize_cid_option",
196        deserialize_with = "deserialize_cid_option"
197    )]
198    pub previous_cid: Option<Cid>,
199    /// Metadata for this version
200    pub metadata: EmbeddingMetadata,
201}
202
203/// Helper function to serialize an Option<CID>
204fn serialize_cid_option<S>(cid: &Option<Cid>, serializer: S) -> std::result::Result<S::Ok, S::Error>
205where
206    S: serde::Serializer,
207{
208    match cid {
209        Some(c) => serializer.serialize_some(&c.to_string()),
210        None => serializer.serialize_none(),
211    }
212}
213
214/// Helper function to deserialize an Option<CID>
215fn deserialize_cid_option<'de, D>(deserializer: D) -> std::result::Result<Option<Cid>, D::Error>
216where
217    D: serde::Deserializer<'de>,
218{
219    let opt: Option<String> = Option::deserialize(deserializer)?;
220    opt.map(|s| s.parse().map_err(serde::de::Error::custom))
221        .transpose()
222}
223
/// Explanation for a search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchExplanation {
    /// Query CID (serialized as a string)
    ///
    /// NOTE(review): `ProvenanceTracker::explain_result` does not receive the query CID and
    /// fills this with `Cid::default()` as a placeholder.
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub query_cid: Cid,
    /// Result CID (serialized as a string)
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub result_cid: Cid,
    /// Similarity score, as supplied by the caller
    pub score: f32,
    /// Feature attributions (which features contributed to similarity), strongest first
    pub attributions: Vec<FeatureAttribution>,
    /// Explanation text
    pub explanation: String,
}
240
/// Attribution of a feature (embedding dimension) to similarity
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureAttribution {
    /// Feature index (dimension within the embedding)
    pub feature_idx: usize,
    /// Contribution to score: the product of the query and result values in this dimension.
    ///
    /// The value is not clamped. It lies within -1.0..=1.0 only when every component of both
    /// embeddings does (for example, unit-normalized embeddings).
    pub contribution: f32,
    /// Feature description (the tracker sets this to `"Dimension {idx}"`)
    pub description: Option<String>,
}
251
/// Provenance tracker for embeddings
///
/// Keeps embedding metadata, per-embedding version histories and an audit log in in-memory
/// collections. Each collection sits behind its own `parking_lot::RwLock`, so every operation
/// works through `&self`.
pub struct ProvenanceTracker {
    /// Embedding metadata storage, keyed by embedding CID
    metadata: Arc<RwLock<HashMap<Cid, EmbeddingMetadata>>>,
    /// Version history, keyed by the CID the history was opened for
    versions: Arc<RwLock<HashMap<Cid, VersionHistory>>>,
    /// Audit log. Append-only through this type's API: entries are pushed, never modified
    /// or removed.
    audit_log: Arc<RwLock<Vec<AuditLogEntry>>>,
    /// Next audit log ID (starts at 0)
    next_audit_id: Arc<RwLock<u64>>,
}
263
264impl ProvenanceTracker {
265    /// Create a new provenance tracker
266    pub fn new() -> Self {
267        Self {
268            metadata: Arc::new(RwLock::new(HashMap::new())),
269            versions: Arc::new(RwLock::new(HashMap::new())),
270            audit_log: Arc::new(RwLock::new(Vec::new())),
271            next_audit_id: Arc::new(RwLock::new(0)),
272        }
273    }
274
275    /// Track a new embedding
276    pub fn track_embedding(&self, metadata: EmbeddingMetadata) -> Result<()> {
277        let cid = metadata.cid;
278
279        // Store metadata
280        self.metadata.write().insert(cid, metadata.clone());
281
282        // Initialize version history
283        let version = EmbeddingVersion {
284            version: 1,
285            timestamp: current_timestamp_ms(),
286            change_log: "Initial version".to_string(),
287            previous_cid: None,
288            metadata: metadata.clone(),
289        };
290
291        let history = VersionHistory {
292            cid,
293            versions: vec![version],
294        };
295
296        self.versions.write().insert(cid, history);
297
298        // Add audit log entry
299        self.add_audit_entry(
300            AuditOperation::Create,
301            cid,
302            "system".to_string(),
303            HashMap::new(),
304        )?;
305
306        Ok(())
307    }
308
309    /// Update embedding metadata (creates new version)
310    pub fn update_embedding(
311        &self,
312        cid: Cid,
313        new_cid: Cid,
314        change_log: impl Into<String>,
315    ) -> Result<()> {
316        let metadata = self
317            .metadata
318            .read()
319            .get(&cid)
320            .cloned()
321            .ok_or_else(|| Error::InvalidInput(format!("Embedding not found: {}", cid)))?;
322
323        // Create new version
324        let mut new_metadata = metadata.clone();
325        new_metadata.cid = new_cid;
326        new_metadata.version += 1;
327        new_metadata.created_at = current_timestamp_ms();
328
329        // Store new metadata
330        self.metadata.write().insert(new_cid, new_metadata.clone());
331
332        // Update version history
333        let mut versions = self.versions.write();
334        let history = versions.entry(cid).or_insert_with(|| VersionHistory {
335            cid,
336            versions: Vec::new(),
337        });
338
339        history.versions.push(EmbeddingVersion {
340            version: new_metadata.version,
341            timestamp: new_metadata.created_at,
342            change_log: change_log.into(),
343            previous_cid: Some(cid),
344            metadata: new_metadata,
345        });
346
347        // Add audit log entry
348        drop(versions);
349        self.add_audit_entry(
350            AuditOperation::Update,
351            new_cid,
352            "system".to_string(),
353            HashMap::from([("previous_cid".to_string(), cid.to_string())]),
354        )?;
355
356        Ok(())
357    }
358
359    /// Get embedding metadata
360    pub fn get_metadata(&self, cid: &Cid) -> Option<EmbeddingMetadata> {
361        self.metadata.read().get(cid).cloned()
362    }
363
364    /// Get version history
365    pub fn get_version_history(&self, cid: &Cid) -> Option<VersionHistory> {
366        self.versions.read().get(cid).cloned()
367    }
368
369    /// Get audit log entries for a CID
370    pub fn get_audit_log(&self, cid: &Cid) -> Vec<AuditLogEntry> {
371        self.audit_log
372            .read()
373            .iter()
374            .filter(|e| &e.cid == cid)
375            .cloned()
376            .collect()
377    }
378
379    /// Get all audit log entries (for compliance/export)
380    pub fn get_full_audit_log(&self) -> Vec<AuditLogEntry> {
381        self.audit_log.read().clone()
382    }
383
384    /// Add an audit log entry
385    fn add_audit_entry(
386        &self,
387        operation: AuditOperation,
388        cid: Cid,
389        actor: String,
390        context: HashMap<String, String>,
391    ) -> Result<()> {
392        let id = {
393            let mut next_id = self.next_audit_id.write();
394            let id = *next_id;
395            *next_id += 1;
396            id
397        };
398
399        let entry = AuditLogEntry {
400            id,
401            timestamp: current_timestamp_ms(),
402            operation,
403            cid,
404            actor,
405            context,
406        };
407
408        self.audit_log.write().push(entry);
409        Ok(())
410    }
411
412    /// Log a query operation
413    pub fn log_query(&self, query_cid: Cid, result_cids: &[Cid]) -> Result<()> {
414        let context = HashMap::from([("result_count".to_string(), result_cids.len().to_string())]);
415
416        self.add_audit_entry(
417            AuditOperation::Query,
418            query_cid,
419            "system".to_string(),
420            context,
421        )?;
422
423        // Log access to each result
424        for result_cid in result_cids {
425            self.add_audit_entry(
426                AuditOperation::Access,
427                *result_cid,
428                "query".to_string(),
429                HashMap::from([("query_cid".to_string(), query_cid.to_string())]),
430            )?;
431        }
432
433        Ok(())
434    }
435
436    /// Generate explanation for a search result
437    pub fn explain_result(
438        &self,
439        query_embedding: &[f32],
440        result_cid: &Cid,
441        result_embedding: &[f32],
442        score: f32,
443    ) -> SearchExplanation {
444        // Calculate feature attributions
445        let mut attributions = Vec::new();
446
447        for (idx, (q, r)) in query_embedding
448            .iter()
449            .zip(result_embedding.iter())
450            .enumerate()
451        {
452            let contribution = q * r; // Simplified: dot product contribution
453            if contribution.abs() > 0.1 {
454                // Only include significant contributions
455                attributions.push(FeatureAttribution {
456                    feature_idx: idx,
457                    contribution,
458                    description: Some(format!("Dimension {}", idx)),
459                });
460            }
461        }
462
463        // Sort by absolute contribution
464        attributions.sort_by(|a, b| {
465            b.contribution
466                .abs()
467                .partial_cmp(&a.contribution.abs())
468                .unwrap_or(std::cmp::Ordering::Equal)
469        });
470
471        // Keep top 10 features
472        attributions.truncate(10);
473
474        // Generate explanation text
475        let explanation = if attributions.is_empty() {
476            format!("Result matched with similarity score {:.3}", score)
477        } else {
478            let top_features: Vec<String> = attributions
479                .iter()
480                .take(3)
481                .map(|a| format!("dim {}: {:.3}", a.feature_idx, a.contribution))
482                .collect();
483
484            format!(
485                "Result matched with similarity score {:.3}. Top contributing features: {}",
486                score,
487                top_features.join(", ")
488            )
489        };
490
491        SearchExplanation {
492            query_cid: Cid::default(), // Placeholder
493            result_cid: *result_cid,
494            score,
495            attributions,
496            explanation,
497        }
498    }
499
500    /// Get statistics about tracked embeddings
501    pub fn stats(&self) -> ProvenanceStats {
502        let metadata = self.metadata.read();
503        let versions = self.versions.read();
504        let audit_log = self.audit_log.read();
505
506        ProvenanceStats {
507            total_embeddings: metadata.len(),
508            total_versions: versions.values().map(|h| h.versions.len()).sum(),
509            total_audit_entries: audit_log.len(),
510            oldest_timestamp: audit_log.first().map(|e| e.timestamp),
511            newest_timestamp: audit_log.last().map(|e| e.timestamp),
512        }
513    }
514}
515
516impl Default for ProvenanceTracker {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
/// Statistics about provenance tracking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvenanceStats {
    /// Total number of tracked embeddings (distinct CIDs with stored metadata)
    pub total_embeddings: usize,
    /// Total number of versions across all version histories
    pub total_versions: usize,
    /// Total audit log entries
    pub total_audit_entries: usize,
    /// Timestamp of the first audit entry (`None` when the log is empty)
    pub oldest_timestamp: Option<u64>,
    /// Timestamp of the last audit entry (`None` when the log is empty)
    pub newest_timestamp: Option<u64>,
}
536
/// Get the current wall-clock time in milliseconds since the Unix epoch.
///
/// This never panics. It returns 0 if the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` rather than silently truncating the `u128` millisecond count.
fn current_timestamp_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// First fixed, valid CID used throughout the tests.
    fn test_cid() -> Cid {
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap()
    }

    /// Second fixed, valid CID distinct from [`test_cid`].
    fn test_cid2() -> Cid {
        "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354"
            .parse()
            .unwrap()
    }

    /// Small manually-created metadata fixture.
    fn manual_metadata(cid: Cid, dimension: usize) -> EmbeddingMetadata {
        EmbeddingMetadata::new(
            cid,
            dimension,
            EmbeddingSource::Manual {
                creator: "test".to_string(),
                description: "test embedding".to_string(),
            },
        )
    }

    #[test]
    fn test_track_embedding() {
        let tracker = ProvenanceTracker::new();

        let metadata = EmbeddingMetadata::new(
            test_cid(),
            768,
            EmbeddingSource::Model {
                name: "bert-base".to_string(),
                version: "1.0".to_string(),
                config: HashMap::new(),
            },
        );

        assert!(tracker.track_embedding(metadata).is_ok());

        let retrieved = tracker.get_metadata(&test_cid());
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().dimension, 768);
    }

    #[test]
    fn test_version_history() {
        let tracker = ProvenanceTracker::new();

        tracker
            .track_embedding(manual_metadata(test_cid(), 768))
            .unwrap();

        tracker
            .update_embedding(test_cid(), test_cid2(), "Updated embedding")
            .unwrap();

        let history = tracker.get_version_history(&test_cid());
        assert!(history.is_some());
        assert_eq!(history.unwrap().versions.len(), 2);
    }

    #[test]
    fn test_update_unknown_embedding_fails() {
        let tracker = ProvenanceTracker::new();

        assert!(tracker
            .update_embedding(test_cid(), test_cid2(), "no such embedding")
            .is_err());
        // A failed update must not leave partial state behind.
        assert!(tracker.get_metadata(&test_cid2()).is_none());
        assert!(tracker.get_full_audit_log().is_empty());
    }

    #[test]
    fn test_update_bumps_version_and_audits() {
        let tracker = ProvenanceTracker::new();
        tracker.track_embedding(manual_metadata(test_cid(), 4)).unwrap();
        tracker
            .update_embedding(test_cid(), test_cid2(), "re-embedded")
            .unwrap();

        let updated = tracker.get_metadata(&test_cid2()).unwrap();
        assert_eq!(updated.version, 2);
        assert_eq!(updated.dimension, 4);
        // The pre-update metadata is kept as-is.
        assert_eq!(tracker.get_metadata(&test_cid()).unwrap().version, 1);

        let history = tracker.get_version_history(&test_cid()).unwrap();
        assert_eq!(history.versions[0].previous_cid, None);
        assert_eq!(history.versions[1].previous_cid, Some(test_cid()));
        assert_eq!(history.versions[1].change_log, "re-embedded");

        let updates = tracker.get_audit_log(&test_cid2());
        assert_eq!(updates.len(), 1);
        assert_eq!(updates[0].operation, AuditOperation::Update);
        assert_eq!(
            updates[0].context.get("previous_cid"),
            Some(&test_cid().to_string())
        );
    }

    #[test]
    fn test_audit_log() {
        let tracker = ProvenanceTracker::new();

        let metadata = EmbeddingMetadata::new(
            test_cid(),
            768,
            EmbeddingSource::Derived {
                source_cid: test_cid2(),
                transformation: "normalize".to_string(),
            },
        );

        tracker.track_embedding(metadata).unwrap();

        let audit_entries = tracker.get_audit_log(&test_cid());
        assert!(!audit_entries.is_empty());
        assert_eq!(audit_entries[0].operation, AuditOperation::Create);
    }

    #[test]
    fn test_log_query() {
        let tracker = ProvenanceTracker::new();

        let result_cids = [test_cid(), test_cid2()];
        let query_cid = test_cid();

        assert!(tracker.log_query(query_cid, &result_cids).is_ok());

        let audit_log = tracker.get_full_audit_log();
        assert_eq!(audit_log.len(), 3); // 1 query + 2 access
    }

    #[test]
    fn test_audit_ids_are_sequential() {
        let tracker = ProvenanceTracker::new();
        tracker.log_query(test_cid(), &[test_cid2()]).unwrap();
        tracker.log_query(test_cid2(), &[test_cid()]).unwrap();

        let ids: Vec<u64> = tracker.get_full_audit_log().iter().map(|e| e.id).collect();
        assert_eq!(ids, [0, 1, 2, 3]);
    }

    #[test]
    fn test_explain_result() {
        let tracker = ProvenanceTracker::new();

        let query_emb: [f32; 3] = [1.0, 0.5, 0.3];
        let result_emb: [f32; 3] = [0.9, 0.6, 0.2];

        let explanation = tracker.explain_result(&query_emb, &test_cid(), &result_emb, 0.95);

        assert_eq!(explanation.result_cid, test_cid());
        assert_eq!(explanation.score, 0.95);
        assert!(!explanation.attributions.is_empty());
        assert!(!explanation.explanation.is_empty());
    }

    #[test]
    fn test_explain_result_without_significant_features() {
        let tracker = ProvenanceTracker::new();
        let small: [f32; 2] = [0.05, 0.05];

        let explanation = tracker.explain_result(&small, &test_cid(), &small, 0.5);

        assert!(explanation.attributions.is_empty());
        assert_eq!(
            explanation.explanation,
            "Result matched with similarity score 0.500"
        );
    }

    #[test]
    fn test_explain_result_sorts_and_truncates() {
        let tracker = ProvenanceTracker::new();

        let query: [f32; 3] = [0.2, -0.9, 0.5];
        let result: [f32; 3] = [1.0, 1.0, 1.0];
        let explanation = tracker.explain_result(&query, &test_cid(), &result, 0.1);
        let order: Vec<usize> = explanation
            .attributions
            .iter()
            .map(|a| a.feature_idx)
            .collect();
        assert_eq!(order, [1, 2, 0]);

        let ones: [f32; 12] = [1.0; 12];
        let explanation = tracker.explain_result(&ones, &test_cid(), &ones, 1.0);
        assert_eq!(explanation.attributions.len(), 10);
    }

    #[test]
    fn test_provenance_stats() {
        let tracker = ProvenanceTracker::new();

        let metadata1 = EmbeddingMetadata::new(
            test_cid(),
            768,
            EmbeddingSource::Manual {
                creator: "test".to_string(),
                description: "test".to_string(),
            },
        );

        let metadata2 = EmbeddingMetadata::new(
            test_cid2(),
            512,
            EmbeddingSource::Manual {
                creator: "test".to_string(),
                description: "test2".to_string(),
            },
        );

        tracker.track_embedding(metadata1).unwrap();
        tracker.track_embedding(metadata2).unwrap();

        let stats = tracker.stats();
        assert_eq!(stats.total_embeddings, 2);
        assert_eq!(stats.total_versions, 2);
        assert_eq!(stats.total_audit_entries, 2);
    }
}