capnweb_core/protocol/
variable_state.rs

1// Variable State Management for Cap'n Web Protocol
2// Implements setVariable, getVariable, and clearAllVariables functionality
3
4use super::tables::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// Variable state manager for capabilities
10#[derive(Debug)]
11pub struct VariableStateManager {
12    /// Variables stored per session/capability
13    variables: Arc<RwLock<HashMap<String, Value>>>,
14    /// Maximum number of variables allowed
15    max_variables: usize,
16    /// Maximum variable name length
17    max_name_length: usize,
18}
19
20impl VariableStateManager {
21    /// Create a new variable state manager
22    pub fn new() -> Self {
23        Self {
24            variables: Arc::new(RwLock::new(HashMap::new())),
25            max_variables: 1000,  // Reasonable default limit
26            max_name_length: 256, // Reasonable default limit
27        }
28    }
29
30    /// Create a new variable state manager with custom limits
31    pub fn with_limits(max_variables: usize, max_name_length: usize) -> Self {
32        Self {
33            variables: Arc::new(RwLock::new(HashMap::new())),
34            max_variables,
35            max_name_length,
36        }
37    }
38
39    /// Set a variable value
40    pub async fn set_variable(&self, name: String, value: Value) -> Result<bool, VariableError> {
41        // Validate variable name
42        if name.is_empty() {
43            return Err(VariableError::InvalidName(
44                "Variable name cannot be empty".to_string(),
45            ));
46        }
47
48        if name.len() > self.max_name_length {
49            return Err(VariableError::InvalidName(format!(
50                "Variable name too long: {} > {}",
51                name.len(),
52                self.max_name_length
53            )));
54        }
55
56        // Check for invalid characters (optional - could be customized)
57        if name.chars().any(|c| c.is_control() || c == '\0') {
58            return Err(VariableError::InvalidName(
59                "Variable name contains invalid characters".to_string(),
60            ));
61        }
62
63        let mut variables = self.variables.write().await;
64
65        // Check variable limits (only for new variables)
66        if !variables.contains_key(&name) && variables.len() >= self.max_variables {
67            return Err(VariableError::TooManyVariables(self.max_variables));
68        }
69
70        // Validate value (ensure it's serializable)
71        self.validate_value(&value)?;
72
73        tracing::debug!(
74            name = %name,
75            value_type = ?std::mem::discriminant(&value),
76            "Setting variable"
77        );
78
79        variables.insert(name, value);
80        Ok(true)
81    }
82
83    /// Get a variable value
84    pub async fn get_variable(&self, name: &str) -> Result<Value, VariableError> {
85        let variables = self.variables.read().await;
86
87        match variables.get(name) {
88            Some(value) => {
89                tracing::debug!(
90                    name = %name,
91                    value_type = ?std::mem::discriminant(value),
92                    "Retrieved variable"
93                );
94                Ok(value.clone())
95            }
96            None => Err(VariableError::VariableNotFound(name.to_string())),
97        }
98    }
99
100    /// Check if a variable exists
101    pub async fn has_variable(&self, name: &str) -> bool {
102        let variables = self.variables.read().await;
103        variables.contains_key(name)
104    }
105
106    /// Delete a variable
107    pub async fn delete_variable(&self, name: &str) -> Result<bool, VariableError> {
108        let mut variables = self.variables.write().await;
109
110        match variables.remove(name) {
111            Some(_) => {
112                tracing::debug!(name = %name, "Variable deleted");
113                Ok(true)
114            }
115            None => Err(VariableError::VariableNotFound(name.to_string())),
116        }
117    }
118
119    /// Clear all variables
120    pub async fn clear_all_variables(&self) -> Result<bool, VariableError> {
121        let mut variables = self.variables.write().await;
122        let count = variables.len();
123        variables.clear();
124
125        tracing::debug!(cleared_count = count, "All variables cleared");
126        Ok(true)
127    }
128
129    /// Get all variable names
130    pub async fn get_variable_names(&self) -> Vec<String> {
131        let variables = self.variables.read().await;
132        variables.keys().cloned().collect()
133    }
134
135    /// Get variable count
136    pub async fn variable_count(&self) -> usize {
137        let variables = self.variables.read().await;
138        variables.len()
139    }
140
141    /// Export all variables as a HashMap (for serialization/debugging)
142    pub async fn export_variables(&self) -> HashMap<String, Value> {
143        let variables = self.variables.read().await;
144        variables.clone()
145    }
146
147    /// Import variables from a HashMap (for deserialization/restoration)
148    pub async fn import_variables(
149        &self,
150        vars: HashMap<String, Value>,
151    ) -> Result<(), VariableError> {
152        // Validate all variables first
153        if vars.len() > self.max_variables {
154            return Err(VariableError::TooManyVariables(self.max_variables));
155        }
156
157        for (name, value) in &vars {
158            if name.len() > self.max_name_length {
159                return Err(VariableError::InvalidName(format!(
160                    "Variable name too long: {} > {}",
161                    name.len(),
162                    self.max_name_length
163                )));
164            }
165            self.validate_value(value)?;
166        }
167
168        // If validation passes, import all variables
169        let mut variables = self.variables.write().await;
170        variables.clear();
171        variables.extend(vars);
172
173        tracing::debug!(imported_count = variables.len(), "Variables imported");
174        Ok(())
175    }
176
177    /// Validate a value for storage
178    fn validate_value(&self, value: &Value) -> Result<(), VariableError> {
179        match value {
180            // Simple types are always valid
181            Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) | Value::Date(_) => {
182                Ok(())
183            }
184
185            // Arrays and objects require recursive validation
186            Value::Array(arr) => {
187                if arr.len() > 1000 {
188                    // Prevent overly large arrays
189                    return Err(VariableError::ValueTooComplex(
190                        "Array too large".to_string(),
191                    ));
192                }
193                for item in arr {
194                    self.validate_value(item)?;
195                }
196                Ok(())
197            }
198
199            Value::Object(obj) => {
200                if obj.len() > 100 {
201                    // Prevent overly complex objects
202                    return Err(VariableError::ValueTooComplex(
203                        "Object too complex".to_string(),
204                    ));
205                }
206                for (key, val) in obj {
207                    if key.len() > self.max_name_length {
208                        return Err(VariableError::ValueTooComplex(
209                            "Object key too long".to_string(),
210                        ));
211                    }
212                    self.validate_value(val)?;
213                }
214                Ok(())
215            }
216
217            // Error values are valid for storage
218            Value::Error { .. } => Ok(()),
219
220            // Complex types (stubs, promises) cannot be stored as variables
221            Value::Stub(_) => Err(VariableError::UnsupportedValueType(
222                "Cannot store stub as variable".to_string(),
223            )),
224            Value::Promise(_) => Err(VariableError::UnsupportedValueType(
225                "Cannot store promise as variable".to_string(),
226            )),
227        }
228    }
229}
230
231impl Default for VariableStateManager {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
237/// Errors related to variable operations
238#[derive(Debug, thiserror::Error)]
239pub enum VariableError {
240    #[error("Invalid variable name: {0}")]
241    InvalidName(String),
242
243    #[error("Variable not found: {0}")]
244    VariableNotFound(String),
245
246    #[error("Too many variables (limit: {0})")]
247    TooManyVariables(usize),
248
249    #[error("Unsupported value type: {0}")]
250    UnsupportedValueType(String),
251
252    #[error("Value too complex: {0}")]
253    ValueTooComplex(String),
254}
255
256/// Enhanced RPC target trait with variable management
257#[async_trait::async_trait]
258pub trait VariableCapableRpcTarget: Send + Sync {
259    /// Set a variable
260    async fn set_variable(&self, name: String, value: Value) -> Result<Value, crate::RpcError>;
261
262    /// Get a variable
263    async fn get_variable(&self, name: String) -> Result<Value, crate::RpcError>;
264
265    /// Clear all variables
266    async fn clear_all_variables(&self) -> Result<Value, crate::RpcError>;
267
268    /// Check if variable exists
269    async fn has_variable(&self, name: String) -> Result<Value, crate::RpcError>;
270
271    /// Get all variable names
272    async fn list_variables(&self) -> Result<Value, crate::RpcError>;
273
274    /// Regular RPC method calls (fallback)
275    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError>;
276}
277
278/// Default implementation of variable-capable RPC target
279#[derive(Debug)]
280pub struct DefaultVariableCapableTarget {
281    variable_manager: VariableStateManager,
282    delegate: Arc<dyn crate::RpcTarget>, // Delegate for non-variable methods
283}
284
285impl DefaultVariableCapableTarget {
286    pub fn new(delegate: Arc<dyn crate::RpcTarget>) -> Self {
287        Self {
288            variable_manager: VariableStateManager::new(),
289            delegate,
290        }
291    }
292
293    pub fn with_variable_limits(
294        delegate: Arc<dyn crate::RpcTarget>,
295        max_variables: usize,
296        max_name_length: usize,
297    ) -> Self {
298        Self {
299            variable_manager: VariableStateManager::with_limits(max_variables, max_name_length),
300            delegate,
301        }
302    }
303}
304
305#[async_trait::async_trait]
306impl VariableCapableRpcTarget for DefaultVariableCapableTarget {
307    async fn set_variable(&self, name: String, value: Value) -> Result<Value, crate::RpcError> {
308        let result = self
309            .variable_manager
310            .set_variable(name, value)
311            .await
312            .map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
313        Ok(Value::Bool(result))
314    }
315
316    async fn get_variable(&self, name: String) -> Result<Value, crate::RpcError> {
317        let value = self
318            .variable_manager
319            .get_variable(&name)
320            .await
321            .map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
322        Ok(value)
323    }
324
325    async fn clear_all_variables(&self) -> Result<Value, crate::RpcError> {
326        let result = self
327            .variable_manager
328            .clear_all_variables()
329            .await
330            .map_err(|e| crate::RpcError::bad_request(e.to_string()))?;
331        Ok(Value::Bool(result))
332    }
333
334    async fn has_variable(&self, name: String) -> Result<Value, crate::RpcError> {
335        let exists = self.variable_manager.has_variable(&name).await;
336        Ok(Value::Bool(exists))
337    }
338
339    async fn list_variables(&self) -> Result<Value, crate::RpcError> {
340        let names = self.variable_manager.get_variable_names().await;
341        let values: Vec<Value> = names.into_iter().map(Value::String).collect();
342        Ok(Value::Array(values))
343    }
344
345    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError> {
346        match method {
347            "setVariable" => {
348                if args.len() != 2 {
349                    return Err(crate::RpcError::bad_request(
350                        "setVariable requires exactly 2 arguments (name, value)",
351                    ));
352                }
353
354                let name = match &args[0] {
355                    Value::String(s) => s.clone(),
356                    _ => {
357                        return Err(crate::RpcError::bad_request(
358                            "Variable name must be a string",
359                        ))
360                    }
361                };
362
363                self.set_variable(name, args[1].clone()).await
364            }
365
366            "getVariable" => {
367                if args.len() != 1 {
368                    return Err(crate::RpcError::bad_request(
369                        "getVariable requires exactly 1 argument (name)",
370                    ));
371                }
372
373                let name = match &args[0] {
374                    Value::String(s) => s.clone(),
375                    _ => {
376                        return Err(crate::RpcError::bad_request(
377                            "Variable name must be a string",
378                        ))
379                    }
380                };
381
382                self.get_variable(name).await
383            }
384
385            "clearAllVariables" => {
386                if !args.is_empty() {
387                    return Err(crate::RpcError::bad_request(
388                        "clearAllVariables takes no arguments",
389                    ));
390                }
391                self.clear_all_variables().await
392            }
393
394            "hasVariable" => {
395                if args.len() != 1 {
396                    return Err(crate::RpcError::bad_request(
397                        "hasVariable requires exactly 1 argument (name)",
398                    ));
399                }
400
401                let name = match &args[0] {
402                    Value::String(s) => s.clone(),
403                    _ => {
404                        return Err(crate::RpcError::bad_request(
405                            "Variable name must be a string",
406                        ))
407                    }
408                };
409
410                self.has_variable(name).await
411            }
412
413            "listVariables" => {
414                if !args.is_empty() {
415                    return Err(crate::RpcError::bad_request(
416                        "listVariables takes no arguments",
417                    ));
418                }
419                self.list_variables().await
420            }
421
422            // Delegate other methods to the underlying target
423            _ => self.delegate.call(method, args).await,
424        }
425    }
426}
427
428#[async_trait::async_trait]
429impl crate::RpcTarget for DefaultVariableCapableTarget {
430    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, crate::RpcError> {
431        VariableCapableRpcTarget::call(self, method, args).await
432    }
433
434    async fn get_property(&self, property: &str) -> Result<Value, crate::RpcError> {
435        // For variables, we could support getting variables as properties
436        if let Ok(value) = self.variable_manager.get_variable(property).await {
437            Ok(value)
438        } else {
439            // Delegate to underlying target
440            self.delegate.get_property(property).await
441        }
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Number;

    /// Shorthand for an integer `Value::Number`.
    fn num(n: i64) -> Value {
        Value::Number(Number::from(n))
    }

    #[tokio::test]
    async fn test_basic_variable_operations() {
        let manager = VariableStateManager::new();

        // Storing succeeds and reports `true`.
        assert!(manager.set_variable("test".to_string(), num(42)).await.unwrap());

        // The stored value round-trips.
        let stored = manager.get_variable("test").await.unwrap();
        assert!(
            matches!(&stored, Value::Number(n) if n.as_i64() == Some(42)),
            "Expected number value"
        );

        // Existence checks and count reflect the single variable.
        assert!(manager.has_variable("test").await);
        assert!(!manager.has_variable("nonexistent").await);
        assert_eq!(manager.variable_count().await, 1);
    }

    #[tokio::test]
    async fn test_variable_validation() {
        let manager = VariableStateManager::with_limits(2, 10);

        // A 20-byte name exceeds the 10-byte limit.
        assert!(manager.set_variable("a".repeat(20), num(1)).await.is_err());

        // Empty names are rejected.
        assert!(manager.set_variable(String::new(), num(1)).await.is_err());

        // Two variables fit; a third distinct name hits the limit.
        for (name, n) in [("var1", 1), ("var2", 2)] {
            manager.set_variable(name.to_string(), num(n)).await.unwrap();
        }
        assert!(manager.set_variable("var3".to_string(), num(3)).await.is_err());
    }

    #[tokio::test]
    async fn test_complex_values() {
        let manager = VariableStateManager::new();

        // Arrays round-trip with all their elements.
        let list = Value::Array(vec![
            num(1),
            Value::String("test".to_string()),
            Value::Bool(true),
        ]);
        manager.set_variable("array".to_string(), list).await.unwrap();
        match manager.get_variable("array").await.unwrap() {
            Value::Array(items) => assert_eq!(items.len(), 3),
            _ => panic!("Expected array"),
        }

        // Objects round-trip with all their keys.
        let user = std::collections::HashMap::from([
            (
                "name".to_string(),
                Box::new(Value::String("Alice".to_string())),
            ),
            ("age".to_string(), Box::new(num(30))),
        ]);
        manager
            .set_variable("user".to_string(), Value::Object(user))
            .await
            .unwrap();
        match manager.get_variable("user").await.unwrap() {
            Value::Object(fields) => {
                assert_eq!(fields.len(), 2);
                assert!(["name", "age"].iter().all(|k| fields.contains_key(*k)));
            }
            _ => panic!("Expected object"),
        }
    }

    #[tokio::test]
    async fn test_clear_operations() {
        let manager = VariableStateManager::new();

        let seeds = [
            ("var1", num(1)),
            ("var2", Value::String("test".to_string())),
            ("var3", Value::Bool(true)),
        ];
        for (name, value) in seeds {
            manager.set_variable(name.to_string(), value).await.unwrap();
        }
        assert_eq!(manager.variable_count().await, 3);

        // Deleting one leaves the others.
        manager.delete_variable("var2").await.unwrap();
        assert_eq!(manager.variable_count().await, 2);
        assert!(!manager.has_variable("var2").await);

        // Clearing empties the store.
        manager.clear_all_variables().await.unwrap();
        assert_eq!(manager.variable_count().await, 0);
    }

    #[tokio::test]
    async fn test_variable_names_list() {
        let manager = VariableStateManager::new();

        let expected = ["alpha", "beta", "gamma"];
        for (n, name) in (1..).zip(expected) {
            manager.set_variable(name.to_string(), num(n)).await.unwrap();
        }

        let names = manager.get_variable_names().await;
        assert_eq!(names.len(), expected.len());
        for name in expected {
            assert!(names.iter().any(|stored| stored == name));
        }
    }

    #[tokio::test]
    async fn test_import_export_variables() {
        let source = VariableStateManager::new();
        source.set_variable("var1".to_string(), num(42)).await.unwrap();
        source
            .set_variable("var2".to_string(), Value::String("hello".to_string()))
            .await
            .unwrap();

        // Export captures both variables.
        let snapshot = source.export_variables().await;
        assert_eq!(snapshot.len(), 2);

        // A fresh manager restored from the snapshot holds the same data.
        let target = VariableStateManager::new();
        target.import_variables(snapshot).await.unwrap();
        assert_eq!(target.variable_count().await, 2);

        let restored = target.get_variable("var1").await.unwrap();
        assert!(
            matches!(&restored, Value::Number(n) if n.as_i64() == Some(42)),
            "Expected number"
        );
    }
}