// aster/session/extension_data.rs
use std::collections::HashMap;

use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use utoipa::ToSchema;

use crate::config::ExtensionConfig;

11#[derive(Debug, Clone, Serialize, Deserialize, Default, ToSchema)]
14pub struct ExtensionData {
15 #[serde(flatten)]
16 pub extension_states: HashMap<String, Value>,
17}
18
19impl ExtensionData {
20 pub fn new() -> Self {
22 Self {
23 extension_states: HashMap::new(),
24 }
25 }
26
27 pub fn get_extension_state(&self, extension_name: &str, version: &str) -> Option<&Value> {
29 let key = format!("{}.{}", extension_name, version);
30 self.extension_states.get(&key)
31 }
32
33 pub fn set_extension_state(&mut self, extension_name: &str, version: &str, state: Value) {
35 let key = format!("{}.{}", extension_name, version);
36 self.extension_states.insert(key, state);
37 }
38}
39
40pub trait ExtensionState: Sized + Serialize + for<'de> Deserialize<'de> {
42 const EXTENSION_NAME: &'static str;
44
45 const VERSION: &'static str;
47
48 fn from_value(value: &Value) -> Result<Self> {
50 serde_json::from_value(value.clone()).map_err(|e| {
51 anyhow::anyhow!(
52 "Failed to deserialize {} state: {}",
53 Self::EXTENSION_NAME,
54 e
55 )
56 })
57 }
58
59 fn to_value(&self) -> Result<Value> {
61 serde_json::to_value(self).map_err(|e| {
62 anyhow::anyhow!("Failed to serialize {} state: {}", Self::EXTENSION_NAME, e)
63 })
64 }
65
66 fn from_extension_data(extension_data: &ExtensionData) -> Option<Self> {
68 extension_data
69 .get_extension_state(Self::EXTENSION_NAME, Self::VERSION)
70 .and_then(|v| Self::from_value(v).ok())
71 }
72
73 fn to_extension_data(&self, extension_data: &mut ExtensionData) -> Result<()> {
75 let value = self.to_value()?;
76 extension_data.set_extension_state(Self::EXTENSION_NAME, Self::VERSION, value);
77 Ok(())
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct TodoState {
84 pub content: String,
85}
86
87impl ExtensionState for TodoState {
88 const EXTENSION_NAME: &'static str = "todo";
89 const VERSION: &'static str = "v0";
90}
91
92impl TodoState {
93 pub fn new(content: String) -> Self {
95 Self { content }
96 }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct EnabledExtensionsState {
102 pub extensions: Vec<ExtensionConfig>,
103}
104
105impl ExtensionState for EnabledExtensionsState {
106 const EXTENSION_NAME: &'static str = "enabled_extensions";
107 const VERSION: &'static str = "v0";
108}
109
110impl EnabledExtensionsState {
111 pub fn new(extensions: Vec<ExtensionConfig>) -> Self {
112 Self { extensions }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_extension_data_basic_operations() {
        let mut extension_data = ExtensionData::new();

        let todo_state = json!({"content": "- Task 1\n- Task 2"});
        extension_data.set_extension_state("todo", "v0", todo_state.clone());

        assert_eq!(
            extension_data.get_extension_state("todo", "v0"),
            Some(&todo_state)
        );
        // A different version is a different key.
        assert_eq!(extension_data.get_extension_state("todo", "v1"), None);
    }

    #[test]
    fn test_set_extension_state_overwrites_existing_value() {
        let mut extension_data = ExtensionData::new();
        extension_data.set_extension_state("todo", "v0", json!("old"));
        extension_data.set_extension_state("todo", "v0", json!("new"));

        assert_eq!(extension_data.extension_states.len(), 1);
        assert_eq!(
            extension_data.get_extension_state("todo", "v0"),
            Some(&json!("new"))
        );
    }

    #[test]
    fn test_multiple_extension_states() {
        let mut extension_data = ExtensionData::new();

        extension_data.set_extension_state("todo", "v0", json!("TODO content"));
        extension_data.set_extension_state("memory", "v1", json!({"items": ["item1", "item2"]}));
        extension_data.set_extension_state("config", "v2", json!({"setting": true}));

        assert_eq!(extension_data.extension_states.len(), 3);
        assert!(extension_data.get_extension_state("todo", "v0").is_some());
        assert!(extension_data.get_extension_state("memory", "v1").is_some());
        assert!(extension_data.get_extension_state("config", "v2").is_some());
    }

    #[test]
    fn test_todo_state_trait() {
        let mut extension_data = ExtensionData::new();

        let todo = TodoState::new("- Task 1\n- Task 2".to_string());
        todo.to_extension_data(&mut extension_data).unwrap();

        let retrieved = TodoState::from_extension_data(&extension_data)
            .expect("todo state was just stored");
        assert_eq!(retrieved.content, "- Task 1\n- Task 2");
    }

    #[test]
    fn test_todo_state_missing_returns_none() {
        let extension_data = ExtensionData::new();
        assert!(TodoState::from_extension_data(&extension_data).is_none());
    }

    #[test]
    fn test_todo_state_malformed_returns_none() {
        let mut extension_data = ExtensionData::new();
        // Right key, wrong shape: `content` must be a string.
        extension_data.set_extension_state("todo", "v0", json!({"content": 42}));

        assert!(TodoState::from_extension_data(&extension_data).is_none());

        let err = TodoState::from_value(&json!({"content": 42})).unwrap_err();
        assert!(err.to_string().contains("Failed to deserialize todo state"));
    }

    #[test]
    fn test_extension_data_serialization() {
        let mut extension_data = ExtensionData::new();
        extension_data.set_extension_state("todo", "v0", json!("TODO content"));
        extension_data.set_extension_state("memory", "v1", json!({"key": "value"}));

        let json = serde_json::to_value(&extension_data).unwrap();

        // `#[serde(flatten)]` puts the state keys at the top level.
        assert!(json.is_object());
        assert_eq!(json.get("todo.v0"), Some(&json!("TODO content")));
        assert_eq!(json.get("memory.v1"), Some(&json!({"key": "value"})));

        let deserialized: ExtensionData = serde_json::from_value(json).unwrap();
        assert_eq!(
            deserialized.get_extension_state("todo", "v0"),
            Some(&json!("TODO content"))
        );
        assert_eq!(
            deserialized.get_extension_state("memory", "v1"),
            Some(&json!({"key": "value"}))
        );
    }
}