1use std::env;
32
33use chrono::{DateTime, Duration, Utc};
34use cognee_utils::tracing_keys::{COGNEE_DB_ROW_COUNT, COGNEE_DB_SYSTEM};
35use sea_orm::{
36 ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection, EntityTrait,
37 FromQueryResult, QueryFilter, Statement, Value,
38};
39use tracing::{Span, instrument};
40use uuid::Uuid;
41
42use crate::conversions::map_sea_err;
43use crate::database_system_label;
44use crate::entities::session_record;
45use crate::traits::{
46 CostByModelRow, SessionLifecycleDb, SessionListFilters, SessionListPage, SessionRowWithStatus,
47 SessionStats,
48};
49use crate::types::DatabaseError;
50use crate::uuid_hex;
51
/// Returns the inactivity window, in seconds, after which a `running` session
/// is treated as `abandoned`.
///
/// The value is read from the `SESSION_ABANDON_AFTER_SECONDS` environment
/// variable on every call, with surrounding whitespace ignored. It falls back to
/// 1800 (30 minutes) when the variable is unset, not valid Unicode, blank, or
/// not parseable as an `i64`. Zero and negative values are returned unchanged.
pub fn abandon_after_seconds() -> i64 {
    const DEFAULT_SECONDS: i64 = 1800;
    match env::var("SESSION_ABANDON_AFTER_SECONDS") {
        // A blank (all-whitespace) value trims to "" which fails to parse,
        // so it lands on the default as well.
        Ok(raw) => raw.trim().parse::<i64>().unwrap_or(DEFAULT_SECONDS),
        Err(_) => DEFAULT_SECONDS,
    }
}
69
70fn abandon_threshold_ts() -> DateTime<Utc> {
75 Utc::now() - Duration::seconds(abandon_after_seconds())
76}
77
78fn effective_status_sql_fragment(threshold: DateTime<Utc>) -> (String, Value) {
85 let sql =
89 "CASE WHEN status = 'running' AND last_activity_at < ? THEN 'abandoned' ELSE status END"
90 .to_string();
91 (sql, threshold.into())
92}
93
94#[instrument(
99 name = "cognee.db.relational.session_lifecycle.ensure_and_touch_session",
100 level = "info",
101 skip_all,
102 fields(cognee.db.system = tracing::field::Empty),
103 err,
104)]
105pub async fn ensure_and_touch_session(
106 db: &DatabaseConnection,
107 session_id: &str,
108 user_id: Uuid,
109 dataset_id: Option<Uuid>,
110) -> Result<(), DatabaseError> {
111 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
112 let now = Utc::now();
113 let backend = db.get_database_backend();
114
115 let user_hex = uuid_hex::to_hex(user_id);
116 let dataset_hex = uuid_hex::to_hex_opt(dataset_id);
117
118 let sql = match backend {
127 DatabaseBackend::Sqlite | DatabaseBackend::Postgres => {
128 "INSERT INTO session_records (\
129 session_id, user_id, dataset_id, status, started_at, \
130 last_activity_at, ended_at, tokens_in, tokens_out, \
131 cost_usd, error_count, last_model\
132 ) VALUES ($1, $2, $3, 'running', $4, $4, NULL, 0, 0, 0.0, 0, NULL)\
133 ON CONFLICT (session_id, user_id) DO UPDATE SET \
134 last_activity_at = $4, \
135 dataset_id = COALESCE(session_records.dataset_id, $3) \
136 WHERE session_records.status = 'running'"
137 }
138 DatabaseBackend::MySql => {
139 return Err(DatabaseError::QueryError(
140 "ensure_and_touch_session: MySQL backend not supported".to_string(),
141 ));
142 }
143 };
144
145 db.execute(Statement::from_sql_and_values(
146 backend,
147 sql,
148 [
149 session_id.into(),
150 user_hex.into(),
151 Value::from(dataset_hex),
152 now.into(),
153 ],
154 ))
155 .await
156 .map_err(map_sea_err)?;
157 Ok(())
158}
159
160#[allow(clippy::too_many_arguments)]
169#[instrument(
170 name = "cognee.db.relational.session_lifecycle.accumulate_usage",
171 level = "info",
172 skip_all,
173 fields(cognee.db.system = tracing::field::Empty),
174 err,
175)]
176pub async fn accumulate_usage(
177 db: &DatabaseConnection,
178 session_id: &str,
179 user_id: Uuid,
180 model: Option<&str>,
181 tokens_in: i64,
182 tokens_out: i64,
183 cost_usd: f64,
184 errored: bool,
185) -> Result<(), DatabaseError> {
186 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
187 if tokens_in == 0 && tokens_out == 0 && cost_usd == 0.0 && !errored && model.is_none() {
190 return Ok(());
191 }
192
193 let backend = db.get_database_backend();
194 let user_hex = uuid_hex::to_hex(user_id);
195
196 let mut set_parts: Vec<String> = Vec::new();
203 let mut params: Vec<Value> = Vec::new();
204 let mut next_idx: usize = 1;
205
206 let push_inc = |col: &str,
207 delta: Value,
208 set_parts: &mut Vec<String>,
209 params: &mut Vec<Value>,
210 next_idx: &mut usize| {
211 set_parts.push(format!("{col} = {col} + ${next_idx}"));
212 params.push(delta);
213 *next_idx += 1;
214 };
215
216 if tokens_in != 0 {
217 let v = i32::try_from(tokens_in).map_err(|_| {
220 DatabaseError::QueryError("accumulate_usage: tokens_in delta overflows i32".to_string())
221 })?;
222 push_inc(
223 "tokens_in",
224 Value::from(v),
225 &mut set_parts,
226 &mut params,
227 &mut next_idx,
228 );
229 }
230 if tokens_out != 0 {
231 let v = i32::try_from(tokens_out).map_err(|_| {
232 DatabaseError::QueryError(
233 "accumulate_usage: tokens_out delta overflows i32".to_string(),
234 )
235 })?;
236 push_inc(
237 "tokens_out",
238 Value::from(v),
239 &mut set_parts,
240 &mut params,
241 &mut next_idx,
242 );
243 }
244 if cost_usd != 0.0 {
245 push_inc(
246 "cost_usd",
247 Value::from(cost_usd),
248 &mut set_parts,
249 &mut params,
250 &mut next_idx,
251 );
252 }
253 if errored {
254 set_parts.push(format!("error_count = error_count + ${next_idx}"));
255 params.push(Value::from(1_i32));
256 next_idx += 1;
257 }
258 if let Some(m) = model {
259 set_parts.push(format!("last_model = ${next_idx}"));
260 params.push(Value::from(m.to_string()));
261 next_idx += 1;
262 }
263
264 if !set_parts.is_empty() {
265 let where_session_idx = next_idx;
267 params.push(Value::from(session_id.to_string()));
268 next_idx += 1;
269 let where_user_idx = next_idx;
270 params.push(Value::from(user_hex.clone()));
271 next_idx += 1;
272
273 let sql = format!(
274 "UPDATE session_records SET {set_clause} \
275 WHERE session_id = ${sid} AND user_id = ${uid} AND status = 'running'",
276 set_clause = set_parts.join(", "),
277 sid = where_session_idx,
278 uid = where_user_idx,
279 );
280 let _ = next_idx;
281
282 db.execute(Statement::from_sql_and_values(backend, sql, params))
283 .await
284 .map_err(map_sea_err)?;
285 }
286
287 if let Some(m) = model
291 && (tokens_in != 0 || tokens_out != 0 || cost_usd != 0.0)
292 {
293 let now = Utc::now();
294 let ti = i32::try_from(tokens_in).map_err(|_| {
295 DatabaseError::QueryError("accumulate_usage: tokens_in delta overflows i32".to_string())
296 })?;
297 let to = i32::try_from(tokens_out).map_err(|_| {
298 DatabaseError::QueryError(
299 "accumulate_usage: tokens_out delta overflows i32".to_string(),
300 )
301 })?;
302
303 let sql = match backend {
304 DatabaseBackend::Sqlite | DatabaseBackend::Postgres => {
305 "INSERT INTO session_model_usage (\
306 session_id, user_id, model, tokens_in, tokens_out, cost_usd, updated_at\
307 ) VALUES ($1, $2, $3, $4, $5, $6, $7)\
308 ON CONFLICT (session_id, user_id, model) DO UPDATE SET \
309 tokens_in = session_model_usage.tokens_in + $4, \
310 tokens_out = session_model_usage.tokens_out + $5, \
311 cost_usd = session_model_usage.cost_usd + $6, \
312 updated_at = $7"
313 }
314 DatabaseBackend::MySql => {
315 return Err(DatabaseError::QueryError(
316 "accumulate_usage: MySQL backend not supported".to_string(),
317 ));
318 }
319 };
320
321 db.execute(Statement::from_sql_and_values(
322 backend,
323 sql,
324 [
325 Value::from(session_id.to_string()),
326 Value::from(user_hex.clone()),
327 Value::from(m.to_string()),
328 Value::from(ti),
329 Value::from(to),
330 Value::from(cost_usd),
331 Value::from(now),
332 ],
333 ))
334 .await
335 .map_err(map_sea_err)?;
336 }
337
338 Ok(())
339}
340
341#[instrument(
346 name = "cognee.db.relational.session_lifecycle.get_session_row",
347 level = "info",
348 skip_all,
349 fields(
350 cognee.db.system = tracing::field::Empty,
351 cognee.db.row_count = tracing::field::Empty,
352 ),
353 err,
354)]
355pub async fn get_session_row(
356 db: &DatabaseConnection,
357 session_id: &str,
358 user_id: Uuid,
359 permitted_dataset_ids: &[Uuid],
360 prefer_other_owner: bool,
361) -> Result<Option<SessionRowWithStatus>, DatabaseError> {
362 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
363 let user_hex = uuid_hex::to_hex(user_id);
364
365 let mut query =
368 session_record::Entity::find().filter(session_record::Column::SessionId.eq(session_id));
369
370 if permitted_dataset_ids.is_empty() {
371 query = query.filter(session_record::Column::UserId.eq(user_hex.clone()));
372 } else {
373 let permitted_hex: Vec<String> = permitted_dataset_ids
374 .iter()
375 .map(|u| uuid_hex::to_hex(*u))
376 .collect();
377 let cond = sea_orm::Condition::any()
379 .add(session_record::Column::UserId.eq(user_hex.clone()))
380 .add(session_record::Column::DatasetId.is_in(permitted_hex));
381 query = query.filter(cond);
382 }
383
384 let rows = query.all(db).await.map_err(map_sea_err)?;
385 if rows.is_empty() {
386 Span::current().record(COGNEE_DB_ROW_COUNT, 0i64);
387 return Ok(None);
388 }
389
390 let chosen = if prefer_other_owner {
393 rows.iter()
394 .find(|r| r.user_id != user_hex)
395 .cloned()
396 .unwrap_or_else(|| rows[0].clone())
397 } else {
398 rows[0].clone()
399 };
400
401 let threshold = abandon_threshold_ts();
402 let effective = compute_effective_status(&chosen, threshold);
403 Span::current().record(COGNEE_DB_ROW_COUNT, 1i64);
404 Ok(Some(SessionRowWithStatus {
405 record: chosen,
406 effective_status: effective,
407 }))
408}
409
410fn compute_effective_status(row: &session_record::Model, threshold: DateTime<Utc>) -> String {
412 if row.status == "running" && row.last_activity_at < threshold {
413 "abandoned".to_string()
414 } else {
415 row.status.clone()
416 }
417}
418
/// Maps a caller-supplied sort key onto a whitelisted `session_records` column.
///
/// The returned name is interpolated directly into an `ORDER BY` clause, so
/// only exact, known column names are accepted. Anything else, including
/// different casing, falls back to `last_activity_at`.
fn sortable_column(order_by: &str) -> &'static str {
    const ALLOWED: [&str; 5] = ["started_at", "ended_at", "cost_usd", "tokens_in", "tokens_out"];
    ALLOWED
        .iter()
        .copied()
        .find(|&column| column == order_by)
        .unwrap_or("last_activity_at")
}
436
/// One row of the paginated session listing. It carries every
/// `session_records` column plus the computed effective status.
///
/// The field names must match the column names/aliases selected by the page
/// query in `list_session_rows`, because `FromQueryResult` maps columns by
/// name.
#[derive(Debug, FromQueryResult)]
struct ListRow {
    session_id: String,
    /// Owner UUID in the hex form produced by `uuid_hex::to_hex`.
    user_id: String,
    /// Dataset UUID in hex form, if the session is attached to one.
    dataset_id: Option<String>,
    /// Stored status, before the abandon threshold is applied.
    status: String,
    started_at: DateTime<Utc>,
    last_activity_at: DateTime<Utc>,
    ended_at: Option<DateTime<Utc>>,
    tokens_in: i32,
    tokens_out: i32,
    cost_usd: f64,
    error_count: i32,
    last_model: Option<String>,
    /// Value of the effective-status `CASE` expression (`AS effective_status`).
    effective_status: String,
}
453
/// Single-column result of a `SELECT COUNT(*) AS n ...` query.
#[derive(Debug, FromQueryResult)]
struct CountRow {
    n: i64,
}
458
459#[instrument(
460 name = "cognee.db.relational.session_lifecycle.list_session_rows",
461 level = "info",
462 skip_all,
463 fields(
464 cognee.db.system = tracing::field::Empty,
465 cognee.db.row_count = tracing::field::Empty,
466 ),
467 err,
468)]
469pub async fn list_session_rows(
470 db: &DatabaseConnection,
471 filters: SessionListFilters,
472) -> Result<SessionListPage, DatabaseError> {
473 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
474 let backend = db.get_database_backend();
475 let threshold = abandon_threshold_ts();
476 let (eff_sql, eff_param) = effective_status_sql_fragment(threshold);
477 let user_hex = uuid_hex::to_hex(filters.user_id);
478
479 let mut where_parts: Vec<String> = Vec::new();
487 let mut where_params: Vec<Value> = Vec::new();
488
489 if filters.permitted_dataset_ids.is_empty() {
491 where_parts.push("user_id = ?".to_string());
492 where_params.push(Value::from(user_hex.clone()));
493 } else {
494 let mut placeholders = Vec::with_capacity(filters.permitted_dataset_ids.len());
495 let mut perm_params: Vec<Value> = Vec::with_capacity(filters.permitted_dataset_ids.len());
496 for ds in &filters.permitted_dataset_ids {
497 placeholders.push("?");
498 perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
499 }
500 where_parts.push(format!(
501 "(user_id = ? OR dataset_id IN ({}))",
502 placeholders.join(", ")
503 ));
504 where_params.push(Value::from(user_hex.clone()));
505 where_params.extend(perm_params);
506 }
507
508 if let Some(since) = filters.since {
509 where_parts.push("last_activity_at >= ?".to_string());
510 where_params.push(Value::from(since));
511 }
512
513 if let Some(ref status_filter) = filters.status_filter {
514 where_parts.push(format!("({eff_sql}) = ?"));
516 where_params.push(eff_param.clone());
517 where_params.push(Value::from(status_filter.clone()));
518 }
519
520 let where_clause = if where_parts.is_empty() {
521 String::new()
522 } else {
523 format!("WHERE {}", where_parts.join(" AND "))
524 };
525
526 let count_sql = format!("SELECT COUNT(*) AS n FROM session_records {where_clause}");
528 let count_row = CountRow::find_by_statement(Statement::from_sql_and_values(
529 backend,
530 &count_sql,
531 where_params.clone(),
532 ))
533 .one(db)
534 .await
535 .map_err(map_sea_err)?;
536 let total = count_row.map(|r| r.n).unwrap_or(0);
537
538 let sort_col = sortable_column(&filters.order_by);
540 let direction = if filters.descending { "DESC" } else { "ASC" };
541
542 let mut page_params: Vec<Value> = Vec::with_capacity(where_params.len() + 3);
545 page_params.push(eff_param.clone()); page_params.extend(where_params);
547
548 let page_sql = format!(
549 "SELECT session_id, user_id, dataset_id, status, started_at, \
550 last_activity_at, ended_at, tokens_in, tokens_out, cost_usd, \
551 error_count, last_model, ({eff_sql}) AS effective_status \
552 FROM session_records {where_clause} \
553 ORDER BY {sort_col} {direction} \
554 LIMIT ? OFFSET ?"
555 );
556 page_params.push(Value::from(i64::from(filters.limit)));
557 page_params.push(Value::from(i64::from(filters.offset)));
558
559 let raw_rows = ListRow::find_by_statement(Statement::from_sql_and_values(
560 backend,
561 &page_sql,
562 page_params,
563 ))
564 .all(db)
565 .await
566 .map_err(map_sea_err)?;
567
568 let sessions: Vec<SessionRowWithStatus> = raw_rows
569 .into_iter()
570 .map(|r| SessionRowWithStatus {
571 record: session_record::Model {
572 session_id: r.session_id,
573 user_id: r.user_id,
574 dataset_id: r.dataset_id,
575 status: r.status,
576 started_at: r.started_at,
577 last_activity_at: r.last_activity_at,
578 ended_at: r.ended_at,
579 tokens_in: r.tokens_in,
580 tokens_out: r.tokens_out,
581 cost_usd: r.cost_usd,
582 error_count: r.error_count,
583 last_model: r.last_model,
584 },
585 effective_status: r.effective_status,
586 })
587 .collect();
588
589 Span::current().record(COGNEE_DB_ROW_COUNT, sessions.len() as i64);
590 Ok(SessionListPage {
591 sessions,
592 total,
593 limit: filters.limit,
594 offset: filters.offset,
595 })
596}
597
/// Totals over the filtered `session_records` rows used by `aggregate_stats`.
///
/// The field names match the `AS sessions / tokens_in / tokens_out / cost_usd`
/// aliases. The sums are wrapped in `COALESCE`, so an empty selection yields
/// zeros rather than NULL.
#[derive(Debug, FromQueryResult)]
struct TotalsRow {
    sessions: i64,
    tokens_in: i64,
    tokens_out: i64,
    cost_usd: f64,
}
609
/// Per-session timestamps that `aggregate_stats` uses to compute durations in
/// Rust.
///
/// Every field is optional, so a NULL in any column simply causes that row to
/// be skipped.
#[derive(Debug, FromQueryResult)]
struct DurRow {
    started_at: Option<DateTime<Utc>>,
    last_activity_at: Option<DateTime<Utc>>,
    /// Preferred end of the session; `last_activity_at` is the fallback.
    ended_at: Option<DateTime<Utc>>,
}
616
/// One bucket of the status histogram in `aggregate_stats`.
#[derive(Debug, FromQueryResult)]
struct StatusBucketRow {
    /// Effective status (`AS s`), with the abandon threshold already applied.
    s: String,
    /// Number of sessions with that effective status (`AS c`).
    c: i64,
}
622
623#[instrument(
624 name = "cognee.db.relational.session_lifecycle.aggregate_stats",
625 level = "info",
626 skip_all,
627 fields(cognee.db.system = tracing::field::Empty),
628 err,
629)]
630pub async fn aggregate_stats(
631 db: &DatabaseConnection,
632 user_id: Uuid,
633 permitted_dataset_ids: &[Uuid],
634 since: Option<DateTime<Utc>>,
635) -> Result<SessionStats, DatabaseError> {
636 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
637 let backend = db.get_database_backend();
638 let user_hex = uuid_hex::to_hex(user_id);
639
640 let mut where_parts: Vec<String> = Vec::new();
644 let mut base_params: Vec<Value> = Vec::new();
645
646 if permitted_dataset_ids.is_empty() {
647 where_parts.push("user_id = ?".to_string());
648 base_params.push(Value::from(user_hex.clone()));
649 } else {
650 let mut placeholders = Vec::with_capacity(permitted_dataset_ids.len());
651 let mut perm_params: Vec<Value> = Vec::with_capacity(permitted_dataset_ids.len());
652 for ds in permitted_dataset_ids {
653 placeholders.push("?");
654 perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
655 }
656 where_parts.push(format!(
657 "(user_id = ? OR dataset_id IN ({}))",
658 placeholders.join(", ")
659 ));
660 base_params.push(Value::from(user_hex.clone()));
661 base_params.extend(perm_params);
662 }
663 if let Some(s) = since {
664 where_parts.push("last_activity_at >= ?".to_string());
665 base_params.push(Value::from(s));
666 }
667 let where_clause = if where_parts.is_empty() {
668 String::new()
669 } else {
670 format!("WHERE {}", where_parts.join(" AND "))
671 };
672
673 let totals_sql = format!(
675 "SELECT COUNT(*) AS sessions, \
676 COALESCE(SUM(tokens_in), 0) AS tokens_in, \
677 COALESCE(SUM(tokens_out), 0) AS tokens_out, \
678 COALESCE(SUM(cost_usd), 0.0) AS cost_usd \
679 FROM session_records {where_clause}"
680 );
681 let totals = TotalsRow::find_by_statement(Statement::from_sql_and_values(
682 backend,
683 &totals_sql,
684 base_params.clone(),
685 ))
686 .one(db)
687 .await
688 .map_err(map_sea_err)?
689 .unwrap_or(TotalsRow {
690 sessions: 0,
691 tokens_in: 0,
692 tokens_out: 0,
693 cost_usd: 0.0,
694 });
695
696 let dur_sql = format!(
702 "SELECT started_at, last_activity_at, ended_at \
703 FROM session_records {where_clause}"
704 );
705 let dur_rows = DurRow::find_by_statement(Statement::from_sql_and_values(
706 backend,
707 &dur_sql,
708 base_params.clone(),
709 ))
710 .all(db)
711 .await
712 .map_err(map_sea_err)?;
713
714 let mut total_seconds: f64 = 0.0;
715 let mut session_count: i64 = 0;
716 for row in &dur_rows {
717 let Some(started) = row.started_at else {
718 continue;
719 };
720 let end = row.ended_at.or(row.last_activity_at);
721 let Some(end) = end else { continue };
722 let delta = (end - started).num_milliseconds() as f64 / 1000.0;
723 total_seconds += delta.max(0.0);
724 session_count += 1;
725 }
726 let avg_seconds = if session_count > 0 {
727 total_seconds / session_count as f64
728 } else {
729 0.0
730 };
731
732 let threshold = abandon_threshold_ts();
734 let (eff_sql, eff_param) = effective_status_sql_fragment(threshold);
735 let mut bucket_params: Vec<Value> = Vec::with_capacity(base_params.len() + 1);
737 bucket_params.push(eff_param);
738 bucket_params.extend(base_params.clone());
739
740 let bucket_sql = format!(
741 "SELECT ({eff_sql}) AS s, COUNT(*) AS c \
742 FROM session_records {where_clause} \
743 GROUP BY s"
744 );
745 let buckets = StatusBucketRow::find_by_statement(Statement::from_sql_and_values(
746 backend,
747 &bucket_sql,
748 bucket_params,
749 ))
750 .all(db)
751 .await
752 .map_err(map_sea_err)?;
753
754 let mut completed: i64 = 0;
755 let mut failed: i64 = 0;
756 let mut abandoned: i64 = 0;
757 let mut running: i64 = 0;
758 for b in &buckets {
759 match b.s.as_str() {
760 "completed" => completed = b.c,
761 "failed" => failed = b.c,
762 "abandoned" => abandoned = b.c,
763 "running" => running = b.c,
764 _ => {}
765 }
766 }
767 let decided = completed + failed + abandoned;
768 let success_rate = if decided > 0 {
769 completed as f64 / decided as f64
770 } else {
771 1.0
772 };
773
774 let sessions_count = totals.sessions;
775 let avg_spend = if sessions_count > 0 {
776 totals.cost_usd / sessions_count as f64
777 } else {
778 0.0
779 };
780
781 Ok(SessionStats {
782 sessions: sessions_count,
783 total_spend_usd: totals.cost_usd,
784 avg_spend_per_session_usd: avg_spend,
785 tokens_in: totals.tokens_in,
786 tokens_out: totals.tokens_out,
787 tokens_total: totals.tokens_in + totals.tokens_out,
788 agent_time_s: total_seconds,
789 avg_session_s: avg_seconds,
790 success_rate,
791 completed,
792 failed,
793 abandoned,
794 running,
795 })
796}
797
/// Per-model usage aggregate returned by the `cost_by_model` query.
///
/// `model` is optional; `cost_by_model` reports a missing model as
/// `"unknown"`.
#[derive(Debug, FromQueryResult)]
struct CostRow {
    model: Option<String>,
    session_count: i64,
    cost_usd: f64,
    tokens_in: i64,
    tokens_out: i64,
}
810
811#[instrument(
812 name = "cognee.db.relational.session_lifecycle.cost_by_model",
813 level = "info",
814 skip_all,
815 fields(
816 cognee.db.system = tracing::field::Empty,
817 cognee.db.row_count = tracing::field::Empty,
818 ),
819 err,
820)]
821pub async fn cost_by_model(
822 db: &DatabaseConnection,
823 user_id: Uuid,
824 permitted_dataset_ids: &[Uuid],
825 since: Option<DateTime<Utc>>,
826) -> Result<Vec<CostByModelRow>, DatabaseError> {
827 Span::current().record(COGNEE_DB_SYSTEM, database_system_label(db));
828 let backend = db.get_database_backend();
829 let user_hex = uuid_hex::to_hex(user_id);
830
831 let mut where_parts: Vec<String> = Vec::new();
832 let mut params: Vec<Value> = Vec::new();
833
834 if permitted_dataset_ids.is_empty() {
835 where_parts.push("sr.user_id = ?".to_string());
836 params.push(Value::from(user_hex.clone()));
837 } else {
838 let mut placeholders = Vec::with_capacity(permitted_dataset_ids.len());
839 let mut perm_params: Vec<Value> = Vec::with_capacity(permitted_dataset_ids.len());
840 for ds in permitted_dataset_ids {
841 placeholders.push("?");
842 perm_params.push(Value::from(uuid_hex::to_hex(*ds)));
843 }
844 where_parts.push(format!(
845 "(sr.user_id = ? OR sr.dataset_id IN ({}))",
846 placeholders.join(", ")
847 ));
848 params.push(Value::from(user_hex.clone()));
849 params.extend(perm_params);
850 }
851 if let Some(s) = since {
852 where_parts.push("sr.last_activity_at >= ?".to_string());
853 params.push(Value::from(s));
854 }
855 let where_clause = if where_parts.is_empty() {
856 String::new()
857 } else {
858 format!("WHERE {}", where_parts.join(" AND "))
859 };
860
861 let sql = format!(
864 "SELECT smu.model AS model, \
865 COUNT(DISTINCT smu.session_id) AS session_count, \
866 COALESCE(SUM(smu.cost_usd), 0.0) AS cost_usd, \
867 COALESCE(SUM(smu.tokens_in), 0) AS tokens_in, \
868 COALESCE(SUM(smu.tokens_out), 0) AS tokens_out \
869 FROM session_model_usage smu \
870 JOIN session_records sr ON smu.session_id = sr.session_id \
871 AND smu.user_id = sr.user_id \
872 {where_clause} \
873 GROUP BY smu.model \
874 ORDER BY SUM(smu.cost_usd) DESC"
875 );
876
877 let rows = CostRow::find_by_statement(Statement::from_sql_and_values(backend, &sql, params))
878 .all(db)
879 .await
880 .map_err(map_sea_err)?;
881
882 let result: Vec<CostByModelRow> = rows
883 .into_iter()
884 .map(|r| CostByModelRow {
885 model: r.model.unwrap_or_else(|| "unknown".to_string()),
886 session_count: r.session_count,
887 cost_usd: r.cost_usd,
888 tokens_in: r.tokens_in,
889 tokens_out: r.tokens_out,
890 })
891 .collect();
892 Span::current().record(COGNEE_DB_ROW_COUNT, result.len() as i64);
893 Ok(result)
894}
895
/// [`SessionLifecycleDb`] for a plain [`DatabaseConnection`]: every method
/// forwards unchanged to the free function of the same name in this module.
///
/// Inside an `impl` block, an unqualified call such as
/// `ensure_and_touch_session(self, ...)` resolves to the module-level free
/// function rather than recursing into the trait method.
#[async_trait::async_trait]
impl SessionLifecycleDb for DatabaseConnection {
    /// Forwards to the module-level `ensure_and_touch_session`.
    async fn ensure_and_touch_session(
        &self,
        session_id: &str,
        user_id: Uuid,
        dataset_id: Option<Uuid>,
    ) -> Result<(), DatabaseError> {
        ensure_and_touch_session(self, session_id, user_id, dataset_id).await
    }

