use std::collections::HashMap;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use super::context::RequestContext;
/// Backing map for session state: session ID -> (state key -> stored JSON value).
type SessionStateMap = HashMap<String, HashMap<String, Value>>;

#[cfg(target_arch = "wasm32")]
use std::cell::RefCell;

#[cfg(target_arch = "wasm32")]
thread_local! {
    /// wasm32 store: one map per thread, mutated through a `RefCell`.
    static SESSION_STATE: RefCell<SessionStateMap> = RefCell::new(SessionStateMap::new());
}

#[cfg(not(target_arch = "wasm32"))]
use std::sync::{LazyLock, RwLock};

/// Native store: a single process-wide map, created lazily and guarded by an `RwLock`.
#[cfg(not(target_arch = "wasm32"))]
static SESSION_STATE: LazyLock<RwLock<SessionStateMap>> =
    LazyLock::new(|| RwLock::new(SessionStateMap::new()));
/// Boxed progress sink accepted by `RichContextExt::report_progress_with_callback`.
///
/// Arguments, in order: progress token (the `RequestContext` impl passes the
/// request ID), current progress, optional total, optional message.
/// `Send + Sync` lets the boxed callback be shared across threads.
pub type ProgressCallback = Box<dyn Fn(&str, u64, Option<u64>, Option<&str>) + Send + Sync>;
/// RAII guard tied to one session ID.
///
/// Dropping the guard calls [`cleanup_session_state`] for its session ID,
/// discarding every value stored for that session. Bind it to a named variable
/// (e.g. `let _guard = SessionStateGuard::new(id);`). Both `let _ = ...` and a
/// bare expression statement drop it immediately, wiping the session right away.
#[must_use = "the session's state is cleaned up as soon as the guard is dropped"]
#[derive(Debug)]
pub struct SessionStateGuard {
    /// Session whose stored state is removed when the guard is dropped.
    session_id: String,
}
impl SessionStateGuard {
pub fn new(session_id: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
}
}
pub fn session_id(&self) -> &str {
&self.session_id
}
}
impl Drop for SessionStateGuard {
    /// Purges everything stored under this guard's session ID.
    fn drop(&mut self) {
        let session_id: &str = &self.session_id;
        cleanup_session_state(session_id);
    }
}
/// Errors returned by the fallible session-state accessors
/// (`RichContextExt::try_get_state` / `RichContextExt::try_set_state`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
/// The request context has no session ID, so there is no session state to use.
NoSessionId,
/// Converting the value to JSON failed; carries the serializer's error text.
SerializationFailed(String),
/// Converting stored JSON back into the requested type failed; carries the error text.
DeserializationFailed(String),
}
impl std::fmt::Display for StateError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoSessionId => write!(f, "no session ID set on context"),
Self::SerializationFailed(e) => write!(f, "serialization failed: {}", e),
Self::DeserializationFailed(e) => write!(f, "deserialization failed: {}", e),
}
}
}
impl std::error::Error for StateError {}
/// Severity of a message passed to `RichContextExt::log`.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// yields `Debug < Info < Warning < Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
/// Verbose diagnostic detail.
Debug,
/// Normal operational messages (also used by `report_progress`).
Info,
/// Something unexpected that did not stop the operation.
Warning,
/// A failure.
Error,
}
/// Session-state, logging and progress helpers layered on top of a request context.
///
/// State is keyed by the context's session ID. Without one, reads yield nothing
/// and writes are rejected (`StateError::NoSessionId` from the fallible variants).
pub trait RichContextExt {
/// Returns the value stored under `key`, or `None` if it is missing or cannot be read.
fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
/// Fallible read of `key`: `Ok(None)` when absent, `Err` when there is no session
/// ID or the stored value cannot be deserialized as `T`.
fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
/// Stores `value` under `key`; returns `false` if the write could not be performed.
fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
/// Fallible write of `value` under `key`, reporting why it failed.
fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
/// Removes `key` from the session's state; returns `false` when the context has
/// no session ID.
fn remove_state(&self, key: &str) -> bool;
/// Removes every key from the session's state (no-op without a session ID).
fn clear_state(&self);
/// Returns whether a value is stored under `key` (`false` without a session ID).
fn has_state(&self, key: &str) -> bool;
/// Logs `message` at `LogLevel::Debug`.
fn log_debug(&self, message: impl AsRef<str>);
/// Logs `message` at `LogLevel::Info`.
fn log_info(&self, message: impl AsRef<str>);
/// Logs `message` at `LogLevel::Warning`.
fn log_warning(&self, message: impl AsRef<str>);
/// Logs `message` at `LogLevel::Error`.
fn log_error(&self, message: impl AsRef<str>);
/// Logs `message` at `level`, tagged with the request ID.
fn log(&self, level: LogLevel, message: impl AsRef<str>);
/// Logs an info-level progress line with a percentage of `current / total`
/// (0% when `total` is 0) and an optional message.
fn report_progress(&self, current: u64, total: u64, message: Option<&str>);
/// Forwards a progress update to `callback`, passing the request ID as the token.
fn report_progress_with_callback(
&self,
current: u64,
total: Option<u64>,
message: Option<&str>,
callback: &ProgressCallback,
);
}
impl RichContextExt for RequestContext {
    /// Lossy variant of `try_get_state`. Every error (missing session ID,
    /// deserialization failure) is reported as `None`.
    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.try_get_state(key).ok().flatten()
    }

    /// Reads `key` from this session's state and deserializes it as `T`.
    ///
    /// Returns `Ok(None)` when nothing is stored under `key`.
    ///
    /// # Errors
    /// `StateError::NoSessionId` if the context has no session ID;
    /// `StateError::DeserializationFailed` if the stored JSON does not fit `T`.
    #[cfg(target_arch = "wasm32")]
    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
        let session_id = self.session_id().ok_or(StateError::NoSessionId)?;
        // Copy the JSON out and end the `RefCell` borrow before running the
        // caller's `Deserialize` impl.
        let stored = SESSION_STATE.with(|state| {
            let state = state.borrow();
            state
                .get(session_id)
                .and_then(|session_state| session_state.get(key))
                .cloned()
        });
        stored
            .map(serde_json::from_value)
            .transpose()
            .map_err(|e| StateError::DeserializationFailed(e.to_string()))
    }

    /// Reads `key` from this session's state and deserializes it as `T`.
    ///
    /// Returns `Ok(None)` when nothing is stored under `key`.
    ///
    /// # Errors
    /// `StateError::NoSessionId` if the context has no session ID;
    /// `StateError::DeserializationFailed` if the stored JSON does not fit `T`.
    #[cfg(not(target_arch = "wasm32"))]
    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
        let session_id = self.session_id().ok_or(StateError::NoSessionId)?;
        // Clone the JSON under the read lock and deserialize only after the guard
        // is released, keeping the critical section short. A poisoned lock is
        // recovered instead of panicking on every later access.
        let stored = {
            let state = SESSION_STATE
                .read()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            state
                .get(session_id)
                .and_then(|session_state| session_state.get(key))
                .cloned()
        };
        stored
            .map(serde_json::from_value)
            .transpose()
            .map_err(|e| StateError::DeserializationFailed(e.to_string()))
    }

    /// Lossy variant of `try_set_state`: returns `true` only if the value was stored.
    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
        self.try_set_state(key, value).is_ok()
    }

    /// Serializes `value` to JSON and stores it under `key`, creating the
    /// session's entry on first use and overwriting any previous value.
    ///
    /// # Errors
    /// `StateError::NoSessionId` if the context has no session ID;
    /// `StateError::SerializationFailed` if `value` cannot be converted to JSON.
    #[cfg(target_arch = "wasm32")]
    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
        let session_id = self.session_id().ok_or(StateError::NoSessionId)?;
        // Serialize before touching the store so a failure leaves it unchanged.
        let json_value = serde_json::to_value(value)
            .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
        SESSION_STATE.with(|state| {
            let mut state = state.borrow_mut();
            state
                .entry(session_id.to_string())
                .or_default()
                .insert(key.to_string(), json_value);
        });
        Ok(())
    }

    /// Serializes `value` to JSON and stores it under `key`, creating the
    /// session's entry on first use and overwriting any previous value.
    ///
    /// # Errors
    /// `StateError::NoSessionId` if the context has no session ID;
    /// `StateError::SerializationFailed` if `value` cannot be converted to JSON.
    #[cfg(not(target_arch = "wasm32"))]
    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
        let session_id = self.session_id().ok_or(StateError::NoSessionId)?;
        // Serialize before taking the lock so a failure leaves the store unchanged.
        let json_value = serde_json::to_value(value)
            .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
        let mut state = SESSION_STATE
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        state
            .entry(session_id.to_string())
            .or_default()
            .insert(key.to_string(), json_value);
        Ok(())
    }

    /// Removes `key` from this session's state.
    ///
    /// Returns `true` only if a value was actually stored under `key` (and is now
    /// gone); `false` if there is no session ID, no state for the session, or no
    /// such key.
    #[cfg(target_arch = "wasm32")]
    fn remove_state(&self, key: &str) -> bool {
        let Some(session_id) = self.session_id() else {
            return false;
        };
        SESSION_STATE.with(|state| {
            let mut state = state.borrow_mut();
            state
                .get_mut(session_id)
                .and_then(|session_state| session_state.remove(key))
                .is_some()
        })
    }

    /// Removes `key` from this session's state.
    ///
    /// Returns `true` only if a value was actually stored under `key` (and is now
    /// gone); `false` if there is no session ID, no state for the session, or no
    /// such key.
    #[cfg(not(target_arch = "wasm32"))]
    fn remove_state(&self, key: &str) -> bool {
        let Some(session_id) = self.session_id() else {
            return false;
        };
        let mut state = SESSION_STATE
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        state
            .get_mut(session_id)
            .and_then(|session_state| session_state.remove(key))
            .is_some()
    }

    /// Removes every key stored for this session. The session's (now empty)
    /// entry is kept. No-op without a session ID.
    #[cfg(target_arch = "wasm32")]
    fn clear_state(&self) {
        if let Some(session_id) = self.session_id() {
            SESSION_STATE.with(|state| {
                let mut state = state.borrow_mut();
                if let Some(session_state) = state.get_mut(session_id) {
                    session_state.clear();
                }
            });
        }
    }

    /// Removes every key stored for this session. The session's (now empty)
    /// entry is kept. No-op without a session ID.
    #[cfg(not(target_arch = "wasm32"))]
    fn clear_state(&self) {
        if let Some(session_id) = self.session_id() {
            let mut state = SESSION_STATE
                .write()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if let Some(session_state) = state.get_mut(session_id) {
                session_state.clear();
            }
        }
    }

    /// Returns whether a value is stored under `key` for this session.
    #[cfg(target_arch = "wasm32")]
    fn has_state(&self, key: &str) -> bool {
        let Some(session_id) = self.session_id() else {
            return false;
        };
        SESSION_STATE.with(|state| {
            let state = state.borrow();
            state
                .get(session_id)
                .is_some_and(|session_state| session_state.contains_key(key))
        })
    }

    /// Returns whether a value is stored under `key` for this session.
    #[cfg(not(target_arch = "wasm32"))]
    fn has_state(&self, key: &str) -> bool {
        let Some(session_id) = self.session_id() else {
            return false;
        };
        let state = SESSION_STATE
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        state
            .get(session_id)
            .is_some_and(|session_state| session_state.contains_key(key))
    }

    fn log_debug(&self, message: impl AsRef<str>) {
        self.log(LogLevel::Debug, message);
    }

    fn log_info(&self, message: impl AsRef<str>) {
        self.log(LogLevel::Info, message);
    }

    fn log_warning(&self, message: impl AsRef<str>) {
        self.log(LogLevel::Warning, message);
    }

    fn log_error(&self, message: impl AsRef<str>) {
        self.log(LogLevel::Error, message);
    }

    /// Writes `[<request id>] <message>` to the browser console method that
    /// matches `level`.
    #[cfg(target_arch = "wasm32")]
    fn log(&self, level: LogLevel, message: impl AsRef<str>) {
        let msg = message.as_ref();
        let prefix = format!("[{}] ", self.request_id());
        let full_msg = format!("{}{}", prefix, msg);
        match level {
            LogLevel::Debug => web_sys::console::debug_1(&full_msg.into()),
            LogLevel::Info => web_sys::console::info_1(&full_msg.into()),
            LogLevel::Warning => web_sys::console::warn_1(&full_msg.into()),
            LogLevel::Error => web_sys::console::error_1(&full_msg.into()),
        }
    }

    /// Writes `<LEVEL> [<request id>] <message>` to stderr.
    #[cfg(not(target_arch = "wasm32"))]
    fn log(&self, level: LogLevel, message: impl AsRef<str>) {
        let level_str = match level {
            LogLevel::Debug => "DEBUG",
            LogLevel::Info => "INFO",
            LogLevel::Warning => "WARN",
            LogLevel::Error => "ERROR",
        };
        // Diagnostics belong on stderr. stdout is often the data channel (e.g. a
        // stdio protocol transport), and log lines interleaved there corrupt it.
        eprintln!("{} [{}] {}", level_str, self.request_id(), message.as_ref());
    }

    /// Logs an info-level progress line.
    ///
    /// The percentage is truncated toward zero, is 0 when `total` is 0, and can
    /// exceed 100 if `current > total`.
    fn report_progress(&self, current: u64, total: u64, message: Option<&str>) {
        let percentage = if total > 0 {
            (current as f64 / total as f64 * 100.0) as u32
        } else {
            0
        };
        let msg = match message {
            Some(m) => format!("Progress: {}% ({}/{}) - {}", percentage, current, total, m),
            None => format!("Progress: {}% ({}/{})", percentage, current, total),
        };
        self.log_info(msg);
    }

    /// Forwards a progress update to `callback`, using the request ID as the token.
    fn report_progress_with_callback(
        &self,
        current: u64,
        total: Option<u64>,
        message: Option<&str>,
        callback: &ProgressCallback,
    ) {
        callback(self.request_id(), current, total, message);
    }
}
/// Discards all state stored for `session_id`; a no-op if the session has none.
#[cfg(target_arch = "wasm32")]
pub fn cleanup_session_state(session_id: &str) {
    SESSION_STATE.with(|state| {
        state.borrow_mut().remove(session_id);
    });
}

