Skip to main content

opensession_core/
stats.rs

1use crate::{EventType, Session, Stats};
2use chrono::Utc;
3
/// Totals accumulated across a collection of sessions.
///
/// Every counter is a `u64` for in-memory summation; convert to `i64` when
/// mapping onto the SQL-based API types via the `From` impls in `api-types`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionAggregate {
    /// Number of sessions folded into this aggregate.
    pub session_count: u64,
    /// Sum of the per-session message counts.
    pub message_count: u64,
    /// Sum of the per-session event counts.
    pub event_count: u64,
    /// Sum of the per-session tool-call counts.
    pub tool_call_count: u64,
    /// Sum of the per-session task counts.
    pub task_count: u64,
    /// Sum of the per-session durations, in seconds.
    pub duration_seconds: u64,
    /// Sum of the per-session input token totals.
    pub total_input_tokens: u64,
    /// Sum of the per-session output token totals.
    pub total_output_tokens: u64,
    /// Sum of the per-session user-message counts.
    pub user_message_count: u64,
    /// Sum of the per-session changed-file counts.
    pub files_changed: u64,
    /// Sum of the per-session added-line counts.
    pub lines_added: u64,
    /// Sum of the per-session removed-line counts.
    pub lines_removed: u64,
}
23
24impl SessionAggregate {
25    fn add_session_stats(&mut self, stats: &Stats) {
26        self.session_count += 1;
27        self.message_count += stats.message_count;
28        self.event_count += stats.event_count;
29        self.tool_call_count += stats.tool_call_count;
30        self.task_count += stats.task_count;
31        self.duration_seconds += stats.duration_seconds;
32        self.total_input_tokens += stats.total_input_tokens;
33        self.total_output_tokens += stats.total_output_tokens;
34        self.user_message_count += stats.user_message_count;
35        self.files_changed += stats.files_changed;
36        self.lines_added += stats.lines_added;
37        self.lines_removed += stats.lines_removed;
38    }
39}
40
41/// Aggregate pre-computed `stats` from every session in the slice.
42pub fn aggregate(sessions: &[Session]) -> SessionAggregate {
43    let mut agg = SessionAggregate::default();
44    for s in sessions {
45        agg.add_session_stats(&s.stats);
46    }
47    agg
48}
49
50/// Group sessions by `agent.tool` and aggregate each group.
51pub fn aggregate_by_tool(sessions: &[Session]) -> Vec<(String, SessionAggregate)> {
52    aggregate_by(sessions, |s| s.agent.tool.clone())
53}
54
55/// Group sessions by `agent.model` and aggregate each group.
56pub fn aggregate_by_model(sessions: &[Session]) -> Vec<(String, SessionAggregate)> {
57    aggregate_by(sessions, |s| s.agent.model.clone())
58}
59
60/// Generic group-by aggregation. Results sorted by session_count descending.
61fn aggregate_by(
62    sessions: &[Session],
63    key_fn: impl Fn(&Session) -> String,
64) -> Vec<(String, SessionAggregate)> {
65    let mut map = std::collections::HashMap::<String, SessionAggregate>::new();
66    for s in sessions {
67        map.entry(key_fn(s))
68            .or_default()
69            .add_session_stats(&s.stats);
70    }
71    let mut result: Vec<_> = map.into_iter().collect();
72    result.sort_by(|a, b| b.1.session_count.cmp(&a.1.session_count));
73    result
74}
75
76/// Filter sessions by a time-range string relative to now.
77///
78/// Supported values: `"24h"`, `"7d"`, `"30d"`, `"all"` (or anything else → no filter).
79pub fn filter_by_time_range<'a>(sessions: &'a [Session], range: &str) -> Vec<&'a Session> {
80    let cutoff = match range {
81        "24h" => Some(Utc::now() - chrono::Duration::days(1)),
82        "7d" => Some(Utc::now() - chrono::Duration::days(7)),
83        "30d" => Some(Utc::now() - chrono::Duration::days(30)),
84        _ => None,
85    };
86    match cutoff {
87        Some(c) => sessions
88            .iter()
89            .filter(|s| s.context.created_at >= c)
90            .collect(),
91        None => sessions.iter().collect(),
92    }
93}
94
95/// Count tool calls per tool name across all sessions, returning `(tool_name, count)` sorted descending.
96pub fn count_tool_calls(sessions: &[Session]) -> Vec<(String, u64)> {
97    let mut map = std::collections::HashMap::<String, u64>::new();
98    for s in sessions {
99        for event in &s.events {
100            let name = match &event.event_type {
101                EventType::ToolCall { name } => Some(name.clone()),
102                EventType::FileRead { .. } => Some("FileRead".to_string()),
103                EventType::CodeSearch { .. } => Some("CodeSearch".to_string()),
104                EventType::FileSearch { .. } => Some("FileSearch".to_string()),
105                EventType::FileEdit { .. } => Some("FileEdit".to_string()),
106                EventType::FileCreate { .. } => Some("FileCreate".to_string()),
107                EventType::FileDelete { .. } => Some("FileDelete".to_string()),
108                EventType::ShellCommand { .. } => Some("ShellCommand".to_string()),
109                EventType::WebSearch { .. } => Some("WebSearch".to_string()),
110                EventType::WebFetch { .. } => Some("WebFetch".to_string()),
111                _ => None,
112            };
113            if let Some(n) = name {
114                *map.entry(n).or_default() += 1;
115            }
116        }
117    }
118    let mut result: Vec<_> = map.into_iter().collect();
119    result.sort_by(|a, b| b.1.cmp(&a.1));
120    result
121}
122
123// ---------------------------------------------------------------------------
124// SQL helpers — shared query strings for SQLite-backed servers
125// ---------------------------------------------------------------------------
126
pub mod sql {
    //! Shared SQL text for SQLite-backed stats endpoints.
    //!
    //! Every query builder binds the team id as positional parameter `?1` and
    //! splices `time_filter` verbatim right after the team predicate. That
    //! splice is raw SQL, so callers should pass only the output of
    //! [`time_range_filter`], never user-supplied text.
    //!
    //! NOTE(review): the time filters compare `s.created_at` against SQLite
    //! `datetime(...)` values as text; this presumes `created_at` is stored in
    //! a format that sorts consistently with `datetime()` output. Confirm
    //! against the schema.