    /// Forwards to the module-level `accumulate_usage`.
    async fn accumulate_usage(
        &self,
        session_id: &str,
        user_id: Uuid,
        model: Option<&str>,
        tokens_in: i64,
        tokens_out: i64,
        cost_usd: f64,
        errored: bool,
    ) -> Result<(), DatabaseError> {
        accumulate_usage(
            self, session_id, user_id, model, tokens_in, tokens_out, cost_usd, errored,
        )
        .await
    }

    /// Forwards to the module-level `get_session_row`.
    async fn get_session_row(
        &self,
        session_id: &str,
        user_id: Uuid,
        permitted_dataset_ids: &[Uuid],
        prefer_other_owner: bool,
    ) -> Result<Option<SessionRowWithStatus>, DatabaseError> {
        get_session_row(
            self,
            session_id,
            user_id,
            permitted_dataset_ids,
            prefer_other_owner,
        )
        .await
    }

    /// Forwards to the module-level `list_session_rows`.
    async fn list_session_rows(
        &self,
        filters: SessionListFilters,
    ) -> Result<SessionListPage, DatabaseError> {
        list_session_rows(self, filters).await
    }

    /// Forwards to the module-level `aggregate_stats`.
    async fn aggregate_stats(
        &self,
        user_id: Uuid,
        permitted_dataset_ids: &[Uuid],
        since: Option<DateTime<Utc>>,
    ) -> Result<SessionStats, DatabaseError> {
        aggregate_stats(self, user_id, permitted_dataset_ids, since).await
    }

    /// Forwards to the module-level `cost_by_model`.
    async fn cost_by_model(
        &self,
        user_id: Uuid,
        permitted_dataset_ids: &[Uuid],
        since: Option<DateTime<Utc>>,
    ) -> Result<Vec<CostByModelRow>, DatabaseError> {
        cost_by_model(self, user_id, permitted_dataset_ids, since).await
    }
}