adk_plugin/context.rs
1//! Type-safe shared state container for plugins.
2//!
3//! [`PluginContext`] uses the TypeMap pattern where each type serves as its own key,
4//! providing a type-safe, concurrent key-value store for plugin shared state.
5//!
6//! # Overview
7//!
8//! Plugins often need to share state across hook invocations within a single agent run.
9//! For example, a rate-limiting plugin tracks request counts, a caching plugin stores
10//! cached responses, and a metrics plugin accumulates statistics.
11//!
12//! `PluginContext` enables this by allowing plugins to insert and retrieve values
13//! keyed by their Rust type. Each unique type can have exactly one value stored.
14//!
15//! # Examples
16//!
17//! ```rust
18//! use adk_plugin::PluginContext;
19//!
20//! #[derive(Clone, Debug, PartialEq)]
21//! struct RequestCount(u32);
22//!
23//! #[derive(Clone, Debug, PartialEq)]
24//! struct CacheHits(u64);
25//!
26//! # #[tokio::main]
27//! # async fn main() {
28//! let ctx = PluginContext::new();
29//!
30//! // Insert typed state
31//! ctx.insert(RequestCount(0)).await;
32//! ctx.insert(CacheHits(42)).await;
33//!
34//! // Retrieve typed state
35//! let count = ctx.get::<RequestCount>().await;
36//! assert_eq!(count, Some(RequestCount(0)));
37//!
38//! // Update state (last write wins)
39//! ctx.insert(RequestCount(5)).await;
40//! let count = ctx.get::<RequestCount>().await;
41//! assert_eq!(count, Some(RequestCount(5)));
42//!
43//! // Remove state
44//! let removed = ctx.remove::<CacheHits>().await;
45//! assert_eq!(removed, Some(CacheHits(42)));
46//! assert_eq!(ctx.get::<CacheHits>().await, None);
47//! # }
48//! ```
49
50use std::any::{Any, TypeId};
51use std::collections::HashMap;
52
53use tokio::sync::RwLock;
54
55/// A type-safe, concurrent key-value store for plugin shared state.
56///
57/// Uses the TypeMap pattern where each type serves as its own key.
58/// Thread-safe via [`tokio::sync::RwLock`] for concurrent async access.
59///
60/// # Concurrency
61///
62/// - Multiple readers can access state concurrently via [`get`](Self::get) and
63/// [`contains`](Self::contains).
64/// - Writers acquire exclusive access via [`insert`](Self::insert) and
65/// [`remove`](Self::remove).
66/// - No locks are held across await points — each method acquires and releases
67/// the lock within a single operation.
68///
69/// # Examples
70///
71/// ```rust
72/// use adk_plugin::PluginContext;
73///
74/// #[derive(Clone, Debug)]
75/// struct RateLimitState {
76/// requests_this_minute: u32,
77/// }
78///
79/// # #[tokio::main]
80/// # async fn main() {
81/// let ctx = PluginContext::new();
82///
83/// // A rate-limiting plugin writes state
84/// ctx.insert(RateLimitState { requests_this_minute: 1 }).await;
85///
86/// // A metrics plugin reads it
87/// if let Some(state) = ctx.get::<RateLimitState>().await {
88/// println!("Requests: {}", state.requests_this_minute);
89/// }
90/// # }
91/// ```
92pub struct PluginContext {
93 state: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
94}
95
96impl PluginContext {
97 /// Creates a new empty `PluginContext`.
98 ///
99 /// # Examples
100 ///
101 /// ```rust
102 /// use adk_plugin::PluginContext;
103 ///
104 /// let ctx = PluginContext::new();
105 /// ```
106 pub fn new() -> Self {
107 Self { state: RwLock::new(HashMap::new()) }
108 }
109
110 /// Inserts a value into the context. The type itself is the key.
111 ///
112 /// If a value of the same type already exists, it is replaced.
113 /// The previous value is discarded.
114 ///
115 /// # Examples
116 ///
117 /// ```rust
118 /// use adk_plugin::PluginContext;
119 ///
120 /// #[derive(Clone, Debug, PartialEq)]
121 /// struct Counter(u32);
122 ///
123 /// # #[tokio::main]
124 /// # async fn main() {
125 /// let ctx = PluginContext::new();
126 /// ctx.insert(Counter(1)).await;
127 /// ctx.insert(Counter(2)).await; // Replaces the previous value
128 ///
129 /// assert_eq!(ctx.get::<Counter>().await, Some(Counter(2)));
130 /// # }
131 /// ```
132 pub async fn insert<T: Send + Sync + 'static>(&self, value: T) {
133 self.state.write().await.insert(TypeId::of::<T>(), Box::new(value));
134 }
135
136 /// Gets a clone of the stored value for type `T`.
137 ///
138 /// Returns `None` if no value of type `T` has been inserted.
139 ///
140 /// # Examples
141 ///
142 /// ```rust
143 /// use adk_plugin::PluginContext;
144 ///
145 /// #[derive(Clone, Debug, PartialEq)]
146 /// struct Name(String);
147 ///
148 /// # #[tokio::main]
149 /// # async fn main() {
150 /// let ctx = PluginContext::new();
151 ///
152 /// assert_eq!(ctx.get::<Name>().await, None);
153 ///
154 /// ctx.insert(Name("alice".to_string())).await;
155 /// assert_eq!(ctx.get::<Name>().await, Some(Name("alice".to_string())));
156 /// # }
157 /// ```
158 pub async fn get<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
159 self.state.read().await.get(&TypeId::of::<T>()).and_then(|v| v.downcast_ref::<T>()).cloned()
160 }
161
162 /// Checks if a value of type `T` exists in the context.
163 ///
164 /// # Examples
165 ///
166 /// ```rust
167 /// use adk_plugin::PluginContext;
168 ///
169 /// #[derive(Clone, Debug)]
170 /// struct Marker;
171 ///
172 /// # #[tokio::main]
173 /// # async fn main() {
174 /// let ctx = PluginContext::new();
175 ///
176 /// assert!(!ctx.contains::<Marker>().await);
177 /// ctx.insert(Marker).await;
178 /// assert!(ctx.contains::<Marker>().await);
179 /// # }
180 /// ```
181 pub async fn contains<T: Send + Sync + 'static>(&self) -> bool {
182 self.state.read().await.contains_key(&TypeId::of::<T>())
183 }
184
185 /// Removes a value of type `T`, returning it if present.
186 ///
187 /// After removal, [`get`](Self::get) and [`contains`](Self::contains) for
188 /// type `T` will return `None` and `false` respectively.
189 ///
190 /// # Examples
191 ///
192 /// ```rust
193 /// use adk_plugin::PluginContext;
194 ///
195 /// #[derive(Clone, Debug, PartialEq)]
196 /// struct Token(String);
197 ///
198 /// # #[tokio::main]
199 /// # async fn main() {
200 /// let ctx = PluginContext::new();
201 /// ctx.insert(Token("abc".to_string())).await;
202 ///
203 /// let removed = ctx.remove::<Token>().await;
204 /// assert_eq!(removed, Some(Token("abc".to_string())));
205 /// assert_eq!(ctx.get::<Token>().await, None);
206 /// # }
207 /// ```
208 pub async fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
209 self.state
210 .write()
211 .await
212 .remove(&TypeId::of::<T>())
213 .and_then(|v| v.downcast::<T>().ok())
214 .map(|b| *b)
215 }
216}
217
218impl Default for PluginContext {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct Counter(u32);

    #[derive(Clone, Debug, PartialEq)]
    struct Name(String);

    #[tokio::test]
    async fn test_new_context_is_empty() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);
        assert_eq!(ctx.get::<Counter>().await, None);
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(42)).await;

        assert_eq!(ctx.get::<Counter>().await, Some(Counter(42)));
    }

    #[tokio::test]
    async fn test_insert_overwrites_previous() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(1)).await;
        ctx.insert(Counter(99)).await;

        assert_eq!(ctx.get::<Counter>().await, Some(Counter(99)));
    }

    #[tokio::test]
    async fn test_multiple_types() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(10)).await;
        ctx.insert(Name("hello".to_string())).await;

        // Distinct types occupy distinct slots and do not clobber each other.
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(10)));
        assert_eq!(ctx.get::<Name>().await, Some(Name("hello".to_string())));
    }

    #[tokio::test]
    async fn test_contains() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);

        ctx.insert(Counter(0)).await;
        assert!(ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_returns_value() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;

        assert_eq!(ctx.remove::<Counter>().await, Some(Counter(7)));
    }

    #[tokio::test]
    async fn test_remove_makes_get_return_none() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;
        ctx.remove::<Counter>().await;

        assert_eq!(ctx.get::<Counter>().await, None);
        assert!(!ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_nonexistent_returns_none() {
        let ctx = PluginContext::new();
        assert_eq!(ctx.remove::<Counter>().await, None);
    }

    #[tokio::test]
    async fn test_default_creates_empty_context() {
        let ctx = PluginContext::default();
        assert!(!ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        use std::sync::Arc;

        let ctx = Arc::new(PluginContext::new());
        ctx.insert(Counter(0)).await;

        let writer_ctx = Arc::clone(&ctx);
        let writer = tokio::spawn(async move {
            for i in 1..=100 {
                writer_ctx.insert(Counter(i)).await;
            }
        });

        let reader_ctx = Arc::clone(&ctx);
        let reader = tokio::spawn(async move {
            // The single writer only ever stores increasing values and nothing
            // removes the entry. Every read must therefore succeed, and
            // successive reads can never observe the value going backwards.
            let mut last_seen = 0;
            for _ in 0..100 {
                let Counter(current) = reader_ctx
                    .get::<Counter>()
                    .await
                    .expect("Counter is inserted before the tasks start and never removed");
                assert!(current >= last_seen, "read {current} after {last_seen}");
                last_seen = current;
            }
        });

        writer.await.unwrap();
        reader.await.unwrap();

        // Final value is the writer's last write.
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(100)));
    }
}