Skip to main content

agent_citation/
store.rs

1use std::fs::{self, OpenOptions};
2use std::io::{BufRead, BufReader, Write};
3use std::path::{Path, PathBuf};
4use std::sync::{Arc, Mutex};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use serde::{Deserialize, Serialize};
8
9use crate::citation::Citation;
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct AttributionRecord {
13    pub turn_id: String,
14    pub text: String,
15    pub citations: Vec<Citation>,
16    pub captured_at: f64,
17}
18
19impl AttributionRecord {
20    pub fn new(turn_id: String, text: String, citations: Vec<Citation>) -> Self {
21        let captured_at = SystemTime::now()
22            .duration_since(UNIX_EPOCH)
23            .map(|d| d.as_secs_f64())
24            .unwrap_or(0.0);
25        AttributionRecord {
26            turn_id,
27            text,
28            citations,
29            captured_at,
30        }
31    }
32
33    pub fn to_json_value(&self) -> serde_json::Value {
34        let cites: Vec<serde_json::Value> =
35            self.citations.iter().map(|c| c.to_json_value()).collect();
36        let mut map = serde_json::Map::new();
37        map.insert(
38            "turn_id".into(),
39            serde_json::Value::String(self.turn_id.clone()),
40        );
41        map.insert("text".into(), serde_json::Value::String(self.text.clone()));
42        map.insert("citations".into(), serde_json::Value::Array(cites));
43        if let Some(n) = serde_json::Number::from_f64(self.captured_at) {
44            map.insert("captured_at".into(), serde_json::Value::Number(n));
45        }
46        serde_json::Value::Object(map)
47    }
48}
49
/// Failures reported by the citation store and its sinks.
#[derive(Debug, Clone, PartialEq)]
pub enum StoreError {
    /// A turn id that is empty or only whitespace was supplied.
    BlankTurnId,
    /// An I/O failure (the in-memory sink also uses this for a poisoned lock);
    /// carries the underlying error message.
    Io(String),
    /// A JSON (de)serialization failure; carries the underlying error message.
    Parse(String),
}

impl std::fmt::Display for StoreError {
    /// Render a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::BlankTurnId => f.write_str("turn_id must be a non-empty string"),
            Self::Io(ref detail) => write!(f, "io error: {}", detail),
            Self::Parse(ref detail) => write!(f, "parse error: {}", detail),
        }
    }
}

