Skip to main content

starpod_session/
lib.rs

1use chrono::{DateTime, Duration, Utc};
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use tracing::debug;
5use uuid::Uuid;
6
7use starpod_core::{StarpodError, Result};
8
/// A channel that sessions are scoped to.
///
/// The channel is persisted as a short string (see [`Channel::as_str`]) in the
/// `channel` column of `session_metadata`. The inactivity gap itself is not
/// stored on the enum; callers pass it to `resolve_session` as `gap_minutes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Channel {
    /// Explicit sessions — client controls lifecycle (web, REPL, CLI).
    Main,
    /// Time-gap sessions — new session after inactivity threshold (6h).
    Telegram,
    /// Time-gap sessions via email — new session after inactivity threshold (24h).
    Email,
}

impl Channel {
    /// Returns the stable string identifier used when storing this channel.
    ///
    /// The mapping is the inverse of [`Channel::from_channel_str`] for all
    /// known variants.
    pub fn as_str(&self) -> &'static str {
        match self {
            Channel::Main => "main",
            Channel::Telegram => "telegram",
            Channel::Email => "email",
        }
    }

    /// Parses a stored channel identifier back into a `Channel`.
    ///
    /// Matching is exact and case-sensitive. Any unrecognised value —
    /// including the empty string — deliberately falls back to
    /// [`Channel::Main`] rather than failing.
    pub fn from_channel_str(s: &str) -> Self {
        match s {
            "telegram" => Channel::Telegram,
            "email" => Channel::Email,
            _ => Channel::Main,
        }
    }
}
38
/// Decision from session resolution on whether to continue or start a new session.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare decisions
/// directly instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionDecision {
    /// Continue an existing session (contains session ID).
    Continue(String),
    /// Start a new session. If a previous session was auto-closed (e.g. time-gap),
    /// `closed_session_id` carries its ID so callers can export it.
    New { closed_session_id: Option<String> },
}
48
/// Metadata for a session.
///
/// Mirrors one row of the `session_metadata` table as read by
/// `session_meta_from_row`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMeta {
    /// Session identifier (a UUID v4 string when created by `create_session_full`).
    pub id: String,
    /// Creation time as an RFC 3339 timestamp.
    pub created_at: String,
    /// Time of the most recent activity as an RFC 3339 timestamp; bumped by `touch_session`.
    pub last_message_at: String,
    /// Whether the session has been closed (see `close_session`).
    pub is_closed: bool,
    /// Summary text written when the session was closed, if any.
    pub summary: Option<String>,
    /// Session title; filled in once from the first user message or `set_title_if_empty`.
    pub title: Option<String>,
    /// Number of times `touch_session` has been called for this session.
    pub message_count: i64,
    /// Channel identifier as produced by `Channel::as_str`.
    pub channel: String,
    /// Channel-specific key the session was created with.
    pub channel_session_key: Option<String>,
    /// Owning user. Falls back to `"admin"` when the column cannot be read.
    pub user_id: String,
    /// Read/unread flag set by `mark_read`. Treated as read when the column cannot be read.
    pub is_read: bool,
    /// Cron job name or `"__heartbeat__"` if this session was triggered by a scheduled job.
    /// `None` for regular user sessions.
    pub triggered_by: Option<String>,
}
67
/// A stored message in a session.
///
/// Mirrors one row of the `session_messages` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
    /// Database-assigned row ID; `get_messages` returns messages ordered by it.
    pub id: i64,
    /// ID of the session this message belongs to.
    pub session_id: String,
    /// Message role, e.g. `"user"`, `"tool_use"` or `"tool_result"`.
    pub role: String,
    /// Message payload. For `tool_use` / `tool_result` rows `cost_overview`
    /// reads this as JSON (`$.name`, `$.id`, `$.tool_use_id`, `$.is_error`).
    pub content: String,
    /// Time the message was saved, as an RFC 3339 timestamp.
    pub timestamp: String,
}
77
/// Usage record for a single turn.
///
/// Passed to `record_usage`, which stores it as one row of `usage_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Uncached input tokens for the turn (cache reads/writes are tracked separately).
    pub input_tokens: u64,
    /// Output tokens produced in the turn.
    pub output_tokens: u64,
    /// Input tokens served from the prompt cache.
    pub cache_read: u64,
    /// Input tokens written to the prompt cache.
    pub cache_write: u64,
    /// Cost of the turn in US dollars.
    pub cost_usd: f64,
    /// Name of the model that served the turn.
    pub model: String,
    /// User the usage is attributed to.
    pub user_id: String,
}
89
/// Manages session lifecycle — creation, channel-aware resolution, closure, and usage tracking.
///
/// The manager holds no state of its own beyond the connection pool; every
/// method issues its queries against that pool.
pub struct SessionManager {
    /// SQLite connection pool used for all queries issued by this manager.
    pool: SqlitePool,
}
94
95impl SessionManager {
96    /// Create a SessionManager from a shared pool.
97    ///
98    /// The pool should already have migrations applied (via `CoreDb`).
99    pub fn from_pool(pool: SqlitePool) -> Self {
100        Self { pool }
101    }
102
103    /// Resolve session for a given channel and key.
104    ///
105    /// - **Main** (explicit): always continues the matching open session if one exists.
106    /// - **Telegram** (time-gap): continues if last message was within the gap threshold,
107    ///   otherwise auto-closes the old session and returns `New`.
108    ///
109    /// `gap_minutes` is the inactivity gap from config. Pass `None` for explicit
110    /// channels that don't use time-gap sessions.
111    pub async fn resolve_session(
112        &self,
113        channel: &Channel,
114        key: &str,
115        gap_minutes: Option<i64>,
116    ) -> Result<SessionDecision> {
117        self.resolve_session_for_user(channel, key, gap_minutes, "admin").await
118    }
119
120    /// Resolve session for a given channel, key, and user.
121    pub async fn resolve_session_for_user(
122        &self,
123        channel: &Channel,
124        key: &str,
125        gap_minutes: Option<i64>,
126        user_id: &str,
127    ) -> Result<SessionDecision> {
128        let row = sqlx::query(
129            "SELECT id, last_message_at
130             FROM session_metadata
131             WHERE channel = ?1 AND channel_session_key = ?2 AND is_closed = 0 AND user_id = ?3
132             ORDER BY last_message_at DESC
133             LIMIT 1",
134        )
135        .bind(channel.as_str())
136        .bind(key)
137        .bind(user_id)
138        .fetch_optional(&self.pool)
139        .await
140        .map_err(|e| StarpodError::Database(format!("Resolve session query failed: {}", e)))?;
141
142        let row = match row {
143            Some(r) => r,
144            None => return Ok(SessionDecision::New { closed_session_id: None }),
145        };
146
147        let session_id: String = row.get("id");
148
149        // For explicit channels (no gap), always continue.
150        let gap_threshold = match gap_minutes {
151            None => {
152                debug!(session_id = %session_id, channel = %channel.as_str(), "Continuing session (explicit channel)");
153                return Ok(SessionDecision::Continue(session_id));
154            }
155            Some(gap) => gap,
156        };
157
158        // For time-gap channels, check inactivity
159        let last_msg_str: String = row.get("last_message_at");
160        let last_msg = DateTime::parse_from_rfc3339(&last_msg_str)
161            .map_err(|e| StarpodError::Session(format!("Bad timestamp: {}", e)))?
162            .with_timezone(&Utc);
163
164        let gap = Utc::now() - last_msg;
165
166        if gap < Duration::minutes(gap_threshold) {
167            debug!(session_id = %session_id, gap_mins = gap.num_minutes(), "Continuing session (within gap)");
168            Ok(SessionDecision::Continue(session_id))
169        } else {
170            debug!(session_id = %session_id, gap_mins = gap.num_minutes(), "Auto-closing session (gap exceeded)");
171            self.close_session(&session_id, "Auto-closed: inactivity").await?;
172            Ok(SessionDecision::New { closed_session_id: Some(session_id) })
173        }
174    }
175
176    /// Create a new session for a channel and key, returning its ID.
177    pub async fn create_session(
178        &self,
179        channel: &Channel,
180        key: &str,
181    ) -> Result<String> {
182        self.create_session_full(channel, key, "admin", None).await
183    }
184
185    /// Create a new session for a channel, key, and user, returning its ID.
186    pub async fn create_session_for_user(
187        &self,
188        channel: &Channel,
189        key: &str,
190        user_id: &str,
191    ) -> Result<String> {
192        self.create_session_full(channel, key, user_id, None).await
193    }
194
195    /// Create a new session with full metadata, including an optional trigger source.
196    ///
197    /// `triggered_by` records the cron job name (e.g. `"daily-digest"`) or
198    /// `"__heartbeat__"` when the session is created by the scheduler.
199    pub async fn create_session_full(
200        &self,
201        channel: &Channel,
202        key: &str,
203        user_id: &str,
204        triggered_by: Option<&str>,
205    ) -> Result<String> {
206        let id = Uuid::new_v4().to_string();
207        let now = Utc::now().to_rfc3339();
208
209        sqlx::query(
210            "INSERT INTO session_metadata (id, created_at, last_message_at, is_closed, message_count, channel, channel_session_key, user_id, triggered_by)
211             VALUES (?1, ?2, ?2, 0, 0, ?3, ?4, ?5, ?6)",
212        )
213        .bind(&id)
214        .bind(&now)
215        .bind(channel.as_str())
216        .bind(key)
217        .bind(user_id)
218        .bind(triggered_by)
219        .execute(&self.pool)
220        .await
221        .map_err(|e| StarpodError::Database(format!("Create session failed: {}", e)))?;
222
223        debug!(session_id = %id, channel = %channel.as_str(), key = %key, "Created new session");
224        Ok(id)
225    }
226
227    /// Mark a session as closed with an optional summary.
228    pub async fn close_session(&self, id: &str, summary: &str) -> Result<()> {
229        sqlx::query("UPDATE session_metadata SET is_closed = 1, summary = ?2 WHERE id = ?1")
230            .bind(id)
231            .bind(summary)
232            .execute(&self.pool)
233            .await
234            .map_err(|e| StarpodError::Database(format!("Close session failed: {}", e)))?;
235
236        debug!(session_id = %id, "Closed session");
237        Ok(())
238    }
239
240    /// Mark a session as read or unread.
241    pub async fn mark_read(&self, id: &str, is_read: bool) -> Result<()> {
242        sqlx::query("UPDATE session_metadata SET is_read = ?2 WHERE id = ?1")
243            .bind(id)
244            .bind(is_read as i64)
245            .execute(&self.pool)
246            .await
247            .map_err(|e| StarpodError::Database(format!("Mark read failed: {}", e)))?;
248        Ok(())
249    }
250
251    /// Update the last_message_at timestamp and increment message_count.
252    pub async fn touch_session(&self, id: &str) -> Result<()> {
253        let now = Utc::now().to_rfc3339();
254        sqlx::query(
255            "UPDATE session_metadata SET last_message_at = ?2, message_count = message_count + 1 WHERE id = ?1",
256        )
257        .bind(id)
258        .bind(&now)
259        .execute(&self.pool)
260        .await
261        .map_err(|e| StarpodError::Database(format!("Touch session failed: {}", e)))?;
262        Ok(())
263    }
264
265    /// Set the session title if it hasn't been set yet.
266    pub async fn set_title_if_empty(&self, id: &str, title: &str) -> Result<()> {
267        let truncated = if title.len() > 100 {
268            let mut end = 100;
269            while end > 0 && !title.is_char_boundary(end) {
270                end -= 1;
271            }
272            format!("{}...", &title[..end])
273        } else {
274            title.to_string()
275        };
276        sqlx::query(
277            "UPDATE session_metadata SET title = ?2 WHERE id = ?1 AND title IS NULL",
278        )
279        .bind(id)
280        .bind(&truncated)
281        .execute(&self.pool)
282        .await
283        .map_err(|e| StarpodError::Database(format!("Set title failed: {}", e)))?;
284        Ok(())
285    }
286
287    /// Record token usage for a turn.
288    pub async fn record_usage(&self, session_id: &str, usage: &UsageRecord, turn: u32) -> Result<()> {
289        let now = Utc::now().to_rfc3339();
290        sqlx::query(
291            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cache_read, cache_write, cost_usd, model, user_id, timestamp)
292             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
293        )
294        .bind(session_id)
295        .bind(turn as i64)
296        .bind(usage.input_tokens as i64)
297        .bind(usage.output_tokens as i64)
298        .bind(usage.cache_read as i64)
299        .bind(usage.cache_write as i64)
300        .bind(usage.cost_usd)
301        .bind(&usage.model)
302        .bind(&usage.user_id)
303        .bind(&now)
304        .execute(&self.pool)
305        .await
306        .map_err(|e| StarpodError::Database(format!("Record usage failed: {}", e)))?;
307
308        Ok(())
309    }
310
311    /// List sessions, most recent first.
312    pub async fn list_sessions(&self, limit: usize) -> Result<Vec<SessionMeta>> {
313        let rows = sqlx::query(
314            "SELECT id, created_at, last_message_at, is_closed, summary, title, message_count, channel, channel_session_key, user_id, is_read, triggered_by
315             FROM session_metadata
316             ORDER BY last_message_at DESC
317             LIMIT ?1",
318        )
319        .bind(limit as i64)
320        .fetch_all(&self.pool)
321        .await
322        .map_err(|e| StarpodError::Database(format!("Query failed: {}", e)))?;
323
324        let sessions: Vec<SessionMeta> = rows
325            .iter()
326            .map(|row| session_meta_from_row(row))
327            .collect();
328
329        Ok(sessions)
330    }
331
332    /// Get a specific session by ID.
333    pub async fn get_session(&self, id: &str) -> Result<Option<SessionMeta>> {
334        let row = sqlx::query(
335            "SELECT id, created_at, last_message_at, is_closed, summary, title, message_count, channel, channel_session_key, user_id, is_read, triggered_by
336             FROM session_metadata WHERE id = ?1",
337        )
338        .bind(id)
339        .fetch_optional(&self.pool)
340        .await
341        .map_err(|e| StarpodError::Database(format!("Get session failed: {}", e)))?;
342
343        Ok(row.map(|r| session_meta_from_row(&r)))
344    }
345
346    /// Get total usage stats for a session.
347    ///
348    /// `total_input_tokens` includes uncached, cache-read, and cache-write
349    /// tokens so the caller gets the true context size. Cache breakdown is
350    /// available via `total_cache_read` / `total_cache_write`.
351    pub async fn session_usage(&self, session_id: &str) -> Result<UsageSummary> {
352        let row = sqlx::query(
353            "SELECT COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti, COALESCE(SUM(output_tokens), 0) as to_,
354                    COALESCE(SUM(cache_read), 0) as cr, COALESCE(SUM(cache_write), 0) as cw,
355                    COALESCE(SUM(cost_usd), 0.0) as cost, COUNT(*) as turns
356             FROM usage_stats WHERE session_id = ?1",
357        )
358        .bind(session_id)
359        .fetch_one(&self.pool)
360        .await
361        .map_err(|e| StarpodError::Database(format!("Usage query failed: {}", e)))?;
362
363        Ok(UsageSummary {
364            total_input_tokens: row.get::<i64, _>("ti") as u64,
365            total_output_tokens: row.get::<i64, _>("to_") as u64,
366            total_cache_read: row.get::<i64, _>("cr") as u64,
367            total_cache_write: row.get::<i64, _>("cw") as u64,
368            total_cost_usd: row.get::<f64, _>("cost"),
369            total_turns: row.get::<i64, _>("turns") as u32,
370        })
371    }
372
    /// Get a full cost overview with breakdowns by user, model, day and tool.
    ///
    /// If `since` is provided (RFC 3339 timestamp), only usage after that time is included.
    /// The filter is applied as a `timestamp >= ?1` comparison inside SQLite, both to
    /// `usage_stats` and to the `tool_use` rows of `session_messages`.
    ///
    /// Five queries are run in sequence: overall totals, totals per user, totals per
    /// model, cost per (day, model), and tool invocation/error counts.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Database` if any of the queries fails.
    pub async fn cost_overview(&self, since: Option<&str>) -> Result<CostOverview> {
        // The same optional WHERE fragment and bind value are reused by every
        // `usage_stats` query below.
        let (where_clause, bind_val) = match since {
            Some(ts) => ("WHERE timestamp >= ?1", Some(ts)),
            None => ("", None),
        };

        // Total
        let total_sql = format!(
            "SELECT COALESCE(SUM(cost_usd), 0.0) as cost,
                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
                    COALESCE(SUM(output_tokens), 0) as to_,
                    COALESCE(SUM(cache_read), 0) as cr,
                    COALESCE(SUM(cache_write), 0) as cw,
                    COUNT(*) as turns
             FROM usage_stats {}",
            where_clause
        );
        let mut q = sqlx::query(&total_sql);
        // Only bind when the WHERE clause (and therefore `?1`) is present.
        if let Some(ts) = bind_val {
            q = q.bind(ts);
        }
        let total_row = q.fetch_one(&self.pool).await
            .map_err(|e| StarpodError::Database(format!("Cost total query failed: {}", e)))?;

        // By user
        let user_sql = format!(
            "SELECT user_id,
                    COALESCE(SUM(cost_usd), 0.0) as cost,
                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
                    COALESCE(SUM(output_tokens), 0) as to_,
                    COALESCE(SUM(cache_read), 0) as cr,
                    COALESCE(SUM(cache_write), 0) as cw,
                    COUNT(*) as turns
             FROM usage_stats {} GROUP BY user_id ORDER BY cost DESC",
            where_clause
        );
        let mut q = sqlx::query(&user_sql);
        if let Some(ts) = bind_val {
            q = q.bind(ts);
        }
        let user_rows = q.fetch_all(&self.pool).await
            .map_err(|e| StarpodError::Database(format!("Cost by-user query failed: {}", e)))?;

        // By model
        let model_sql = format!(
            "SELECT model,
                    COALESCE(SUM(cost_usd), 0.0) as cost,
                    COALESCE(SUM(input_tokens + cache_read + cache_write), 0) as ti,
                    COALESCE(SUM(output_tokens), 0) as to_,
                    COALESCE(SUM(cache_read), 0) as cr,
                    COALESCE(SUM(cache_write), 0) as cw,
                    COUNT(*) as turns
             FROM usage_stats {} GROUP BY model ORDER BY cost DESC",
            where_clause
        );
        let mut q = sqlx::query(&model_sql);
        if let Some(ts) = bind_val {
            q = q.bind(ts);
        }
        let model_rows = q.fetch_all(&self.pool).await
            .map_err(|e| StarpodError::Database(format!("Cost by-model query failed: {}", e)))?;

        // By day + model
        let day_sql = format!(
            "SELECT DATE(timestamp) as day, COALESCE(model, 'unknown') as model,
                    COALESCE(SUM(cost_usd), 0.0) as cost
             FROM usage_stats {} GROUP BY day, model ORDER BY day ASC",
            where_clause
        );
        let mut q = sqlx::query(&day_sql);
        if let Some(ts) = bind_val {
            q = q.bind(ts);
        }
        let day_rows = q.fetch_all(&self.pool).await
            .map_err(|e| StarpodError::Database(format!("Cost by-day query failed: {}", e)))?;

        // Group day rows into DayCostSummary
        // Rows arrive ordered by day, so all rows of one day are adjacent and it is
        // enough to compare each row against the last summary pushed so far.
        let mut by_day: Vec<DayCostSummary> = Vec::new();
        for row in &day_rows {
            let date: String = row.get("day");
            let model: String = row.get("model");
            let cost: f64 = row.get::<f64, _>("cost");
            if let Some(last) = by_day.last_mut().filter(|d| d.date == date) {
                // Same day as the previous row: add this model to the existing entry.
                last.total_cost_usd += cost;
                last.by_model.push(DayModelCost { model, cost_usd: cost });
            } else {
                // First row of a new day.
                by_day.push(DayCostSummary {
                    date,
                    total_cost_usd: cost,
                    by_model: vec![DayModelCost { model, cost_usd: cost }],
                });
            }
        }

        // By tool (from session_messages)
        // Each `tool_use` message is left-joined to `tool_result` messages of the same
        // session whose `$.tool_use_id` equals the tool_use's `$.id`; a joined result
        // with `$.is_error = 1` counts as an error for that tool.
        let tool_sql = format!(
            "SELECT json_extract(sm.content, '$.name') AS tool_name,
                    COUNT(*) AS invocations,
                    COALESCE(SUM(
                      CASE WHEN tr.content IS NOT NULL
                           AND json_extract(tr.content, '$.is_error') = 1
                      THEN 1 ELSE 0 END
                    ), 0) AS errors
             FROM session_messages sm
             LEFT JOIN session_messages tr
               ON tr.session_id = sm.session_id
               AND tr.role = 'tool_result'
               AND json_extract(tr.content, '$.tool_use_id') = json_extract(sm.content, '$.id')
             WHERE sm.role = 'tool_use'
               {}
             GROUP BY tool_name
             ORDER BY invocations DESC",
            // This query already has a WHERE, so the time filter is appended with AND.
            if bind_val.is_some() { "AND sm.timestamp >= ?1" } else { "" }
        );
        let mut q = sqlx::query(&tool_sql);
        if let Some(ts) = bind_val {
            q = q.bind(ts);
        }
        let tool_rows = q.fetch_all(&self.pool).await
            .map_err(|e| StarpodError::Database(format!("Cost by-tool query failed: {}", e)))?;

        let by_tool: Vec<ToolUsageSummary> = tool_rows.iter().map(|r| ToolUsageSummary {
            // Falls back to "unknown" when the name cannot be read as a string.
            tool_name: r.try_get("tool_name").unwrap_or_else(|_| "unknown".to_string()),
            invocations: r.get::<i64, _>("invocations") as u32,
            errors: r.get::<i64, _>("errors") as u32,
        }).collect();

        Ok(CostOverview {
            total_cost_usd: total_row.get::<f64, _>("cost"),
            total_input_tokens: total_row.get::<i64, _>("ti") as u64,
            total_output_tokens: total_row.get::<i64, _>("to_") as u64,
            total_cache_read: total_row.get::<i64, _>("cr") as u64,
            total_cache_write: total_row.get::<i64, _>("cw") as u64,
            total_turns: total_row.get::<i64, _>("turns") as u32,
            by_user: user_rows.iter().map(|r| UserCostSummary {
                user_id: r.get("user_id"),
                total_cost_usd: r.get::<f64, _>("cost"),
                total_input_tokens: r.get::<i64, _>("ti") as u64,
                total_output_tokens: r.get::<i64, _>("to_") as u64,
                total_cache_read: r.get::<i64, _>("cr") as u64,
                total_cache_write: r.get::<i64, _>("cw") as u64,
                total_turns: r.get::<i64, _>("turns") as u32,
            }).collect(),
            by_model: model_rows.iter().map(|r| ModelCostSummary {
                // Falls back to "unknown" when the model cannot be read as a string.
                model: r.try_get("model").unwrap_or_else(|_| "unknown".to_string()),
                total_cost_usd: r.get::<f64, _>("cost"),
                total_input_tokens: r.get::<i64, _>("ti") as u64,
                total_output_tokens: r.get::<i64, _>("to_") as u64,
                total_cache_read: r.get::<i64, _>("cr") as u64,
                total_cache_write: r.get::<i64, _>("cw") as u64,
                total_turns: r.get::<i64, _>("turns") as u32,
            }).collect(),
            by_day,
            by_tool,
        })
    }
