1use crate::agents::{LoadSession, SaveSession};
8use crate::auth::session::{SessionData, SessionId};
9use crate::state::ActonHtmxState;
10use acton_reactive::prelude::{AgentHandle, AgentHandleInterface};
11use axum::{
12 body::Body,
13 extract::Request,
14 http::header::{COOKIE, SET_COOKIE},
15 response::Response,
16};
17use std::str::FromStr;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21use tower::{Layer, Service};
22
/// Name of the session cookie used by [`SessionConfig::default`].
pub const SESSION_COOKIE_NAME: &str = "acton_session";
25
26#[derive(Clone, Debug)]
28pub struct SessionConfig {
29 pub cookie_name: String,
31 pub cookie_path: String,
33 pub http_only: bool,
35 pub secure: bool,
37 pub same_site: SameSite,
39 pub max_age_secs: u64,
41 pub agent_timeout_ms: u64,
43}
44
45impl Default for SessionConfig {
46 fn default() -> Self {
47 Self {
48 cookie_name: SESSION_COOKIE_NAME.to_string(),
49 cookie_path: "/".to_string(),
50 http_only: true,
51 secure: !cfg!(debug_assertions),
52 same_site: SameSite::Lax,
53 max_age_secs: 86400, agent_timeout_ms: 100,
55 }
56 }
57}
58
/// Value of the `SameSite` attribute placed on the session cookie.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configured values can be compared
/// (e.g. when validating or testing a [`SessionConfig`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SameSite {
    /// Emitted as `SameSite=Strict`.
    Strict,
    /// Emitted as `SameSite=Lax`; this is the default.
    #[default]
    Lax,
    /// Emitted as `SameSite=None`. Browsers only accept such cookies when they
    /// also carry the `Secure` attribute.
    None,
}

