Skip to main content

cognee_database/ops/
session_lifecycle.rs

1//! Repository implementation for `SessionLifecycleDb` (LIB-05).
2//!
3//! Line-for-line port of Python's
4//! `cognee/modules/session_lifecycle/metrics.py` plus the three
5//! aggregate queries that live inline in
6//! `cognee/api/v1/sessions/routers/get_sessions_router.py` (stats /
7//! cost-by-model). The trait, public domain types, and effective-status
8//! semantics live in `traits::session_lifecycle_db`.
9//!
10//! Implementation choices:
11//!   * UUIDs persist as 32-char hex (per LIB-03 / `uuid_hex.rs`); the
12//!     repository converts at the boundary so the trait surface is plain
13//!     `Uuid`.
14//!   * `ensure_and_touch_session` and the per-model upsert in
15//!     `accumulate_usage` use raw SQL `INSERT ... ON CONFLICT DO UPDATE`
16//!     via `Statement::from_sql_and_values` to express the COALESCE
17//!     dataset-backfill and the `WHERE status = 'running'` clause on the
18//!     update — neither of which `sea_orm::sea_query::OnConflict`
19//!     surfaces portably. The dialect is SQLite/Postgres-shared syntax;
20//!     branching on `get_database_backend` selects the right backend
21//!     marker.
22//!   * The effective-status helper computes `now - threshold` in Rust
23//!     and binds it as a parameter (mirrors Python at
24//!     `metrics.py:281-282`), so no SQL function for elapsed time is
25//!     needed and the expression is portable across SQLite / Postgres.
26//!   * Duration aggregation pulls `(started_at, ended_at,
27//!     last_activity_at)` rows and folds in Rust — Python does the same
28//!     fallback at `get_sessions_router.py:148-158` because SQLite has
29//!     no `EXTRACT(epoch ...)`.
30
31use std::env;
32
33use chrono::{DateTime, Duration, Utc};
34use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
35use sea_orm::{
36    ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection, EntityTrait,
37    FromQueryResult, QueryFilter, Statement, Value,
38};
39use tracing::{Span, instrument};
40use uuid::Uuid;
41
42use crate::conversions::map_sea_err;
43use crate::database_system_label;
44use crate::entities::session_record;
45use crate::traits::{
46    CostByModelRow, SessionLifecycleDb, SessionListFilters, SessionListPage, SessionRowWithStatus,
47    SessionStats,
48};
49use crate::types::DatabaseError;
50use crate::uuid_hex;
51
/// Abandonment threshold in seconds, taken from the
/// `SESSION_ABANDON_AFTER_SECONDS` environment variable.
///
/// Surrounding whitespace is ignored. When the variable is unset, not
/// valid Unicode, empty after trimming, or not parseable as an `i64`,
/// the default of `1800` (30 minutes) is returned instead.
pub fn abandon_after_seconds() -> i64 {
    const DEFAULT_SECONDS: i64 = 1800;

    match env::var("SESSION_ABANDON_AFTER_SECONDS") {
        // An empty (or whitespace-only) value fails to parse and so
        // lands on the default, same as any other non-numeric input.
        Ok(raw) => raw.trim().parse::<i64>().unwrap_or(DEFAULT_SECONDS),
        Err(_) => DEFAULT_SECONDS,
    }
}
69
70/// Compute the `<` cutoff for `last_activity_at` that flips a running
71/// row to `abandoned`. Centralized so callers (effective-status
72/// expression in raw SQL, status-bucket query, list query) all use the
73/// same wall-clock snapshot.
74fn abandon_threshold_ts() -> DateTime<Utc> {
75    Utc::now() - Duration::seconds(abandon_after_seconds())
76}
77
78/// Render the `effective_status` SQL fragment used inside SELECT lists.
79/// Returns the literal SQL plus the bound parameter (the cutoff
80/// timestamp) so callers can splice it into larger statements.
81///
82/// Matches Python's `get_effective_status_sql` at
83/// `cognee/modules/session_lifecycle/metrics.py:271-292`.
84fn effective_status_sql_fragment(threshold: DateTime<Utc>) -> (String, Value) {
85    // Both SQLite (sqlx-sqlite) and Postgres (sqlx-postgres) support
86    // CASE WHEN ... THEN ... ELSE ... END; the bound timestamp is
87    // dialect-portable via the `Value` enum.
88    let sql =
89        "CASE WHEN status = 'running' AND last_activity_at < ? THEN 'abandoned' ELSE status END"
90            .to_string();
91    (sql, threshold.into())
92}
93
94// ---------------------------------------------------------------------------
95// ensure_and_touch_session
96// ---------------------------------------------------------------------------
97
98#[instrument(
99    name = "cognee.db.relational.session_lifecycle.ensure_and_touch_session",
100    level = "info",
101    skip_all,
102    fields(cognee.db.system = tracing::field::Empty),
103    err,
104)]
105pub async fn ensure_and_touch_session(
106    db: &DatabaseConnection,
107    session_id: &str,
108    user_id: Uuid,
109    dataset_id: Option<Uuid>,
110) -> Result<(), DatabaseError> {
111    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
112    let now = Utc::now();
113    let backend = db.get_database_backend();
114
115    let user_hex = uuid_hex::to_hex(user_id);
116    let dataset_hex = uuid_hex::to_hex_opt(dataset_id);
117
118    // SQLite & Postgres share `INSERT ... ON CONFLICT(...) DO UPDATE
119    // SET ... WHERE ...` syntax. SeaORM's `OnConflict` doesn't expose
120    // the WHERE clause on the update or COALESCE-style backfill, so we
121    // hand-roll the SQL.
122    //
123    // The COALESCE on dataset_id mirrors Python's `case(...)` at
124    // `metrics.py:100-103`: if the existing row's dataset_id is NULL,
125    // adopt the new value; otherwise keep the existing one.
126    let sql = match backend {
127        DatabaseBackend::Sqlite | DatabaseBackend::Postgres => {
128            "INSERT INTO session_records (\
129                session_id, user_id, dataset_id, status, started_at, \
130                last_activity_at, ended_at, tokens_in, tokens_out, \
131                cost_usd, error_count, last_model\
132             ) VALUES ($1, $2, $3, 'running', $4, $4, NULL, 0, 0, 0.0, 0, NULL)\
133             ON CONFLICT (session_id, user_id) DO UPDATE SET \
134                last_activity_at = $4, \
135                dataset_id = COALESCE(session_records.dataset_id, $3) \
136             WHERE session_records.status = 'running'"
137        }
138        DatabaseBackend::MySql => {
139            return Err(DatabaseError::QueryError(
140                "ensure_and_touch_session: MySQL backend not supported".to_string(),
141            ));
142        }
143    };
144
145    db.execute(Statement::from_sql_and_values(
146        backend,
147        sql,
148        [
149            session_id.into(),
150            user_hex.into(),
151            Value::from(dataset_hex),
152            now.into(),
153        ],
154    ))
155    .await
156    .map_err(map_sea_err)?;
157    Ok(())
158}
159
160// ---------------------------------------------------------------------------
161// accumulate_usage
162// ---------------------------------------------------------------------------
163
164// Argument list mirrors Python's `accumulate_usage` keyword arguments at
165// `cognee/modules/session_lifecycle/metrics.py:133-141` for line-for-line
166// parity; introducing a struct just to silence clippy would diverge from
167// the reference shape without adding value.
168#[allow(clippy::too_many_arguments)]
169#[instrument(
170    name = "cognee.db.relational.session_lifecycle.accumulate_usage",
171    level = "info",
172    skip_all,
173    fields(cognee.db.system = tracing::field::Empty),
174    err,
175)]
176pub async fn accumulate_usage(
177    db: &DatabaseConnection,
178    session_id: &str,
179    user_id: Uuid,
180    model: Option<&str>,
181    tokens_in: i64,
182    tokens_out: i64,
183    cost_usd: f64,
184    errored: bool,
185) -> Result<(), DatabaseError> {
186    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
187    // Skip the no-op shortcut from Python `metrics.py:150-151`: nothing
188    // to credit, no error to count, no model to remember.
189    if tokens_in == 0 && tokens_out == 0 && cost_usd == 0.0 && !errored && model.is_none() {
190        return Ok(());
191    }
192
193    let backend = db.get_database_backend();
194    let user_hex = uuid_hex::to_hex(user_id);
195
196    // Step 1: gated UPDATE on session_records. We build the SET clause
197    // dynamically because Python only includes columns that change.
198    //
199    // The `WHERE status = 'running'` gate keeps terminal sessions
200    // frozen so a late straggler can't resurrect or distort them
201    // (Python `metrics.py:172-181`).
202    let mut set_parts: Vec<String> = Vec::new();
203    let mut params: Vec<Value> = Vec::new();
204    let mut next_idx: usize = 1;
205
206    let push_inc = |col: &str,
207                    delta: Value,
208                    set_parts: &mut Vec<String>,
209                    params: &mut Vec<Value>,
210                    next_idx: &mut usize| {
211        set_parts.push(format!("{col} = {col} + ${next_idx}"));
212        params.push(delta);
213        *next_idx += 1;
214    };
215
216    if tokens_in != 0 {
217        // session_records.tokens_in is INTEGER (i32 in SeaORM model);
218        // i64 deltas overflow at ~2.1B which we treat as caller error.
219        let v = i32::try_from(tokens_in).map_err(|_| {
220            DatabaseError::QueryError("accumulate_usage: tokens_in delta overflows i32".to_string())
221        })?;
222        push_inc(
223            "tokens_in",
224            Value::from(v),
225            &mut set_parts,
226            &mut params,
227            &mut next_idx,
228        );
229    }
230    if tokens_out != 0 {
231        let v = i32::try_from(tokens_out).map_err(|_| {
232            DatabaseError::QueryError(
233                "accumulate_usage: tokens_out delta overflows i32".to_string(),
234            )
235        })?;
236        push_inc(
237            "tokens_out",
238            Value::from(v),
239            &mut set_parts,
240            &mut params,
241            &mut next_idx,
242        );
243    }
244    if cost_usd != 0.0 {
245        push_inc(
246            "cost_usd",
247            Value::from(cost_usd),
248            &mut set_parts,
249            &mut params,
250            &mut next_idx,
251        );
252    }
253    if errored {
254        set_parts.push(format!("error_count = error_count + ${next_idx}"));
255        params.push(Value::from(1_i32));
256        next_idx += 1;
257    }
258    if let Some(m) = model {
259        set_parts.push(format!("last_model = ${next_idx}"));
260        params.push(Value::from(m.to_string()));
261        next_idx += 1;
262    }
263
264    if !set_parts.is_empty() {
265        // Append WHERE bindings.
266        let where_session_idx = next_idx;
267        params.push(Value::from(session_id.to_string()));
268        next_idx += 1;
269        let where_user_idx = next_idx;
270        params.push(Value::from(user_hex.clone()));
271        next_idx += 1;
272
273        let sql = format!(
274            "UPDATE session_records SET {set_clause} \
275             WHERE session_id = ${sid} AND user_id = ${uid} AND status = 'running'",
276            set_clause = set_parts.join(", "),
277            sid = where_session_idx,
278            uid = where_user_idx,
279        );
280        let _ = next_idx;
281
282        db.execute(Statement::from_sql_and_values(backend, sql, params))
283            .await
284            .map_err(map_sea_err)?;
285    }
286
287    // Step 2: per-model upsert. Only when there's actual usage to
288    // credit (Python `metrics.py:184`). Errored-only or model-only
289    // calls don't touch session_model_usage.
290    if let Some(m) = model
291        && (tokens_in != 0 || tokens_out != 0 || cost_usd != 0.0)
292    {
293        let now = Utc::now();
294        let ti = i32::try_from(tokens_in).map_err(|_| {
295            DatabaseError::QueryError("accumulate_usage: tokens_in delta overflows i32".to_string())
296        })?;
297        let to = i32::try_from(tokens_out).map_err(|_| {
298            DatabaseError::QueryError(
299                "accumulate_usage: tokens_out delta overflows i32".to_string(),
300            )
301        })?;
302
303        let sql = match backend {
304            DatabaseBackend::Sqlite | DatabaseBackend::Postgres => {
305                "INSERT INTO session_model_usage (\
306                    session_id, user_id, model, tokens_in, tokens_out, cost_usd, updated_at\
307                 ) VALUES ($1, $2, $3, $4, $5, $6, $7)\
308                 ON CONFLICT (session_id, user_id, model) DO UPDATE SET \
309                    tokens_in = session_model_usage.tokens_in + $4, \
310                    tokens_out = session_model_usage.tokens_out + $5, \
311                    cost_usd = session_model_usage.cost_usd + $6, \
312                    updated_at = $7"
313            }
314            DatabaseBackend::MySql => {
315                return Err(DatabaseError::QueryError(
316                    "accumulate_usage: MySQL backend not supported".to_string(),
317                ));
318            }
319        };
320
321        db.execute(Statement::from_sql_and_values(
322            backend,
323            sql,
324            [
325                Value::from(session_id.to_string()),
326                Value::from(user_hex.clone()),
327                Value::from(m.to_string()),
328                Value::from(ti),
329                Value::from(to),
330                Value::from(cost_usd),
331                Value::from(now),
332            ],
333        ))
334        .await
335        .map_err(map_sea_err)?;
336    }
337
338    Ok(())
339}
340
341// ---------------------------------------------------------------------------
342// get_session_row
343// ---------------------------------------------------------------------------
344
345#[instrument(
346    name = "cognee.db.relational.session_lifecycle.get_session_row",
347    level = "info",
348    skip_all,
349    fields(
350        cognee.db.system = tracing::field::Empty,
351        cognee.db.row_count = tracing::field::Empty,
352    ),
353    err,
354)]
355pub async fn get_session_row(
356    db: &DatabaseConnection,
357    session_id: &str,
358    user_id: Uuid,
359    permitted_dataset_ids: &[Uuid],
360    prefer_other_owner: bool,
361) -> Result<Option<SessionRowWithStatus>, DatabaseError> {
362    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
363    let user_hex = uuid_hex::to_hex(user_id);
364
365    // Visibility: caller's own OR session's dataset in permitted set.
366    // Mirrors Python `metrics.py:315-324`.
367    let mut query =
368        session_record::Entity::find().filter(session_record::Column::SessionId.eq(session_id));
369
370    if permitted_dataset_ids.is_empty() {
371        query = query.filter(session_record::Column::UserId.eq(user_hex.clone()));
372    } else {
373        let permitted_hex: Vec<String> = permitted_dataset_ids
374            .iter()
375            .map(|u| uuid_hex::to_hex(*u))
376            .collect();
377        // user_id == :u OR dataset_id IN :permitted
378        let cond = sea_orm::Condition::any()
379            .add(session_record::Column::UserId.eq(user_hex.clone()))
380            .add(session_record::Column::DatasetId.is_in(permitted_hex));
381        query = query.filter(cond);
382    }
383
384    let rows = query.all(db).await.map_err(map_sea_err)?;
385    if rows.is_empty() {
386        Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
387        return Ok(None);
388    }
389
390    // prefer_other_owner: when multiple rows match the visibility OR,
391    // return one whose owner is NOT the caller. Python `metrics.py:329-332`.
392    let chosen = if prefer_other_owner {
393        rows.iter()
394            .find(|r| r.user_id != user_hex)
395            .cloned()
396            .unwrap_or_else(|| rows[0].clone())
397    } else {
398        rows[0].clone()
399    };
400
401    let threshold = abandon_threshold_ts();
402    let effective = compute_effective_status(&chosen, threshold);
403    Span::current().record(COGNEE_DB_ROW_COUNT, 1i64);
404    Ok(Some(SessionRowWithStatus {
405        record: chosen,
406        effective_status: effective,
407    }))
408}
409
410/// Python `get_effective_status_sql` evaluated in Rust for a single row.
411fn compute_effective_status(row: &session_record::Model, threshold: DateTime<Utc>) -> String {
412    if row.status == "running" && row.last_activity_at < threshold {
413        "abandoned".to_string()
414    } else {
415        row.status.clone()
416    }
417}
418
419// ---------------------------------------------------------------------------
420// list_session_rows
421// ---------------------------------------------------------------------------
422
/// Resolve a caller-supplied `order_by` value to a known column name.
///
/// Only an exact match against the allow-list is returned; every other
/// input (including `"last_activity_at"` itself) resolves to
/// `last_activity_at`. Because the result is always one of these
/// static names, it is safe to splice into an `ORDER BY` clause.
fn sortable_column(order_by: &str) -> &'static str {
    const FALLBACK: &str = "last_activity_at";
    const ALLOWED: [&str; 5] = [
        "started_at",
        "ended_at",
        "cost_usd",
        "tokens_in",
        "tokens_out",
    ];

    ALLOWED
        .iter()
        .copied()
        .find(|candidate| *candidate == order_by)
        .unwrap_or(FALLBACK)
}
436
437#[derive(Debug, FromQueryResult)]
438struct ListRow {
439    session_id: String,
440    user_id: String,
441    dataset_id: Option<String>,
442    status: String,
443    started_at: DateTime<Utc>,
444    last_activity_at: DateTime<Utc>,
445    ended_at: Option<DateTime<Utc>>,
446    tokens_in: i32,
447    tokens_out: i32,
448    cost_usd: f64,
449    error_count: i32,
450    last_model: Option<String>,
451    effective_status: String,
452}
453
454#[derive(Debug, FromQueryResult)]
455struct CountRow {
456    n: i64,
457}
458
459#[instrument(
460    name = "cognee.db.relational.session_lifecycle.list_session_rows",
461    level = "info",
462    skip_all,
463    fields(
464        cognee.db.system = tracing::field::Empty,
465        cognee.db.row_count = tracing::field::Empty,
466    ),
467    err,
468)]
469pub async fn list_session_rows(
470    db: &DatabaseConnection,
471    filters: SessionListFilters,
472) -> Result<SessionListPage, DatabaseError> {
473    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
474    let backend = db.get_database_backend();
475    let threshold = abandon_threshold_ts();
476    let (eff_sql, eff_param) = effective_status_sql_fragment(threshold);
477    let user_hex = uuid_hex::to_hex(filters.user_id);
478
479    // ---- Build WHERE clause ----------------------------------------------
480    // We track parameter index so we can splice the effective-status
481    // fragment plus user-supplied filter values in order.
482    //
483    // The `eff_param` is bound *only* when `status_filter` is set; the
484    // SELECT list always references it though, so when listing without
485    // a status filter we still need the bound value at SELECT time.
486    let mut where_parts: Vec<String> = Vec::new();
487    let mut where_params: Vec<Value> = Vec::new();
488
489    // visibility predicate
490    if filters.permitted_dataset_ids.is_empty() {
491        where_parts.push("user_id = ?".to_string());
492        where_params.push(Value::from(user_hex.clone()));
493    } else {
494        let mut placeholders = Vec::with_capacity(filters.permitted_dataset_ids.len());
495        let mut perm_params: Vec<Value> = Vec::with_capacity(filters.permitted_dataset_ids.len());
496        for ds in &filters.permitted_dataset_ids {
497            placeholders.push("?");
498            perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
499        }
500        where_parts.push(format!(
501            "(user_id = ? OR dataset_id IN ({}))",
502            placeholders.join(", ")
503        ));
504        where_params.push(Value::from(user_hex.clone()));
505        where_params.extend(perm_params);
506    }
507
508    if let Some(since) = filters.since {
509        where_parts.push("last_activity_at >= ?".to_string());
510        where_params.push(Value::from(since));
511    }
512
513    if let Some(ref status_filter) = filters.status_filter {
514        // The effective-status SQL fragment binds the threshold timestamp.
515        where_parts.push(format!("({eff_sql}) = ?"));
516        where_params.push(eff_param.clone());
517        where_params.push(Value::from(status_filter.clone()));
518    }
519
520    let where_clause = if where_parts.is_empty() {
521        String::new()
522    } else {
523        format!("WHERE {}", where_parts.join(" AND "))
524    };
525
526    // ---- Count query -----------------------------------------------------
527    let count_sql = format!("SELECT COUNT(*) AS n FROM session_records {where_clause}");
528    let count_row = CountRow::find_by_statement(Statement::from_sql_and_values(
529        backend,
530        &count_sql,
531        where_params.clone(),
532    ))
533    .one(db)
534    .await
535    .map_err(map_sea_err)?;
536    let total = count_row.map(|r| r.n).unwrap_or(0);
537
538    // ---- Page query ------------------------------------------------------
539    let sort_col = sortable_column(&filters.order_by);
540    let direction = if filters.descending { "DESC" } else { "ASC" };
541
542    // SELECT must always bind the effective-status threshold. Build
543    // params in the order: SELECT params, WHERE params, LIMIT/OFFSET.
544    let mut page_params: Vec<Value> = Vec::with_capacity(where_params.len() + 3);
545    page_params.push(eff_param.clone()); // for the SELECT list expression
546    page_params.extend(where_params);
547
548    let page_sql = format!(
549        "SELECT session_id, user_id, dataset_id, status, started_at, \
550                last_activity_at, ended_at, tokens_in, tokens_out, cost_usd, \
551                error_count, last_model, ({eff_sql}) AS effective_status \
552         FROM session_records {where_clause} \
553         ORDER BY {sort_col} {direction} \
554         LIMIT ? OFFSET ?"
555    );
556    page_params.push(Value::from(i64::from(filters.limit)));
557    page_params.push(Value::from(i64::from(filters.offset)));
558
559    let raw_rows = ListRow::find_by_statement(Statement::from_sql_and_values(
560        backend,
561        &page_sql,
562        page_params,
563    ))
564    .all(db)
565    .await
566    .map_err(map_sea_err)?;
567
568    let sessions: Vec<SessionRowWithStatus> = raw_rows
569        .into_iter()
570        .map(|r| SessionRowWithStatus {
571            record: session_record::Model {
572                session_id: r.session_id,
573                user_id: r.user_id,
574                dataset_id: r.dataset_id,
575                status: r.status,
576                started_at: r.started_at,
577                last_activity_at: r.last_activity_at,
578                ended_at: r.ended_at,
579                tokens_in: r.tokens_in,
580                tokens_out: r.tokens_out,
581                cost_usd: r.cost_usd,
582                error_count: r.error_count,
583                last_model: r.last_model,
584            },
585            effective_status: r.effective_status,
586        })
587        .collect();
588
589    Span::current().record(COGNEE_DB_ROW_COUNT, sessions.len() as i64);
590    Ok(SessionListPage {
591        sessions,
592        total,
593        limit: filters.limit,
594        offset: filters.offset,
595    })
596}
597
598// ---------------------------------------------------------------------------
599// aggregate_stats
600// ---------------------------------------------------------------------------
601
602#[derive(Debug, FromQueryResult)]
603struct TotalsRow {
604    sessions: i64,
605    tokens_in: i64,
606    tokens_out: i64,
607    cost_usd: f64,
608}
609
610#[derive(Debug, FromQueryResult)]
611struct DurRow {
612    started_at: Option<DateTime<Utc>>,
613    last_activity_at: Option<DateTime<Utc>>,
614    ended_at: Option<DateTime<Utc>>,
615}
616
617#[derive(Debug, FromQueryResult)]
618struct StatusBucketRow {
619    s: String,
620    c: i64,
621}
622
623#[instrument(
624    name = "cognee.db.relational.session_lifecycle.aggregate_stats",
625    level = "info",
626    skip_all,
627    fields(cognee.db.system = tracing::field::Empty),
628    err,
629)]
630pub async fn aggregate_stats(
631    db: &DatabaseConnection,
632    user_id: Uuid,
633    permitted_dataset_ids: &[Uuid],
634    since: Option<DateTime<Utc>>,
635) -> Result<SessionStats, DatabaseError> {
636    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
637    let backend = db.get_database_backend();
638    let user_hex = uuid_hex::to_hex(user_id);
639
640    // Shared visibility / since predicate. The same WHERE clause is
641    // reused across totals / duration / status-bucket queries — Python
642    // builds it once at `get_sessions_router.py:124-131` and reuses.
643    let mut where_parts: Vec<String> = Vec::new();
644    let mut base_params: Vec<Value> = Vec::new();
645
646    if permitted_dataset_ids.is_empty() {
647        where_parts.push("user_id = ?".to_string());
648        base_params.push(Value::from(user_hex.clone()));
649    } else {
650        let mut placeholders = Vec::with_capacity(permitted_dataset_ids.len());
651        let mut perm_params: Vec<Value> = Vec::with_capacity(permitted_dataset_ids.len());
652        for ds in permitted_dataset_ids {
653            placeholders.push("?");
654            perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
655        }
656        where_parts.push(format!(
657            "(user_id = ? OR dataset_id IN ({}))",
658            placeholders.join(", ")
659        ));
660        base_params.push(Value::from(user_hex.clone()));
661        base_params.extend(perm_params);
662    }
663    if let Some(s) = since {
664        where_parts.push("last_activity_at >= ?".to_string());
665        base_params.push(Value::from(s));
666    }
667    let where_clause = if where_parts.is_empty() {
668        String::new()
669    } else {
670        format!("WHERE {}", where_parts.join(" AND "))
671    };
672
673    // ---- (a) Totals ------------------------------------------------------
674    let totals_sql = format!(
675        "SELECT COUNT(*) AS sessions, \
676                COALESCE(SUM(tokens_in), 0) AS tokens_in, \
677                COALESCE(SUM(tokens_out), 0) AS tokens_out, \
678                COALESCE(SUM(cost_usd), 0.0) AS cost_usd \
679         FROM session_records {where_clause}"
680    );
681    let totals = TotalsRow::find_by_statement(Statement::from_sql_and_values(
682        backend,
683        &totals_sql,
684        base_params.clone(),
685    ))
686    .one(db)
687    .await
688    .map_err(map_sea_err)?
689    .unwrap_or(TotalsRow {
690        sessions: 0,
691        tokens_in: 0,
692        tokens_out: 0,
693        cost_usd: 0.0,
694    });
695
696    // ---- (b) Duration ---------------------------------------------------
697    // SQLite has no `EXTRACT(epoch FROM ...)`. Python falls back to
698    // loading `(started, ended, last_activity)` rows and folding in
699    // Python (`get_sessions_router.py:142-159`); we mirror that
700    // exactly for cross-backend portability.
701    let dur_sql = format!(
702        "SELECT started_at, last_activity_at, ended_at \
703         FROM session_records {where_clause}"
704    );
705    let dur_rows = DurRow::find_by_statement(Statement::from_sql_and_values(
706        backend,
707        &dur_sql,
708        base_params.clone(),
709    ))
710    .all(db)
711    .await
712    .map_err(map_sea_err)?;
713
714    let mut total_seconds: f64 = 0.0;
715    let mut session_count: i64 = 0;
716    for row in &dur_rows {
717        let Some(started) = row.started_at else {
718            continue;
719        };
720        let end = row.ended_at.or(row.last_activity_at);
721        let Some(end) = end else { continue };
722        let delta = (end - started).num_milliseconds() as f64 / 1000.0;
723        total_seconds += delta.max(0.0);
724        session_count += 1;
725    }
726    let avg_seconds = if session_count > 0 {
727        total_seconds / session_count as f64
728    } else {
729        0.0
730    };
731
732    // ---- (c) Status buckets ---------------------------------------------
733    let threshold = abandon_threshold_ts();
734    let (eff_sql, eff_param) = effective_status_sql_fragment(threshold);
735    // SELECT params: eff_param first, then base_params (where).
736    let mut bucket_params: Vec<Value> = Vec::with_capacity(base_params.len() + 1);
737    bucket_params.push(eff_param);
738    bucket_params.extend(base_params.clone());
739
740    let bucket_sql = format!(
741        "SELECT ({eff_sql}) AS s, COUNT(*) AS c \
742         FROM session_records {where_clause} \
743         GROUP BY s"
744    );
745    let buckets = StatusBucketRow::find_by_statement(Statement::from_sql_and_values(
746        backend,
747        &bucket_sql,
748        bucket_params,
749    ))
750    .all(db)
751    .await
752    .map_err(map_sea_err)?;
753
754    let mut completed: i64 = 0;
755    let mut failed: i64 = 0;
756    let mut abandoned: i64 = 0;
757    let mut running: i64 = 0;
758    for b in &buckets {
759        match b.s.as_str() {
760            "completed" => completed = b.c,
761            "failed" => failed = b.c,
762            "abandoned" => abandoned = b.c,
763            "running" => running = b.c,
764            _ => {}
765        }
766    }
767    let decided = completed + failed + abandoned;
768    let success_rate = if decided > 0 {
769        completed as f64 / decided as f64
770    } else {
771        1.0
772    };
773
774    let sessions_count = totals.sessions;
775    let avg_spend = if sessions_count > 0 {
776        totals.cost_usd / sessions_count as f64
777    } else {
778        0.0
779    };
780
781    Ok(SessionStats {
782        sessions: sessions_count,
783        total_spend_usd: totals.cost_usd,
784        avg_spend_per_session_usd: avg_spend,
785        tokens_in: totals.tokens_in,
786        tokens_out: totals.tokens_out,
787        tokens_total: totals.tokens_in + totals.tokens_out,
788        agent_time_s: total_seconds,
789        avg_session_s: avg_seconds,
790        success_rate,
791        completed,
792        failed,
793        abandoned,
794        running,
795    })
796}
797
798// ---------------------------------------------------------------------------
799// cost_by_model
800// ---------------------------------------------------------------------------
801
802#[derive(Debug, FromQueryResult)]
803struct CostRow {
804    model: Option<String>,
805    session_count: i64,
806    cost_usd: f64,
807    tokens_in: i64,
808    tokens_out: i64,
809}
810
811#[instrument(
812    name = "cognee.db.relational.session_lifecycle.cost_by_model",
813    level = "info",
814    skip_all,
815    fields(
816        cognee.db.system = tracing::field::Empty,
817        cognee.db.row_count = tracing::field::Empty,
818    ),
819    err,
820)]
821pub async fn cost_by_model(
822    db: &DatabaseConnection,
823    user_id: Uuid,
824    permitted_dataset_ids: &[Uuid],
825    since: Option<DateTime<Utc>>,
826) -> Result<Vec<CostByModelRow>, DatabaseError> {
827    Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
828    let backend = db.get_database_backend();
829    let user_hex = uuid_hex::to_hex(user_id);
830
831    let mut where_parts: Vec<String> = Vec::new();
832    let mut params: Vec<Value> = Vec::new();
833
834    if permitted_dataset_ids.is_empty() {
835        where_parts.push("sr.user_id = ?".to_string());
836        params.push(Value::from(user_hex.clone()));
837    } else {
838        let mut placeholders = Vec::with_capacity(permitted_dataset_ids.len());
839        let mut perm_params: Vec<Value> = Vec::with_capacity(permitted_dataset_ids.len());
840        for ds in permitted_dataset_ids {
841            placeholders.push("?");
842            perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
843        }
844        where_parts.push(format!(
845            "(sr.user_id = ? OR sr.dataset_id IN ({}))",
846            placeholders.join(", ")
847        ));
848        params.push(Value::from(user_hex.clone()));
849        params.extend(perm_params);
850    }
851    if let Some(s) = since {
852        where_parts.push("sr.last_activity_at >= ?".to_string());
853        params.push(Value::from(s));
854    }
855    let where_clause = if where_parts.is_empty() {
856        String::new()
857    } else {
858        format!("WHERE {}", where_parts.join(" AND "))
859    };
860
861    // COUNT(DISTINCT smu.session_id) — not raw row count — matches
862    // `get_sessions_router.py:220`. ORDER BY total cost descending.
863    let sql = format!(
864        "SELECT smu.model AS model, \
865                COUNT(DISTINCT smu.session_id) AS session_count, \
866                COALESCE(SUM(smu.cost_usd), 0.0) AS cost_usd, \
867                COALESCE(SUM(smu.tokens_in), 0) AS tokens_in, \
868                COALESCE(SUM(smu.tokens_out), 0) AS tokens_out \
869         FROM session_model_usage smu \
870         JOIN session_records sr ON smu.session_id = sr.session_id \
871                                 AND smu.user_id = sr.user_id \
872         {where_clause} \
873         GROUP BY smu.model \
874         ORDER BY SUM(smu.cost_usd) DESC"
875    );
876
877    let rows = CostRow::find_by_statement(Statement::from_sql_and_values(backend, &sql, params))
878        .all(db)
879        .await
880        .map_err(map_sea_err)?;
881
882    let result: Vec<CostByModelRow> = rows
883        .into_iter()
884        .map(|r| CostByModelRow {
885            model: r.model.unwrap_or_else(|| "unknown".to_string()),
886            session_count: r.session_count,
887            cost_usd: r.cost_usd,
888            tokens_in: r.tokens_in,
889            tokens_out: r.tokens_out,
890        })
891        .collect();
892    Span::current().record(COGNEE_DB_ROW_COUNT, result.len() as i64);
893    Ok(result)
894}
895
896// ---------------------------------------------------------------------------
897// Trait impl on DatabaseConnection
898// ---------------------------------------------------------------------------
899
900#[async_trait::async_trait]
901impl SessionLifecycleDb for DatabaseConnection {
902    async fn ensure_and_touch_session(
903        &self,
904        session_id: &str,
905        user_id: Uuid,
906        dataset_id: Option<Uuid>,
907    ) -> Result<(), DatabaseError> {
908        ensure_and_touch_session(self, session_id, user_id, dataset_id).await
909    }
910
911    async fn accumulate_usage(
912        &self,
913        session_id: &str,
914        user_id: Uuid,
915        model: Option<&str>,
916        tokens_in: i64,
917        tokens_out: i64,
918        cost_usd: f64,
919        errored: bool,
920    ) -> Result<(), DatabaseError> {
921        accumulate_usage(
922            self, session_id, user_id, model, tokens_in, tokens_out, cost_usd, errored,
923        )
924        .await
925    }
926
927    async fn get_session_row(
928        &self,
929        session_id: &str,
930        user_id: Uuid,
931        permitted_dataset_ids: &[Uuid],
932        prefer_other_owner: bool,
933    ) -> Result<Option<SessionRowWithStatus>, DatabaseError> {
934        get_session_row(
935            self,
936            session_id,
937            user_id,
938            permitted_dataset_ids,
939            prefer_other_owner,
940        )
941        .await
942    }
943
944    async fn list_session_rows(
945        &self,
946        filters: SessionListFilters,
947    ) -> Result<SessionListPage, DatabaseError> {
948        list_session_rows(self, filters).await
949    }
950
951    async fn aggregate_stats(
952        &self,
953        user_id: Uuid,
954        permitted_dataset_ids: &[Uuid],
955        since: Option<DateTime<Utc>>,
956    ) -> Result<SessionStats, DatabaseError> {
957        aggregate_stats(self, user_id, permitted_dataset_ids, since).await
958    }
959
960    async fn cost_by_model(
961        &self,
962        user_id: Uuid,
963        permitted_dataset_ids: &[Uuid],
964        since: Option<DateTime<Utc>>,
965    ) -> Result<Vec<CostByModelRow>, DatabaseError> {
966        cost_by_model(self, user_id, permitted_dataset_ids, since).await
967    }
968}