532
533    /// Record a compaction event for a session.
534    pub async fn record_compaction(
535        &self,
536        session_id: &str,
537        trigger: &str,
538        pre_tokens: u64,
539        summary: &str,
540        messages_compacted: usize,
541    ) -> Result<()> {
542        let now = Utc::now().to_rfc3339();
543        sqlx::query(
544            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary, messages_compacted)
545             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
546        )
547        .bind(session_id)
548        .bind(&now)
549        .bind(trigger)
550        .bind(pre_tokens as i64)
551        .bind(summary)
552        .bind(messages_compacted as i64)
553        .execute(&self.pool)
554        .await
555        .map_err(|e| StarpodError::Database(format!("Record compaction failed: {}", e)))?;
556
557        debug!(session_id = %session_id, pre_tokens, messages_compacted, "Recorded compaction event");
558        Ok(())
559    }
560
561    /// Save a message to a session.
562    ///
563    /// When the first "user" message is saved, the session title is automatically
564    /// set to the message text (truncated to 100 chars).
565    pub async fn save_message(&self, session_id: &str, role: &str, content: &str) -> Result<()> {
566        let now = Utc::now().to_rfc3339();
567        sqlx::query(
568            "INSERT INTO session_messages (session_id, role, content, timestamp)
569             VALUES (?1, ?2, ?3, ?4)",
570        )
571        .bind(session_id)
572        .bind(role)
573        .bind(content)
574        .bind(&now)
575        .execute(&self.pool)
576        .await
577        .map_err(|e| StarpodError::Database(format!("Save message failed: {}", e)))?;
578
579        // Auto-set title from first user message
580        if role == "user" {
581            let title = if content.len() > 100 {
582                let mut end = 100;
583                while end > 0 && !content.is_char_boundary(end) {
584                    end -= 1;
585                }
586                format!("{}...", &content[..end])
587            } else {
588                content.to_string()
589            };
590            // Only set if title is currently NULL (first message)
591            sqlx::query(
592                "UPDATE session_metadata SET title = ?2 WHERE id = ?1 AND title IS NULL",
593            )
594            .bind(session_id)
595            .bind(&title)
596            .execute(&self.pool)
597            .await
598            .map_err(|e| StarpodError::Database(format!("Set title failed: {}", e)))?;
599        }
600
601        Ok(())
602    }
603
604    /// Get all messages for a session, ordered by ID.
605    pub async fn get_messages(&self, session_id: &str) -> Result<Vec<SessionMessage>> {
606        let rows = sqlx::query(
607            "SELECT id, session_id, role, content, timestamp
608             FROM session_messages
609             WHERE session_id = ?1
610             ORDER BY id ASC",
611        )
612        .bind(session_id)
613        .fetch_all(&self.pool)
614        .await
615        .map_err(|e| StarpodError::Database(format!("Get messages failed: {}", e)))?;
616
617        Ok(rows
618            .iter()
619            .map(|r| SessionMessage {
620                id: r.get("id"),
621                session_id: r.get("session_id"),
622                role: r.get("role"),
623                content: r.get("content"),
624                timestamp: r.get("timestamp"),
625            })
626            .collect())
627    }
628}
629
630/// Extract a SessionMeta from a database row.
631fn session_meta_from_row(row: &sqlx::sqlite::SqliteRow) -> SessionMeta {
632    SessionMeta {
633        id: row.get("id"),
634        created_at: row.get("created_at"),
635        last_message_at: row.get("last_message_at"),
636        is_closed: row.get::<i64, _>("is_closed") != 0,
637        summary: row.get("summary"),
638        title: row.get("title"),
639        message_count: row.get("message_count"),
640        channel: row.get("channel"),
641        channel_session_key: row.get("channel_session_key"),
642        user_id: row.try_get("user_id").unwrap_or_else(|_| "admin".to_string()),
643        is_read: row.try_get::<i64, _>("is_read").unwrap_or(1) != 0,
644        triggered_by: row.try_get("triggered_by").unwrap_or(None),
645    }
646}
647
/// Aggregated usage summary for a session.
///
/// ## Token accounting
///
/// `total_input_tokens` is the **total** input context size across all turns,
/// i.e. `SUM(input_tokens + cache_read + cache_write)` from the per-turn
/// records. This is what the UI displays as "X in".
///
/// `total_cache_read` and `total_cache_write` are the cached subsets of
/// that total — useful for showing cache efficiency (e.g. "2.1k cached").
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSummary {
    /// Total input tokens (uncached + cache_read + cache_write).
    pub total_input_tokens: u64,
    /// Total output tokens across all recorded turns.
    pub total_output_tokens: u64,
    /// Tokens served from prompt cache.
    pub total_cache_read: u64,
    /// Tokens written to prompt cache.
    pub total_cache_write: u64,
    /// Summed cost of all recorded turns, in US dollars.
    pub total_cost_usd: f64,
    /// Number of usage rows recorded for the session.
    pub total_turns: u32,
}
670
/// Cost summary per user.
///
/// One entry of `CostOverview::by_user`, aggregated over all usage rows of a
/// single `user_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserCostSummary {
    /// User the totals belong to.
    pub user_id: String,
    /// Summed cost for this user, in US dollars.
    pub total_cost_usd: f64,
    /// Total input tokens (uncached + cache_read + cache_write).
    pub total_input_tokens: u64,
    /// Total output tokens for this user.
    pub total_output_tokens: u64,
    /// Input tokens served from the prompt cache.
    pub total_cache_read: u64,
    /// Input tokens written to the prompt cache.
    pub total_cache_write: u64,
    /// Number of usage rows recorded for this user.
    pub total_turns: u32,
}
683
/// Cost summary per model.
///
/// One entry of `CostOverview::by_model`, aggregated over all usage rows of a
/// single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCostSummary {
    /// Model name; `"unknown"` when the stored value cannot be read as a string.
    pub model: String,
    /// Summed cost for this model, in US dollars.
    pub total_cost_usd: f64,
    /// Total input tokens (uncached + cache_read + cache_write).
    pub total_input_tokens: u64,
    /// Total output tokens for this model.
    pub total_output_tokens: u64,
    /// Input tokens served from the prompt cache.
    pub total_cache_read: u64,
    /// Input tokens written to the prompt cache.
    pub total_cache_write: u64,
    /// Number of usage rows recorded for this model.
    pub total_turns: u32,
}
696
/// Cost summary for a single day, broken down by model.
///
/// The day is whatever SQLite's `DATE(timestamp)` yields for the stored
/// usage timestamps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DayCostSummary {
    /// Date string (YYYY-MM-DD).
    pub date: String,
    /// Cost per model on this day.
    pub by_model: Vec<DayModelCost>,
    /// Total cost for this day (the sum of the `by_model` costs), in US dollars.
    pub total_cost_usd: f64,
}
707
/// Cost for a single model on a single day.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DayModelCost {
    /// Model name; `"unknown"` when no model was stored for the usage rows.
    pub model: String,
    /// Summed cost for this model on the day, in US dollars.
    pub cost_usd: f64,
}
714
/// Aggregated tool invocation statistics, grouped by tool name.
///
/// Extracted from `session_messages` rows with `role = "tool_use"` and
/// `role = "tool_result"`.  The error count is derived by joining each
/// `tool_use` to its matching `tool_result` and checking the `is_error` flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUsageSummary {
    /// The tool name (e.g. `"MemorySearch"`, `"VaultGet"`), or `"unknown"`
    /// when the name cannot be read from the message content.
    pub tool_name: String,
    /// Total number of times this tool was invoked.
    pub invocations: u32,
    /// How many of those invocations resulted in an error.
    pub errors: u32,
}
729
/// Full cost overview with breakdowns by user, model, day and tool.
///
/// All `total_input_tokens` fields include cached tokens — see [`UsageSummary`]
/// for the full accounting explanation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostOverview {
    /// Summed cost of all included usage rows, in US dollars.
    pub total_cost_usd: f64,
    /// Total input tokens (uncached + cache_read + cache_write).
    pub total_input_tokens: u64,
    /// Total output tokens.
    pub total_output_tokens: u64,
    /// Input tokens served from the prompt cache.
    pub total_cache_read: u64,
    /// Input tokens written to the prompt cache.
    pub total_cache_write: u64,
    /// Number of usage rows included in the overview.
    pub total_turns: u32,
    /// Totals per user, sorted by cost descending.
    pub by_user: Vec<UserCostSummary>,
    /// Totals per model, sorted by cost descending.
    pub by_model: Vec<ModelCostSummary>,
    /// Cost per day in ascending date order, each broken down by model.
    pub by_day: Vec<DayCostSummary>,
    /// Tool invocation counts, sorted by invocations descending.
    pub by_tool: Vec<ToolUsageSummary>,
}
749
750#[cfg(test)]
751mod tests {
752    use super::*;
753    use tempfile::TempDir;
754
755    async fn setup() -> (TempDir, SessionManager) {
756        let tmp = TempDir::new().unwrap();
757        let db = starpod_db::CoreDb::in_memory().await.unwrap();
758        let mgr = SessionManager::from_pool(db.pool().clone());
759        (tmp, mgr)
760    }
761
762    #[tokio::test]
763    async fn test_create_and_get_session() {
764        let (_tmp, mgr) = setup().await;
765        let id = mgr.create_session(&Channel::Main, "test-key").await.unwrap();
766
767        let session = mgr.get_session(&id).await.unwrap().unwrap();
768        assert_eq!(session.id, id);
769        assert!(!session.is_closed);
770        assert_eq!(session.message_count, 0);
771        assert_eq!(session.channel, "main");
772        assert_eq!(session.channel_session_key.as_deref(), Some("test-key"));
773    }
774
775    #[tokio::test]
776    async fn test_close_session() {
777        let (_tmp, mgr) = setup().await;
778        let id = mgr.create_session(&Channel::Main, "test-key").await.unwrap();
779
780        mgr.close_session(&id, "Discussed Rust memory management").await.unwrap();
781
782        let session = mgr.get_session(&id).await.unwrap().unwrap();
783        assert!(session.is_closed);
784        assert_eq!(session.summary.as_deref(), Some("Discussed Rust memory management"));
785    }
786
787    #[tokio::test]
788    async fn test_touch_session() {
789        let (_tmp, mgr) = setup().await;
790        let id = mgr.create_session(&Channel::Main, "test-key").await.unwrap();
791
792        mgr.touch_session(&id).await.unwrap();
793        mgr.touch_session(&id).await.unwrap();
794
795        let session = mgr.get_session(&id).await.unwrap().unwrap();
796        assert_eq!(session.message_count, 2);
797    }
798
799    #[tokio::test]
800    async fn test_resolve_session_new_when_empty() {
801        let (_tmp, mgr) = setup().await;
802
803        match mgr.resolve_session(&Channel::Main, "some-key", None).await.unwrap() {
804            SessionDecision::New { .. } => {} // expected
805            SessionDecision::Continue(_) => panic!("Should be New when no sessions exist"),
806        }
807    }
808
809    #[tokio::test]
810    async fn test_resolve_session_continue_recent() {
811        let (_tmp, mgr) = setup().await;
812        let id = mgr.create_session(&Channel::Main, "key-1").await.unwrap();
813        mgr.touch_session(&id).await.unwrap();
814
815        match mgr.resolve_session(&Channel::Main, "key-1", None).await.unwrap() {
816            SessionDecision::Continue(sid) => assert_eq!(sid, id),
817            SessionDecision::New { .. } => panic!("Should continue recent session"),
818        }
819    }
820
821    #[tokio::test]
822    async fn test_resolve_session_new_when_closed() {
823        let (_tmp, mgr) = setup().await;
824        let id = mgr.create_session(&Channel::Main, "key-1").await.unwrap();
825        mgr.touch_session(&id).await.unwrap();
826        mgr.close_session(&id, "done").await.unwrap();
827
828        match mgr.resolve_session(&Channel::Main, "key-1", None).await.unwrap() {
829            SessionDecision::New { .. } => {} // expected
830            SessionDecision::Continue(_) => panic!("Should not continue closed session"),
831        }
832    }
833
834    #[tokio::test]
835    async fn test_list_sessions() {
836        let (_tmp, mgr) = setup().await;
837        mgr.create_session(&Channel::Main, "k1").await.unwrap();
838        mgr.create_session(&Channel::Main, "k2").await.unwrap();
839        mgr.create_session(&Channel::Telegram, "chat-1").await.unwrap();
840
841        let sessions = mgr.list_sessions(10).await.unwrap();
842        assert_eq!(sessions.len(), 3);
843    }
844
845    #[tokio::test]
846    async fn test_record_and_query_usage() {
847        let (_tmp, mgr) = setup().await;
848        let id = mgr.create_session(&Channel::Main, "test-key").await.unwrap();
849
850        mgr.record_usage(
851            &id,
852            &UsageRecord {
853                input_tokens: 1000,
854                output_tokens: 500,
855                cache_read: 200,
856                cache_write: 100,
857                cost_usd: 0.01,
858                model: "claude-sonnet".into(),
859                user_id: "admin".into(),
860            },
861            1,
862        )
863        .await
864        .unwrap();
865
866        mgr.record_usage(
867            &id,
868            &UsageRecord {
869                input_tokens: 800,
870                output_tokens: 400,
871                cache_read: 150,
872                cache_write: 50,
873                cost_usd: 0.008,
874                model: "claude-sonnet".into(),
875                user_id: "admin".into(),
876            },
877            2,
878        )
879        .await
880        .unwrap();
881
882        let summary = mgr.session_usage(&id).await.unwrap();
883        // total_input_tokens includes input_tokens + cache_read + cache_write
884        // Turn 1: 1000 + 200 + 100 = 1300, Turn 2: 800 + 150 + 50 = 1000
885        assert_eq!(summary.total_input_tokens, 2300);
886        assert_eq!(summary.total_output_tokens, 900);
887        assert_eq!(summary.total_turns, 2);
888        assert!((summary.total_cost_usd - 0.018).abs() < 0.001);
889    }
890
891    #[tokio::test]
892    async fn test_usage_cache_breakdown() {
893        let (_tmp, mgr) = setup().await;
894        let id = mgr.create_session(&Channel::Main, "cache-test").await.unwrap();
895
896        // Turn 1: cache miss — all tokens go to cache_write
897        mgr.record_usage(&id, &UsageRecord {
898            input_tokens: 500, output_tokens: 200, cache_read: 0, cache_write: 4000,
899            cost_usd: 0.05, model: "claude-sonnet".into(), user_id: "admin".into(),
900        }, 1).await.unwrap();
901
902        // Turn 2: cache hit — most tokens served from cache
903        mgr.record_usage(&id, &UsageRecord {
904            input_tokens: 100, output_tokens: 300, cache_read: 4000, cache_write: 0,
905            cost_usd: 0.01, model: "claude-sonnet".into(), user_id: "admin".into(),
906        }, 2).await.unwrap();
907
908        let summary = mgr.session_usage(&id).await.unwrap();
909
910        // total_input_tokens = (500 + 0 + 4000) + (100 + 4000 + 0) = 8600
911        assert_eq!(summary.total_input_tokens, 8600);
912        assert_eq!(summary.total_output_tokens, 500);
913        // Cache breakdown preserved separately
914        assert_eq!(summary.total_cache_read, 4000);
915        assert_eq!(summary.total_cache_write, 4000);
916        assert_eq!(summary.total_turns, 2);
917        assert!((summary.total_cost_usd - 0.06).abs() < 0.001);
918    }
919
920    // --- New channel-specific tests ---
921
922    #[tokio::test]
923    async fn test_main_explicit_sessions() {
924        let (_tmp, mgr) = setup().await;
925
926        // Create session for key "abc"
927        let id = mgr.create_session(&Channel::Main, "abc").await.unwrap();
928        mgr.touch_session(&id).await.unwrap();
929
930        // Same key → continue
931        match mgr.resolve_session(&Channel::Main, "abc", None).await.unwrap() {
932            SessionDecision::Continue(sid) => assert_eq!(sid, id),
933            SessionDecision::New { .. } => panic!("Should continue with same key"),
934        }
935
936        // Different key → new
937        match mgr.resolve_session(&Channel::Main, "xyz", None).await.unwrap() {
938            SessionDecision::New { .. } => {} // expected
939            SessionDecision::Continue(_) => panic!("Different key should get new session"),
940        }
941    }
942
943    #[tokio::test]
944    async fn test_telegram_time_gap() {
945        let (_tmp, mgr) = setup().await;
946        let gap = Some(360); // 6h, as configured via [channels.telegram] gap_minutes
947
948        // Create a telegram session
949        let id = mgr.create_session(&Channel::Telegram, "chat-123").await.unwrap();
950        mgr.touch_session(&id).await.unwrap();
951
952        // Within 6h → continue
953        match mgr.resolve_session(&Channel::Telegram, "chat-123", gap).await.unwrap() {
954            SessionDecision::Continue(sid) => assert_eq!(sid, id),
955            SessionDecision::New { .. } => panic!("Should continue within gap"),
956        }
957
958        // Manually set last_message_at to 7h ago to simulate inactivity
959        let old_time = (Utc::now() - Duration::hours(7)).to_rfc3339();
960        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
961            .bind(&old_time)
962            .bind(&id)
963            .execute(&mgr.pool)
964            .await
965            .unwrap();
966
967        // Beyond 6h → new (old session auto-closed)
968        match mgr.resolve_session(&Channel::Telegram, "chat-123", gap).await.unwrap() {
969            SessionDecision::New { .. } => {} // expected
970            SessionDecision::Continue(_) => panic!("Should start new session after 7h gap"),
971        }
972
973        // Verify old session was auto-closed
974        let old = mgr.get_session(&id).await.unwrap().unwrap();
975        assert!(old.is_closed);
976        assert_eq!(old.summary.as_deref(), Some("Auto-closed: inactivity"));
977    }
978
979    #[tokio::test]
980    async fn test_record_compaction() {
981        let (_tmp, mgr) = setup().await;
982        let id = mgr.create_session(&Channel::Main, "test-key").await.unwrap();
983
984        mgr.record_compaction(&id, "auto", 150_000, "Summary of old messages", 12)
985            .await
986            .unwrap();
987
988        // Verify via raw query
989        let row = sqlx::query(
990            "SELECT trigger, pre_tokens, summary, messages_compacted FROM compaction_log WHERE session_id = ?1",
991        )
992        .bind(&id)
993        .fetch_one(&mgr.pool)
994        .await
995        .unwrap();
996
997        assert_eq!(row.get::<String, _>("trigger"), "auto");
998        assert_eq!(row.get::<i64, _>("pre_tokens"), 150_000);
999        assert_eq!(row.get::<String, _>("summary"), "Summary of old messages");
1000        assert_eq!(row.get::<i64, _>("messages_compacted"), 12);
1001    }
1002
1003    #[tokio::test]
1004    async fn test_telegram_custom_gap_override() {
1005        let (_tmp, mgr) = setup().await;
1006
1007        // Create a Telegram session
1008        let id = mgr.create_session(&Channel::Telegram, "chat-gap").await.unwrap();
1009        mgr.touch_session(&id).await.unwrap();
1010
1011        // Set last_message_at to 2 hours ago
1012        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1013        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1014            .bind(&two_hours_ago)
1015            .bind(&id)
1016            .execute(&mgr.pool)
1017            .await
1018            .unwrap();
1019
1020        // gap_minutes=60 (1h) — 2h ago exceeds 1h → should be New
1021        match mgr.resolve_session(&Channel::Telegram, "chat-gap", Some(60)).await.unwrap() {
1022            SessionDecision::New { .. } => {} // expected
1023            SessionDecision::Continue(_) => panic!("Should start new session when 2h > 1h gap"),
1024        }
1025
1026        // The old session was auto-closed, create a fresh one and backdate it again
1027        let id2 = mgr.create_session(&Channel::Telegram, "chat-gap").await.unwrap();
1028        mgr.touch_session(&id2).await.unwrap();
1029        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1030        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1031            .bind(&two_hours_ago)
1032            .bind(&id2)
1033            .execute(&mgr.pool)
1034            .await
1035            .unwrap();
1036
1037        // gap_minutes=180 (3h) — 2h ago is within 3h → should Continue
1038        match mgr.resolve_session(&Channel::Telegram, "chat-gap", Some(180)).await.unwrap() {
1039            SessionDecision::Continue(sid) => assert_eq!(sid, id2),
1040            SessionDecision::New { .. } => panic!("Should continue session when 2h < 3h gap"),
1041        }
1042    }
1043
1044    #[tokio::test]
1045    async fn test_main_channel_ignores_gap() {
1046        let (_tmp, mgr) = setup().await;
1047
1048        // Create a Main session
1049        let id = mgr.create_session(&Channel::Main, "main-gap").await.unwrap();
1050        mgr.touch_session(&id).await.unwrap();
1051
1052        // Without a gap_minutes override, Main channel always continues (explicit)
1053        match mgr.resolve_session(&Channel::Main, "main-gap", None).await.unwrap() {
1054            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1055            SessionDecision::New { .. } => panic!("Main channel should always continue without gap override"),
1056        }
1057
1058        // Even backdating last_message_at to 24h ago, Main without gap override still continues
1059        let old = (Utc::now() - Duration::hours(24)).to_rfc3339();
1060        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1061            .bind(&old)
1062            .bind(&id)
1063            .execute(&mgr.pool)
1064            .await
1065            .unwrap();
1066
1067        match mgr.resolve_session(&Channel::Main, "main-gap", None).await.unwrap() {
1068            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1069            SessionDecision::New { .. } => panic!("Main channel should continue even with old last_message_at when gap_minutes is None"),
1070        }
1071    }
1072
1073    #[tokio::test]
1074    async fn test_channel_isolation() {
1075        let (_tmp, mgr) = setup().await;
1076
1077        // Create sessions with same key on different channels
1078        let main_id = mgr.create_session(&Channel::Main, "shared-key").await.unwrap();
1079        let tg_id = mgr.create_session(&Channel::Telegram, "shared-key").await.unwrap();
1080        mgr.touch_session(&main_id).await.unwrap();
1081        mgr.touch_session(&tg_id).await.unwrap();
1082
1083        // Each channel resolves to its own session
1084        match mgr.resolve_session(&Channel::Main, "shared-key", None).await.unwrap() {
1085            SessionDecision::Continue(sid) => assert_eq!(sid, main_id),
1086            SessionDecision::New { .. } => panic!("Main should find its session"),
1087        }
1088        match mgr.resolve_session(&Channel::Telegram, "shared-key", None).await.unwrap() {
1089            SessionDecision::Continue(sid) => assert_eq!(sid, tg_id),
1090            SessionDecision::New { .. } => panic!("Telegram should find its session"),
1091        }
1092    }
1093
1094    #[tokio::test]
1095    async fn test_auto_close_returns_closed_session_id() {
1096        let (_tmp, mgr) = setup().await;
1097        let gap = Some(60); // 1h
1098
1099        // Create and backdate a Telegram session
1100        let id = mgr.create_session(&Channel::Telegram, "export-test").await.unwrap();
1101        mgr.touch_session(&id).await.unwrap();
1102        mgr.save_message(&id, "user", "Hello!").await.unwrap();
1103        mgr.save_message(&id, "assistant", "Hi there!").await.unwrap();
1104
1105        let two_hours_ago = (Utc::now() - Duration::hours(2)).to_rfc3339();
1106        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1107            .bind(&two_hours_ago)
1108            .bind(&id)
1109            .execute(&mgr.pool)
1110            .await
1111            .unwrap();
1112
1113        // Resolve should return New with the closed session's ID
1114        match mgr.resolve_session(&Channel::Telegram, "export-test", gap).await.unwrap() {
1115            SessionDecision::New { closed_session_id } => {
1116                assert_eq!(closed_session_id, Some(id.clone()), "Should return the closed session ID");
1117            }
1118            SessionDecision::Continue(_) => panic!("Should start new session after 2h > 1h gap"),
1119        }
1120
1121        // First resolve with no prior session → New without closed ID
1122        match mgr.resolve_session(&Channel::Main, "fresh-key", None).await.unwrap() {
1123            SessionDecision::New { closed_session_id } => {
1124                assert!(closed_session_id.is_none(), "No prior session means no closed ID");
1125            }
1126            SessionDecision::Continue(_) => panic!("Should be new"),
1127        }
1128    }
1129
1130    #[tokio::test]
1131    async fn test_auto_close_closed_id_is_correct_session() {
1132        let (_tmp, mgr) = setup().await;
1133        let gap = Some(60); // 1h
1134
1135        // Create two Telegram sessions for different keys
1136        let id_a = mgr.create_session(&Channel::Telegram, "chat-a").await.unwrap();
1137        mgr.touch_session(&id_a).await.unwrap();
1138        mgr.save_message(&id_a, "user", "Message in chat A").await.unwrap();
1139        mgr.save_message(&id_a, "assistant", "Reply in chat A").await.unwrap();
1140
1141        let id_b = mgr.create_session(&Channel::Telegram, "chat-b").await.unwrap();
1142        mgr.touch_session(&id_b).await.unwrap();
1143        mgr.save_message(&id_b, "user", "Message in chat B").await.unwrap();
1144
1145        // Backdate only chat-a beyond the gap
1146        let old_time = (Utc::now() - Duration::hours(2)).to_rfc3339();
1147        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1148            .bind(&old_time)
1149            .bind(&id_a)
1150            .execute(&mgr.pool)
1151            .await
1152            .unwrap();
1153
1154        // Resolve chat-a → should auto-close and return its ID
1155        match mgr.resolve_session(&Channel::Telegram, "chat-a", gap).await.unwrap() {
1156            SessionDecision::New { closed_session_id } => {
1157                assert_eq!(
1158                    closed_session_id,
1159                    Some(id_a.clone()),
1160                    "closed_session_id must match the session that was auto-closed"
1161                );
1162            }
1163            SessionDecision::Continue(_) => panic!("Should start new session after gap"),
1164        }
1165
1166        // Verify the closed session's messages are still accessible
1167        let messages = mgr.get_messages(&id_a).await.unwrap();
1168        assert_eq!(messages.len(), 2);
1169        assert_eq!(messages[0].content, "Message in chat A");
1170        assert_eq!(messages[1].content, "Reply in chat A");
1171
1172        // Verify chat-b is unaffected (still open, still continuable)
1173        match mgr.resolve_session(&Channel::Telegram, "chat-b", gap).await.unwrap() {
1174            SessionDecision::Continue(sid) => assert_eq!(sid, id_b),
1175            SessionDecision::New { .. } => panic!("chat-b should still be continuable"),
1176        }
1177    }
1178
1179    #[tokio::test]
1180    async fn test_no_closed_id_for_main_channel() {
1181        let (_tmp, mgr) = setup().await;
1182
1183        // Create a Main session and backdate it far in the past
1184        let id = mgr.create_session(&Channel::Main, "main-key").await.unwrap();
1185        mgr.touch_session(&id).await.unwrap();
1186
1187        let old_time = (Utc::now() - Duration::hours(48)).to_rfc3339();
1188        sqlx::query("UPDATE session_metadata SET last_message_at = ?1 WHERE id = ?2")
1189            .bind(&old_time)
1190            .bind(&id)
1191            .execute(&mgr.pool)
1192            .await
1193            .unwrap();
1194
1195        // Main channel uses gap_minutes=None → never auto-closes
1196        match mgr.resolve_session(&Channel::Main, "main-key", None).await.unwrap() {
1197            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1198            SessionDecision::New { .. } => panic!("Main channel should never auto-close"),
1199        }
1200
1201        // Even with a fresh key (no session), New should have closed_session_id=None
1202        match mgr.resolve_session(&Channel::Main, "new-main-key", None).await.unwrap() {
1203            SessionDecision::New { closed_session_id } => {
1204                assert!(
1205                    closed_session_id.is_none(),
1206                    "Main channel should never produce a closed_session_id"
1207                );
1208            }
1209            SessionDecision::Continue(_) => panic!("No session for this key, should be New"),
1210        }
1211    }
1212
1213    #[tokio::test]
1214    async fn test_no_closed_id_when_session_manually_closed() {
1215        let (_tmp, mgr) = setup().await;
1216        let gap = Some(60); // 1h
1217
1218        // Create a Telegram session and manually close it
1219        let id = mgr.create_session(&Channel::Telegram, "manual-close").await.unwrap();
1220        mgr.touch_session(&id).await.unwrap();
1221        mgr.save_message(&id, "user", "Hello").await.unwrap();
1222        mgr.close_session(&id, "Manually closed by user").await.unwrap();
1223
1224        // Resolve should return New with closed_session_id=None because
1225        // there's no open session to auto-close
1226        match mgr.resolve_session(&Channel::Telegram, "manual-close", gap).await.unwrap() {
1227            SessionDecision::New { closed_session_id } => {
1228                assert!(
1229                    closed_session_id.is_none(),
1230                    "Manually closed session should not produce closed_session_id on resolve"
1231                );
1232            }
1233            SessionDecision::Continue(_) => panic!("Closed session should not be continued"),
1234        }
1235    }
1236
1237    #[tokio::test]
1238    async fn test_cost_overview_empty() {
1239        let (_tmp, mgr) = setup().await;
1240
1241        let overview = mgr.cost_overview(None).await.unwrap();
1242        assert_eq!(overview.total_cost_usd, 0.0);
1243        assert_eq!(overview.total_input_tokens, 0);
1244        assert_eq!(overview.total_output_tokens, 0);
1245        assert_eq!(overview.total_turns, 0);
1246        assert!(overview.by_user.is_empty());
1247        assert!(overview.by_model.is_empty());
1248    }
1249
1250    #[tokio::test]
1251    async fn test_cost_overview_by_user() {
1252        let (_tmp, mgr) = setup().await;
1253        let sid = mgr.create_session(&Channel::Main, "cost-test").await.unwrap();
1254
1255        // Record usage for two different users
1256        mgr.record_usage(&sid, &UsageRecord {
1257            input_tokens: 1000, output_tokens: 500, cache_read: 0, cache_write: 0,
1258            cost_usd: 0.05, model: "claude-sonnet".into(), user_id: "alice".into(),
1259        }, 1).await.unwrap();
1260
1261        mgr.record_usage(&sid, &UsageRecord {
1262            input_tokens: 2000, output_tokens: 800, cache_read: 0, cache_write: 0,
1263            cost_usd: 0.10, model: "claude-sonnet".into(), user_id: "bob".into(),
1264        }, 2).await.unwrap();
1265
1266        mgr.record_usage(&sid, &UsageRecord {
1267            input_tokens: 500, output_tokens: 200, cache_read: 0, cache_write: 0,
1268            cost_usd: 0.02, model: "claude-haiku".into(), user_id: "alice".into(),
1269        }, 3).await.unwrap();
1270
1271        let overview = mgr.cost_overview(None).await.unwrap();
1272
1273        // Totals
1274        assert_eq!(overview.total_turns, 3);
1275        assert!((overview.total_cost_usd - 0.17).abs() < 0.001);
1276        assert_eq!(overview.total_input_tokens, 3500);
1277        assert_eq!(overview.total_output_tokens, 1500);
1278
1279        // By user (sorted by cost desc)
1280        assert_eq!(overview.by_user.len(), 2);
1281        assert_eq!(overview.by_user[0].user_id, "bob");
1282        assert!((overview.by_user[0].total_cost_usd - 0.10).abs() < 0.001);
1283        assert_eq!(overview.by_user[0].total_turns, 1);
1284        assert_eq!(overview.by_user[1].user_id, "alice");
1285        assert!((overview.by_user[1].total_cost_usd - 0.07).abs() < 0.001);
1286        assert_eq!(overview.by_user[1].total_turns, 2);
1287
1288        // By model (sorted by cost desc)
1289        assert_eq!(overview.by_model.len(), 2);
1290        assert_eq!(overview.by_model[0].model, "claude-sonnet");
1291        assert!((overview.by_model[0].total_cost_usd - 0.15).abs() < 0.001);
1292        assert_eq!(overview.by_model[1].model, "claude-haiku");
1293        assert!((overview.by_model[1].total_cost_usd - 0.02).abs() < 0.001);
1294    }
1295
1296    #[tokio::test]
1297    async fn test_cost_overview_since_filter() {
1298        let (_tmp, mgr) = setup().await;
1299        let sid = mgr.create_session(&Channel::Main, "cost-filter").await.unwrap();
1300
1301        // Record usage now
1302        mgr.record_usage(&sid, &UsageRecord {
1303            input_tokens: 1000, output_tokens: 500, cache_read: 0, cache_write: 0,
1304            cost_usd: 0.05, model: "claude-sonnet".into(), user_id: "admin".into(),
1305        }, 1).await.unwrap();
1306
1307        // "Since" far in the future should return nothing
1308        let future = (Utc::now() + Duration::hours(1)).to_rfc3339();
1309        let overview = mgr.cost_overview(Some(&future)).await.unwrap();
1310        assert_eq!(overview.total_turns, 0);
1311        assert_eq!(overview.total_cost_usd, 0.0);
1312
1313        // "Since" far in the past should return everything
1314        let past = (Utc::now() - Duration::days(365)).to_rfc3339();
1315        let overview = mgr.cost_overview(Some(&past)).await.unwrap();
1316        assert_eq!(overview.total_turns, 1);
1317        assert!((overview.total_cost_usd - 0.05).abs() < 0.001);
1318    }
1319
1320    #[tokio::test]
1321    async fn test_cost_overview_user_id_recorded() {
1322        let (_tmp, mgr) = setup().await;
1323        let sid = mgr.create_session(&Channel::Main, "uid-test").await.unwrap();
1324
1325        mgr.record_usage(&sid, &UsageRecord {
1326            input_tokens: 100, output_tokens: 50, cache_read: 0, cache_write: 0,
1327            cost_usd: 0.01, model: "m".into(), user_id: "user-42".into(),
1328        }, 1).await.unwrap();
1329
1330        let overview = mgr.cost_overview(None).await.unwrap();
1331        assert_eq!(overview.by_user.len(), 1);
1332        assert_eq!(overview.by_user[0].user_id, "user-42");
1333        assert_eq!(overview.by_user[0].total_input_tokens, 100);
1334        assert_eq!(overview.by_user[0].total_output_tokens, 50);
1335    }
1336
1337    #[tokio::test]
1338    async fn test_cost_overview_cache_breakdown() {
1339        let (_tmp, mgr) = setup().await;
1340        let sid = mgr.create_session(&Channel::Main, "cache-cost").await.unwrap();
1341
1342        // Alice: cache miss (writes to cache)
1343        mgr.record_usage(&sid, &UsageRecord {
1344            input_tokens: 200, output_tokens: 100, cache_read: 0, cache_write: 3000,
1345            cost_usd: 0.04, model: "claude-sonnet".into(), user_id: "alice".into(),
1346        }, 1).await.unwrap();
1347
1348        // Alice: cache hit (reads from cache)
1349        mgr.record_usage(&sid, &UsageRecord {
1350            input_tokens: 50, output_tokens: 150, cache_read: 3000, cache_write: 0,
1351            cost_usd: 0.01, model: "claude-sonnet".into(), user_id: "alice".into(),
1352        }, 2).await.unwrap();
1353
1354        // Bob: no caching
1355        mgr.record_usage(&sid, &UsageRecord {
1356            input_tokens: 800, output_tokens: 400, cache_read: 0, cache_write: 0,
1357            cost_usd: 0.03, model: "claude-haiku".into(), user_id: "bob".into(),
1358        }, 3).await.unwrap();
1359
1360        let overview = mgr.cost_overview(None).await.unwrap();
1361
1362        // Totals: input = (200+0+3000) + (50+3000+0) + (800+0+0) = 7050
1363        assert_eq!(overview.total_input_tokens, 7050);
1364        assert_eq!(overview.total_output_tokens, 650);
1365        assert_eq!(overview.total_cache_read, 3000);
1366        assert_eq!(overview.total_cache_write, 3000);
1367
1368        // By user: alice first (higher cost)
1369        assert_eq!(overview.by_user.len(), 2);
1370        let alice = overview.by_user.iter().find(|u| u.user_id == "alice").unwrap();
1371        assert_eq!(alice.total_input_tokens, 6250); // (200+3000) + (50+3000)
1372        assert_eq!(alice.total_cache_read, 3000);
1373        assert_eq!(alice.total_cache_write, 3000);
1374
1375        let bob = overview.by_user.iter().find(|u| u.user_id == "bob").unwrap();
1376        assert_eq!(bob.total_input_tokens, 800);
1377        assert_eq!(bob.total_cache_read, 0);
1378        assert_eq!(bob.total_cache_write, 0);
1379
1380        // By model
1381        let sonnet = overview.by_model.iter().find(|m| m.model == "claude-sonnet").unwrap();
1382        assert_eq!(sonnet.total_cache_read, 3000);
1383        assert_eq!(sonnet.total_cache_write, 3000);
1384
1385        let haiku = overview.by_model.iter().find(|m| m.model == "claude-haiku").unwrap();
1386        assert_eq!(haiku.total_cache_read, 0);
1387        assert_eq!(haiku.total_cache_write, 0);
1388    }
1389
1390    // ── Read/unread state tests ────────────────────────────────────────
1391
1392    #[tokio::test]
1393    async fn test_new_session_is_read_by_default() {
1394        let (_tmp, mgr) = setup().await;
1395        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1396
1397        let session = mgr.get_session(&id).await.unwrap().unwrap();
1398        assert!(session.is_read, "New sessions should default to is_read=true");
1399    }
1400
1401    #[tokio::test]
1402    async fn test_mark_read_false() {
1403        let (_tmp, mgr) = setup().await;
1404        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1405
1406        mgr.mark_read(&id, false).await.unwrap();
1407
1408        let session = mgr.get_session(&id).await.unwrap().unwrap();
1409        assert!(!session.is_read, "Session should be unread after mark_read(false)");
1410    }
1411
1412    #[tokio::test]
1413    async fn test_mark_read_true() {
1414        let (_tmp, mgr) = setup().await;
1415        let id = mgr.create_session(&Channel::Main, "key").await.unwrap();
1416
1417        // Mark unread, then mark read again
1418        mgr.mark_read(&id, false).await.unwrap();
1419        mgr.mark_read(&id, true).await.unwrap();
1420
1421        let session = mgr.get_session(&id).await.unwrap().unwrap();
1422        assert!(session.is_read, "Session should be read after mark_read(true)");
1423    }
1424
1425    #[tokio::test]
1426    async fn test_list_sessions_includes_is_read() {
1427        let (_tmp, mgr) = setup().await;
1428        let id1 = mgr.create_session(&Channel::Main, "key1").await.unwrap();
1429        let id2 = mgr.create_session(&Channel::Main, "key2").await.unwrap();
1430
1431        mgr.mark_read(&id1, false).await.unwrap();
1432
1433        let sessions = mgr.list_sessions(10).await.unwrap();
1434        let s1 = sessions.iter().find(|s| s.id == id1).unwrap();
1435        let s2 = sessions.iter().find(|s| s.id == id2).unwrap();
1436
1437        assert!(!s1.is_read, "Session 1 should be unread");
1438        assert!(s2.is_read, "Session 2 should still be read");
1439    }
1440
1441    #[tokio::test]
1442    async fn test_mark_read_nonexistent_session_succeeds() {
1443        let (_tmp, mgr) = setup().await;
1444        // Should not error — just a no-op UPDATE matching zero rows
1445        mgr.mark_read("nonexistent-id", true).await.unwrap();
1446    }
1447
1448    // --- Email channel tests ---
1449
/// `Channel::Email` is rendered as the string `"email"`.
#[test]
fn test_email_channel_as_str() {
    let rendered = Channel::Email.as_str();
    assert_eq!(rendered, "email");
}
1454
/// The string `"email"` is parsed back into `Channel::Email`.
#[test]
fn test_email_channel_from_str() {
    let parsed = Channel::from_channel_str("email");
    assert_eq!(parsed, Channel::Email);
}
1459
/// Unrecognised channel strings, including the empty string, fall back to
/// `Channel::Main`.
#[test]
fn test_unknown_channel_defaults_to_main() {
    for raw in ["unknown", ""].iter() {
        assert_eq!(Channel::from_channel_str(raw), Channel::Main);
    }
}
1465
1466    #[tokio::test]
1467    async fn test_create_email_session() {
1468        let (_tmp, mgr) = setup().await;
1469        let id = mgr.create_session(&Channel::Email, "user@example.com").await.unwrap();
1470
1471        let session = mgr.get_session(&id).await.unwrap().unwrap();
1472        assert_eq!(session.channel, "email");
1473        assert_eq!(session.channel_session_key.as_deref(), Some("user@example.com"));
1474    }
1475
1476    #[tokio::test]
1477    async fn test_resolve_email_session_continues_for_same_sender() {
1478        let (_tmp, mgr) = setup().await;
1479        let id = mgr.create_session(&Channel::Email, "sender@test.com").await.unwrap();
1480        mgr.touch_session(&id).await.unwrap();
1481
1482        match mgr.resolve_session(&Channel::Email, "sender@test.com", None).await.unwrap() {
1483            SessionDecision::Continue(sid) => assert_eq!(sid, id),
1484            SessionDecision::New { .. } => panic!("Should continue recent email session"),
1485        }
1486    }
1487
1488    #[tokio::test]
1489    async fn test_resolve_email_session_new_for_different_sender() {
1490        let (_tmp, mgr) = setup().await;
1491        let id = mgr.create_session(&Channel::Email, "sender-a@test.com").await.unwrap();
1492        mgr.touch_session(&id).await.unwrap();
1493
1494        match mgr.resolve_session(&Channel::Email, "sender-b@test.com", None).await.unwrap() {
1495            SessionDecision::New { .. } => {} // expected — different sender
1496            SessionDecision::Continue(_) => panic!("Should not continue session for different sender"),
1497        }
1498    }
1499
1500    #[tokio::test]
1501    async fn test_email_and_telegram_sessions_are_separate() {
1502        let (_tmp, mgr) = setup().await;
1503        let email_id = mgr.create_session(&Channel::Email, "user@test.com").await.unwrap();
1504        let tg_id = mgr.create_session(&Channel::Telegram, "user@test.com").await.unwrap();
1505
1506        assert_ne!(email_id, tg_id);
1507
1508        // Each channel resolves independently
1509        mgr.touch_session(&email_id).await.unwrap();
1510        mgr.touch_session(&tg_id).await.unwrap();
1511
1512        match mgr.resolve_session(&Channel::Email, "user@test.com", None).await.unwrap() {
1513            SessionDecision::Continue(sid) => assert_eq!(sid, email_id),
1514            SessionDecision::New { .. } => panic!("Should continue email session"),
1515        }
1516        match mgr.resolve_session(&Channel::Telegram, "user@test.com", None).await.unwrap() {
1517            SessionDecision::Continue(sid) => assert_eq!(sid, tg_id),
1518            SessionDecision::New { .. } => panic!("Should continue telegram session"),
1519        }
1520    }
1521
1522    // ── Tool usage stats tests ────────────────────────────────────────
1523
1524    #[tokio::test]
1525    async fn test_cost_overview_by_tool_empty() {
1526        let (_tmp, mgr) = setup().await;
1527
1528        let overview = mgr.cost_overview(None).await.unwrap();
1529        assert!(overview.by_tool.is_empty(), "No tool messages → empty by_tool");
1530    }
1531
1532    #[tokio::test]
1533    async fn test_cost_overview_by_tool_counts() {
1534        let (_tmp, mgr) = setup().await;
1535        let sid = mgr.create_session(&Channel::Main, "tool-test").await.unwrap();
1536
1537        // Simulate 3 MemorySearch invocations (all successful)
1538        for i in 0..3 {
1539            let tool_use = serde_json::json!({
1540                "type": "tool_use",
1541                "id": format!("tu_mem_{i}"),
1542                "name": "MemorySearch",
1543                "input": {"query": "test"}
1544            });
1545            mgr.save_message(&sid, "tool_use", &tool_use.to_string()).await.unwrap();
1546
1547            let tool_result = serde_json::json!({
1548                "type": "tool_result",
1549                "tool_use_id": format!("tu_mem_{i}"),
1550                "content": "some result",
1551                "is_error": false
1552            });
1553            mgr.save_message(&sid, "tool_result", &tool_result.to_string()).await.unwrap();
1554        }
1555
1556        // Simulate 2 VaultGet invocations: 1 success, 1 error
1557        let tool_use = serde_json::json!({
1558            "type": "tool_use", "id": "tu_vault_0", "name": "VaultGet",
1559            "input": {"key": "api_key"}
1560        });
1561        mgr.save_message(&sid, "tool_use", &tool_use.to_string()).await.unwrap();
1562        let tool_result = serde_json::json!({
1563            "type": "tool_result", "tool_use_id": "tu_vault_0",
1564            "content": "secret-value", "is_error": false
1565        });
1566        mgr.save_message(&sid, "tool_result", &tool_result.to_string()).await.unwrap();
1567
1568        let tool_use = serde_json::json!({
1569            "type": "tool_use", "id": "tu_vault_1", "name": "VaultGet",
1570            "input": {"key": "missing"}
1571        });
1572        mgr.save_message(&sid, "tool_use", &tool_use.to_string()).await.unwrap();
1573        let tool_result = serde_json::json!({
1574            "type": "tool_result", "tool_use_id": "tu_vault_1",
1575            "content": "key not found", "is_error": true
1576        });
1577        mgr.save_message(&sid, "tool_result", &tool_result.to_string()).await.unwrap();
1578
1579        let overview = mgr.cost_overview(None).await.unwrap();
1580
1581        // Sorted by invocations DESC: MemorySearch(3), VaultGet(2)
1582        assert_eq!(overview.by_tool.len(), 2);
1583        assert_eq!(overview.by_tool[0].tool_name, "MemorySearch");
1584        assert_eq!(overview.by_tool[0].invocations, 3);
1585        assert_eq!(overview.by_tool[0].errors, 0);
1586        assert_eq!(overview.by_tool[1].tool_name, "VaultGet");
1587        assert_eq!(overview.by_tool[1].invocations, 2);
1588        assert_eq!(overview.by_tool[1].errors, 1);
1589    }
1590
1591    #[tokio::test]
1592    async fn test_cost_overview_by_tool_since_filter() {
1593        let (_tmp, mgr) = setup().await;
1594        let sid = mgr.create_session(&Channel::Main, "tool-filter").await.unwrap();
1595
1596        // Save a tool_use message now
1597        let tool_use = serde_json::json!({
1598            "type": "tool_use", "id": "tu_1", "name": "CronList", "input": {}
1599        });
1600        mgr.save_message(&sid, "tool_use", &tool_use.to_string()).await.unwrap();
1601
1602        // "Since" far in the future should exclude it
1603        let future = (Utc::now() + Duration::hours(1)).to_rfc3339();
1604        let overview = mgr.cost_overview(Some(&future)).await.unwrap();
1605        assert!(overview.by_tool.is_empty());
1606
1607        // "Since" far in the past should include it
1608        let past = (Utc::now() - Duration::days(365)).to_rfc3339();
1609        let overview = mgr.cost_overview(Some(&past)).await.unwrap();
1610        assert_eq!(overview.by_tool.len(), 1);
1611        assert_eq!(overview.by_tool[0].tool_name, "CronList");
1612        assert_eq!(overview.by_tool[0].invocations, 1);
1613    }
1614
1615    #[tokio::test]
1616    async fn test_cost_overview_by_tool_without_result() {
1617        let (_tmp, mgr) = setup().await;
1618        let sid = mgr.create_session(&Channel::Main, "tool-no-result").await.unwrap();
1619
1620        // tool_use without a matching tool_result (e.g. stream interrupted)
1621        let tool_use = serde_json::json!({
1622            "type": "tool_use", "id": "tu_orphan", "name": "SkillList", "input": {}
1623        });
1624        mgr.save_message(&sid, "tool_use", &tool_use.to_string()).await.unwrap();
1625
1626        let overview = mgr.cost_overview(None).await.unwrap();
1627        assert_eq!(overview.by_tool.len(), 1);
1628        assert_eq!(overview.by_tool[0].tool_name, "SkillList");
1629        assert_eq!(overview.by_tool[0].invocations, 1);
1630        assert_eq!(overview.by_tool[0].errors, 0, "No result means no error, not an error");
1631    }
1632}