// roder_api/extension_state.rs

use anyhow::Context as _;
use serde::{Deserialize, Serialize, de::DeserializeOwned};

use crate::{
    events::{ThreadId, TurnId},
    extension::ExtensionId,
};

/// Name of a piece of extension state.
///
/// Stored in [`ExtensionStateRecord::key`] and compared against [`ExtensionStateCodec::key`]
/// when a record is decoded.
pub type ExtensionStateKey = String;

/// Scope that a piece of extension state belongs to.
///
/// Stored in [`ExtensionStateRecord::scope`]; [`ExtensionStateCodec::decode_state`] rejects a
/// record whose scope differs from the codec's own scope.
///
/// No `#[serde(...)]` attributes are applied, so serde's default (externally tagged) enum
/// representation is used: renaming variants or their fields changes the persisted format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ExtensionStoreScope {
    /// Not tied to any workspace, thread, or turn.
    Global,
    /// Tied to a workspace.
    Workspace {
        /// Workspace identifier. NOTE(review): format is not defined in this file
        /// (presumably a path or name) — confirm against callers.
        workspace: String,
    },
    /// Tied to a single thread.
    Thread {
        /// Thread the state belongs to.
        thread_id: ThreadId,
    },
    /// Tied to a single turn within a thread.
    Turn {
        /// Thread containing the turn.
        thread_id: ThreadId,
        /// Turn the state belongs to.
        turn_id: TurnId,
    },
}

/// Type-erased, serializable form of an extension's state.
///
/// Produced by [`ExtensionStateCodec::encode_state`] and consumed by
/// [`ExtensionStateCodec::decode_state`], which checks `extension_id`, `key`, `scope` and
/// `schema_version` before deserializing `value`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExtensionStateRecord {
    /// Extension this state belongs to; must match [`ExtensionStateCodec::extension_id`].
    pub extension_id: ExtensionId,
    /// Key naming this piece of state; must match [`ExtensionStateCodec::key`].
    pub key: ExtensionStateKey,
    /// Scope of the state; must match [`ExtensionStateCodec::scope`].
    pub scope: ExtensionStoreScope,
    /// Schema version `value` was written with. If it differs from
    /// [`ExtensionStateCodec::schema_version`], decoding goes through
    /// [`ExtensionStateCodec::migrate_state`].
    pub schema_version: u32,
    /// The state serialized as JSON (`encode_state` fills it via `serde_json::to_value`).
    pub value: serde_json::Value,
}

