Skip to main content

aster/session/
extension_data.rs

//! Extension data management for sessions.
//! Provides a simple way to store extension-specific data under versioned keys.
3
4use crate::config::ExtensionConfig;
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use utoipa::ToSchema;
10
11/// Extension data containing all extension states
12/// Keys are in format "extension_name.version" (e.g., "todo.v0")
13#[derive(Debug, Clone, Serialize, Deserialize, Default, ToSchema)]
14pub struct ExtensionData {
15    #[serde(flatten)]
16    pub extension_states: HashMap<String, Value>,
17}
18
19impl ExtensionData {
20    /// Create a new empty ExtensionData
21    pub fn new() -> Self {
22        Self {
23            extension_states: HashMap::new(),
24        }
25    }
26
27    /// Get extension state for a specific extension and version
28    pub fn get_extension_state(&self, extension_name: &str, version: &str) -> Option<&Value> {
29        let key = format!("{}.{}", extension_name, version);
30        self.extension_states.get(&key)
31    }
32
33    /// Set extension state for a specific extension and version
34    pub fn set_extension_state(&mut self, extension_name: &str, version: &str, state: Value) {
35        let key = format!("{}.{}", extension_name, version);
36        self.extension_states.insert(key, state);
37    }
38}
39
40/// Helper trait for extension-specific state management
41pub trait ExtensionState: Sized + Serialize + for<'de> Deserialize<'de> {
42    /// The name of the extension
43    const EXTENSION_NAME: &'static str;
44
45    /// The version of the extension state format
46    const VERSION: &'static str;
47
48    /// Convert from JSON value
49    fn from_value(value: &Value) -> Result<Self> {
50        serde_json::from_value(value.clone()).map_err(|e| {
51            anyhow::anyhow!(
52                "Failed to deserialize {} state: {}",
53                Self::EXTENSION_NAME,
54                e
55            )
56        })
57    }
58
59    /// Convert to JSON value
60    fn to_value(&self) -> Result<Value> {
61        serde_json::to_value(self).map_err(|e| {
62            anyhow::anyhow!("Failed to serialize {} state: {}", Self::EXTENSION_NAME, e)
63        })
64    }
65
66    /// Get state from extension data
67    fn from_extension_data(extension_data: &ExtensionData) -> Option<Self> {
68        extension_data
69            .get_extension_state(Self::EXTENSION_NAME, Self::VERSION)
70            .and_then(|v| Self::from_value(v).ok())
71    }
72
73    /// Save state to extension data
74    fn to_extension_data(&self, extension_data: &mut ExtensionData) -> Result<()> {
75        let value = self.to_value()?;
76        extension_data.set_extension_state(Self::EXTENSION_NAME, Self::VERSION, value);
77        Ok(())
78    }
79}
80
81/// TODO extension state implementation
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct TodoState {
84    pub content: String,
85}
86
87impl ExtensionState for TodoState {
88    const EXTENSION_NAME: &'static str = "todo";
89    const VERSION: &'static str = "v0";
90}
91
92impl TodoState {
93    /// Create a new TODO state
94    pub fn new(content: String) -> Self {
95        Self { content }
96    }
97}
98
99/// Enabled extensions state implementation for storing which extensions are active
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct EnabledExtensionsState {
102    pub extensions: Vec<ExtensionConfig>,
103}
104
105impl ExtensionState for EnabledExtensionsState {
106    const EXTENSION_NAME: &'static str = "enabled_extensions";
107    const VERSION: &'static str = "v0";
108}
109
110impl EnabledExtensionsState {
111    pub fn new(extensions: Vec<ExtensionConfig>) -> Self {
112        Self { extensions }
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_extension_data_basic_operations() {
        let mut extension_data = ExtensionData::new();

        // Test setting and getting extension state
        let todo_state = json!({"content": "- Task 1\n- Task 2"});
        extension_data.set_extension_state("todo", "v0", todo_state.clone());

        assert_eq!(
            extension_data.get_extension_state("todo", "v0"),
            Some(&todo_state)
        );
        // A different version of the same extension is a separate entry.
        assert_eq!(extension_data.get_extension_state("todo", "v1"), None);
    }

    #[test]
    fn test_set_extension_state_overwrites_existing_value() {
        let mut extension_data = ExtensionData::new();
        extension_data.set_extension_state("todo", "v0", json!("old"));
        extension_data.set_extension_state("todo", "v0", json!("new"));

        assert_eq!(extension_data.extension_states.len(), 1);
        assert_eq!(
            extension_data.get_extension_state("todo", "v0"),
            Some(&json!("new"))
        );
    }

    #[test]
    fn test_multiple_extension_states() {
        let mut extension_data = ExtensionData::new();

        // Add multiple extension states
        extension_data.set_extension_state("todo", "v0", json!("TODO content"));
        extension_data.set_extension_state("memory", "v1", json!({"items": ["item1", "item2"]}));
        extension_data.set_extension_state("config", "v2", json!({"setting": true}));

        // Check all states exist
        assert_eq!(extension_data.extension_states.len(), 3);
        assert!(extension_data.get_extension_state("todo", "v0").is_some());
        assert!(extension_data.get_extension_state("memory", "v1").is_some());
        assert!(extension_data.get_extension_state("config", "v2").is_some());
    }

    #[test]
    fn test_todo_state_trait() {
        let mut extension_data = ExtensionData::new();

        // Create and save TODO state
        let todo = TodoState::new("- Task 1\n- Task 2".to_string());
        todo.to_extension_data(&mut extension_data).unwrap();

        // The typed helper stores the state under ("todo", "v0").
        assert!(extension_data.get_extension_state("todo", "v0").is_some());

        // Retrieve TODO state
        let retrieved =
            TodoState::from_extension_data(&extension_data).expect("TODO state was just saved");
        assert_eq!(retrieved.content, "- Task 1\n- Task 2");
    }

    #[test]
    fn test_todo_state_missing_or_malformed_is_none() {
        // Nothing stored yet.
        let mut extension_data = ExtensionData::new();
        assert!(TodoState::from_extension_data(&extension_data).is_none());

        // Stored value does not match TodoState's shape (`content` is not a string).
        extension_data.set_extension_state("todo", "v0", json!({"content": 42}));
        assert!(TodoState::from_extension_data(&extension_data).is_none());
    }

    #[test]
    fn test_from_value_error_names_extension() {
        let err = TodoState::from_value(&json!(42)).unwrap_err();
        assert!(err
            .to_string()
            .contains("Failed to deserialize todo state"));
    }

    #[test]
    fn test_extension_data_serialization() {
        let mut extension_data = ExtensionData::new();
        extension_data.set_extension_state("todo", "v0", json!("TODO content"));
        extension_data.set_extension_state("memory", "v1", json!({"key": "value"}));

        // Serialize to JSON
        let json = serde_json::to_value(&extension_data).unwrap();

        // Check the structure
        assert!(json.is_object());
        assert_eq!(json.get("todo.v0"), Some(&json!("TODO content")));
        assert_eq!(json.get("memory.v1"), Some(&json!({"key": "value"})));

        // Deserialize back
        let deserialized: ExtensionData = serde_json::from_value(json).unwrap();
        assert_eq!(
            deserialized.get_extension_state("todo", "v0"),
            Some(&json!("TODO content"))
        );
        assert_eq!(
            deserialized.get_extension_state("memory", "v1"),
            Some(&json!({"key": "value"}))
        );
    }
}