perspt_store/
store.rs

1//! Session Store Implementation
2//!
//! Provides create, read, and update operations for SRBN sessions, node
//! states, energy history, and LLM request/response logs, persisted in DuckDB.
4
use anyhow::{Context, Result};
use duckdb::Connection;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};

use crate::schema::init_schema;
12
/// Record for a session
///
/// One row of the `sessions` table. Written by
/// [`SessionStore::create_session`] and read back by
/// [`SessionStore::get_session`] and [`SessionStore::list_recent_sessions`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
    /// Identifier used to look the session up.
    pub session_id: String,
    /// Task text associated with the session.
    pub task: String,
    /// Working directory recorded for the session, stored as a plain string.
    pub working_dir: String,
    /// Raw Merkle root bytes, if one has been recorded; persisted hex-encoded.
    pub merkle_root: Option<Vec<u8>>,
    /// Toolchain detected for the session, if any.
    pub detected_toolchain: Option<String>,
    /// Free-form status string; see [`SessionStore::update_session_status`].
    pub status: String,
}
23
/// Record for node state
///
/// One row of the `node_states` table. Written by
/// [`SessionStore::record_node_state`] and read by
/// [`SessionStore::get_node_states`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeStateRecord {
    /// Identifier of the node this state belongs to.
    pub node_id: String,
    /// Session the node belongs to.
    pub session_id: String,
    /// State label, stored as a free-form string.
    pub state: String,
    /// Total energy value recorded alongside this state.
    pub v_total: f32,
    /// Merkle hash bytes for the node, if any; persisted hex-encoded.
    pub merkle_hash: Option<Vec<u8>>,
    /// Attempt counter stored with this state.
    pub attempt_count: i32,
}
34
/// Record for energy history
///
/// One row of the `energy_history` table. Written by
/// [`SessionStore::record_energy`] and read by
/// [`SessionStore::get_energy_history`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnergyRecord {
    /// Identifier of the node this measurement belongs to.
    pub node_id: String,
    /// Session the node belongs to.
    pub session_id: String,
    // Individual energy components. Their exact meaning is defined by the
    // SRBN energy model, not in this file.
    pub v_syn: f32,
    pub v_str: f32,
    pub v_log: f32,
    pub v_boot: f32,
    pub v_sheaf: f32,
    /// Total energy for this measurement.
    /// NOTE(review): presumably a combination of the components above — confirm.
    pub v_total: f32,
}
47
/// Record for LLM request/response logging
///
/// One row of the `llm_requests` table. Written by
/// [`SessionStore::record_llm_request`] and read by
/// [`SessionStore::get_llm_requests`] and [`SessionStore::get_all_llm_requests`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequestRecord {
    /// Session the request was made in.
    pub session_id: String,
    /// Node the request was issued for, if any.
    pub node_id: Option<String>,
    /// Model name the request was sent to.
    pub model: String,
    /// Full prompt text.
    pub prompt: String,
    /// Full response text.
    pub response: String,
    /// Input token count.
    pub tokens_in: i32,
    /// Output token count.
    pub tokens_out: i32,
    /// Request latency in milliseconds, as the `_ms` suffix indicates.
    pub latency_ms: i32,
}
60
61use std::sync::Mutex;
62
/// Session store for SRBN persistence
///
/// Wraps a single DuckDB [`Connection`]. The connection sits behind a
/// [`Mutex`] so that every method can take `&self`. Each method locks it
/// for the duration of its own query.
pub struct SessionStore {
    // The one connection all queries go through, guarded by the mutex.
    conn: Mutex<Connection>,
}
67
68impl SessionStore {
69    /// Create a new session store with default path
70    pub fn new() -> Result<Self> {
71        let db_path = Self::default_db_path()?;
72        Self::open(&db_path)
73    }
74
75    /// Open a session store at the given path
76    pub fn open(path: &PathBuf) -> Result<Self> {
77        // Ensure parent directory exists
78        if let Some(parent) = path.parent() {
79            std::fs::create_dir_all(parent)?;
80        }
81
82        let conn = Connection::open(path).context("Failed to open DuckDB")?;
83        init_schema(&conn)?;
84
85        Ok(Self {
86            conn: Mutex::new(conn),
87        })
88    }
89
90    /// Get the default database path (~/.local/share/perspt/perspt.db or similar)
91    pub fn default_db_path() -> Result<PathBuf> {
92        let data_dir = dirs::data_local_dir()
93            .context("Could not find local data directory")?
94            .join("perspt");
95        Ok(data_dir.join("perspt.db"))
96    }
97
98    /// Create a new session
99    pub fn create_session(&self, session: &SessionRecord) -> Result<()> {
100        self.conn.lock().unwrap().execute(
101            r#"
102            INSERT INTO sessions (session_id, task, working_dir, merkle_root, detected_toolchain, status)
103            VALUES (?, ?, ?, ?, ?, ?)
104            "#,
105            [
106                &session.session_id,
107                &session.task,
108                &session.working_dir,
109                &session.merkle_root.as_ref().map(hex::encode).unwrap_or_default(),
110                &session.detected_toolchain.clone().unwrap_or_default(),
111                &session.status,
112            ],
113        )?;
114        Ok(())
115    }
116
117    /// Update session merkle root
118    pub fn update_merkle_root(&self, session_id: &str, merkle_root: &[u8]) -> Result<()> {
119        self.conn.lock().unwrap().execute(
120            "UPDATE sessions SET merkle_root = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
121            [hex::encode(merkle_root), session_id.to_string()],
122        )?;
123        Ok(())
124    }
125
126    /// Record node state
127    pub fn record_node_state(&self, record: &NodeStateRecord) -> Result<()> {
128        self.conn.lock().unwrap().execute(
129            r#"
130            INSERT INTO node_states (node_id, session_id, state, v_total, merkle_hash, attempt_count)
131            VALUES (?, ?, ?, ?, ?, ?)
132            "#,
133            [
134                &record.node_id,
135                &record.session_id,
136                &record.state,
137                &record.v_total.to_string(),
138                &record.merkle_hash.as_ref().map(hex::encode).unwrap_or_default(),
139                &record.attempt_count.to_string(),
140            ],
141        )?;
142        Ok(())
143    }
144
145    /// Record energy measurement
146    pub fn record_energy(&self, record: &EnergyRecord) -> Result<()> {
147        self.conn.lock().unwrap().execute(
148            r#"
149            INSERT INTO energy_history (node_id, session_id, v_syn, v_str, v_log, v_boot, v_sheaf, v_total)
150            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
151            "#,
152            [
153                &record.node_id,
154                &record.session_id,
155                &record.v_syn.to_string(),
156                &record.v_str.to_string(),
157                &record.v_log.to_string(),
158                &record.v_boot.to_string(),
159                &record.v_sheaf.to_string(),
160                &record.v_total.to_string(),
161            ],
162        )?;
163        Ok(())
164    }
165
166    /// Calculate Merkle hash for content
167    pub fn calculate_hash(content: &[u8]) -> Vec<u8> {
168        let mut hasher = Sha256::new();
169        hasher.update(content);
170        hasher.finalize().to_vec()
171    }
172
173    /// Get session by ID
174    pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>> {
175        let conn = self.conn.lock().unwrap();
176        let mut stmt = conn.prepare(
177            "SELECT session_id, task, working_dir, merkle_root, detected_toolchain, status FROM sessions WHERE session_id = ?"
178        )?;
179
180        let mut rows = stmt.query([session_id])?;
181        if let Some(row) = rows.next()? {
182            Ok(Some(SessionRecord {
183                session_id: row.get(0)?,
184                task: row.get(1)?,
185                working_dir: row.get(2)?,
186                merkle_root: row
187                    .get::<_, Option<String>>(3)?
188                    .and_then(|s| hex::decode(s).ok()),
189                detected_toolchain: row.get(4)?,
190                status: row.get(5)?,
191            }))
192        } else {
193            Ok(None)
194        }
195    }
196
197    /// Get the directory for session artifacts (~/.local/share/perspt/sessions/<id>)
198    pub fn get_session_dir(&self, session_id: &str) -> Result<PathBuf> {
199        let data_dir = dirs::data_local_dir()
200            .context("Could not find local data directory")?
201            .join("perspt")
202            .join("sessions")
203            .join(session_id);
204        Ok(data_dir)
205    }
206
207    /// Ensure a session directory exists and return the path
208    pub fn create_session_dir(&self, session_id: &str) -> Result<PathBuf> {
209        let dir = self.get_session_dir(session_id)?;
210        if !dir.exists() {
211            std::fs::create_dir_all(&dir).context("Failed to create session directory")?;
212        }
213        Ok(dir)
214    }
215
216    /// Get energy history for a node (query)
217    pub fn get_energy_history(&self, session_id: &str, node_id: &str) -> Result<Vec<EnergyRecord>> {
218        let conn = self.conn.lock().unwrap();
219        let mut stmt = conn.prepare(
220            "SELECT node_id, session_id, v_syn, v_str, v_log, v_boot, v_sheaf, v_total FROM energy_history WHERE session_id = ? AND node_id = ? ORDER BY timestamp"
221        )?;
222
223        let mut rows = stmt.query([session_id, node_id])?;
224        let mut records = Vec::new();
225
226        while let Some(row) = rows.next()? {
227            records.push(EnergyRecord {
228                node_id: row.get(0)?,
229                session_id: row.get(1)?,
230                v_syn: row.get::<_, f64>(2)? as f32,
231                v_str: row.get::<_, f64>(3)? as f32,
232                v_log: row.get::<_, f64>(4)? as f32,
233                v_boot: row.get::<_, f64>(5)? as f32,
234                v_sheaf: row.get::<_, f64>(6)? as f32,
235                v_total: row.get::<_, f64>(7)? as f32,
236            });
237        }
238
239        Ok(records)
240    }
241
242    /// List recent sessions (newest first)
243    pub fn list_recent_sessions(&self, limit: usize) -> Result<Vec<SessionRecord>> {
244        let conn = self.conn.lock().unwrap();
245        let mut stmt = conn.prepare(
246            "SELECT session_id, task, working_dir, merkle_root, detected_toolchain, status
247             FROM sessions ORDER BY created_at DESC LIMIT ?",
248        )?;
249
250        let mut rows = stmt.query([limit.to_string()])?;
251        let mut records = Vec::new();
252
253        while let Some(row) = rows.next()? {
254            // merkle_root is stored as BLOB, read it directly as Option<Vec<u8>>
255            let merkle_root: Option<Vec<u8>> = row.get(3).ok();
256
257            records.push(SessionRecord {
258                session_id: row.get(0)?,
259                task: row.get(1)?,
260                working_dir: row.get(2)?,
261                merkle_root,
262                detected_toolchain: row.get(4)?,
263                status: row.get(5)?,
264            });
265        }
266
267        Ok(records)
268    }
269
270    /// Get all node states for a session
271    pub fn get_node_states(&self, session_id: &str) -> Result<Vec<NodeStateRecord>> {
272        let conn = self.conn.lock().unwrap();
273        let mut stmt = conn.prepare(
274            "SELECT node_id, session_id, state, v_total, merkle_hash, attempt_count
275             FROM node_states WHERE session_id = ? ORDER BY created_at",
276        )?;
277
278        let mut rows = stmt.query([session_id])?;
279        let mut records = Vec::new();
280
281        while let Some(row) = rows.next()? {
282            records.push(NodeStateRecord {
283                node_id: row.get(0)?,
284                session_id: row.get(1)?,
285                state: row.get(2)?,
286                v_total: row.get::<_, f64>(3)? as f32,
287                merkle_hash: row
288                    .get::<_, Option<String>>(4)?
289                    .and_then(|s| hex::decode(s).ok()),
290                attempt_count: row.get(5)?,
291            });
292        }
293
294        Ok(records)
295    }
296
297    /// Update session status
298    pub fn update_session_status(&self, session_id: &str, status: &str) -> Result<()> {
299        self.conn.lock().unwrap().execute(
300            "UPDATE sessions SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
301            [status, session_id],
302        )?;
303        Ok(())
304    }
305
306    /// Record an LLM request/response
307    pub fn record_llm_request(&self, record: &LlmRequestRecord) -> Result<()> {
308        let conn = self.conn.lock().unwrap();
309        conn.execute(
310            r#"
311            INSERT INTO llm_requests (session_id, node_id, model, prompt, response, tokens_in, tokens_out, latency_ms)
312            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
313            "#,
314            [
315                &record.session_id,
316                &record.node_id.clone().unwrap_or_default(),
317                &record.model,
318                &record.prompt,
319                &record.response,
320                &record.tokens_in.to_string(),
321                &record.tokens_out.to_string(),
322                &record.latency_ms.to_string(),
323            ],
324        )?;
325        Ok(())
326    }
327
328    /// Get LLM requests for a session
329    pub fn get_llm_requests(&self, session_id: &str) -> Result<Vec<LlmRequestRecord>> {
330        let conn = self.conn.lock().unwrap();
331        let mut stmt = conn.prepare(
332            "SELECT session_id, node_id, model, prompt, response, tokens_in, tokens_out, latency_ms
333             FROM llm_requests WHERE session_id = ? ORDER BY timestamp",
334        )?;
335
336        let mut rows = stmt.query([session_id])?;
337        let mut records = Vec::new();
338
339        while let Some(row) = rows.next()? {
340            let node_id: Option<String> = row.get(1)?;
341            records.push(LlmRequestRecord {
342                session_id: row.get(0)?,
343                node_id: if node_id.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
344                    None
345                } else {
346                    node_id
347                },
348                model: row.get(2)?,
349                prompt: row.get(3)?,
350                response: row.get(4)?,
351                tokens_in: row.get(5)?,
352                tokens_out: row.get(6)?,
353                latency_ms: row.get(7)?,
354            });
355        }
356
357        Ok(records)
358    }
359
360    /// Count all LLM requests in the database (for debugging)
361    pub fn count_all_llm_requests(&self) -> Result<i64> {
362        let conn = self.conn.lock().unwrap();
363        let mut stmt = conn.prepare("SELECT COUNT(*) FROM llm_requests")?;
364        let count: i64 = stmt.query_row([], |row| row.get(0))?;
365        Ok(count)
366    }
367
368    /// Get all LLM requests (for debugging)
369    pub fn get_all_llm_requests(&self, limit: usize) -> Result<Vec<LlmRequestRecord>> {
370        let conn = self.conn.lock().unwrap();
371        let mut stmt = conn.prepare(
372            "SELECT session_id, node_id, model, prompt, response, tokens_in, tokens_out, latency_ms
373             FROM llm_requests ORDER BY timestamp DESC LIMIT ?",
374        )?;
375
376        let mut rows = stmt.query([limit as i64])?;
377        let mut records = Vec::new();
378
379        while let Some(row) = rows.next()? {
380            let node_id: Option<String> = row.get(1)?;
381            records.push(LlmRequestRecord {
382                session_id: row.get(0)?,
383                node_id: if node_id.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
384                    None
385                } else {
386                    node_id
387                },
388                model: row.get(2)?,
389                prompt: row.get(3)?,
390                response: row.get(4)?,
391                tokens_in: row.get(5)?,
392                tokens_out: row.get(6)?,
393                latency_ms: row.get(7)?,
394            });
395        }
396
397        Ok(records)
398    }
399}