1use ipfrs_core::{Cid, Error, Result};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub enum EmbeddingSource {
18 Model {
20 name: String,
22 version: String,
24 config: HashMap<String, String>,
26 },
27 Manual {
29 creator: String,
31 description: String,
33 },
34 Derived {
36 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
38 source_cid: Cid,
39 transformation: String,
41 },
42 Aggregated {
44 #[serde(
46 serialize_with = "serialize_cid_vec",
47 deserialize_with = "deserialize_cid_vec"
48 )]
49 source_cids: Vec<Cid>,
50 method: String,
52 },
53}
54
55fn serialize_cid<S>(cid: &Cid, serializer: S) -> std::result::Result<S::Ok, S::Error>
57where
58 S: serde::Serializer,
59{
60 serializer.serialize_str(&cid.to_string())
61}
62
63fn deserialize_cid<'de, D>(deserializer: D) -> std::result::Result<Cid, D::Error>
65where
66 D: serde::Deserializer<'de>,
67{
68 let s = String::deserialize(deserializer)?;
69 s.parse().map_err(serde::de::Error::custom)
70}
71
72fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> std::result::Result<S::Ok, S::Error>
74where
75 S: serde::Serializer,
76{
77 let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
78 strings.serialize(serializer)
79}
80
81fn deserialize_cid_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Cid>, D::Error>
83where
84 D: serde::Deserializer<'de>,
85{
86 let strings: Vec<String> = Vec::deserialize(deserializer)?;
87 strings
88 .into_iter()
89 .map(|s| s.parse().map_err(serde::de::Error::custom))
90 .collect()
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct EmbeddingMetadata {
96 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
98 pub cid: Cid,
99 pub version: u32,
101 pub source: EmbeddingSource,
103 pub created_at: u64,
105 pub input_reference: Option<String>,
107 pub dimension: usize,
109 pub extra: HashMap<String, String>,
111}
112
113impl EmbeddingMetadata {
114 pub fn new(cid: Cid, dimension: usize, source: EmbeddingSource) -> Self {
116 Self {
117 cid,
118 version: 1,
119 source,
120 created_at: current_timestamp_ms(),
121 input_reference: None,
122 dimension,
123 extra: HashMap::new(),
124 }
125 }
126
127 pub fn with_input_reference(mut self, reference: impl Into<String>) -> Self {
129 self.input_reference = Some(reference.into());
130 self
131 }
132
133 pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135 self.extra.insert(key.into(), value.into());
136 self
137 }
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct AuditLogEntry {
143 pub id: u64,
145 pub timestamp: u64,
147 pub operation: AuditOperation,
149 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
151 pub cid: Cid,
152 pub actor: String,
154 pub context: HashMap<String, String>,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
160pub enum AuditOperation {
161 Create,
163 Update,
165 Delete,
167 Query,
169 Access,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct VersionHistory {
176 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
178 pub cid: Cid,
179 pub versions: Vec<EmbeddingVersion>,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct EmbeddingVersion {
186 pub version: u32,
188 pub timestamp: u64,
190 pub change_log: String,
192 #[serde(
194 skip_serializing_if = "Option::is_none",
195 serialize_with = "serialize_cid_option",
196 deserialize_with = "deserialize_cid_option"
197 )]
198 pub previous_cid: Option<Cid>,
199 pub metadata: EmbeddingMetadata,
201}
202
203fn serialize_cid_option<S>(cid: &Option<Cid>, serializer: S) -> std::result::Result<S::Ok, S::Error>
205where
206 S: serde::Serializer,
207{
208 match cid {
209 Some(c) => serializer.serialize_some(&c.to_string()),
210 None => serializer.serialize_none(),
211 }
212}
213
214fn deserialize_cid_option<'de, D>(deserializer: D) -> std::result::Result<Option<Cid>, D::Error>
216where
217 D: serde::Deserializer<'de>,
218{
219 let opt: Option<String> = Option::deserialize(deserializer)?;
220 opt.map(|s| s.parse().map_err(serde::de::Error::custom))
221 .transpose()
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct SearchExplanation {
227 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
229 pub query_cid: Cid,
230 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
232 pub result_cid: Cid,
233 pub score: f32,
235 pub attributions: Vec<FeatureAttribution>,
237 pub explanation: String,
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct FeatureAttribution {
244 pub feature_idx: usize,
246 pub contribution: f32,
248 pub description: Option<String>,
250}
251
252pub struct ProvenanceTracker {
254 metadata: Arc<RwLock<HashMap<Cid, EmbeddingMetadata>>>,
256 versions: Arc<RwLock<HashMap<Cid, VersionHistory>>>,
258 audit_log: Arc<RwLock<Vec<AuditLogEntry>>>,
260 next_audit_id: Arc<RwLock<u64>>,
262}
263
264impl ProvenanceTracker {
265 pub fn new() -> Self {
267 Self {
268 metadata: Arc::new(RwLock::new(HashMap::new())),
269 versions: Arc::new(RwLock::new(HashMap::new())),
270 audit_log: Arc::new(RwLock::new(Vec::new())),
271 next_audit_id: Arc::new(RwLock::new(0)),
272 }
273 }
274
275 pub fn track_embedding(&self, metadata: EmbeddingMetadata) -> Result<()> {
277 let cid = metadata.cid;
278
279 self.metadata.write().insert(cid, metadata.clone());
281
282 let version = EmbeddingVersion {
284 version: 1,
285 timestamp: current_timestamp_ms(),
286 change_log: "Initial version".to_string(),
287 previous_cid: None,
288 metadata: metadata.clone(),
289 };
290
291 let history = VersionHistory {
292 cid,
293 versions: vec![version],
294 };
295
296 self.versions.write().insert(cid, history);
297
298 self.add_audit_entry(
300 AuditOperation::Create,
301 cid,
302 "system".to_string(),
303 HashMap::new(),
304 )?;
305
306 Ok(())
307 }
308
309 pub fn update_embedding(
311 &self,
312 cid: Cid,
313 new_cid: Cid,
314 change_log: impl Into<String>,
315 ) -> Result<()> {
316 let metadata = self
317 .metadata
318 .read()
319 .get(&cid)
320 .cloned()
321 .ok_or_else(|| Error::InvalidInput(format!("Embedding not found: {}", cid)))?;
322
323 let mut new_metadata = metadata.clone();
325 new_metadata.cid = new_cid;
326 new_metadata.version += 1;
327 new_metadata.created_at = current_timestamp_ms();
328
329 self.metadata.write().insert(new_cid, new_metadata.clone());
331
332 let mut versions = self.versions.write();
334 let history = versions.entry(cid).or_insert_with(|| VersionHistory {
335 cid,
336 versions: Vec::new(),
337 });
338
339 history.versions.push(EmbeddingVersion {
340 version: new_metadata.version,
341 timestamp: new_metadata.created_at,
342 change_log: change_log.into(),
343 previous_cid: Some(cid),
344 metadata: new_metadata,
345 });
346
347 drop(versions);
349 self.add_audit_entry(
350 AuditOperation::Update,
351 new_cid,
352 "system".to_string(),
353 HashMap::from([("previous_cid".to_string(), cid.to_string())]),
354 )?;
355
356 Ok(())
357 }
358
359 pub fn get_metadata(&self, cid: &Cid) -> Option<EmbeddingMetadata> {
361 self.metadata.read().get(cid).cloned()
362 }
363
364 pub fn get_version_history(&self, cid: &Cid) -> Option<VersionHistory> {
366 self.versions.read().get(cid).cloned()
367 }
368
369 pub fn get_audit_log(&self, cid: &Cid) -> Vec<AuditLogEntry> {
371 self.audit_log
372 .read()
373 .iter()
374 .filter(|e| &e.cid == cid)
375 .cloned()
376 .collect()
377 }
378
379 pub fn get_full_audit_log(&self) -> Vec<AuditLogEntry> {
381 self.audit_log.read().clone()
382 }
383
384 fn add_audit_entry(
386 &self,
387 operation: AuditOperation,
388 cid: Cid,
389 actor: String,
390 context: HashMap<String, String>,
391 ) -> Result<()> {
392 let id = {
393 let mut next_id = self.next_audit_id.write();
394 let id = *next_id;
395 *next_id += 1;
396 id
397 };
398
399 let entry = AuditLogEntry {
400 id,
401 timestamp: current_timestamp_ms(),
402 operation,
403 cid,
404 actor,
405 context,
406 };
407
408 self.audit_log.write().push(entry);
409 Ok(())
410 }
411
412 pub fn log_query(&self, query_cid: Cid, result_cids: &[Cid]) -> Result<()> {
414 let context = HashMap::from([("result_count".to_string(), result_cids.len().to_string())]);
415
416 self.add_audit_entry(
417 AuditOperation::Query,
418 query_cid,
419 "system".to_string(),
420 context,
421 )?;
422
423 for result_cid in result_cids {
425 self.add_audit_entry(
426 AuditOperation::Access,
427 *result_cid,
428 "query".to_string(),
429 HashMap::from([("query_cid".to_string(), query_cid.to_string())]),
430 )?;
431 }
432
433 Ok(())
434 }
435
436 pub fn explain_result(
438 &self,
439 query_embedding: &[f32],
440 result_cid: &Cid,
441 result_embedding: &[f32],
442 score: f32,
443 ) -> SearchExplanation {
444 let mut attributions = Vec::new();
446
447 for (idx, (q, r)) in query_embedding
448 .iter()
449 .zip(result_embedding.iter())
450 .enumerate()
451 {
452 let contribution = q * r; if contribution.abs() > 0.1 {
454 attributions.push(FeatureAttribution {
456 feature_idx: idx,
457 contribution,
458 description: Some(format!("Dimension {}", idx)),
459 });
460 }
461 }
462
463 attributions.sort_by(|a, b| {
465 b.contribution
466 .abs()
467 .partial_cmp(&a.contribution.abs())
468 .unwrap_or(std::cmp::Ordering::Equal)
469 });
470
471 attributions.truncate(10);
473
474 let explanation = if attributions.is_empty() {
476 format!("Result matched with similarity score {:.3}", score)
477 } else {
478 let top_features: Vec<String> = attributions
479 .iter()
480 .take(3)
481 .map(|a| format!("dim {}: {:.3}", a.feature_idx, a.contribution))
482 .collect();
483
484 format!(
485 "Result matched with similarity score {:.3}. Top contributing features: {}",
486 score,
487 top_features.join(", ")
488 )
489 };
490
491 SearchExplanation {
492 query_cid: Cid::default(), result_cid: *result_cid,
494 score,
495 attributions,
496 explanation,
497 }
498 }
499
500 pub fn stats(&self) -> ProvenanceStats {
502 let metadata = self.metadata.read();
503 let versions = self.versions.read();
504 let audit_log = self.audit_log.read();
505
506 ProvenanceStats {
507 total_embeddings: metadata.len(),
508 total_versions: versions.values().map(|h| h.versions.len()).sum(),
509 total_audit_entries: audit_log.len(),
510 oldest_timestamp: audit_log.first().map(|e| e.timestamp),
511 newest_timestamp: audit_log.last().map(|e| e.timestamp),
512 }
513 }
514}
515
516impl Default for ProvenanceTracker {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522#[derive(Debug, Clone, Serialize, Deserialize)]
524pub struct ProvenanceStats {
525 pub total_embeddings: usize,
527 pub total_versions: usize,
529 pub total_audit_entries: usize,
531 pub oldest_timestamp: Option<u64>,
533 pub newest_timestamp: Option<u64>,
535}
536
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Never panics: if the system clock reports a time before the epoch the
/// result is 0, and a millisecond count that does not fit in `u64` saturates
/// at `u64::MAX`.
fn current_timestamp_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // `as_millis` is a `u128`; clamp before narrowing so the cast cannot wrap.
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        // `duration_since` fails only when the clock is set before the epoch.
        .unwrap_or(0)
}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// First fixed CID used as a fixture.
    fn test_cid() -> Cid {
        let encoded = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi";
        encoded.parse().unwrap()
    }

    /// Second fixed CID, distinct from `test_cid`.
    fn test_cid2() -> Cid {
        let encoded = "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354";
        encoded.parse().unwrap()
    }

    /// Metadata with a `Manual` source created by "test".
    fn manual_metadata(cid: Cid, dimension: usize, description: &str) -> EmbeddingMetadata {
        let source = EmbeddingSource::Manual {
            creator: "test".to_string(),
            description: description.to_string(),
        };
        EmbeddingMetadata::new(cid, dimension, source)
    }

    /// Tracking an embedding makes its metadata retrievable by CID.
    #[test]
    fn test_track_embedding() {
        let tracker = ProvenanceTracker::new();
        let source = EmbeddingSource::Model {
            name: "bert-base".to_string(),
            version: "1.0".to_string(),
            config: HashMap::new(),
        };

        let outcome = tracker.track_embedding(EmbeddingMetadata::new(test_cid(), 768, source));
        assert!(outcome.is_ok());

        let stored = tracker.get_metadata(&test_cid());
        assert!(stored.is_some());
        assert_eq!(stored.unwrap().dimension, 768);
    }

    /// An update appends a second version to the original CID's history.
    #[test]
    fn test_version_history() {
        let tracker = ProvenanceTracker::new();
        tracker
            .track_embedding(manual_metadata(test_cid(), 768, "test embedding"))
            .unwrap();

        tracker
            .update_embedding(test_cid(), test_cid2(), "Updated embedding")
            .unwrap();

        let history = tracker.get_version_history(&test_cid());
        assert!(history.is_some());
        assert_eq!(history.unwrap().versions.len(), 2);
    }

    /// Tracking an embedding records a `Create` audit entry for its CID.
    #[test]
    fn test_audit_log() {
        let tracker = ProvenanceTracker::new();
        let source = EmbeddingSource::Derived {
            source_cid: test_cid2(),
            transformation: "normalize".to_string(),
        };
        tracker
            .track_embedding(EmbeddingMetadata::new(test_cid(), 768, source))
            .unwrap();

        let entries = tracker.get_audit_log(&test_cid());
        assert!(!entries.is_empty());
        assert_eq!(entries[0].operation, AuditOperation::Create);
    }

    /// Logging a query with two results yields one query entry plus two
    /// access entries.
    #[test]
    fn test_log_query() {
        let tracker = ProvenanceTracker::new();
        let results = vec![test_cid(), test_cid2()];

        assert!(tracker.log_query(test_cid(), &results).is_ok());

        let full_log = tracker.get_full_audit_log();
        assert_eq!(full_log.len(), 3);
    }

    /// An explanation carries the result CID and score and lists attributions.
    #[test]
    fn test_explain_result() {
        let tracker = ProvenanceTracker::new();
        let query_emb = vec![1.0, 0.5, 0.3];
        let result_emb = vec![0.9, 0.6, 0.2];

        let explained = tracker.explain_result(&query_emb, &test_cid(), &result_emb, 0.95);

        assert_eq!(explained.result_cid, test_cid());
        assert_eq!(explained.score, 0.95);
        assert!(!explained.attributions.is_empty());
        assert!(!explained.explanation.is_empty());
    }

    /// Two tracked embeddings produce two records, versions and audit entries.
    #[test]
    fn test_provenance_stats() {
        let tracker = ProvenanceTracker::new();
        let fixtures = vec![
            manual_metadata(test_cid(), 768, "test"),
            manual_metadata(test_cid2(), 512, "test2"),
        ];
        for metadata in fixtures {
            tracker.track_embedding(metadata).unwrap();
        }

        let stats = tracker.stats();
        assert_eq!(stats.total_embeddings, 2);
        assert_eq!(stats.total_versions, 2);
        assert_eq!(stats.total_audit_entries, 2);
    }
}