Skip to main content

forge_runtime/signals/
session.rs

1//! Server-side session management for the signals pipeline.
2//!
3//! Sessions are created when the first event arrives for a visitor_id.
4//! They are closed after `session_timeout_mins` of inactivity.
5
6use std::sync::Arc;
7
8use sqlx::PgPool;
9use tracing::{debug, error};
10use uuid::Uuid;
11
12/// Create or update a session for the given visitor.
13///
14/// Returns the session ID (existing or newly created).
15#[allow(clippy::too_many_arguments)]
16pub async fn upsert_session(
17    pool: &PgPool,
18    session_id: Option<Uuid>,
19    visitor_id: &str,
20    user_id: Option<Uuid>,
21    tenant_id: Option<Uuid>,
22    page_url: Option<&str>,
23    referrer: Option<&str>,
24    user_agent: Option<&str>,
25    client_ip: Option<&str>,
26    is_bot: bool,
27    event_type: &str,
28    device_type: Option<&str>,
29    browser: Option<&str>,
30    os: Option<&str>,
31) -> Option<Uuid> {
32    if let Some(sid) = session_id {
33        let is_page_view = event_type == "page_view";
34        let is_error = event_type == "error";
35        let is_rpc = event_type == "rpc_call";
36
37        let result = sqlx::query(
38            "UPDATE forge_signals_sessions SET
39                last_activity_at = NOW(),
40                event_count = event_count + 1,
41                page_view_count = page_view_count + CASE WHEN $2 THEN 1 ELSE 0 END,
42                rpc_call_count = rpc_call_count + CASE WHEN $3 THEN 1 ELSE 0 END,
43                error_count = error_count + CASE WHEN $4 THEN 1 ELSE 0 END,
44                exit_page = COALESCE($5, exit_page),
45                user_id = COALESCE(user_id, $6),
46                is_bounce = CASE WHEN page_view_count + CASE WHEN $2 THEN 1 ELSE 0 END > 1 THEN FALSE ELSE is_bounce END
47            WHERE id = $1",
48        )
49        .bind(sid)
50        .bind(is_page_view)
51        .bind(is_rpc)
52        .bind(is_error)
53        .bind(page_url)
54        .bind(user_id)
55        .execute(pool)
56        .await;
57
58        match result {
59            Ok(r) if r.rows_affected() > 0 => return Some(sid),
60            Ok(_) => {} // Session not found, create new one below
61            Err(e) => {
62                error!(error = %e, "failed to update signal session");
63                return Some(sid);
64            }
65        }
66    }
67
68    let new_id = Uuid::new_v4();
69    let referrer_domain = referrer.and_then(extract_domain);
70
71    let result = sqlx::query(
72        "INSERT INTO forge_signals_sessions (
73            id, visitor_id, user_id, tenant_id,
74            entry_page, exit_page,
75            referrer, referrer_domain,
76            user_agent, client_ip,
77            device_type, browser, os,
78            is_bot, event_count, page_view_count, rpc_call_count, error_count
79        ) VALUES ($1, $2, $3, $4, $5, $5, $6, $7, $8, $9, $10, $11, $12, $13, 1,
80            CASE WHEN $14 = 'page_view' THEN 1 ELSE 0 END,
81            CASE WHEN $14 = 'rpc_call' THEN 1 ELSE 0 END,
82            CASE WHEN $14 = 'error' THEN 1 ELSE 0 END
83        )",
84    )
85    .bind(new_id)
86    .bind(visitor_id)
87    .bind(user_id)
88    .bind(tenant_id)
89    .bind(page_url)
90    .bind(referrer)
91    .bind(referrer_domain)
92    .bind(user_agent)
93    .bind(client_ip)
94    .bind(device_type)
95    .bind(browser)
96    .bind(os)
97    .bind(is_bot)
98    .bind(event_type)
99    .execute(pool)
100    .await;
101
102    match result {
103        Ok(_) => {
104            debug!(session_id = %new_id, visitor_id, "created signal session");
105            Some(new_id)
106        }
107        Err(e) => {
108            error!(error = %e, "failed to create signal session");
109            None
110        }
111    }
112}
113
114/// Close stale sessions that have been inactive longer than the timeout.
115pub async fn close_stale_sessions(pool: &PgPool, timeout_mins: u32) {
116    let result = sqlx::query(
117        "UPDATE forge_signals_sessions SET
118            ended_at = NOW(),
119            duration_secs = EXTRACT(EPOCH FROM NOW() - started_at)::integer
120        WHERE ended_at IS NULL
121        AND last_activity_at < NOW() - ($1 || ' minutes')::interval",
122    )
123    .bind(timeout_mins as i32)
124    .execute(pool)
125    .await;
126
127    match result {
128        Ok(r) if r.rows_affected() > 0 => {
129            debug!(count = r.rows_affected(), "closed stale signal sessions");
130        }
131        Ok(_) => {}
132        Err(e) => error!(error = %e, "failed to close stale signal sessions"),
133    }
134}
135
136/// Link a user_id to an existing session (on identify).
137pub async fn identify_session(pool: &PgPool, session_id: Uuid, user_id: Uuid) {
138    let result = sqlx::query(
139        "UPDATE forge_signals_sessions SET user_id = $2 WHERE id = $1 AND user_id IS NULL",
140    )
141    .bind(session_id)
142    .bind(user_id)
143    .execute(pool)
144    .await;
145
146    if let Err(e) = result {
147        error!(error = %e, "failed to identify signal session");
148    }
149}
150
151/// Upsert user in forge_signals_users on identify().
152pub async fn upsert_user(
153    pool: &PgPool,
154    user_id: Uuid,
155    traits: &serde_json::Value,
156    referrer: Option<&str>,
157    utm_source: Option<&str>,
158    utm_medium: Option<&str>,
159    utm_campaign: Option<&str>,
160) {
161    let referrer_domain = referrer.and_then(extract_domain);
162
163    let result = sqlx::query(
164        "INSERT INTO forge_signals_users (
165            id, first_referrer, first_referrer_domain,
166            first_utm_source, first_utm_medium, first_utm_campaign,
167            traits, total_sessions, total_events
168        ) VALUES ($1, $2, $3, $4, $5, $6, $7, 1, 1)
169        ON CONFLICT (id) DO UPDATE SET
170            last_seen_at = NOW(),
171            total_events = forge_signals_users.total_events + 1,
172            traits = forge_signals_users.traits || $7,
173            updated_at = NOW()",
174    )
175    .bind(user_id)
176    .bind(referrer)
177    .bind(referrer_domain)
178    .bind(utm_source)
179    .bind(utm_medium)
180    .bind(utm_campaign)
181    .bind(traits)
182    .execute(pool)
183    .await;
184
185    if let Err(e) = result {
186        error!(error = %e, "failed to upsert signal user");
187    }
188}
189
190/// Spawn a background task that periodically closes stale sessions.
191pub fn spawn_session_reaper(pool: Arc<PgPool>, timeout_mins: u32) {
192    tokio::spawn(async move {
193        // Delay first run to avoid DB pool contention during startup.
194        tokio::time::sleep(std::time::Duration::from_secs(60)).await;
195        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
196        loop {
197            interval.tick().await;
198            close_stale_sessions(&pool, timeout_mins).await;
199        }
200    });
201}
202
/// Extract the lower-cased host from a URL string
/// (e.g. "https://google.com/search" -> "google.com").
///
/// Handling, in order:
/// - surrounding whitespace is trimmed;
/// - a leading `scheme://` is removed for any syntactically valid scheme,
///   regardless of case (`HTTPS://`, `android-app://`, ...); a
///   protocol-relative `//` prefix is removed as well;
/// - the authority ends at the first `/`, `?` or `#`;
/// - `user:password@` userinfo is dropped;
/// - the port is dropped; a bracketed IPv6 literal keeps its brackets
///   (`http://[::1]:8080/` -> `[::1]`).
///
/// Input without a scheme is treated as starting with the host, so
/// `"not-a-url"` yields `Some("not-a-url")`.
///
/// Returns `None` when no host remains (empty input, `https:///path`, or an
/// unterminated `[` IPv6 literal).
fn extract_domain(url: &str) -> Option<String> {
    let url = url.trim();

    // RFC 3986 scheme: a letter followed by letters, digits, '+', '-' or '.'.
    // Checking this keeps a "://" that appears later in the string (e.g. inside
    // a query parameter) from being mistaken for the scheme separator.
    let is_scheme = |s: &str| {
        let mut chars = s.chars();
        matches!(chars.next(), Some(c) if c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };

    let rest = match url.split_once("://") {
        Some((scheme, rest)) if is_scheme(scheme) => rest,
        _ => url.strip_prefix("//").unwrap_or(url),
    };

    // The authority stops at the start of the path, query or fragment.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;

    // Drop any userinfo; the host follows the last '@'.
    let host_port = authority.rsplit('@').next()?;

    let host = if host_port.starts_with('[') {
        // IPv6 literal: the colons inside the brackets are not a port
        // separator. `find` returns the byte offset of the one-byte ']',
        // so the inclusive slice ends on a char boundary.
        let end = host_port.find(']')?;
        &host_port[..=end]
    } else {
        host_port.split(':').next()?
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_lowercase())
    }
}
222
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for the `Some(host)` value `extract_domain` is expected to return.
    fn host(expected: &str) -> Option<String> {
        Some(expected.to_owned())
    }

    // `extract_domain` is synchronous, so these are plain `#[test]`s; nothing
    // here needs an async runtime.

    #[test]
    fn extracts_domain_from_url() {
        assert_eq!(extract_domain("https://google.com/search"), host("google.com"));
        assert_eq!(
            extract_domain("http://example.com:8080/path"),
            host("example.com")
        );
        assert_eq!(
            extract_domain("https://Sub.Domain.COM/"),
            host("sub.domain.com")
        );
    }

    #[test]
    fn handles_edge_cases() {
        assert_eq!(extract_domain(""), None);
        assert_eq!(extract_domain("not-a-url"), host("not-a-url"));
        // Scheme with nothing after it, or with an empty authority.
        assert_eq!(extract_domain("https://"), None);
        assert_eq!(extract_domain("http:///path"), None);
        // No scheme at all, but a port to strip.
        assert_eq!(extract_domain("example.com:3000"), host("example.com"));
    }
}