Skip to main content

stygian_proxy/
session.rs

1//! Domain-scoped proxy session stickiness.
2//!
3//! A *sticky session* binds a target domain to a specific proxy for a
4//! configurable TTL. Requests to the same domain reuse the same proxy,
5//! preserving IP consistency for anti-bot fingerprint checks while still
6//! rotating across different domains.
7//!
8//! # Example
9//!
10//! ```
11//! use stygian_proxy::session::{SessionMap, StickyPolicy};
12//! use std::time::Duration;
13//! use uuid::Uuid;
14//!
15//! let map = SessionMap::new();
16//! let ttl = Duration::from_secs(300);
17//! let proxy_id = Uuid::new_v4();
18//!
19//! map.bind("example.com", proxy_id, ttl);
20//! assert_eq!(map.lookup("example.com"), Some(proxy_id));
21//! ```
22
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26
27use serde::{Deserialize, Serialize};
28use tokio::sync::RwLock;
29use uuid::Uuid;
30
/// Session TTL, in seconds, used when no explicit TTL is given (5 minutes).
const DEFAULT_TTL_SECS: u64 = 5 * 60;
33
34// ── StickyPolicy ─────────────────────────────────────────────────────────────
35
36/// Policy controlling when and how proxy sessions are pinned to a key.
37///
38/// # Example
39///
40/// ```
41/// use stygian_proxy::session::StickyPolicy;
42/// use std::time::Duration;
43///
44/// let policy = StickyPolicy::domain(Duration::from_secs(600));
45/// assert!(!policy.is_disabled());
46/// ```
47#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case", tag = "mode")]
49#[non_exhaustive]
50pub enum StickyPolicy {
51    /// No session stickiness — every request may use a different proxy.
52    #[default]
53    Disabled,
54    /// Pin by domain with a fixed TTL per binding.
55    Domain {
56        /// How long a domain→proxy binding remains valid.
57        #[serde(with = "serde_duration_secs")]
58        ttl: Duration,
59    },
60}
61
62impl StickyPolicy {
63    /// Create a domain-scoped policy with the given TTL.
64    pub const fn domain(ttl: Duration) -> Self {
65        Self::Domain { ttl }
66    }
67
68    /// Create a domain-scoped policy with the default TTL (5 minutes).
69    pub fn domain_default() -> Self {
70        Self::Domain {
71            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
72        }
73    }
74
75    /// Returns `true` when session stickiness is turned off.
76    pub const fn is_disabled(&self) -> bool {
77        matches!(self, Self::Disabled)
78    }
79}
80
81// ── ProxySession ─────────────────────────────────────────────────────────────
82
83/// A single domain→proxy binding with an expiration deadline.
84#[derive(Debug, Clone)]
85struct ProxySession {
86    /// The proxy this session is bound to.
87    proxy_id: Uuid,
88    /// When this session was created.
89    bound_at: Instant,
90    /// How long the binding is valid.
91    ttl: Duration,
92}
93
94impl ProxySession {
95    /// Returns `true` when `bound_at + ttl` has elapsed.
96    fn is_expired(&self) -> bool {
97        self.bound_at.elapsed() >= self.ttl
98    }
99}
100
101// ── SessionMap ───────────────────────────────────────────────────────────────
102
103/// Thread-safe map of session keys (typically domains) to proxy bindings.
104///
105/// All operations acquire short-lived locks to minimise contention.
106///
107/// # Example
108///
109/// ```
110/// use stygian_proxy::session::SessionMap;
111/// use std::time::Duration;
112/// use uuid::Uuid;
113///
114/// let map = SessionMap::new();
115/// let id = Uuid::new_v4();
116/// map.bind("example.com", id, Duration::from_secs(60));
117/// assert_eq!(map.lookup("example.com"), Some(id));
118/// ```
119#[derive(Debug, Clone)]
120pub struct SessionMap {
121    inner: Arc<RwLock<HashMap<String, ProxySession>>>,
122}
123
124impl Default for SessionMap {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl SessionMap {
131    /// Create an empty session map.
132    pub fn new() -> Self {
133        Self {
134            inner: Arc::new(RwLock::new(HashMap::new())),
135        }
136    }
137
138    /// Look up the proxy bound to `key`, returning `None` when no session
139    /// exists or the existing session has expired.
140    ///
141    /// Expired entries are lazily removed on the next [`bind`](Self::bind)
142    /// or [`purge_expired`](Self::purge_expired) call.
143    pub fn lookup(&self, key: &str) -> Option<Uuid> {
144        // try_read avoids blocking if a write is in progress.
145        let guard = self.inner.try_read().ok()?;
146        guard
147            .get(key)
148            .filter(|s| !s.is_expired())
149            .map(|s| s.proxy_id)
150    }
151
152    /// Bind `key` to `proxy_id` with the given TTL. Overwrites any existing
153    /// session for the same key.
154    pub fn bind(&self, key: &str, proxy_id: Uuid, ttl: Duration) {
155        let session = ProxySession {
156            proxy_id,
157            bound_at: Instant::now(),
158            ttl,
159        };
160        if let Ok(mut guard) = self.inner.try_write() {
161            guard.insert(key.to_string(), session);
162        }
163    }
164
165    /// Remove all expired sessions, returning the number removed.
166    pub fn purge_expired(&self) -> usize {
167        let Ok(mut guard) = self.inner.try_write() else {
168            return 0;
169        };
170        let before = guard.len();
171        guard.retain(|_, s| !s.is_expired());
172        before - guard.len()
173    }
174
175    /// Remove a specific session by key.
176    pub fn unbind(&self, key: &str) {
177        if let Ok(mut guard) = self.inner.try_write() {
178            guard.remove(key);
179        }
180    }
181
182    /// Returns the number of active (non-expired) sessions.
183    pub fn active_count(&self) -> usize {
184        let Ok(guard) = self.inner.try_read() else {
185            return 0;
186        };
187        guard.values().filter(|s| !s.is_expired()).count()
188    }
189}
190
191// ── serde helper ─────────────────────────────────────────────────────────────
192
193mod serde_duration_secs {
194    use serde::{Deserialize, Deserializer, Serialize, Serializer};
195    use std::time::Duration;
196
197    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
198        d.as_secs().serialize(s)
199    }
200
201    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
202        Ok(Duration::from_secs(u64::deserialize(d)?))
203    }
204}
205
206// ── tests ────────────────────────────────────────────────────────────────────
207
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrapping is fine in test code: a failure should abort the test
mod tests {
    use super::*;

