1use super::Store;
4use crate::error::MemoryError;
5use kernex_core::pricing::pricing_for;
6
/// Optional per-component token counts for a single recorded request.
///
/// Every field is optional. A `None` field is bound as `NULL` by
/// [`Store::record_usage_full`]. The summary queries wrap each sum in
/// `COALESCE(…, 0)`, so a `None` contributes nothing to the corresponding
/// [`UsageSummary`] total.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct UsageBreakdown {
    /// Input-token count for the request, if reported.
    pub input_tokens: Option<u64>,
    /// Output-token count for the request, if reported.
    pub output_tokens: Option<u64>,
    /// Cache-read token count for the request, if reported.
    pub cache_read_tokens: Option<u64>,
    /// Cache-creation token count for the request, if reported.
    pub cache_creation_tokens: Option<u64>,
}
18
/// Aggregated token usage returned by [`Store::get_session_usage`] and
/// [`Store::get_total_usage`].
///
/// Every total is zero when no rows match. The breakdown totals treat rows
/// recorded without that component as zero.
// `Eq` is not derived because `total_cost_usd` is an `f64`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UsageSummary {
    /// Sum of the `tokens` column over the aggregated rows.
    pub total_tokens: i64,
    /// Sum of the estimated `cost_usd` column over the aggregated rows.
    pub total_cost_usd: f64,
    /// Number of aggregated rows (one per recorded request).
    pub request_count: i64,
    /// Sum of `input_tokens`; rows without this count contribute 0.
    pub total_input_tokens: i64,
    /// Sum of `output_tokens`; rows without this count contribute 0.
    pub total_output_tokens: i64,
    /// Sum of `cache_read_tokens`; rows without this count contribute 0.
    pub total_cache_read_tokens: i64,
    /// Sum of `cache_creation_tokens`; rows without this count contribute 0.
    pub total_cache_creation_tokens: i64,
}
37
38impl Store {
39 pub async fn record_usage(
45 &self,
46 sender_id: &str,
47 session_id: &str,
48 tokens: u64,
49 model: &str,
50 ) -> Result<(), MemoryError> {
51 self.record_usage_full(
52 sender_id,
53 session_id,
54 tokens,
55 model,
56 UsageBreakdown::default(),
57 )
58 .await
59 }
60
61 pub async fn record_usage_full(
67 &self,
68 sender_id: &str,
69 session_id: &str,
70 tokens: u64,
71 model: &str,
72 breakdown: UsageBreakdown,
73 ) -> Result<(), MemoryError> {
74 let cost = pricing_for(model)
75 .map(|p| p.estimate_cost(tokens))
76 .unwrap_or(0.0);
77
78 sqlx::query(
79 "INSERT INTO token_usage (
80 sender_id, session_id, model, tokens, cost_usd,
81 input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens
82 )
83 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
84 )
85 .bind(sender_id)
86 .bind(session_id)
87 .bind(model)
88 .bind(tokens as i64)
89 .bind(cost)
90 .bind(breakdown.input_tokens.map(|v| v as i64))
91 .bind(breakdown.output_tokens.map(|v| v as i64))
92 .bind(breakdown.cache_read_tokens.map(|v| v as i64))
93 .bind(breakdown.cache_creation_tokens.map(|v| v as i64))
94 .execute(&self.pool)
95 .await
96 .map_err(|e| MemoryError::sqlite("failed to record token usage", e))?;
97
98 Ok(())
99 }
100
101 pub async fn get_session_usage(&self, session_id: &str) -> Result<UsageSummary, MemoryError> {
103 let row: Option<(i64, f64, i64, i64, i64, i64, i64)> = sqlx::query_as(
104 "SELECT
105 COALESCE(SUM(tokens), 0),
106 COALESCE(SUM(cost_usd), 0.0),
107 COUNT(*),
108 COALESCE(SUM(input_tokens), 0),
109 COALESCE(SUM(output_tokens), 0),
110 COALESCE(SUM(cache_read_tokens), 0),
111 COALESCE(SUM(cache_creation_tokens), 0)
112 FROM token_usage WHERE session_id = ?",
113 )
114 .bind(session_id)
115 .fetch_optional(&self.pool)
116 .await
117 .map_err(|e| MemoryError::sqlite("failed to query session usage", e))?;
118
119 let (
120 total_tokens,
121 total_cost_usd,
122 request_count,
123 total_input_tokens,
124 total_output_tokens,
125 total_cache_read_tokens,
126 total_cache_creation_tokens,
127 ) = row.unwrap_or((0, 0.0, 0, 0, 0, 0, 0));
128
129 Ok(UsageSummary {
130 total_tokens,
131 total_cost_usd,
132 request_count,
133 total_input_tokens,
134 total_output_tokens,
135 total_cache_read_tokens,
136 total_cache_creation_tokens,
137 })
138 }
139
140 pub async fn get_total_usage(&self) -> Result<UsageSummary, MemoryError> {
145 let row: Option<(i64, f64, i64, i64, i64, i64, i64)> = sqlx::query_as(
146 "SELECT
147 COALESCE(SUM(tokens), 0),
148 COALESCE(SUM(cost_usd), 0.0),
149 COUNT(*),
150 COALESCE(SUM(input_tokens), 0),
151 COALESCE(SUM(output_tokens), 0),
152 COALESCE(SUM(cache_read_tokens), 0),
153 COALESCE(SUM(cache_creation_tokens), 0)
154 FROM token_usage",
155 )
156 .fetch_optional(&self.pool)
157 .await
158 .map_err(|e| MemoryError::sqlite("failed to query total usage", e))?;
159
160 let (
161 total_tokens,
162 total_cost_usd,
163 request_count,
164 total_input_tokens,
165 total_output_tokens,
166 total_cache_read_tokens,
167 total_cache_creation_tokens,
168 ) = row.unwrap_or((0, 0.0, 0, 0, 0, 0, 0));
169
170 Ok(UsageSummary {
171 total_tokens,
172 total_cost_usd,
173 request_count,
174 total_input_tokens,
175 total_output_tokens,
176 total_cache_read_tokens,
177 total_cache_creation_tokens,
178 })
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use kernex_core::config::MemoryConfig;

    /// Builds a store with `db_path` set to `":memory:"`, so no on-disk
    /// database file is used.
    async fn make_store() -> Store {
        let config = MemoryConfig {
            db_path: ":memory:".to_string(),
            ..Default::default()
        };
        Store::new(&config).await.unwrap()
    }

    #[tokio::test]
    async fn test_record_and_get_usage() {
        let store = make_store().await;
        store
            .record_usage("user-1", "sess-abc", 1000, "claude-sonnet-4-6")
            .await
            .unwrap();
        store
            .record_usage("user-1", "sess-abc", 500, "claude-sonnet-4-6")
            .await
            .unwrap();

        let summary = store.get_session_usage("sess-abc").await.unwrap();
        assert_eq!(summary.total_tokens, 1500);
        assert_eq!(summary.request_count, 2);
        assert!(summary.total_cost_usd > 0.0);
    }

    #[tokio::test]
    async fn test_get_usage_empty_session() {
        let store = make_store().await;
        let summary = store.get_session_usage("sess-nonexistent").await.unwrap();
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.request_count, 0);
        assert_eq!(summary.total_cost_usd, 0.0);
    }

    // Mirrors `test_get_usage_empty_session` for the unfiltered aggregate: with
    // no rows at all, every COALESCE'd sum and the count must come back as zero.
    #[tokio::test]
    async fn test_get_total_usage_empty_store() {
        let store = make_store().await;
        let summary = store.get_total_usage().await.unwrap();
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.request_count, 0);
        assert_eq!(summary.total_cost_usd, 0.0);
        assert_eq!(summary.total_input_tokens, 0);
        assert_eq!(summary.total_output_tokens, 0);
        assert_eq!(summary.total_cache_read_tokens, 0);
        assert_eq!(summary.total_cache_creation_tokens, 0);
    }

    #[tokio::test]
    async fn test_record_usage_unknown_model_zero_cost() {
        let store = make_store().await;
        store
            .record_usage("user-1", "sess-local", 2000, "llama3.2")
            .await
            .unwrap();

        let summary = store.get_session_usage("sess-local").await.unwrap();
        assert_eq!(summary.total_tokens, 2000);
        assert_eq!(summary.total_cost_usd, 0.0);
    }

    #[tokio::test]
    async fn test_usage_isolated_by_session() {
        let store = make_store().await;
        store
            .record_usage("user-1", "sess-1", 100, "gpt-4o")
            .await
            .unwrap();
        store
            .record_usage("user-1", "sess-2", 200, "gpt-4o")
            .await
            .unwrap();

        let s1 = store.get_session_usage("sess-1").await.unwrap();
        let s2 = store.get_session_usage("sess-2").await.unwrap();
        assert_eq!(s1.total_tokens, 100);
        assert_eq!(s2.total_tokens, 200);
    }

    #[tokio::test]
    async fn test_record_usage_full_persists_breakdown() {
        let store = make_store().await;
        store
            .record_usage_full(
                "user-1",
                "sess-cache",
                1500,
                "claude-sonnet-4-6",
                UsageBreakdown {
                    input_tokens: Some(200),
                    output_tokens: Some(100),
                    cache_read_tokens: Some(1000),
                    cache_creation_tokens: Some(200),
                },
            )
            .await
            .unwrap();
        store
            .record_usage_full(
                "user-1",
                "sess-cache",
                500,
                "claude-sonnet-4-6",
                UsageBreakdown {
                    input_tokens: Some(150),
                    output_tokens: Some(50),
                    cache_read_tokens: Some(300),
                    cache_creation_tokens: None,
                },
            )
            .await
            .unwrap();

        let summary = store.get_session_usage("sess-cache").await.unwrap();
        assert_eq!(summary.total_tokens, 2000);
        assert_eq!(summary.request_count, 2);
        assert_eq!(summary.total_input_tokens, 350);
        assert_eq!(summary.total_output_tokens, 150);
        assert_eq!(summary.total_cache_read_tokens, 1300);
        // The second row's `None` is stored as NULL and ignored by SUM.
        assert_eq!(summary.total_cache_creation_tokens, 200);
    }

    #[tokio::test]
    async fn test_get_total_usage_aggregates_across_sessions() {
        let store = make_store().await;
        store
            .record_usage_full(
                "user-1",
                "sess-a",
                400,
                "claude-sonnet-4-6",
                UsageBreakdown {
                    cache_read_tokens: Some(300),
                    ..UsageBreakdown::default()
                },
            )
            .await
            .unwrap();
        store
            .record_usage_full(
                "user-2",
                "sess-b",
                600,
                "gpt-4o",
                UsageBreakdown {
                    cache_read_tokens: Some(100),
                    ..UsageBreakdown::default()
                },
            )
            .await
            .unwrap();

        let summary = store.get_total_usage().await.unwrap();
        assert_eq!(summary.total_tokens, 1000);
        assert_eq!(summary.request_count, 2);
        assert_eq!(summary.total_cache_read_tokens, 400);
    }

    #[tokio::test]
    async fn test_record_usage_wrapper_leaves_breakdown_null() {
        let store = make_store().await;
        store
            .record_usage("user-1", "sess-plain", 700, "gpt-4o")
            .await
            .unwrap();

        let summary = store.get_session_usage("sess-plain").await.unwrap();
        assert_eq!(summary.total_tokens, 700);
        // NULL breakdown columns surface as zero through COALESCE.
        assert_eq!(summary.total_input_tokens, 0);
        assert_eq!(summary.total_output_tokens, 0);
        assert_eq!(summary.total_cache_read_tokens, 0);
        assert_eq!(summary.total_cache_creation_tokens, 0);
    }
}