    /// Metric columns shared by the grouped (`by_user` / `by_tool`) queries.
    const GROUPED_METRICS: &str = concat!(
        "COUNT(*) as session_count, ",
        "COALESCE(SUM(s.message_count), 0) as message_count, ",
        "COALESCE(SUM(s.event_count), 0) as event_count, ",
        "COALESCE(SUM(s.duration_seconds), 0) as duration_seconds, ",
        "COALESCE(SUM(s.total_input_tokens), 0) as total_input_tokens, ",
        "COALESCE(SUM(s.total_output_tokens), 0) as total_output_tokens ",
    );

    /// Map a time-range name to a `WHERE`-clause fragment (with leading ` AND`).
    ///
    /// Recognises `"24h"`, `"7d"` and `"30d"`; `"all"` and anything else map
    /// to the empty string (no restriction).
    pub fn time_range_filter(range: &str) -> &'static str {
        const FILTERS: [(&str, &str); 3] = [
            ("24h", " AND s.created_at >= datetime('now', '-1 day')"),
            ("7d", " AND s.created_at >= datetime('now', '-7 days')"),
            ("30d", " AND s.created_at >= datetime('now', '-30 days')"),
        ];
        FILTERS
            .iter()
            .find(|(name, _)| *name == range)
            .map_or("", |&(_, clause)| clause)
    }

    /// Single-row totals over all sessions with `team_id = ?1`, optionally
    /// narrowed by `time_filter`.
    pub fn totals_query(time_filter: &str) -> String {
        [
            "SELECT ",
            "COUNT(*) as session_count, ",
            "COALESCE(SUM(s.message_count), 0) as message_count, ",
            "COALESCE(SUM(s.event_count), 0) as event_count, ",
            "COALESCE(SUM(s.tool_call_count), 0) as tool_call_count, ",
            "COALESCE(SUM(s.duration_seconds), 0) as duration_seconds, ",
            "COALESCE(SUM(s.total_input_tokens), 0) as total_input_tokens, ",
            "COALESCE(SUM(s.total_output_tokens), 0) as total_output_tokens ",
            "FROM sessions s ",
            "WHERE s.team_id = ?1",
            time_filter,
        ]
        .concat()
    }

    /// Per-user totals for `team_id = ?1`, joined against `users` for the
    /// nickname (`'unknown'` when no user row matches), busiest user first.
    pub fn by_user_query(time_filter: &str) -> String {
        [
            "SELECT ",
            "s.user_id as user_id, ",
            "COALESCE(u.nickname, 'unknown') as nickname, ",
            GROUPED_METRICS,
            "FROM sessions s ",
            "LEFT JOIN users u ON u.id = s.user_id ",
            "WHERE s.team_id = ?1",
            time_filter,
            " GROUP BY s.user_id ",
            "ORDER BY session_count DESC",
        ]
        .concat()
    }

    /// Per-tool totals for `team_id = ?1`, busiest tool first.
    pub fn by_tool_query(time_filter: &str) -> String {
        [
            "SELECT ",
            "s.tool as tool, ",
            GROUPED_METRICS,
            "FROM sessions s ",
            "WHERE s.team_id = ?1",
            time_filter,
            " GROUP BY s.tool ",
            "ORDER BY session_count DESC",
        ]
        .concat()
    }
}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Agent, Content, Event, Session, Stats};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;

    // --- fixtures ---

    /// Agent fixture with provider `"test"` and no tool version.
    fn make_agent(tool: &str, model: &str) -> Agent {
        Agent {
            provider: String::from("test"),
            model: String::from(model),
            tool: String::from(tool),
            tool_version: None,
        }
    }

    /// Session fixture with id `"s1"` whose stats are overwritten with `stats`.
    fn make_session_with_stats(tool: &str, model: &str, stats: Stats) -> Session {
        let mut session = Session::new(String::from("s1"), make_agent(tool, model));
        session.stats = stats;
        session
    }

    /// Stats fixture: one task, 100 input / 200 output tokens, other fields defaulted.
    fn sample_stats(msg: u64, events: u64, tools: u64, dur: u64) -> Stats {
        Stats {
            message_count: msg,
            event_count: events,
            tool_call_count: tools,
            duration_seconds: dur,
            task_count: 1,
            total_input_tokens: 100,
            total_output_tokens: 200,
            ..Default::default()
        }
    }

    /// Event fixture stamped with the current time and no task, duration, or attributes.
    fn make_event(event_id: &str, event_type: EventType, content: Content) -> Event {
        Event {
            event_id: event_id.to_string(),
            timestamp: Utc::now(),
            event_type,
            task_id: None,
            content,
            duration_ms: None,
            attributes: HashMap::new(),
        }
    }

    // --- in-memory aggregation ---

    #[test]
    fn test_aggregate_empty() {
        assert_eq!(aggregate(&[]), SessionAggregate::default());
    }

    #[test]
    fn test_aggregate_single() {
        let sessions = [make_session_with_stats(
            "claude-code",
            "opus",
            sample_stats(5, 10, 3, 60),
        )];
        let agg = aggregate(&sessions);
        assert_eq!(
            (
                agg.session_count,
                agg.message_count,
                agg.event_count,
                agg.tool_call_count,
                agg.duration_seconds,
                agg.total_input_tokens,
                agg.total_output_tokens,
            ),
            (1, 5, 10, 3, 60, 100, 200)
        );
    }

    #[test]
    fn test_aggregate_multiple() {
        let sessions = [
            make_session_with_stats("claude-code", "opus", sample_stats(5, 10, 3, 60)),
            make_session_with_stats("cursor", "gpt-4o", sample_stats(3, 6, 2, 30)),
        ];
        let agg = aggregate(&sessions);
        assert_eq!(
            (
                agg.session_count,
                agg.message_count,
                agg.event_count,
                agg.tool_call_count,
                agg.duration_seconds,
                agg.total_input_tokens,
                agg.total_output_tokens,
            ),
            (2, 8, 16, 5, 90, 200, 400)
        );
    }

    #[test]
    fn test_aggregate_by_tool() {
        let sessions = [
            make_session_with_stats("claude-code", "opus", sample_stats(5, 10, 3, 60)),
            make_session_with_stats("claude-code", "sonnet", sample_stats(3, 6, 2, 30)),
            make_session_with_stats("cursor", "gpt-4o", sample_stats(1, 2, 1, 10)),
        ];
        let by_tool = aggregate_by_tool(&sessions);
        // claude-code has 2 sessions, so it must rank ahead of cursor.
        let ranking: Vec<(&str, u64)> = by_tool
            .iter()
            .map(|(tool, agg)| (tool.as_str(), agg.session_count))
            .collect();
        assert_eq!(ranking, [("claude-code", 2), ("cursor", 1)]);
    }

    #[test]
    fn test_aggregate_by_model() {
        let sessions = [
            make_session_with_stats("claude-code", "opus", sample_stats(5, 10, 3, 60)),
            make_session_with_stats("cursor", "opus", sample_stats(3, 6, 2, 30)),
            make_session_with_stats("cursor", "gpt-4o", sample_stats(1, 2, 1, 10)),
        ];
        let by_model = aggregate_by_model(&sessions);
        assert_eq!(by_model.len(), 2);
        let (top_model, top_agg) = &by_model[0];
        assert_eq!(top_model, "opus");
        assert_eq!(top_agg.session_count, 2);
    }

    // --- time filtering ---

    #[test]
    fn test_filter_by_time_range_all() {
        let sessions = [make_session_with_stats(
            "cc",
            "opus",
            sample_stats(1, 1, 0, 10),
        )];
        assert_eq!(filter_by_time_range(&sessions, "all").len(), 1);
    }

    #[test]
    fn test_filter_by_time_range_24h() {
        let mut recent = make_session_with_stats("cc", "opus", sample_stats(1, 1, 0, 10));
        recent.context.created_at = Utc::now();

        let mut stale = make_session_with_stats("cc", "opus", sample_stats(1, 1, 0, 10));
        stale.context.created_at = Utc::now() - Duration::days(2);

        let sessions = [recent, stale];
        assert_eq!(filter_by_time_range(&sessions, "24h").len(), 1);
    }

    // --- tool-call counting ---

    #[test]
    fn test_count_tool_calls() {
        let mut session = Session::new("s1".to_string(), make_agent("cc", "opus"));
        session.events.push(make_event(
            "e1",
            EventType::ToolCall {
                name: "Read".to_string(),
            },
            Content::empty(),
        ));
        session.events.push(make_event(
            "e2",
            EventType::FileRead {
                path: "/tmp/a.rs".to_string(),
            },
            Content::empty(),
        ));
        // A plain user message is not a tool call and must not be counted.
        session
            .events
            .push(make_event("e3", EventType::UserMessage, Content::text("hello")));

        let counts = count_tool_calls(&[session]);
        assert_eq!(counts.len(), 2);
        // Both the generic ToolCall name and the built-in FileRead kind appear.
        let names: Vec<&str> = counts.iter().map(|(name, _)| name.as_str()).collect();
        assert!(names.contains(&"Read"));
        assert!(names.contains(&"FileRead"));
    }

    // --- SQL helper tests ---

    #[test]
    fn test_sql_time_range_filter() {
        let cases = [
            ("24h", " AND s.created_at >= datetime('now', '-1 day')"),
            ("7d", " AND s.created_at >= datetime('now', '-7 days')"),
            ("30d", " AND s.created_at >= datetime('now', '-30 days')"),
            ("all", ""),
            ("unknown", ""),
        ];
        for (range, expected) in cases {
            assert_eq!(sql::time_range_filter(range), expected, "range {range:?}");
        }
    }

    #[test]
    fn test_sql_totals_query_contains_expected_fragments() {
        let query = sql::totals_query("");
        for fragment in [
            "COUNT(*) as session_count",
            "SUM(s.message_count)",
            "SUM(s.total_input_tokens)",
            "WHERE s.team_id = ?1",
        ] {
            assert!(query.contains(fragment), "missing {fragment:?}");
        }
    }

    #[test]
    fn test_sql_totals_query_with_time_filter() {
        let query = sql::totals_query(sql::time_range_filter("24h"));
        assert!(query.contains("datetime('now', '-1 day')"));
    }

    #[test]
    fn test_sql_by_user_query() {
        let query = sql::by_user_query("");
        for fragment in [
            "LEFT JOIN users u",
            "GROUP BY s.user_id",
            "ORDER BY session_count DESC",
        ] {
            assert!(query.contains(fragment), "missing {fragment:?}");
        }
    }

    #[test]
    fn test_sql_by_tool_query() {
        let query = sql::by_tool_query("");
        for fragment in ["GROUP BY s.tool", "ORDER BY session_count DESC"] {
            assert!(query.contains(fragment), "missing {fragment:?}");
        }
    }
}