    /// Repeated lookups of one key keep returning the proxy that was bound.
    #[test]
    fn same_domain_returns_same_proxy() {
        let sessions = SessionMap::new();
        let proxy = Uuid::new_v4();
        sessions.bind("example.com", proxy, Duration::from_secs(60));
        for _ in 0..2 {
            assert_eq!(sessions.lookup("example.com"), Some(proxy));
        }
    }

    /// Bindings for distinct keys do not interfere with each other.
    #[test]
    fn different_domains_independent() {
        let sessions = SessionMap::new();
        let bindings = [("a.com", Uuid::new_v4()), ("b.com", Uuid::new_v4())];
        for &(domain, proxy) in &bindings {
            sessions.bind(domain, proxy, Duration::from_secs(60));
        }
        for &(domain, proxy) in &bindings {
            assert_eq!(sessions.lookup(domain), Some(proxy));
        }
    }

    /// A session whose TTL has run out is invisible to `lookup`.
    #[test]
    fn expired_session_returns_none() {
        let sessions = SessionMap::new();
        let proxy = Uuid::new_v4();
        // A zero TTL makes the binding expire immediately.
        sessions.bind("example.com", proxy, Duration::ZERO);
        // Brief pause so that some time has definitely elapsed.
        std::thread::sleep(Duration::from_millis(1));
        assert_eq!(sessions.lookup("example.com"), None);
    }

    /// `purge_expired` drops only the expired binding and reports the count.
    #[test]
    fn purge_removes_expired() {
        let sessions = SessionMap::new();
        sessions.bind("expired.com", Uuid::new_v4(), Duration::ZERO);
        sessions.bind("active.com", Uuid::new_v4(), Duration::from_secs(300));
        std::thread::sleep(Duration::from_millis(1));

        assert_eq!(sessions.purge_expired(), 1);
        assert_eq!(sessions.active_count(), 1);
    }

    /// After `unbind`, the key no longer resolves to a proxy.
    #[test]
    fn unbind_removes_session() {
        let sessions = SessionMap::new();
        sessions.bind("example.com", Uuid::new_v4(), Duration::from_secs(60));
        sessions.unbind("example.com");
        assert_eq!(sessions.lookup("example.com"), None);
    }

    /// Binding an already-bound key replaces the earlier proxy.
    #[test]
    fn rebind_overwrites_previous() {
        let sessions = SessionMap::new();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        for proxy in [first, second] {
            sessions.bind("example.com", proxy, Duration::from_secs(60));
        }
        assert_eq!(sessions.lookup("example.com"), Some(second));
    }

    /// `domain_default` yields a `Domain` policy with a 300-second TTL.
    #[test]
    fn policy_domain_default_ttl() {
        let StickyPolicy::Domain { ttl } = StickyPolicy::domain_default() else {
            panic!("expected Domain variant");
        };
        assert_eq!(ttl, Duration::from_secs(300));
    }

    /// The `Default` policy is the disabled one.
    #[test]
    fn policy_disabled_by_default() {
        assert!(StickyPolicy::default().is_disabled());
    }

    /// A `Domain` policy survives a JSON round trip with its TTL intact.
    #[test]
    fn policy_serde_roundtrip() {
        let original = StickyPolicy::domain(Duration::from_secs(120));
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StickyPolicy = serde_json::from_str(&encoded).unwrap();
        let StickyPolicy::Domain { ttl } = decoded else {
            panic!("expected Domain variant");
        };
        assert_eq!(ttl, Duration::from_secs(120));
    }
}