/// Lets `StoreError` be used as a `dyn std::error::Error` (e.g. boxed or with `?`).
impl std::error::Error for StoreError {}
68
69pub trait Sink: Send + Sync {
70    fn write(&self, record: AttributionRecord) -> Result<(), StoreError>;
71    fn read_all(&self) -> Result<Vec<AttributionRecord>, StoreError>;
72}
73
74#[derive(Default)]
75pub struct InMemorySink {
76    records: Arc<Mutex<Vec<AttributionRecord>>>,
77}
78
79impl InMemorySink {
80    pub fn new() -> Self {
81        InMemorySink {
82            records: Arc::new(Mutex::new(Vec::new())),
83        }
84    }
85
86    pub fn clear(&self) {
87        if let Ok(mut g) = self.records.lock() {
88            g.clear();
89        }
90    }
91
92    pub fn len(&self) -> usize {
93        self.records.lock().map(|g| g.len()).unwrap_or(0)
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.len() == 0
98    }
99}
100
101impl Sink for InMemorySink {
102    fn write(&self, record: AttributionRecord) -> Result<(), StoreError> {
103        let mut g = self
104            .records
105            .lock()
106            .map_err(|e| StoreError::Io(e.to_string()))?;
107        g.push(record);
108        Ok(())
109    }
110
111    fn read_all(&self) -> Result<Vec<AttributionRecord>, StoreError> {
112        let g = self
113            .records
114            .lock()
115            .map_err(|e| StoreError::Io(e.to_string()))?;
116        Ok(g.clone())
117    }
118}
119
120pub struct JsonlSink {
121    path: PathBuf,
122}
123
124impl JsonlSink {
125    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, StoreError> {
126        let path = path.as_ref().to_path_buf();
127        if let Some(parent) = path.parent() {
128            if !parent.as_os_str().is_empty() {
129                fs::create_dir_all(parent).map_err(|e| StoreError::Io(e.to_string()))?;
130            }
131        }
132        Ok(JsonlSink { path })
133    }
134
135    pub fn path(&self) -> &Path {
136        &self.path
137    }
138}
139
140impl Sink for JsonlSink {
141    fn write(&self, record: AttributionRecord) -> Result<(), StoreError> {
142        let payload = serde_json::to_string(&record.to_json_value())
143            .map_err(|e| StoreError::Parse(e.to_string()))?;
144        let mut fh = OpenOptions::new()
145            .append(true)
146            .create(true)
147            .open(&self.path)
148            .map_err(|e| StoreError::Io(e.to_string()))?;
149        fh.write_all(payload.as_bytes())
150            .map_err(|e| StoreError::Io(e.to_string()))?;
151        fh.write_all(b"\n").map_err(|e| StoreError::Io(e.to_string()))?;
152        Ok(())
153    }
154
155    fn read_all(&self) -> Result<Vec<AttributionRecord>, StoreError> {
156        if !self.path.exists() {
157            return Ok(Vec::new());
158        }
159        let fh = std::fs::File::open(&self.path).map_err(|e| StoreError::Io(e.to_string()))?;
160        let rdr = BufReader::new(fh);
161        let mut out: Vec<AttributionRecord> = Vec::new();
162        for line in rdr.lines() {
163            let raw = line.map_err(|e| StoreError::Io(e.to_string()))?;
164            let trimmed = raw.trim();
165            if trimmed.is_empty() {
166                continue;
167            }
168            let v: serde_json::Value =
169                serde_json::from_str(trimmed).map_err(|e| StoreError::Parse(e.to_string()))?;
170            let turn_id = v
171                .get("turn_id")
172                .and_then(|x| x.as_str())
173                .unwrap_or("")
174                .to_string();
175            let text = v
176                .get("text")
177                .and_then(|x| x.as_str())
178                .unwrap_or("")
179                .to_string();
180            let captured_at = v
181                .get("captured_at")
182                .and_then(|x| x.as_f64())
183                .unwrap_or(0.0);
184            let citations: Vec<Citation> = v
185                .get("citations")
186                .and_then(|x| x.as_array())
187                .map(|arr| {
188                    arr.iter()
189                        .filter_map(|c| Citation::from_json_value(c).ok())
190                        .collect()
191                })
192                .unwrap_or_default();
193            out.push(AttributionRecord {
194                turn_id,
195                text,
196                citations,
197                captured_at,
198            });
199        }
200        Ok(out)
201    }
202}
203
204pub struct CitationStore {
205    sink: Box<dyn Sink>,
206}
207
208impl Default for CitationStore {
209    fn default() -> Self {
210        CitationStore {
211            sink: Box::new(InMemorySink::new()),
212        }
213    }
214}
215
216impl CitationStore {
217    pub fn new() -> Self {
218        Self::default()
219    }
220
221    pub fn with_sink(sink: Box<dyn Sink>) -> Self {
222        CitationStore { sink }
223    }
224
225    pub fn sink(&self) -> &dyn Sink {
226        self.sink.as_ref()
227    }
228
229    pub fn attach(
230        &self,
231        turn_id: impl Into<String>,
232        text: impl Into<String>,
233        citations: impl IntoIterator<Item = Citation>,
234    ) -> Result<AttributionRecord, StoreError> {
235        let turn_id = turn_id.into();
236        if turn_id.trim().is_empty() {
237            return Err(StoreError::BlankTurnId);
238        }
239        let record = AttributionRecord::new(turn_id, text.into(), citations.into_iter().collect());
240        self.sink.write(record.clone())?;
241        Ok(record)
242    }
243
244    pub fn export(&self) -> Result<Vec<serde_json::Value>, StoreError> {
245        Ok(self
246            .sink
247            .read_all()?
248            .iter()
249            .map(|r| r.to_json_value())
250            .collect())
251    }
252
253    pub fn render_text_summary(&self) -> Result<String, StoreError> {
254        let mut out = String::new();
255        for record in self.sink.read_all()? {
256            let cite_ids: String = if record.citations.is_empty() {
257                "-".to_string()
258            } else {
259                record
260                    .citations
261                    .iter()
262                    .map(|c| c.id.clone())
263                    .collect::<Vec<_>>()
264                    .join(",")
265            };
266            out.push_str(&format!(
267                "[{}] cites={} len={}\n",
268                record.turn_id,
269                cite_ids,
270                record.text.len()
271            ));
272        }
273        Ok(out)
274    }
275}