use super::tables::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Store of named [`Value`]s guarded by a tokio [`RwLock`], with caps on how many
/// variables may exist and how long their names may be.
#[derive(Debug)]
pub struct VariableStateManager {
    /// The variable table, shared behind an async reader/writer lock.
    variables: Arc<RwLock<HashMap<String, Value>>>,
    /// Maximum number of distinct variables the table may hold.
    max_variables: usize,
    /// Maximum length of a variable name, measured in bytes.
    max_name_length: usize,
}
impl VariableStateManager {
pub fn new() -> Self {
Self {
variables: Arc::new(RwLock::new(HashMap::new())),
max_variables: 1000, max_name_length: 256, }
}
pub fn with_limits(max_variables: usize, max_name_length: usize) -> Self {
Self {
variables: Arc::new(RwLock::new(HashMap::new())),
max_variables,
max_name_length,
}
}
pub async fn set_variable(&self, name: String, value: Value) -> Result<bool, VariableError> {
if name.is_empty() {
return Err(VariableError::InvalidName(
"Variable name cannot be empty".to_string(),
));
}
if name.len() > self.max_name_length {
return Err(VariableError::InvalidName(format!(
"Variable name too long: {} > {}",
name.len(),
self.max_name_length
)));
}
if name.chars().any(|c| c.is_control() || c == '\0') {
return Err(VariableError::InvalidName(
"Variable name contains invalid characters".to_string(),
));
}
let mut variables = self.variables.write().await;
if !variables.contains_key(&name) && variables.len() >= self.max_variables {
return Err(VariableError::TooManyVariables(self.max_variables));
}
self.validate_value(&value)?;
tracing::debug!(
name = %name,
value_type = ?std::mem::discriminant(&value),
"Setting variable"
);
variables.insert(name, value);
Ok(true)
}
pub async fn get_variable(&self, name: &str) -> Result<Value, VariableError> {
let variables = self.variables.read().await;
match variables.get(name) {
Some(value) => {
tracing::debug!(
name = %name,
value_type = ?std::mem::discriminant(value),
"Retrieved variable"
);
Ok(value.clone())
}
None => Err(VariableError::VariableNotFound(name.to_string())),
}
}
pub async fn has_variable(&self, name: &str) -> bool {
let variables = self.variables.read().await;
variables.contains_key(name)
}
pub async fn delete_variable(&self, name: &str) -> Result<bool, VariableError> {
let mut variables = self.variables.write().await;
match variables.remove(name) {
Some(_) => {
tracing::debug!(name = %name, "Variable deleted");
Ok(true)
}
None => Err(VariableError::VariableNotFound(name.to_string())),
}
}
pub async fn clear_all_variables(&self) -> Result<bool, VariableError> {
let mut variables = self.variables.write().await;
let count = variables.len();
variables.clear();
tracing::debug!(cleared_count = count, "All variables cleared");
Ok(true)
}
pub async fn get_variable_names(&self) -> Vec<String> {
let variables = self.variables.read().await;
variables.keys().cloned().collect()
}
pub async fn variable_count(&self) -> usize {
let variables = self.variables.read().await;
variables.len()
}
pub async fn export_variables(&self) -> HashMap<String, Value> {
let variables = self.variables.read().await;
variables.clone()
}
pub async fn import_variables(
&self,
vars: HashMap<String, Value>,
) -> Result<(), VariableError> {
if vars.len() > self.max_variables {
return Err(VariableError::TooManyVariables(self.max_variables));
}
for (name, value) in &vars {
if name.len() > self.max_name_length {
return Err(VariableError::InvalidName(format!(
"Variable name too long: {} > {}",
name.len(),
self.max_name_length
)));
}
self.validate_value(value)?;
}
let mut variables = self.variables.write().await;
variables.clear();
variables.extend(vars);
tracing::debug!(imported_count = variables.len(), "Variables imported");
Ok(())
}
fn validate_value(&self, value: &Value) -> Result<(), VariableError> {
match value {
Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) | Value::Date(_) => {
Ok(())
}
Value::Array(arr) => {
if arr.len() > 1000 {
return Err(VariableError::ValueTooComplex(
"Array too large".to_string(),
));
}
for item in arr {
self.validate_value(item)?;
}
Ok(())
}
Value::Object(obj) => {
if obj.len() > 100 {
return Err(VariableError::ValueTooComplex(
"Object too complex".to_string(),
));
}
for (key, val) in obj {
if key.len() > self.max_name_length {
return Err(VariableError::ValueTooComplex(
"Object key too long".to_string(),
));
}
self.validate_value(val)?;
}
Ok(())
}
Value::Error { .. } => Ok(()),
Value::Stub(_) => Err(VariableError::UnsupportedValueType(
"Cannot store stub as variable".to_string(),
)),
Value::Promise(_) => Err(VariableError::UnsupportedValueType(
"Cannot store promise as variable".to_string(),
)),
}
}
}
impl Default for VariableStateManager {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, thiserror::Error)]
pub enum VariableError {
#[error("Invalid variable name: {0}")]
InvalidName(String),
#[error("Variable not found: {0}")]
VariableNotFound(String),
#[error("Too many variables (limit: {0})")]
TooManyVariables(usize),
#[error("Unsupported value type: {0}")]
UnsupportedValueType(String),
#[error("Value too complex: {0}")]
ValueTooComplex(String),
}
/// An RPC target that can keep named variables on behalf of its callers.
///
/// Each operation reports its result as a [`Value`], so it can be returned
/// directly over RPC. `call` is the name-based dispatch entry point.
#[async_trait::async_trait]
pub trait VariableCapableRpcTarget: Send + Sync {
    /// Stores `value` under `name`.
    async fn set_variable(&self, name: String, value: Value) -> Result<Value, crate::RpcError>;

