Skip to main content

roder_api/
extension_state.rs

1use serde::{Deserialize, Serialize, de::DeserializeOwned};
2
3use crate::{
4    events::{ThreadId, TurnId},
5    extension::ExtensionId,
6};
7
/// Name under which a piece of extension state is stored (see [`ExtensionStateRecord::key`]).
pub type ExtensionStateKey = std::string::String;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub enum ExtensionStoreScope {
12    Global,
13    Workspace {
14        workspace: String,
15    },
16    Thread {
17        thread_id: ThreadId,
18    },
19    Turn {
20        thread_id: ThreadId,
21        turn_id: TurnId,
22    },
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
26pub struct ExtensionStateRecord {
27    pub extension_id: ExtensionId,
28    pub key: ExtensionStateKey,
29    pub scope: ExtensionStoreScope,
30    pub schema_version: u32,
31    pub value: serde_json::Value,
32}
33
34pub trait ExtensionStateCodec: Send + Sync + 'static {
35    type State: Serialize + DeserializeOwned + Send + Sync + 'static;
36
37    fn extension_id(&self) -> ExtensionId;
38    fn key(&self) -> ExtensionStateKey;
39    fn scope(&self) -> ExtensionStoreScope;
40    fn schema_version(&self) -> u32;
41    fn migrate_state(
42        &self,
43        _record: &ExtensionStateRecord,
44    ) -> anyhow::Result<Option<ExtensionStateRecord>> {
45        Ok(None)
46    }
47
48    fn encode_state(&self, state: &Self::State) -> anyhow::Result<ExtensionStateRecord> {
49        Ok(ExtensionStateRecord {
50            extension_id: self.extension_id(),
51            key: self.key(),
52            scope: self.scope(),
53            schema_version: self.schema_version(),
54            value: serde_json::to_value(state)?,
55        })
56    }
57
58    fn decode_state(&self, record: &ExtensionStateRecord) -> anyhow::Result<Self::State> {
59        if record.extension_id != self.extension_id() {
60            anyhow::bail!(
61                "extension state id mismatch: expected {}, got {}",
62                self.extension_id(),
63                record.extension_id
64            );
65        }
66        if record.key != self.key() {
67            anyhow::bail!(
68                "extension state key mismatch: expected {}, got {}",
69                self.key(),
70                record.key
71            );
72        }
73        if record.scope != self.scope() {
74            anyhow::bail!("extension state scope mismatch");
75        }
76        let record = if record.schema_version == self.schema_version() {
77            record.clone()
78        } else if let Some(migrated) = self.migrate_state(record)? {
79            if migrated.schema_version != self.schema_version() {
80                anyhow::bail!(
81                    "extension state migration produced schema {}, expected {}",
82                    migrated.schema_version,
83                    self.schema_version()
84                );
85            }
86            migrated
87        } else {
88            anyhow::bail!(
89                "extension state schema mismatch: expected {}, got {}",
90                self.schema_version(),
91                record.schema_version
92            );
93        };
94        Ok(serde_json::from_value(record.value)?)
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct DemoState {
        value: String,
    }

    /// Scope shared by every codec and record in these tests.
    fn demo_scope() -> ExtensionStoreScope {
        ExtensionStoreScope::Thread {
            thread_id: "thread-a".to_string(),
        }
    }

    /// Builds a record carrying the `demo`/`state` identity with the given schema and payload.
    fn demo_record(schema_version: u32, value: serde_json::Value) -> ExtensionStateRecord {
        ExtensionStateRecord {
            extension_id: "demo".to_string(),
            key: "state".to_string(),
            scope: demo_scope(),
            schema_version,
            value,
        }
    }

    /// Schema-1 codec for `demo`/`state` with no migrations.
    struct DemoCodec;

    impl ExtensionStateCodec for DemoCodec {
        type State = DemoState;

        fn extension_id(&self) -> ExtensionId {
            "demo".to_string()
        }

        fn key(&self) -> ExtensionStateKey {
            "state".to_string()
        }

        fn scope(&self) -> ExtensionStoreScope {
            demo_scope()
        }

        fn schema_version(&self) -> u32 {
            1
        }
    }

    /// Decodes `record` with `DemoCodec`, expecting failure, and returns the error message.
    fn demo_decode_error(record: &ExtensionStateRecord) -> String {
        DemoCodec.decode_state(record).unwrap_err().to_string()
    }

    #[test]
    fn extension_state_codec_round_trips_thread_scoped_state() {
        let codec = DemoCodec;
        let state = DemoState {
            value: "expanded".to_string(),
        };

        let record = codec.encode_state(&state).unwrap();
        assert_eq!(
            record.scope,
            ExtensionStoreScope::Thread {
                thread_id: "thread-a".to_string()
            }
        );
        assert_eq!(record.schema_version, 1);
        assert_eq!(codec.decode_state(&record).unwrap(), state);
    }

    #[test]
    fn extension_state_codec_can_migrate_older_schema() {
        struct MigratingCodec;

        impl ExtensionStateCodec for MigratingCodec {
            type State = DemoState;

            fn extension_id(&self) -> ExtensionId {
                "demo".to_string()
            }

            fn key(&self) -> ExtensionStateKey {
                "state".to_string()
            }

            fn scope(&self) -> ExtensionStoreScope {
                demo_scope()
            }

            fn schema_version(&self) -> u32 {
                2
            }

            fn migrate_state(
                &self,
                record: &ExtensionStateRecord,
            ) -> anyhow::Result<Option<ExtensionStateRecord>> {
                if record.schema_version != 1 {
                    return Ok(None);
                }
                Ok(Some(ExtensionStateRecord {
                    schema_version: 2,
                    value: serde_json::json!({
                        "value": record.value["legacy_value"],
                    }),
                    ..record.clone()
                }))
            }
        }

        let state = MigratingCodec
            .decode_state(&demo_record(
                1,
                serde_json::json!({ "legacy_value": "expanded" }),
            ))
            .unwrap();

        assert_eq!(
            state,
            DemoState {
                value: "expanded".to_string()
            }
        );
    }

    #[test]
    fn extension_state_codec_rejects_foreign_extension_id() {
        let record = ExtensionStateRecord {
            extension_id: "other".to_string(),
            ..demo_record(1, serde_json::json!({ "value": "expanded" }))
        };
        assert!(demo_decode_error(&record).contains("extension state id mismatch"));
    }

    #[test]
    fn extension_state_codec_rejects_foreign_key() {
        let record = ExtensionStateRecord {
            key: "other".to_string(),
            ..demo_record(1, serde_json::json!({ "value": "expanded" }))
        };
        assert!(demo_decode_error(&record).contains("extension state key mismatch"));
    }

    #[test]
    fn extension_state_codec_rejects_foreign_scope() {
        for scope in [
            ExtensionStoreScope::Global,
            ExtensionStoreScope::Thread {
                thread_id: "thread-b".to_string(),
            },
        ] {
            let record = ExtensionStateRecord {
                scope,
                ..demo_record(1, serde_json::json!({ "value": "expanded" }))
            };
            assert!(demo_decode_error(&record).contains("extension state scope mismatch"));
        }
    }

    #[test]
    fn extension_state_codec_rejects_schema_without_migration() {
        let record = demo_record(2, serde_json::json!({ "value": "expanded" }));
        assert!(demo_decode_error(&record).contains("extension state schema mismatch"));
    }

    #[test]
    fn extension_state_codec_rejects_migration_to_wrong_schema() {
        /// Claims schema 2 but "migrates" by returning the record unchanged.
        struct StaleMigrationCodec;

        impl ExtensionStateCodec for StaleMigrationCodec {
            type State = DemoState;

            fn extension_id(&self) -> ExtensionId {
                "demo".to_string()
            }

            fn key(&self) -> ExtensionStateKey {
                "state".to_string()
            }

            fn scope(&self) -> ExtensionStoreScope {
                demo_scope()
            }

            fn schema_version(&self) -> u32 {
                2
            }

            fn migrate_state(
                &self,
                record: &ExtensionStateRecord,
            ) -> anyhow::Result<Option<ExtensionStateRecord>> {
                Ok(Some(record.clone()))
            }
        }

        let message = StaleMigrationCodec
            .decode_state(&demo_record(1, serde_json::json!({ "value": "expanded" })))
            .unwrap_err()
            .to_string();
        assert!(message.contains("extension state migration produced schema 1, expected 2"));
    }

    #[test]
    fn extension_state_codec_rejects_payload_that_does_not_fit_state() {
        let record = demo_record(1, serde_json::json!({ "unexpected": 42 }));
        assert!(DemoCodec.decode_state(&record).is_err());
    }
}