Skip to main content

kernex_memory/store/
usage.rs

1//! Token usage recording and cost tracking.
2
3use super::Store;
4use kernex_core::error::KernexError;
5use kernex_core::pricing::pricing_for;
6
/// Per-dimension token breakdown reported by providers that distinguish
/// regular input, output, prompt-cache reads, and prompt-cache creations
/// (e.g. Anthropic).
///
/// Every field is optional: a provider that does not report a breakdown
/// should pass `UsageBreakdown::default()`, which leaves all fields `None`.
/// The type is a small `Copy` value and supports equality comparison so it
/// can be asserted on directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UsageBreakdown {
    /// Regular input tokens, when the provider reports them.
    pub input_tokens: Option<u64>,
    /// Output tokens, when the provider reports them.
    pub output_tokens: Option<u64>,
    /// Tokens read from the prompt cache, when the provider reports them.
    pub cache_read_tokens: Option<u64>,
    /// Tokens written to the prompt cache, when the provider reports them.
    pub cache_creation_tokens: Option<u64>,
}
18
/// Aggregated token usage over a set of recorded requests, such as a single
/// session or the whole store.
///
/// `UsageSummary::default()` is the all-zero summary. The type implements
/// `PartialEq` (but not `Eq`, because the cost is a floating-point value) so
/// whole summaries can be compared in a single assertion.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UsageSummary {
    /// Total tokens consumed across all recorded requests.
    pub total_tokens: i64,
    /// Estimated total cost in USD.
    pub total_cost_usd: f64,
    /// Number of API requests recorded.
    pub request_count: i64,
    /// Sum of input tokens across requests that reported a breakdown.
    pub total_input_tokens: i64,
    /// Sum of output tokens across requests that reported a breakdown.
    pub total_output_tokens: i64,
    /// Sum of prompt-cache reads across requests that reported a breakdown.
    pub total_cache_read_tokens: i64,
    /// Sum of prompt-cache creations across requests that reported a breakdown.
    pub total_cache_creation_tokens: i64,
}
37
38impl Store {
39    /// Record token usage for a completed API request, total tokens only.
40    ///
41    /// Thin wrapper over [`Store::record_usage_full`] for callers that do not
42    /// have a per-dimension breakdown. Cost is estimated using known per-model
43    /// pricing; unrecognized models record cost as 0.0.
44    pub async fn record_usage(
45        &self,
46        sender_id: &str,
47        session_id: &str,
48        tokens: u64,
49        model: &str,
50    ) -> Result<(), KernexError> {
51        self.record_usage_full(
52            sender_id,
53            session_id,
54            tokens,
55            model,
56            UsageBreakdown::default(),
57        )
58        .await
59    }
60
61    /// Record token usage with a per-dimension breakdown.
62    ///
63    /// `tokens` is the authoritative total used for cost estimation and
64    /// summary aggregation. The breakdown columns are stored verbatim and
65    /// surface in [`UsageSummary`] for cost telemetry (e.g. cache hit ratio).
66    pub async fn record_usage_full(
67        &self,
68        sender_id: &str,
69        session_id: &str,
70        tokens: u64,
71        model: &str,
72        breakdown: UsageBreakdown,
73    ) -> Result<(), KernexError> {
74        let cost = pricing_for(model)
75            .map(|p| p.estimate_cost(tokens))
76            .unwrap_or(0.0);
77
78        sqlx::query(
79            "INSERT INTO token_usage (
80                 sender_id, session_id, model, tokens, cost_usd,
81                 input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens
82             )
83             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
84        )
85        .bind(sender_id)
86        .bind(session_id)
87        .bind(model)
88        .bind(tokens as i64)
89        .bind(cost)
90        .bind(breakdown.input_tokens.map(|v| v as i64))
91        .bind(breakdown.output_tokens.map(|v| v as i64))
92        .bind(breakdown.cache_read_tokens.map(|v| v as i64))
93        .bind(breakdown.cache_creation_tokens.map(|v| v as i64))
94        .execute(&self.pool)
95        .await
96        .map_err(|e| KernexError::Store(format!("failed to record token usage: {e}")))?;
97
98        Ok(())
99    }
100
101    /// Get aggregated token usage for a session.
102    pub async fn get_session_usage(&self, session_id: &str) -> Result<UsageSummary, KernexError> {
103        let row: Option<(i64, f64, i64, i64, i64, i64, i64)> = sqlx::query_as(
104            "SELECT
105                 COALESCE(SUM(tokens), 0),
106                 COALESCE(SUM(cost_usd), 0.0),
107                 COUNT(*),
108                 COALESCE(SUM(input_tokens), 0),
109                 COALESCE(SUM(output_tokens), 0),
110                 COALESCE(SUM(cache_read_tokens), 0),
111                 COALESCE(SUM(cache_creation_tokens), 0)
112             FROM token_usage WHERE session_id = ?",
113        )
114        .bind(session_id)
115        .fetch_optional(&self.pool)
116        .await
117        .map_err(|e| KernexError::Store(format!("failed to query session usage: {e}")))?;
118
119        let (
120            total_tokens,
121            total_cost_usd,
122            request_count,
123            total_input_tokens,
124            total_output_tokens,
125            total_cache_read_tokens,
126            total_cache_creation_tokens,
127        ) = row.unwrap_or((0, 0.0, 0, 0, 0, 0, 0));
128
129        Ok(UsageSummary {
130            total_tokens,
131            total_cost_usd,
132            request_count,
133            total_input_tokens,
134            total_output_tokens,
135            total_cache_read_tokens,
136            total_cache_creation_tokens,
137        })
138    }
139
140    /// Get aggregated token usage across all sessions in the store.
141    ///
142    /// Useful for project-wide cost reporting (e.g. the kx `/cost`
143    /// command) when callers do not maintain a stable session id.
144    pub async fn get_total_usage(&self) -> Result<UsageSummary, KernexError> {
145        let row: Option<(i64, f64, i64, i64, i64, i64, i64)> = sqlx::query_as(
146            "SELECT
147                 COALESCE(SUM(tokens), 0),
148                 COALESCE(SUM(cost_usd), 0.0),
149                 COUNT(*),
150                 COALESCE(SUM(input_tokens), 0),
151                 COALESCE(SUM(output_tokens), 0),
152                 COALESCE(SUM(cache_read_tokens), 0),
153                 COALESCE(SUM(cache_creation_tokens), 0)
154             FROM token_usage",
155        )
156        .fetch_optional(&self.pool)
157        .await
158        .map_err(|e| KernexError::Store(format!("failed to query total usage: {e}")))?;
159
160        let (
161            total_tokens,
162            total_cost_usd,
163            request_count,
164            total_input_tokens,
165            total_output_tokens,
166            total_cache_read_tokens,
167            total_cache_creation_tokens,
168        ) = row.unwrap_or((0, 0.0, 0, 0, 0, 0, 0));
169
170        Ok(UsageSummary {
171            total_tokens,
172            total_cost_usd,
173            request_count,
174            total_input_tokens,
175            total_output_tokens,
176            total_cache_read_tokens,
177            total_cache_creation_tokens,
178        })
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Build a fresh store whose `db_path` is `:memory:`, so each test starts
    /// from an empty usage table.
    async fn in_memory_store() -> Store {
        let config = MemoryConfig {
            db_path: String::from(":memory:"),
            ..MemoryConfig::default()
        };
        Store::new(&config).await.unwrap()
    }

    /// Two requests in one session add up in tokens, count, and cost.
    #[tokio::test]
    async fn test_record_and_get_usage() {
        let store = in_memory_store().await;
        for tokens in [1000, 500] {
            store
                .record_usage("user-1", "sess-abc", tokens, "claude-sonnet-4-6")
                .await
                .unwrap();
        }

        let usage = store.get_session_usage("sess-abc").await.unwrap();
        assert_eq!(usage.total_tokens, 1500);
        assert_eq!(usage.request_count, 2);
        assert!(usage.total_cost_usd > 0.0);
    }

    /// Querying a session that was never recorded yields zeros, not an error.
    #[tokio::test]
    async fn test_get_usage_empty_session() {
        let store = in_memory_store().await;

        let usage = store.get_session_usage("sess-nonexistent").await.unwrap();
        assert_eq!(usage.total_tokens, 0);
        assert_eq!(usage.request_count, 0);
        assert_eq!(usage.total_cost_usd, 0.0);
    }

    /// A model without pricing still records its tokens, at zero cost.
    #[tokio::test]
    async fn test_record_usage_unknown_model_zero_cost() {
        let store = in_memory_store().await;
        store
            .record_usage("user-1", "sess-local", 2000, "llama3.2")
            .await
            .unwrap();

        let usage = store.get_session_usage("sess-local").await.unwrap();
        assert_eq!(usage.total_tokens, 2000);
        assert_eq!(usage.total_cost_usd, 0.0);
    }

    /// Usage recorded under one session id does not leak into another.
    #[tokio::test]
    async fn test_usage_isolated_by_session() {
        let store = in_memory_store().await;
        for (session, tokens) in [("sess-1", 100), ("sess-2", 200)] {
            store
                .record_usage("user-1", session, tokens, "gpt-4o")
                .await
                .unwrap();
        }

        let first = store.get_session_usage("sess-1").await.unwrap();
        let second = store.get_session_usage("sess-2").await.unwrap();
        assert_eq!(first.total_tokens, 100);
        assert_eq!(second.total_tokens, 200);
    }

    /// Breakdown values are summed per dimension; a `None` adds nothing.
    #[tokio::test]
    async fn test_record_usage_full_persists_breakdown() {
        let store = in_memory_store().await;
        let requests = [
            (
                1500,
                UsageBreakdown {
                    input_tokens: Some(200),
                    output_tokens: Some(100),
                    cache_read_tokens: Some(1000),
                    cache_creation_tokens: Some(200),
                },
            ),
            (
                500,
                UsageBreakdown {
                    input_tokens: Some(150),
                    output_tokens: Some(50),
                    cache_read_tokens: Some(300),
                    cache_creation_tokens: None,
                },
            ),
        ];
        for (tokens, breakdown) in requests {
            store
                .record_usage_full(
                    "user-1",
                    "sess-cache",
                    tokens,
                    "claude-sonnet-4-6",
                    breakdown,
                )
                .await
                .unwrap();
        }

        let usage = store.get_session_usage("sess-cache").await.unwrap();
        assert_eq!(usage.total_tokens, 2000);
        assert_eq!(usage.request_count, 2);
        assert_eq!(usage.total_input_tokens, 350);
        assert_eq!(usage.total_output_tokens, 150);
        assert_eq!(usage.total_cache_read_tokens, 1300);
        assert_eq!(usage.total_cache_creation_tokens, 200);
    }

    /// The store-wide total spans different senders, sessions, and models.
    #[tokio::test]
    async fn test_get_total_usage_aggregates_across_sessions() {
        let store = in_memory_store().await;
        let requests = [
            ("user-1", "sess-a", 400, "claude-sonnet-4-6", 300),
            ("user-2", "sess-b", 600, "gpt-4o", 100),
        ];
        for (sender, session, tokens, model, cache_reads) in requests {
            let breakdown = UsageBreakdown {
                cache_read_tokens: Some(cache_reads),
                ..UsageBreakdown::default()
            };
            store
                .record_usage_full(sender, session, tokens, model, breakdown)
                .await
                .unwrap();
        }

        let usage = store.get_total_usage().await.unwrap();
        assert_eq!(usage.total_tokens, 1000);
        assert_eq!(usage.request_count, 2);
        assert_eq!(usage.total_cache_read_tokens, 400);
    }

    /// The total-only wrapper stores no breakdown at all.
    #[tokio::test]
    async fn test_record_usage_wrapper_leaves_breakdown_null() {
        let store = in_memory_store().await;
        store
            .record_usage("user-1", "sess-plain", 700, "gpt-4o")
            .await
            .unwrap();

        let usage = store.get_session_usage("sess-plain").await.unwrap();
        assert_eq!(usage.total_tokens, 700);
        // No breakdown was provided — aggregates remain at zero (NULLs sum to 0).
        let breakdown_totals = [
            usage.total_input_tokens,
            usage.total_output_tokens,
            usage.total_cache_read_tokens,
            usage.total_cache_creation_tokens,
        ];
        for total in breakdown_totals {
            assert_eq!(total, 0);
        }
    }
}