    /// Fetches the value stored under `name`.
    async fn get_variable(&self, name: String) -> Result<Value, crate::RpcError>;

    /// Removes every stored variable.
    async fn clear_all_variables(&self) -> Result<Value, crate::RpcError>;

    /// Reports whether a variable named `name` exists.
    async fn has_variable(&self, name: String) -> Result<Value, crate::RpcError>;

    /// Lists the names of all stored variables.
    async fn list_variables(&self) -> Result<Value, crate::RpcError>;

    /// Invokes `method` with positional `args`.
    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError>;
}
/// Standard [`VariableCapableRpcTarget`]: variable operations are served from an
/// in-memory [`VariableStateManager`], and everything else goes to a wrapped target.
#[derive(Debug)]
pub struct DefaultVariableCapableTarget {
    /// Store backing the variable methods.
    variable_manager: VariableStateManager,
    /// Target receiving method calls and property lookups not handled here.
    delegate: Arc<dyn crate::RpcTarget>,
}
impl DefaultVariableCapableTarget {
    /// Wraps `delegate` with a variable store that uses the default limits.
    pub fn new(delegate: Arc<dyn crate::RpcTarget>) -> Self {
        let variable_manager = VariableStateManager::default();
        Self {
            delegate,
            variable_manager,
        }
    }

