Skip to main content

kaizen/store/
remote_cache.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
//! Local SQLite cache for provider pull results. Writers use a single transaction per refresh.
3
4use crate::store::Store;
5use crate::sync::outbound::OutboundEvent;
6use anyhow::Result;
7use rusqlite::{Transaction, params};
8use std::collections::{HashMap, HashSet};
9
/// Contents of the singleton `remote_pull_state` row (`id = 1`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemotePullState {
    /// Value of the `query_provider` column; `"none"` in the default state.
    pub query_provider: String,
    /// Value of the `cursor_json` column; empty in the default state.
    pub cursor_json: String,
    /// Value of the `last_success_ms` column; `None` in the default state.
    pub last_success_ms: Option<i64>,
}

impl Default for RemotePullState {
    /// State with provider `"none"`, an empty cursor, and no recorded success time.
    fn default() -> Self {
        RemotePullState {
            query_provider: String::from("none"),
            cursor_json: String::default(),
            last_success_ms: Option::None,
        }
    }
}
27
28/// Read/write remote cache tables and pull cursor. Implemented for [`Store`].
29pub trait RemoteCacheStore {
30    fn get_pull_state(&self) -> Result<RemotePullState>;
31    /// Update cursor and last success time (call only after a successful import in the same txn or immediately after commit).
32    fn set_pull_state(&self, state: &RemotePullState) -> Result<()>;
33    /// Run `f` inside a transaction (use for clear + bulk insert + cursor update).
34    fn with_remote_refresh<T>(
35        &self,
36        f: impl for<'a> FnOnce(&'a Transaction<'_>) -> Result<T>,
37    ) -> Result<T>;
38}
39
40/// Delete all rows from `remote_*` data tables (use inside `with_remote_refresh` before re-insert).
41pub fn clear_remote_cache_tables(tx: &Transaction<'_>) -> Result<()> {
42    for table in [
43        "remote_sessions",
44        "remote_events",
45        "remote_tool_spans",
46        "remote_repo_snapshots",
47        "remote_workspace_facts",
48    ] {
49        tx.execute(&format!("DELETE FROM {table}"), [])?;
50    }
51    Ok(())
52}
53
54impl RemoteCacheStore for Store {
55    fn get_pull_state(&self) -> Result<RemotePullState> {
56        let conn = self.conn();
57        let row = conn.query_row(
58            "SELECT query_provider, cursor_json, last_success_ms FROM remote_pull_state WHERE id = 1",
59            [],
60            |r| {
61                Ok(RemotePullState {
62                    query_provider: r.get(0)?,
63                    cursor_json: r.get(1)?,
64                    last_success_ms: r.get(2)?,
65                })
66            },
67        );
68        row.map_err(Into::into)
69    }
70
71    fn set_pull_state(&self, state: &RemotePullState) -> Result<()> {
72        self.conn().execute(
73            "UPDATE remote_pull_state SET query_provider = ?1, cursor_json = ?2, last_success_ms = ?3 WHERE id = 1",
74            params![
75                &state.query_provider,
76                &state.cursor_json,
77                state.last_success_ms
78            ],
79        )?;
80        Ok(())
81    }
82
83    fn with_remote_refresh<T>(
84        &self,
85        f: impl for<'a> FnOnce(&'a Transaction<'_>) -> Result<T>,
86    ) -> Result<T> {
87        let tx = self.conn().unchecked_transaction()?;
88        let out = f(&tx)?;
89        tx.commit()?;
90        Ok(out)
91    }
92}
93
94impl Store {
95    /// Upsert one remote event row (caller runs inside transaction as needed).
96    pub fn remote_insert_event(
97        &self,
98        team_id: &str,
99        workspace_hash: &str,
100        session_id_hash: &str,
101        event_seq: i64,
102        json: &str,
103    ) -> Result<()> {
104        self.conn().execute(
105            "INSERT OR REPLACE INTO remote_events (team_id, workspace_hash, session_id_hash, event_seq, json)
106             VALUES (?1, ?2, ?3, ?4, ?5)",
107            params![team_id, workspace_hash, session_id_hash, event_seq, json],
108        )?;
109        Ok(())
110    }
111
112    /// JSON payloads in remote_events for this team/workspace (for provider-side retro/merge).
113    pub fn list_remote_event_jsons(
114        &self,
115        team_id: &str,
116        workspace_hash: &str,
117    ) -> Result<Vec<String>> {
118        let mut stmt = self.conn().prepare(
119            "SELECT json FROM remote_events WHERE team_id = ?1 AND workspace_hash = ?2 ORDER BY session_id_hash, event_seq",
120        )?;
121        let rows = stmt.query_map(params![team_id, workspace_hash], |r| r.get::<_, String>(0))?;
122        let mut out = Vec::new();
123        for row in rows {
124            out.push(row?);
125        }
126        Ok(out)
127    }
128
129    /// Event-derived aggregates for `summary` / `insights` / `metrics` when `DataSource` is not local.
130    pub fn remote_event_aggregate(
131        &self,
132        team_id: &str,
133        workspace_hash: &str,
134    ) -> Result<RemoteEventAgg> {
135        let mut out = RemoteEventAgg::default();
136        let now_ms = now_ms();
137        let week_ago = now_ms.saturating_sub(7 * 86_400_000);
138        let now_day = now_ms / 86_400_000;
139
140        let mut sessions: HashSet<String> = HashSet::new();
141        let mut by_agent: HashMap<String, HashSet<String>> = HashMap::new();
142        let mut by_model: HashMap<String, HashSet<String>> = HashMap::new();
143        let mut top_tools: HashMap<String, u64> = HashMap::new();
144        let mut tool_tokens: HashMap<String, u64> = HashMap::new();
145        let mut sessions_by_day: [HashSet<String>; 7] = std::array::from_fn(|_| HashSet::new());
146        let mut with_cost: HashSet<String> = HashSet::new();
147
148        for raw in self.list_remote_event_jsons(team_id, workspace_hash)? {
149            let o: OutboundEvent = match serde_json::from_str(&raw) {
150                Ok(x) => x,
151                Err(_) => continue,
152            };
153            out.event_count = out.event_count.saturating_add(1);
154            sessions.insert(o.session_id_hash.clone());
155            if o.ts_ms >= week_ago {
156                for i in 0..7 {
157                    let target = now_day.saturating_sub(6 - i);
158                    let d = o.ts_ms / 86_400_000;
159                    if d == target {
160                        sessions_by_day[i as usize].insert(o.session_id_hash.clone());
161                    }
162                }
163            }
164            if let Some(c) = o.cost_usd_e6 {
165                out.total_cost_usd_e6 = out.total_cost_usd_e6.saturating_add(c);
166                with_cost.insert(o.session_id_hash.clone());
167            }
168            by_agent
169                .entry(o.agent.clone())
170                .or_default()
171                .insert(o.session_id_hash.clone());
172            by_model
173                .entry(o.model.clone())
174                .or_default()
175                .insert(o.session_id_hash.clone());
176            if let Some(t) = o.tool.as_ref() {
177                *top_tools.entry(t.clone()).or_insert(0) += 1;
178                let tok = (o.tokens_in.unwrap_or(0) as u64)
179                    .saturating_add(o.tokens_out.unwrap_or(0) as u64)
180                    .saturating_add(o.reasoning_tokens.unwrap_or(0) as u64);
181                *tool_tokens.entry(t.clone()).or_insert(0) += tok;
182            }
183        }
184
185        out.session_count = sessions.len() as u64;
186        out.sessions_with_cost = with_cost.len() as u64;
187        out.by_agent = key_sets_to_top(by_agent);
188        out.by_model = key_sets_to_top(by_model);
189        out.top_tools = top_hash_to_vec(&top_tools, 10);
190        out.tool_token_totals = top_hash_to_vec(&tool_tokens, 20);
191        out.sessions_by_day = (0u64..7)
192            .map(|i| {
193                (
194                    day_label(now_day.saturating_sub(6 - i)).to_string(),
195                    sessions_by_day[i as usize].len() as u64,
196                )
197            })
198            .collect();
199        Ok(out)
200    }
201}
202
/// Roll-up of cached remote events used by `kaizen summary` / `insights`, plus the per-tool
/// rows consumed by `metrics`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RemoteEventAgg {
    /// Number of distinct `session_id_hash` values seen.
    pub session_count: u64,
    /// Number of events that deserialized successfully.
    pub event_count: u64,
    /// Sum of `cost_usd_e6` over all events that carry a cost.
    pub total_cost_usd_e6: i64,
    /// Number of distinct sessions with at least one event that carries a cost.
    pub sessions_with_cost: u64,
    /// `(agent, distinct session count)`, highest count first, ties broken by name.
    pub by_agent: Vec<(String, u64)>,
    /// `(model, distinct session count)`, highest count first, ties broken by name.
    pub by_model: Vec<(String, u64)>,
    /// `(tool, event count)` for the ten most frequent tools.
    pub top_tools: Vec<(String, u64)>,
    /// Seven `(weekday label, distinct session count)` entries, from six days ago through
    /// today. Intended to line up with local `InsightsStats::sessions_by_day`.
    pub sessions_by_day: Vec<(String, u64)>,
    /// `(tool, tokens in + out + reasoning)` for the twenty heaviest tools; merged into
    /// `highest_token_tools`.
    pub tool_token_totals: Vec<(String, u64)>,
}
218
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Yields 0 when the clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
225
/// Three-letter weekday label for a day index counted from the Unix epoch.
///
/// Total over all of `u64`: the index is reduced modulo 7 before the offset is added, so no
/// input can overflow.
fn day_label(day_idx: u64) -> &'static str {
    const LABELS: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    // Day 0 maps to "Thu" (1970-01-01 was a Thursday), hence the +4 offset.
    LABELS[((day_idx % 7 + 4) % 7) as usize]
}
229
/// Turn a `key -> set` map into `(key, set size)` pairs, largest set first.
///
/// Pairs with equal sizes are ordered by key, ascending.
fn key_sets_to_top(m: HashMap<String, HashSet<String>>) -> Vec<(String, u64)> {
    let mut ranked: Vec<(String, u64)> = Vec::with_capacity(m.len());
    ranked.extend(
        m.into_iter()
            .map(|(key, members)| (key, members.len() as u64)),
    );
    ranked.sort_by(|lhs, rhs| {
        (std::cmp::Reverse(lhs.1), &lhs.0).cmp(&(std::cmp::Reverse(rhs.1), &rhs.0))
    });
    ranked
}
235
/// Return at most `limit` `(key, count)` pairs from `m`, highest count first.
///
/// Pairs with equal counts are ordered by key, ascending.
fn top_hash_to_vec(m: &HashMap<String, u64>, limit: usize) -> Vec<(String, u64)> {
    let mut ranked: Vec<(String, u64)> = Vec::with_capacity(m.len());
    for (name, count) in m {
        ranked.push((name.clone(), *count));
    }
    ranked.sort_by(|lhs, rhs| match rhs.1.cmp(&lhs.1) {
        std::cmp::Ordering::Equal => lhs.0.cmp(&rhs.0),
        unequal => unequal,
    });
    ranked.truncate(limit);
    ranked
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::Store;
    use tempfile::tempdir;

    /// Number of rows currently in `remote_events`.
    fn remote_event_count(s: &Store) -> i64 {
        s.conn()
            .query_row("SELECT COUNT(*) FROM remote_events", [], |r| r.get(0))
            .unwrap()
    }

    /// A freshly opened store reports provider "none"; a written state reads back unchanged.
    #[test]
    fn pull_state_roundtrip() {
        let dir = tempdir().unwrap();
        let db = dir.path().join("t.db");
        let s = Store::open(&db).unwrap();
        let st = s.get_pull_state().unwrap();
        assert_eq!(st.query_provider, "none");

        let written = RemotePullState {
            query_provider: "posthog".into(),
            cursor_json: r#"{"x":1}"#.into(),
            last_success_ms: Some(42),
        };
        s.set_pull_state(&written).unwrap();

        // Compare the whole struct so `cursor_json` is verified along with the other fields.
        let st2 = s.get_pull_state().unwrap();
        assert_eq!(st2, written);
    }

    /// Clearing inside `with_remote_refresh` removes previously inserted events.
    #[test]
    fn clear_remote_tx() {
        let dir = tempdir().unwrap();
        let db = dir.path().join("t.db");
        let s = Store::open(&db).unwrap();
        s.remote_insert_event("t", "w", "s", 0, "{}").unwrap();
        // Confirm the row exists first, otherwise the final assertion would pass vacuously.
        assert_eq!(remote_event_count(&s), 1);

        s.with_remote_refresh(|tx| {
            clear_remote_cache_tables(tx)?;
            Ok(())
        })
        .unwrap();
        assert_eq!(remote_event_count(&s), 0);
    }
}
284}