34pub trait ExtensionStateCodec: Send + Sync + 'static {
35 type State: Serialize + DeserializeOwned + Send + Sync + 'static;
36
37 fn extension_id(&self) -> ExtensionId;
38 fn key(&self) -> ExtensionStateKey;
39 fn scope(&self) -> ExtensionStoreScope;
40 fn schema_version(&self) -> u32;
41 fn migrate_state(
42 &self,
43 _record: &ExtensionStateRecord,
44 ) -> anyhow::Result<Option<ExtensionStateRecord>> {
45 Ok(None)
46 }
47
48 fn encode_state(&self, state: &Self::State) -> anyhow::Result<ExtensionStateRecord> {
49 Ok(ExtensionStateRecord {
50 extension_id: self.extension_id(),
51 key: self.key(),
52 scope: self.scope(),
53 schema_version: self.schema_version(),
54 value: serde_json::to_value(state)?,
55 })
56 }
57
58 fn decode_state(&self, record: &ExtensionStateRecord) -> anyhow::Result<Self::State> {
59 if record.extension_id != self.extension_id() {
60 anyhow::bail!(
61 "extension state id mismatch: expected {}, got {}",
62 self.extension_id(),
63 record.extension_id
64 );
65 }
66 if record.key != self.key() {
67 anyhow::bail!(
68 "extension state key mismatch: expected {}, got {}",
69 self.key(),
70 record.key
71 );
72 }
73 if record.scope != self.scope() {
74 anyhow::bail!("extension state scope mismatch");
75 }
76 let record = if record.schema_version == self.schema_version() {
77 record.clone()
78 } else if let Some(migrated) = self.migrate_state(record)? {
79 if migrated.schema_version != self.schema_version() {
80 anyhow::bail!(
81 "extension state migration produced schema {}, expected {}",
82 migrated.schema_version,
83 self.schema_version()
84 );
85 }
86 migrated
87 } else {
88 anyhow::bail!(
89 "extension state schema mismatch: expected {}, got {}",
90 self.schema_version(),
91 record.schema_version
92 );
93 };
94 Ok(serde_json::from_value(record.value)?)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct DemoState {
        value: String,
    }

    /// Codec at schema version 1 with no migrations.
    struct DemoCodec;

    impl ExtensionStateCodec for DemoCodec {
        type State = DemoState;

        fn extension_id(&self) -> ExtensionId {
            "demo".to_string()
        }

        fn key(&self) -> ExtensionStateKey {
            "state".to_string()
        }

        fn scope(&self) -> ExtensionStoreScope {
            ExtensionStoreScope::Thread {
                thread_id: "thread-a".to_string(),
            }
        }

        fn schema_version(&self) -> u32 {
            1
        }
    }

    /// Codec at schema version 2 that migrates v1 records (`legacy_value` -> `value`).
    struct MigratingCodec;

    impl ExtensionStateCodec for MigratingCodec {
        type State = DemoState;

        fn extension_id(&self) -> ExtensionId {
            DemoCodec.extension_id()
        }

        fn key(&self) -> ExtensionStateKey {
            DemoCodec.key()
        }

        fn scope(&self) -> ExtensionStoreScope {
            DemoCodec.scope()
        }

        fn schema_version(&self) -> u32 {
            2
        }

        fn migrate_state(
            &self,
            record: &ExtensionStateRecord,
        ) -> anyhow::Result<Option<ExtensionStateRecord>> {
            if record.schema_version != 1 {
                return Ok(None);
            }
            Ok(Some(ExtensionStateRecord {
                schema_version: 2,
                value: serde_json::json!({
                    "value": record.value["legacy_value"],
                }),
                ..record.clone()
            }))
        }
    }

    /// Codec whose migration forgets to bump the schema version.
    struct BrokenMigrationCodec;

    impl ExtensionStateCodec for BrokenMigrationCodec {
        type State = DemoState;

        fn extension_id(&self) -> ExtensionId {
            DemoCodec.extension_id()
        }

        fn key(&self) -> ExtensionStateKey {
            DemoCodec.key()
        }

        fn scope(&self) -> ExtensionStoreScope {
            DemoCodec.scope()
        }

        fn schema_version(&self) -> u32 {
            2
        }

        fn migrate_state(
            &self,
            record: &ExtensionStateRecord,
        ) -> anyhow::Result<Option<ExtensionStateRecord>> {
            Ok(Some(record.clone()))
        }
    }

    /// Builds a record matching `DemoCodec`'s identity at the given schema version.
    fn demo_record(schema_version: u32, value: serde_json::Value) -> ExtensionStateRecord {
        ExtensionStateRecord {
            extension_id: "demo".to_string(),
            key: "state".to_string(),
            scope: ExtensionStoreScope::Thread {
                thread_id: "thread-a".to_string(),
            },
            schema_version,
            value,
        }
    }

    #[test]
    fn extension_state_codec_round_trips_thread_scoped_state() {
        let codec = DemoCodec;
        let state = DemoState {
            value: "expanded".to_string(),
        };

        let record = codec.encode_state(&state).unwrap();
        assert_eq!(
            record.scope,
            ExtensionStoreScope::Thread {
                thread_id: "thread-a".to_string()
            }
        );
        assert_eq!(record.schema_version, 1);
        assert_eq!(codec.decode_state(&record).unwrap(), state);
    }

    #[test]
    fn extension_state_codec_can_migrate_older_schema() {
        let state = MigratingCodec
            .decode_state(&demo_record(
                1,
                serde_json::json!({ "legacy_value": "expanded" }),
            ))
            .unwrap();

        assert_eq!(
            state,
            DemoState {
                value: "expanded".to_string()
            }
        );
    }

    #[test]
    fn extension_state_codec_rejects_mismatched_identity() {
        let record = demo_record(1, serde_json::json!({ "value": "expanded" }));

        let wrong_id = ExtensionStateRecord {
            extension_id: "other".to_string(),
            ..record.clone()
        };
        let err = DemoCodec.decode_state(&wrong_id).unwrap_err();
        assert!(err.to_string().contains("id mismatch"), "{}", err);

        let wrong_key = ExtensionStateRecord {
            key: "other".to_string(),
            ..record.clone()
        };
        let err = DemoCodec.decode_state(&wrong_key).unwrap_err();
        assert!(err.to_string().contains("key mismatch"), "{}", err);

        let wrong_scope = ExtensionStateRecord {
            scope: ExtensionStoreScope::Global,
            ..record
        };
        let err = DemoCodec.decode_state(&wrong_scope).unwrap_err();
        assert!(err.to_string().contains("scope mismatch"), "{}", err);
    }

    #[test]
    fn extension_state_codec_rejects_unsupported_schema() {
        // DemoCodec has no migrations at all.
        let err = DemoCodec
            .decode_state(&demo_record(2, serde_json::json!({ "value": "expanded" })))
            .unwrap_err();
        assert!(err.to_string().contains("schema mismatch"), "{}", err);

        // MigratingCodec only knows how to migrate from v1.
        let err = MigratingCodec
            .decode_state(&demo_record(3, serde_json::json!({ "value": "expanded" })))
            .unwrap_err();
        assert!(err.to_string().contains("schema mismatch"), "{}", err);
    }

    #[test]
    fn extension_state_codec_rejects_migration_to_wrong_schema() {
        let err = BrokenMigrationCodec
            .decode_state(&demo_record(1, serde_json::json!({ "value": "expanded" })))
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("migration produced schema 1, expected 2"),
            "{}",
            err
        );
    }

    #[test]
    fn extension_state_codec_rejects_malformed_value() {
        let result =
            DemoCodec.decode_state(&demo_record(1, serde_json::json!({ "unexpected": true })));
        assert!(result.is_err());
    }
}