Skip to main content

neleus_db/
provenance.rs

1use serde::{Deserialize, Serialize};
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use crate::hash::Hash;
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub enum SourceType {
8    ToolOutput,
9    WebPage,
10    UserInput,
11    Document,
12    ApiResponse,
13    Custom(String),
14}
15
16impl SourceType {
17    pub fn as_str(&self) -> &str {
18        match self {
19            SourceType::ToolOutput => "tool_output",
20            SourceType::WebPage => "web_page",
21            SourceType::UserInput => "user_input",
22            SourceType::Document => "document",
23            SourceType::ApiResponse => "api_response",
24            SourceType::Custom(s) => s,
25        }
26    }
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub struct Evidence {
31    pub source_blob: Hash,
32
33    #[serde(with = "source_type_serde")]
34    pub source_type: SourceType,
35    pub extracted_text: String,
36    pub position: Option<(usize, usize)>,
37    pub timestamp: u64,
38    #[serde(default)]
39    pub metadata: std::collections::BTreeMap<String, String>,
40}
41
42mod source_type_serde {
43    use super::SourceType;
44    use serde::{Deserialize, Deserializer, Serializer};
45
46    pub fn serialize<S>(source_type: &SourceType, serializer: S) -> Result<S::Ok, S::Error>
47    where
48        S: Serializer,
49    {
50        serializer.serialize_str(source_type.as_str())
51    }
52
53    pub fn deserialize<'de, D>(deserializer: D) -> Result<SourceType, D::Error>
54    where
55        D: Deserializer<'de>,
56    {
57        let s = String::deserialize(deserializer)?;
58        Ok(match s.as_str() {
59            "tool_output" => SourceType::ToolOutput,
60            "web_page" => SourceType::WebPage,
61            "user_input" => SourceType::UserInput,
62            "document" => SourceType::Document,
63            "api_response" => SourceType::ApiResponse,
64            custom => SourceType::Custom(custom.to_string()),
65        })
66    }
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70pub struct ProvenanceRecord {
71    pub claim_id: String,
72    pub claim_text: String,
73    pub evidence: Vec<Evidence>,
74    pub agent_id: String,
75    pub timestamp: u64,
76    pub confidence: f32,
77    pub reasoning: Option<String>,
78    #[serde(default)]
79    pub tags: Vec<String>,
80    #[serde(default = "default_provenance_version")]
81    pub schema_version: u32,
82}
83
84fn default_provenance_version() -> u32 {
85    1
86}
87
88impl ProvenanceRecord {
89    pub fn new(claim_id: String, claim_text: String, agent_id: String, confidence: f32) -> Self {
90        let timestamp = SystemTime::now()
91            .duration_since(UNIX_EPOCH)
92            .unwrap_or_default()
93            .as_secs();
94
95        Self {
96            claim_id,
97            claim_text,
98            evidence: Vec::new(),
99            agent_id,
100            timestamp,
101            confidence,
102            reasoning: None,
103            tags: Vec::new(),
104            schema_version: 1,
105        }
106    }
107
108    pub fn add_evidence(&mut self, evidence: Evidence) -> &mut Self {
109        self.evidence.push(evidence);
110        self
111    }
112
113    pub fn add_evidence_batch(&mut self, evidence: Vec<Evidence>) -> &mut Self {
114        self.evidence.extend(evidence);
115        self
116    }
117
118    /// Set reasoning for how evidence supports the claim
119    pub fn with_reasoning(mut self, reasoning: String) -> Self {
120        self.reasoning = Some(reasoning);
121        self
122    }
123
124    /// Add tags for organization
125    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
126        self.tags = tags;
127        self
128    }
129
130    /// Validate provenance record (returns error if invalid)
131    pub fn validate(&self) -> anyhow::Result<()> {
132        if self.claim_id.is_empty() {
133            return Err(anyhow::anyhow!("claim_id cannot be empty"));
134        }
135
136        if !(0.0..=1.0).contains(&self.confidence) {
137            return Err(anyhow::anyhow!(
138                "confidence must be between 0.0 and 1.0, got {}",
139                self.confidence
140            ));
141        }
142
143        if self.evidence.is_empty() {
144            return Err(anyhow::anyhow!(
145                "provenance record must have at least one piece of evidence"
146            ));
147        }
148
149        Ok(())
150    }
151}
152
153/// Manifest for storing provenance records
154#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
155pub struct ProvenanceManifest {
156    /// Schema version
157    #[serde(default = "default_provenance_version")]
158    pub schema_version: u32,
159
160    /// List of provenance records
161    pub records: Vec<ProvenanceRecord>,
162
163    /// When manifest was created
164    pub created_at: u64,
165
166    /// Agent that created this manifest
167    pub agent_id: String,
168}
169
170impl ProvenanceManifest {
171    /// Create new empty provenance manifest
172    pub fn new(agent_id: String) -> Self {
173        let created_at = SystemTime::now()
174            .duration_since(UNIX_EPOCH)
175            .unwrap_or_default()
176            .as_secs();
177
178        Self {
179            schema_version: 1,
180            records: Vec::new(),
181            created_at,
182            agent_id,
183        }
184    }
185
186    /// Add a provenance record
187    pub fn add_record(&mut self, record: ProvenanceRecord) -> anyhow::Result<&mut Self> {
188        record.validate()?;
189        self.records.push(record);
190        Ok(self)
191    }
192
193    /// Find records by claim ID
194    pub fn find_by_claim_id(&self, claim_id: &str) -> Option<&ProvenanceRecord> {
195        self.records.iter().find(|r| r.claim_id == claim_id)
196    }
197
198    /// Find records by tag
199    pub fn find_by_tag(&self, tag: &str) -> Vec<&ProvenanceRecord> {
200        self.records
201            .iter()
202            .filter(|r| r.tags.contains(&tag.to_string()))
203            .collect()
204    }
205
206    /// Find records by agent
207    pub fn find_by_agent(&self, agent_id: &str) -> Vec<&ProvenanceRecord> {
208        self.records
209            .iter()
210            .filter(|r| r.agent_id == agent_id)
211            .collect()
212    }
213
214    /// Get confidence score for a claim
215    pub fn get_claim_confidence(&self, claim_id: &str) -> Option<f32> {
216        self.find_by_claim_id(claim_id).map(|r| r.confidence)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evidence item with a fixed blob hash, timestamp, and no span,
    /// so each test only spells out the fields it actually cares about.
    fn sample_evidence(source_type: SourceType, text: &str) -> Evidence {
        Evidence {
            source_blob: Hash::zero(),
            source_type,
            extracted_text: text.to_string(),
            position: None,
            timestamp: 1234567890,
            metadata: Default::default(),
        }
    }

    /// Builds a record for `agent_1` with the given claim ID and confidence and no evidence.
    fn sample_record(claim_id: &str, confidence: f32) -> ProvenanceRecord {
        ProvenanceRecord::new(
            claim_id.to_string(),
            "test claim".to_string(),
            "agent_1".to_string(),
            confidence,
        )
    }

    #[test]
    fn create_evidence() {
        let evidence = Evidence {
            source_blob: Hash::zero(),
            source_type: SourceType::ToolOutput,
            extracted_text: "Tool result".to_string(),
            position: Some((0, 11)),
            timestamp: 1234567890,
            metadata: Default::default(),
        };

        assert_eq!(evidence.source_type, SourceType::ToolOutput);
        assert_eq!(evidence.extracted_text, "Tool result");
    }

    #[test]
    fn source_type_names() {
        assert_eq!(SourceType::ToolOutput.as_str(), "tool_output");
        assert_eq!(SourceType::WebPage.as_str(), "web_page");
        assert_eq!(SourceType::UserInput.as_str(), "user_input");
        assert_eq!(SourceType::Document.as_str(), "document");
        assert_eq!(SourceType::ApiResponse.as_str(), "api_response");
        assert_eq!(SourceType::Custom("sensor".to_string()).as_str(), "sensor");
    }

    #[test]
    fn create_provenance_record() {
        let record = ProvenanceRecord::new(
            "claim_1".to_string(),
            "The agent decided to run task X".to_string(),
            "agent_1".to_string(),
            0.95,
        );

        assert_eq!(record.claim_id, "claim_1");
        assert_eq!(record.confidence, 0.95);
        assert!(record.evidence.is_empty());
        assert_eq!(record.schema_version, 1);
    }

    #[test]
    fn provenance_with_evidence() {
        let mut record = ProvenanceRecord::new(
            "claim_1".to_string(),
            "Decided to do X".to_string(),
            "agent_1".to_string(),
            0.9,
        );

        record.add_evidence(sample_evidence(SourceType::UserInput, "user wants X"));
        assert_eq!(record.evidence.len(), 1);
        assert!(record.validate().is_ok());
    }

    #[test]
    fn add_evidence_batch_appends_in_order() {
        let mut record = sample_record("claim_1", 0.5);
        record.add_evidence(sample_evidence(SourceType::ToolOutput, "zeroth"));
        record.add_evidence_batch(vec![
            sample_evidence(SourceType::Document, "first"),
            sample_evidence(SourceType::WebPage, "second"),
        ]);

        let texts: Vec<&str> = record
            .evidence
            .iter()
            .map(|e| e.extracted_text.as_str())
            .collect();
        assert_eq!(texts, ["zeroth", "first", "second"]);
    }

    #[test]
    fn provenance_manifest() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        let mut record = sample_record("claim_1", 0.8);
        record.add_evidence(sample_evidence(SourceType::ToolOutput, "result"));
        manifest.add_record(record).unwrap();

        assert_eq!(manifest.records.len(), 1);
        assert_eq!(manifest.schema_version, 1);
        assert!(manifest.find_by_claim_id("claim_1").is_some());
        assert!(manifest.find_by_claim_id("claim_2").is_none());
        assert_eq!(manifest.get_claim_confidence("claim_1"), Some(0.8));
        assert_eq!(manifest.get_claim_confidence("claim_2"), None);
    }

    #[test]
    fn invalid_confidence() {
        let mut record = ProvenanceRecord::new(
            "claim_1".to_string(),
            "test".to_string(),
            "agent_1".to_string(),
            1.5, // Invalid: > 1.0
        );

        record.add_evidence(sample_evidence(SourceType::UserInput, "evidence"));

        assert!(record.validate().is_err());
    }

    #[test]
    fn nan_confidence_is_rejected() {
        let mut record = sample_record("claim_1", f32::NAN);
        record.add_evidence(sample_evidence(SourceType::UserInput, "evidence"));

        assert!(record.validate().is_err());
    }

    #[test]
    fn confidence_bounds_are_inclusive() {
        for &confidence in &[0.0_f32, 1.0] {
            let mut record = sample_record("claim_1", confidence);
            record.add_evidence(sample_evidence(SourceType::UserInput, "evidence"));

            assert!(record.validate().is_ok(), "confidence {} should be valid", confidence);
        }
    }

    #[test]
    fn empty_claim_id_is_rejected() {
        let mut record = sample_record("", 0.5);
        record.add_evidence(sample_evidence(SourceType::UserInput, "evidence"));

        assert!(record.validate().is_err());
    }

    #[test]
    fn record_without_evidence_is_rejected() {
        let record = sample_record("claim_1", 0.5);

        assert!(record.validate().is_err());
    }

    #[test]
    fn manifest_rejects_invalid_record() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        // No evidence attached, so validation must fail.
        assert!(manifest.add_record(sample_record("claim_1", 0.5)).is_err());
        assert!(manifest.records.is_empty());
    }

    #[test]
    fn find_by_tag() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        let mut record = sample_record("claim_1", 0.8);
        record = record.with_tags(vec!["important".to_string(), "decision".to_string()]);
        record.add_evidence(sample_evidence(SourceType::ToolOutput, "result"));

        manifest.add_record(record).unwrap();

        let results = manifest.find_by_tag("important");
        assert_eq!(results.len(), 1);
        assert!(manifest.find_by_tag("missing").is_empty());
    }

    #[test]
    fn find_by_agent() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        for &(claim_id, agent_id) in &[
            ("claim_1", "agent_1"),
            ("claim_2", "agent_2"),
            ("claim_3", "agent_1"),
        ] {
            let mut record = ProvenanceRecord::new(
                claim_id.to_string(),
                "test claim".to_string(),
                agent_id.to_string(),
                0.5,
            );
            record.add_evidence(sample_evidence(SourceType::ToolOutput, "result"));
            manifest.add_record(record).unwrap();
        }

        let ids: Vec<&str> = manifest
            .find_by_agent("agent_1")
            .into_iter()
            .map(|r| r.claim_id.as_str())
            .collect();
        assert_eq!(ids, ["claim_1", "claim_3"]);
        assert!(manifest.find_by_agent("agent_3").is_empty());
    }
}