impl SameSite {
    /// Returns the attribute value exactly as written in a `Set-Cookie` header.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Strict => "Strict",
            Self::Lax => "Lax",
            Self::None => "None",
        }
    }
}
82
83#[derive(Clone)]
88pub struct SessionLayer {
89 config: SessionConfig,
90 session_manager: AgentHandle,
91}
92
93impl std::fmt::Debug for SessionLayer {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("SessionLayer")
96 .field("config", &self.config)
97 .field("session_manager", &"AgentHandle")
98 .finish()
99 }
100}
101
102impl SessionLayer {
103 #[must_use]
105 pub fn new(state: &ActonHtmxState) -> Self {
106 Self {
107 config: SessionConfig::default(),
108 session_manager: state.session_manager().clone(),
109 }
110 }
111
112 #[must_use]
114 pub fn with_config(state: &ActonHtmxState, config: SessionConfig) -> Self {
115 Self {
116 config,
117 session_manager: state.session_manager().clone(),
118 }
119 }
120
121 #[must_use]
123 pub fn from_handle(session_manager: AgentHandle) -> Self {
124 Self {
125 config: SessionConfig::default(),
126 session_manager,
127 }
128 }
129}
130
131impl<S> Layer<S> for SessionLayer {
132 type Service = SessionMiddleware<S>;
133
134 fn layer(&self, inner: S) -> Self::Service {
135 SessionMiddleware {
136 inner,
137 config: Arc::new(self.config.clone()),
138 session_manager: self.session_manager.clone(),
139 }
140 }
141}
142
143#[derive(Clone)]
148pub struct SessionMiddleware<S> {
149 inner: S,
150 config: Arc<SessionConfig>,
151 session_manager: AgentHandle,
152}
153
154impl<S: std::fmt::Debug> std::fmt::Debug for SessionMiddleware<S> {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("SessionMiddleware")
157 .field("inner", &self.inner)
158 .field("config", &self.config)
159 .field("session_manager", &"AgentHandle")
160 .finish()
161 }
162}
163
164impl<S> Service<Request> for SessionMiddleware<S>
165where
166 S: Service<Request, Response = Response<Body>> + Clone + Send + 'static,
167 S::Future: Send + 'static,
168{
169 type Response = Response<Body>;
170 type Error = S::Error;
171 type Future = std::pin::Pin<
172 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
173 >;
174
175 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176 self.inner.poll_ready(cx)
177 }
178
179 fn call(&mut self, mut req: Request) -> Self::Future {
180 let config = self.config.clone();
181 let session_manager = self.session_manager.clone();
182 let mut inner = self.inner.clone();
183 let timeout = Duration::from_millis(config.agent_timeout_ms);
184
185 Box::pin(async move {
186 let existing_session_id = extract_session_id(&req, &config.cookie_name);
188
189 let (session_id, session_data, is_new) = if let Some(id) = existing_session_id {
191 let (request, rx) = LoadSession::with_response(id.clone());
193 session_manager.send(request).await;
194
195 if let Ok(Ok(Some(data))) = tokio::time::timeout(timeout, rx).await {
197 (id, data, false)
198 } else {
199 let new_id = SessionId::generate();
201 (new_id, SessionData::new(), true)
202 }
203 } else {
204 let id = SessionId::generate();
206 (id, SessionData::new(), true)
207 };
208
209 req.extensions_mut().insert(session_id.clone());
211 req.extensions_mut().insert(session_data.clone());
212
213 let mut response = inner.call(req).await?;
215
216 let final_session_data = response
219 .extensions()
220 .get::<SessionData>()
221 .cloned()
222 .unwrap_or(session_data);
223
224 let save_request = SaveSession::new(session_id.clone(), final_session_data);
226 session_manager.send(save_request).await;
227
228 if is_new {
230 set_session_cookie(&mut response, &session_id, &config);
231 }
232
233 Ok(response)
234 })
235 }
236}
237
238fn extract_session_id(req: &Request, cookie_name: &str) -> Option<SessionId> {
240 let cookie_header = req.headers().get(COOKIE)?;
241 let cookie_str = cookie_header.to_str().ok()?;
242
243 for cookie in cookie_str.split(';') {
245 let cookie = cookie.trim();
246 if let Some((name, value)) = cookie.split_once('=') {
247 if name.trim() == cookie_name {
248 return SessionId::from_str(value.trim()).ok();
249 }
250 }
251 }
252
253 None
254}
255
256fn set_session_cookie(
258 response: &mut Response<Body>,
259 session_id: &SessionId,
260 config: &SessionConfig,
261) {
262 let mut cookie_value = format!(
263 "{}={}; Path={}; Max-Age={}; SameSite={}",
264 config.cookie_name,
265 session_id.as_str(),
266 config.cookie_path,
267 config.max_age_secs,
268 config.same_site.as_str()
269 );
270
271 if config.http_only {
272 cookie_value.push_str("; HttpOnly");
273 }
274
275 if config.secure {
276 cookie_value.push_str("; Secure");
277 }
278
279 if let Ok(header_value) = cookie_value.parse() {
280 response.headers_mut().append(SET_COOKIE, header_value);
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request whose only header is `Cookie: <cookie_header>`.
    fn request_with_cookie(cookie_header: &str) -> Request {
        axum::http::Request::builder()
            .header(COOKIE, cookie_header)
            .body(Body::empty())
            .expect("static test request is valid")
    }

    /// Returns every `Set-Cookie` header value on `response`, in order.
    fn set_cookie_headers(response: &Response<Body>) -> Vec<String> {
        response
            .headers()
            .get_all(SET_COOKIE)
            .iter()
            .map(|value| value.to_str().expect("Set-Cookie is ASCII").to_owned())
            .collect()
    }

    #[test]
    fn test_session_config_default() {
        let config = SessionConfig::default();
        assert_eq!(config.cookie_name, SESSION_COOKIE_NAME);
        assert_eq!(config.cookie_path, "/");
        assert!(config.http_only);
        assert_eq!(config.secure, !cfg!(debug_assertions));
        assert!(matches!(config.same_site, SameSite::Lax));
        assert_eq!(config.max_age_secs, 86400);
        assert_eq!(config.agent_timeout_ms, 100);
    }

    #[test]
    fn test_same_site_as_str() {
        assert_eq!(SameSite::Strict.as_str(), "Strict");
        assert_eq!(SameSite::Lax.as_str(), "Lax");
        assert_eq!(SameSite::None.as_str(), "None");
    }

    #[test]
    fn test_extract_session_id_without_cookie_header() {
        let req = axum::http::Request::builder()
            .body(Body::empty())
            .expect("static test request is valid");
        assert!(extract_session_id(&req, SESSION_COOKIE_NAME).is_none());
    }

    #[test]
    fn test_extract_session_id_ignores_other_cookies() {
        let req = request_with_cookie("theme=dark; lang=en; malformed");
        assert!(extract_session_id(&req, SESSION_COOKIE_NAME).is_none());
    }

    #[test]
    fn test_set_session_cookie_includes_configured_attributes() {
        let id = SessionId::generate();
        let config = SessionConfig {
            secure: true,
            same_site: SameSite::Strict,
            ..SessionConfig::default()
        };
        let mut response = Response::new(Body::empty());

        set_session_cookie(&mut response, &id, &config);

        let headers = set_cookie_headers(&response);
        assert_eq!(headers.len(), 1);
        let attrs: Vec<&str> = headers[0].split("; ").collect();
        assert_eq!(attrs[0], format!("{}={}", SESSION_COOKIE_NAME, id.as_str()));
        assert!(attrs.contains(&"Path=/"));
        assert!(attrs.contains(&"Max-Age=86400"));
        assert!(attrs.contains(&"SameSite=Strict"));
        assert!(attrs.contains(&"HttpOnly"));
        assert!(attrs.contains(&"Secure"));
    }

    #[test]
    fn test_set_session_cookie_omits_disabled_flags() {
        let id = SessionId::generate();
        let config = SessionConfig {
            secure: false,
            http_only: false,
            same_site: SameSite::Lax,
            ..SessionConfig::default()
        };
        let mut response = Response::new(Body::empty());

        set_session_cookie(&mut response, &id, &config);

        let headers = set_cookie_headers(&response);
        assert_eq!(headers.len(), 1);
        let attrs: Vec<&str> = headers[0].split("; ").collect();
        assert!(!attrs.contains(&"HttpOnly"));
        assert!(!attrs.contains(&"Secure"));
    }

    #[test]
    fn test_issued_cookie_is_read_back() {
        let id = SessionId::generate();
        let mut response = Response::new(Body::empty());
        set_session_cookie(&mut response, &id, &SessionConfig::default());

        let issued = set_cookie_headers(&response)
            .into_iter()
            .next()
            .expect("a Set-Cookie header is emitted");
        // A browser echoes back only the `name=value` pair, not the attributes.
        let pair = issued.split(';').next().expect("cookie has a name=value pair");
        let req = request_with_cookie(&format!("other=1; {pair}"));

        assert!(extract_session_id(&req, SESSION_COOKIE_NAME).is_some());
    }
}