/// Discards all state stored for `session_id`; a no-op if the session has none.
///
/// `SessionStateGuard`'s `Drop` impl calls this, so a poisoned lock is recovered
/// rather than unwrapped. Panicking here while already unwinding would abort
/// the process.
#[cfg(not(target_arch = "wasm32"))]
pub fn cleanup_session_state(session_id: &str) {
    SESSION_STATE
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(session_id);
}
/// Returns how many sessions currently have an entry in the state store.
#[cfg(target_arch = "wasm32")]
pub fn active_sessions_count() -> usize {
    SESSION_STATE.with(|state| state.borrow().len())
}

/// Returns how many sessions currently have an entry in the state store.
///
/// A poisoned lock is recovered rather than turned into a panic.
#[cfg(not(target_arch = "wasm32"))]
pub fn active_sessions_count() -> usize {
    SESSION_STATE
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .len()
}
/// Test-only helper that drops the state of every session on the current target.
#[cfg(test)]
#[allow(dead_code)] // not called by the current tests; kept for ad-hoc test use
fn clear_all_session_state() {
    #[cfg(target_arch = "wasm32")]
    {
        SESSION_STATE.with(|state| state.borrow_mut().clear());
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        SESSION_STATE
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // On native targets session state is one process-wide map, and `cargo test`
    // runs tests in parallel. Each test therefore owns a unique session ID and
    // cleans up only that session through a `SessionStateGuard`, which also runs
    // when an assertion panics. Wiping other tests' sessions at the start of a
    // test would race with those tests and delete their state mid-run.

    #[test]
    fn test_get_set_state() {
        let _guard = SessionStateGuard::new("test-session-1");
        let ctx = RequestContext::new().with_session_id("test-session-1");
        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);
        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));
        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));
        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();
        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));
        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_state_isolation() {
        let _guard1 = SessionStateGuard::new("session-iso-1");
        let _guard2 = SessionStateGuard::new("session-iso-2");
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");
        ctx1.set_state("value", &1i32);
        ctx2.set_state("value", &2i32);
        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));
    }

    #[test]
    fn test_complex_types() {
        let _guard = SessionStateGuard::new("complex-session-1");
        let ctx = RequestContext::new().with_session_id("complex-session-1");
        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }
        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };
        ctx.set_state("data", &data);
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";
        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);
            ctx.set_state("key", &"value");
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));
            assert!(ctx.has_state("key"));
            // Our own session has an entry, whatever other tests are doing.
            assert!(active_sessions_count() > 0);
        }
        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let _guard = SessionStateGuard::new("error-test-session");
        let ctx = RequestContext::new().with_session_id("error-test-session");
        ctx.set_state("number", &42i32);
        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert!(
            StateError::SerializationFailed("test".into())
                .to_string()
                .contains("serialization failed")
        );
        assert!(
            StateError::DeserializationFailed("test".into())
                .to_string()
                .contains("deserialization failed")
        );
    }

    #[test]
    fn test_logging() {
        let ctx = RequestContext::new().with_session_id("logging-test");
        ctx.log_debug("debug message");
        ctx.log_info("info message");
        ctx.log_warning("warning message");
        ctx.log_error("error message");
        ctx.log(LogLevel::Info, "custom level message");
    }

    #[test]
    fn test_progress_reporting() {
        let ctx = RequestContext::new().with_session_id("progress-test");
        ctx.report_progress(50, 100, Some("halfway"));
        ctx.report_progress(100, 100, None);
        let calls = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let calls_clone = calls.clone();
        let callback: ProgressCallback = Box::new(move |token, current, total, message| {
            calls_clone.lock().unwrap().push((
                token.to_string(),
                current,
                total,
                message.map(String::from),
            ));
        });
        ctx.report_progress_with_callback(25, Some(100), Some("processing"), &callback);
        let expected_token: &str = ctx.request_id();
        let recorded = calls.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].0, expected_token);
        assert_eq!(recorded[0].1, 25);
        assert_eq!(recorded[0].2, Some(100));
        assert_eq!(recorded[0].3, Some("processing".to_string()));
    }

    #[test]
    fn test_log_level_ordering() {
        assert!(LogLevel::Debug < LogLevel::Info);
        assert!(LogLevel::Info < LogLevel::Warning);
        assert!(LogLevel::Warning < LogLevel::Error);
    }
}