    /// Wraps `delegate` with a variable store capped at `max_variables` entries,
    /// with names of at most `max_name_length` bytes.
    pub fn with_variable_limits(
        delegate: Arc<dyn crate::RpcTarget>,
        max_variables: usize,
        max_name_length: usize,
    ) -> Self {
        let variable_manager = VariableStateManager::with_limits(max_variables, max_name_length);
        Self {
            delegate,
            variable_manager,
        }
    }
}
#[async_trait::async_trait]
impl VariableCapableRpcTarget for DefaultVariableCapableTarget {
async fn set_variable(&self, name: String, value: Value) -> Result<Value, crate::RpcError> {
let result = self
.variable_manager
.set_variable(name, value)
.await
.map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
Ok(Value::Bool(result))
}
async fn get_variable(&self, name: String) -> Result<Value, crate::RpcError> {
let value = self
.variable_manager
.get_variable(&name)
.await
.map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
Ok(value)
}
async fn clear_all_variables(&self) -> Result<Value, crate::RpcError> {
let result = self
.variable_manager
.clear_all_variables()
.await
.map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
Ok(Value::Bool(result))
}
async fn has_variable(&self, name: String) -> Result<Value, crate::RpcError> {
let exists = self.variable_manager.has_variable(&name).await;
Ok(Value::Bool(exists))
}
async fn list_variables(&self) -> Result<Value, crate::RpcError> {
let names = self.variable_manager.get_variable_names().await;
let values: Vec<Value> = names.into_iter().map(Value::String).collect();
Ok(Value::Array(values))
}
async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError> {
match method {
"setVariable" => {
if args.len() != 2 {
return Err(crate::RpcError::bad_request(
"setVariable requires exactly 2 arguments (name, value)",
));
}
let name = match &args[0] {
Value::String(s) => s.clone(),
_ => {
return Err(crate::RpcError::bad_request(
"Variable name must be a string",
))
}
};
self.set_variable(name, args[1].clone()).await
}
"getVariable" => {
if args.len() != 1 {
return Err(crate::RpcError::bad_request(
"getVariable requires exactly 1 argument (name)",
));
}
let name = match &args[0] {
Value::String(s) => s.clone(),
_ => {
return Err(crate::RpcError::bad_request(
"Variable name must be a string",
))
}
};
self.get_variable(name).await
}
"clearAllVariables" => {
if !args.is_empty() {
return Err(crate::RpcError::bad_request(
"clearAllVariables takes no arguments",
));
}
self.clear_all_variables().await
}
"hasVariable" => {
if args.len() != 1 {
return Err(crate::RpcError::bad_request(
"hasVariable requires exactly 1 argument (name)",
));
}
let name = match &args[0] {
Value::String(s) => s.clone(),
_ => {
return Err(crate::RpcError::bad_request(
"Variable name must be a string",
))
}
};
self.has_variable(name).await
}
"listVariables" => {
if !args.is_empty() {
return Err(crate::RpcError::bad_request(
"listVariables takes no arguments",
));
}
self.list_variables().await
}
_ => self.delegate.call(method, args).await,
}
}
}
#[async_trait::async_trait]
impl crate::RpcTarget for DefaultVariableCapableTarget {
async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError> {
VariableCapableRpcTarget::call(self, method, args).await
}
async fn get_property(&self, property: &str) -> Result<Value, crate::RpcError> {
if let Ok(value) = self.variable_manager.get_variable(property).await {
Ok(value)
} else {
self.delegate.get_property(property).await
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Number;

    #[tokio::test]
    async fn test_basic_variable_operations() {
        let manager = VariableStateManager::new();
        let result = manager
            .set_variable("test".to_string(), Value::Number(Number::from(42)))
            .await
            .unwrap();
        assert!(result);
        let value = manager.get_variable("test").await.unwrap();
        match value {
            Value::Number(n) => assert_eq!(n.as_i64(), Some(42)),
            _ => panic!("Expected number value"),
        }
        assert!(manager.has_variable("test").await);
        assert!(!manager.has_variable("nonexistent").await);
        assert_eq!(manager.variable_count().await, 1);
    }

    #[tokio::test]
    async fn test_variable_validation() {
        let manager = VariableStateManager::with_limits(2, 10);
        let long_name = "a".repeat(20);
        let result = manager
            .set_variable(long_name, Value::Number(Number::from(1)))
            .await;
        assert!(result.is_err());
        let result = manager
            .set_variable("".to_string(), Value::Number(Number::from(1)))
            .await;
        assert!(result.is_err());
        manager
            .set_variable("var1".to_string(), Value::Number(Number::from(1)))
            .await
            .unwrap();
        manager
            .set_variable("var2".to_string(), Value::Number(Number::from(2)))
            .await
            .unwrap();
        let result = manager
            .set_variable("var3".to_string(), Value::Number(Number::from(3)))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_control_characters_rejected() {
        let manager = VariableStateManager::new();
        for bad in ["bad\nname", "nul\0name", "tab\tname"] {
            let result = manager.set_variable(bad.to_string(), Value::Bool(true)).await;
            assert!(matches!(result, Err(VariableError::InvalidName(_))));
        }
        assert_eq!(manager.variable_count().await, 0);
    }

    #[tokio::test]
    async fn test_overwrite_allowed_at_capacity() {
        let manager = VariableStateManager::with_limits(1, 16);
        manager
            .set_variable("only".to_string(), Value::Number(Number::from(1)))
            .await
            .unwrap();
        manager
            .set_variable("only".to_string(), Value::Number(Number::from(2)))
            .await
            .unwrap();
        match manager.get_variable("only").await.unwrap() {
            Value::Number(n) => assert_eq!(n.as_i64(), Some(2)),
            _ => panic!("Expected number value"),
        }
        let result = manager.set_variable("other".to_string(), Value::Null).await;
        assert!(matches!(result, Err(VariableError::TooManyVariables(1))));
        assert_eq!(manager.variable_count().await, 1);
    }

    #[tokio::test]
    async fn test_missing_variable_errors() {
        let manager = VariableStateManager::new();
        let got = manager.get_variable("missing").await;
        assert!(matches!(got, Err(VariableError::VariableNotFound(n)) if n == "missing"));
        let deleted = manager.delete_variable("missing").await;
        assert!(matches!(deleted, Err(VariableError::VariableNotFound(n)) if n == "missing"));
    }

    #[tokio::test]
    async fn test_value_size_limits() {
        let manager = VariableStateManager::new();
        let at_limit = Value::Array((0..1000).map(|i| Value::Number(Number::from(i))).collect());
        manager
            .set_variable("at_limit".to_string(), at_limit)
            .await
            .unwrap();

        let too_big = Value::Array((0..1001).map(|i| Value::Number(Number::from(i))).collect());
        let result = manager.set_variable("too_big".to_string(), too_big).await;
        assert!(matches!(result, Err(VariableError::ValueTooComplex(_))));

        let wide: std::collections::HashMap<String, Box<Value>> = (0..101)
            .map(|i| (format!("k{}", i), Box::new(Value::Null)))
            .collect();
        let result = manager
            .set_variable("wide".to_string(), Value::Object(wide))
            .await;
        assert!(matches!(result, Err(VariableError::ValueTooComplex(_))));

        assert!(!manager.has_variable("too_big").await);
        assert!(!manager.has_variable("wide").await);
    }

    #[tokio::test]
    async fn test_complex_values() {
        let manager = VariableStateManager::new();
        let array_val = Value::Array(vec![
            Value::Number(Number::from(1)),
            Value::String("test".to_string()),
            Value::Bool(true),
        ]);
        manager
            .set_variable("array".to_string(), array_val)
            .await
            .unwrap();
        let retrieved = manager.get_variable("array").await.unwrap();
        match retrieved {
            Value::Array(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("Expected array"),
        }
        let mut obj = std::collections::HashMap::new();
        obj.insert(
            "name".to_string(),
            Box::new(Value::String("Alice".to_string())),
        );
        obj.insert("age".to_string(), Box::new(Value::Number(Number::from(30))));
        let obj_val = Value::Object(obj);
        manager
            .set_variable("user".to_string(), obj_val)
            .await
            .unwrap();
        let retrieved = manager.get_variable("user").await.unwrap();
        match retrieved {
            Value::Object(obj) => {
                assert_eq!(obj.len(), 2);
                assert!(obj.contains_key("name"));
                assert!(obj.contains_key("age"));
            }
            _ => panic!("Expected object"),
        }
    }

    #[tokio::test]
    async fn test_clear_operations() {
        let manager = VariableStateManager::new();
        manager
            .set_variable("var1".to_string(), Value::Number(Number::from(1)))
            .await
            .unwrap();
        manager
            .set_variable("var2".to_string(), Value::String("test".to_string()))
            .await
            .unwrap();
        manager
            .set_variable("var3".to_string(), Value::Bool(true))
            .await
            .unwrap();
        assert_eq!(manager.variable_count().await, 3);
        manager.delete_variable("var2").await.unwrap();
        assert_eq!(manager.variable_count().await, 2);
        assert!(!manager.has_variable("var2").await);
        manager.clear_all_variables().await.unwrap();
        assert_eq!(manager.variable_count().await, 0);
    }

    #[tokio::test]
    async fn test_variable_names_list() {
        let manager = VariableStateManager::new();
        manager
            .set_variable("alpha".to_string(), Value::Number(Number::from(1)))
            .await
            .unwrap();
        manager
            .set_variable("beta".to_string(), Value::Number(Number::from(2)))
            .await
            .unwrap();
        manager
            .set_variable("gamma".to_string(), Value::Number(Number::from(3)))
            .await
            .unwrap();
        let names = manager.get_variable_names().await;
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
        assert!(names.contains(&"gamma".to_string()));
    }

    #[tokio::test]
    async fn test_import_export_variables() {
        let manager = VariableStateManager::new();
        manager
            .set_variable("var1".to_string(), Value::Number(Number::from(42)))
            .await
            .unwrap();
        manager
            .set_variable("var2".to_string(), Value::String("hello".to_string()))
            .await
            .unwrap();
        let exported = manager.export_variables().await;
        assert_eq!(exported.len(), 2);
        let manager2 = VariableStateManager::new();
        manager2.import_variables(exported).await.unwrap();
        assert_eq!(manager2.variable_count().await, 2);
        let val = manager2.get_variable("var1").await.unwrap();
        match val {
            Value::Number(n) => assert_eq!(n.as_i64(), Some(42)),
            _ => panic!("Expected number"),
        }
    }

    #[tokio::test]
    async fn test_import_replaces_state_and_enforces_limit() {
        let manager = VariableStateManager::with_limits(2, 16);
        manager
            .set_variable("old".to_string(), Value::Bool(true))
            .await
            .unwrap();

        let mut vars = std::collections::HashMap::new();
        vars.insert("new".to_string(), Value::Number(Number::from(7)));
        manager.import_variables(vars).await.unwrap();
        assert!(!manager.has_variable("old").await);
        assert!(manager.has_variable("new").await);

        let mut too_many = std::collections::HashMap::new();
        for i in 0..3 {
            too_many.insert(format!("v{}", i), Value::Null);
        }
        let result = manager.import_variables(too_many).await;
        assert!(matches!(result, Err(VariableError::TooManyVariables(2))));
        // A rejected import must leave the previous state untouched.
        assert!(manager.has_variable("new").await);
        assert_eq!(manager.variable_count().await, 1);
    }
}