1use serde::{Deserialize, Serialize};
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use crate::hash::Hash;
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7pub enum SourceType {
8 ToolOutput,
9 WebPage,
10 UserInput,
11 Document,
12 ApiResponse,
13 Custom(String),
14}
15
16impl SourceType {
17 pub fn as_str(&self) -> &str {
18 match self {
19 SourceType::ToolOutput => "tool_output",
20 SourceType::WebPage => "web_page",
21 SourceType::UserInput => "user_input",
22 SourceType::Document => "document",
23 SourceType::ApiResponse => "api_response",
24 SourceType::Custom(s) => s,
25 }
26 }
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub struct Evidence {
31 pub source_blob: Hash,
32
33 #[serde(with = "source_type_serde")]
34 pub source_type: SourceType,
35 pub extracted_text: String,
36 pub position: Option<(usize, usize)>,
37 pub timestamp: u64,
38 #[serde(default)]
39 pub metadata: std::collections::BTreeMap<String, String>,
40}
41
42mod source_type_serde {
43 use super::SourceType;
44 use serde::{Deserialize, Deserializer, Serializer};
45
46 pub fn serialize<S>(source_type: &SourceType, serializer: S) -> Result<S::Ok, S::Error>
47 where
48 S: Serializer,
49 {
50 serializer.serialize_str(source_type.as_str())
51 }
52
53 pub fn deserialize<'de, D>(deserializer: D) -> Result<SourceType, D::Error>
54 where
55 D: Deserializer<'de>,
56 {
57 let s = String::deserialize(deserializer)?;
58 Ok(match s.as_str() {
59 "tool_output" => SourceType::ToolOutput,
60 "web_page" => SourceType::WebPage,
61 "user_input" => SourceType::UserInput,
62 "document" => SourceType::Document,
63 "api_response" => SourceType::ApiResponse,
64 custom => SourceType::Custom(custom.to_string()),
65 })
66 }
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70pub struct ProvenanceRecord {
71 pub claim_id: String,
72 pub claim_text: String,
73 pub evidence: Vec<Evidence>,
74 pub agent_id: String,
75 pub timestamp: u64,
76 pub confidence: f32,
77 pub reasoning: Option<String>,
78 #[serde(default)]
79 pub tags: Vec<String>,
80 #[serde(default = "default_provenance_version")]
81 pub schema_version: u32,
82}
83
/// Supplies the `schema_version` serde fills in when that field is missing
/// from deserialized input.
fn default_provenance_version() -> u32 {
    const FALLBACK_SCHEMA_VERSION: u32 = 1;
    FALLBACK_SCHEMA_VERSION
}
87
88impl ProvenanceRecord {
89 pub fn new(claim_id: String, claim_text: String, agent_id: String, confidence: f32) -> Self {
90 let timestamp = SystemTime::now()
91 .duration_since(UNIX_EPOCH)
92 .unwrap_or_default()
93 .as_secs();
94
95 Self {
96 claim_id,
97 claim_text,
98 evidence: Vec::new(),
99 agent_id,
100 timestamp,
101 confidence,
102 reasoning: None,
103 tags: Vec::new(),
104 schema_version: 1,
105 }
106 }
107
108 pub fn add_evidence(&mut self, evidence: Evidence) -> &mut Self {
109 self.evidence.push(evidence);
110 self
111 }
112
113 pub fn add_evidence_batch(&mut self, evidence: Vec<Evidence>) -> &mut Self {
114 self.evidence.extend(evidence);
115 self
116 }
117
118 pub fn with_reasoning(mut self, reasoning: String) -> Self {
120 self.reasoning = Some(reasoning);
121 self
122 }
123
124 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
126 self.tags = tags;
127 self
128 }
129
130 pub fn validate(&self) -> anyhow::Result<()> {
132 if self.claim_id.is_empty() {
133 return Err(anyhow::anyhow!("claim_id cannot be empty"));
134 }
135
136 if !(0.0..=1.0).contains(&self.confidence) {
137 return Err(anyhow::anyhow!(
138 "confidence must be between 0.0 and 1.0, got {}",
139 self.confidence
140 ));
141 }
142
143 if self.evidence.is_empty() {
144 return Err(anyhow::anyhow!(
145 "provenance record must have at least one piece of evidence"
146 ));
147 }
148
149 Ok(())
150 }
151}
152
153#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
155pub struct ProvenanceManifest {
156 #[serde(default = "default_provenance_version")]
158 pub schema_version: u32,
159
160 pub records: Vec<ProvenanceRecord>,
162
163 pub created_at: u64,
165
166 pub agent_id: String,
168}
169
170impl ProvenanceManifest {
171 pub fn new(agent_id: String) -> Self {
173 let created_at = SystemTime::now()
174 .duration_since(UNIX_EPOCH)
175 .unwrap_or_default()
176 .as_secs();
177
178 Self {
179 schema_version: 1,
180 records: Vec::new(),
181 created_at,
182 agent_id,
183 }
184 }
185
186 pub fn add_record(&mut self, record: ProvenanceRecord) -> anyhow::Result<&mut Self> {
188 record.validate()?;
189 self.records.push(record);
190 Ok(self)
191 }
192
193 pub fn find_by_claim_id(&self, claim_id: &str) -> Option<&ProvenanceRecord> {
195 self.records.iter().find(|r| r.claim_id == claim_id)
196 }
197
198 pub fn find_by_tag(&self, tag: &str) -> Vec<&ProvenanceRecord> {
200 self.records
201 .iter()
202 .filter(|r| r.tags.contains(&tag.to_string()))
203 .collect()
204 }
205
206 pub fn find_by_agent(&self, agent_id: &str) -> Vec<&ProvenanceRecord> {
208 self.records
209 .iter()
210 .filter(|r| r.agent_id == agent_id)
211 .collect()
212 }
213
214 pub fn get_claim_confidence(&self, claim_id: &str) -> Option<f32> {
216 self.find_by_claim_id(claim_id).map(|r| r.confidence)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // Arbitrary fixed timestamp shared by all evidence fixtures.
    const FIXTURE_TIMESTAMP: u64 = 1234567890;

    // Builds evidence with a zero hash, no position and empty metadata.
    fn evidence_of(source_type: SourceType, text: &str) -> Evidence {
        Evidence {
            source_blob: Hash::zero(),
            source_type,
            extracted_text: text.to_string(),
            position: None,
            timestamp: FIXTURE_TIMESTAMP,
            metadata: Default::default(),
        }
    }

    // Builds a record for claim "claim_1" owned by "agent_1".
    fn claim_1(claim_text: &str, confidence: f32) -> ProvenanceRecord {
        ProvenanceRecord::new(
            "claim_1".to_string(),
            claim_text.to_string(),
            "agent_1".to_string(),
            confidence,
        )
    }

    #[test]
    fn create_evidence() {
        let evidence = Evidence {
            position: Some((0, 11)),
            ..evidence_of(SourceType::ToolOutput, "Tool result")
        };

        assert_eq!(evidence.source_type, SourceType::ToolOutput);
        assert_eq!(evidence.extracted_text, "Tool result");
    }

    #[test]
    fn create_provenance_record() {
        let record = claim_1("The agent decided to run task X", 0.95);

        assert_eq!(record.claim_id, "claim_1");
        assert_eq!(record.confidence, 0.95);
        assert!(record.evidence.is_empty());
    }

    #[test]
    fn provenance_with_evidence() {
        let mut record = claim_1("Decided to do X", 0.9);

        record.add_evidence(evidence_of(SourceType::UserInput, "user wants X"));
        assert_eq!(record.evidence.len(), 1);
        assert!(record.validate().is_ok());
    }

    #[test]
    fn provenance_manifest() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        let mut record = claim_1("test claim", 0.8);
        record.add_evidence(evidence_of(SourceType::ToolOutput, "result"));
        manifest.add_record(record).unwrap();

        assert_eq!(manifest.records.len(), 1);
        assert!(manifest.find_by_claim_id("claim_1").is_some());
        assert_eq!(manifest.get_claim_confidence("claim_1"), Some(0.8));
    }

    #[test]
    fn invalid_confidence() {
        // 1.5 lies outside the accepted 0.0..=1.0 range.
        let mut record = claim_1("test", 1.5);

        record.add_evidence(evidence_of(SourceType::UserInput, "evidence"));

        assert!(record.validate().is_err());
    }

    #[test]
    fn find_by_tag() {
        let mut manifest = ProvenanceManifest::new("agent_1".to_string());

        let mut record = claim_1("test claim", 0.8)
            .with_tags(vec!["important".to_string(), "decision".to_string()]);
        record.add_evidence(evidence_of(SourceType::ToolOutput, "result"));

        manifest.add_record(record).unwrap();

        let results = manifest.find_by_tag("important");
        assert_eq!(results.len(), 1);
    }
}