Skip to main content

adk_plugin/
context.rs

1//! Type-safe shared state container for plugins.
2//!
3//! [`PluginContext`] uses the TypeMap pattern where each type serves as its own key,
4//! providing a type-safe, concurrent key-value store for plugin shared state.
5//!
6//! # Overview
7//!
8//! Plugins often need to share state across hook invocations within a single agent run.
9//! For example, a rate-limiting plugin tracks request counts, a caching plugin stores
10//! cached responses, and a metrics plugin accumulates statistics.
11//!
12//! `PluginContext` enables this by allowing plugins to insert and retrieve values
//! keyed by their Rust type. Each type can have at most one value stored at a time.
14//!
15//! # Examples
16//!
17//! ```rust
18//! use adk_plugin::PluginContext;
19//!
20//! #[derive(Clone, Debug, PartialEq)]
21//! struct RequestCount(u32);
22//!
23//! #[derive(Clone, Debug, PartialEq)]
24//! struct CacheHits(u64);
25//!
26//! # #[tokio::main]
27//! # async fn main() {
28//! let ctx = PluginContext::new();
29//!
30//! // Insert typed state
31//! ctx.insert(RequestCount(0)).await;
32//! ctx.insert(CacheHits(42)).await;
33//!
34//! // Retrieve typed state
35//! let count = ctx.get::<RequestCount>().await;
36//! assert_eq!(count, Some(RequestCount(0)));
37//!
38//! // Update state (last write wins)
39//! ctx.insert(RequestCount(5)).await;
40//! let count = ctx.get::<RequestCount>().await;
41//! assert_eq!(count, Some(RequestCount(5)));
42//!
43//! // Remove state
44//! let removed = ctx.remove::<CacheHits>().await;
45//! assert_eq!(removed, Some(CacheHits(42)));
46//! assert_eq!(ctx.get::<CacheHits>().await, None);
47//! # }
48//! ```
49
50use std::any::{Any, TypeId};
51use std::collections::HashMap;
52
53use tokio::sync::RwLock;
54
55/// A type-safe, concurrent key-value store for plugin shared state.
56///
57/// Uses the TypeMap pattern where each type serves as its own key.
58/// Thread-safe via [`tokio::sync::RwLock`] for concurrent async access.
59///
60/// # Concurrency
61///
62/// - Multiple readers can access state concurrently via [`get`](Self::get) and
63///   [`contains`](Self::contains).
64/// - Writers acquire exclusive access via [`insert`](Self::insert) and
65///   [`remove`](Self::remove).
66/// - No locks are held across await points — each method acquires and releases
67///   the lock within a single operation.
68///
69/// # Examples
70///
71/// ```rust
72/// use adk_plugin::PluginContext;
73///
74/// #[derive(Clone, Debug)]
75/// struct RateLimitState {
76///     requests_this_minute: u32,
77/// }
78///
79/// # #[tokio::main]
80/// # async fn main() {
81/// let ctx = PluginContext::new();
82///
83/// // A rate-limiting plugin writes state
84/// ctx.insert(RateLimitState { requests_this_minute: 1 }).await;
85///
86/// // A metrics plugin reads it
87/// if let Some(state) = ctx.get::<RateLimitState>().await {
88///     println!("Requests: {}", state.requests_this_minute);
89/// }
90/// # }
91/// ```
92pub struct PluginContext {
93    state: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
94}
95
96impl PluginContext {
97    /// Creates a new empty `PluginContext`.
98    ///
99    /// # Examples
100    ///
101    /// ```rust
102    /// use adk_plugin::PluginContext;
103    ///
104    /// let ctx = PluginContext::new();
105    /// ```
106    pub fn new() -> Self {
107        Self { state: RwLock::new(HashMap::new()) }
108    }
109
110    /// Inserts a value into the context. The type itself is the key.
111    ///
112    /// If a value of the same type already exists, it is replaced.
113    /// The previous value is discarded.
114    ///
115    /// # Examples
116    ///
117    /// ```rust
118    /// use adk_plugin::PluginContext;
119    ///
120    /// #[derive(Clone, Debug, PartialEq)]
121    /// struct Counter(u32);
122    ///
123    /// # #[tokio::main]
124    /// # async fn main() {
125    /// let ctx = PluginContext::new();
126    /// ctx.insert(Counter(1)).await;
127    /// ctx.insert(Counter(2)).await; // Replaces the previous value
128    ///
129    /// assert_eq!(ctx.get::<Counter>().await, Some(Counter(2)));
130    /// # }
131    /// ```
132    pub async fn insert<T: Send + Sync + 'static>(&self, value: T) {
133        self.state.write().await.insert(TypeId::of::<T>(), Box::new(value));
134    }
135
136    /// Gets a clone of the stored value for type `T`.
137    ///
138    /// Returns `None` if no value of type `T` has been inserted.
139    ///
140    /// # Examples
141    ///
142    /// ```rust
143    /// use adk_plugin::PluginContext;
144    ///
145    /// #[derive(Clone, Debug, PartialEq)]
146    /// struct Name(String);
147    ///
148    /// # #[tokio::main]
149    /// # async fn main() {
150    /// let ctx = PluginContext::new();
151    ///
152    /// assert_eq!(ctx.get::<Name>().await, None);
153    ///
154    /// ctx.insert(Name("alice".to_string())).await;
155    /// assert_eq!(ctx.get::<Name>().await, Some(Name("alice".to_string())));
156    /// # }
157    /// ```
158    pub async fn get<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
159        self.state.read().await.get(&TypeId::of::<T>()).and_then(|v| v.downcast_ref::<T>()).cloned()
160    }
161
162    /// Checks if a value of type `T` exists in the context.
163    ///
164    /// # Examples
165    ///
166    /// ```rust
167    /// use adk_plugin::PluginContext;
168    ///
169    /// #[derive(Clone, Debug)]
170    /// struct Marker;
171    ///
172    /// # #[tokio::main]
173    /// # async fn main() {
174    /// let ctx = PluginContext::new();
175    ///
176    /// assert!(!ctx.contains::<Marker>().await);
177    /// ctx.insert(Marker).await;
178    /// assert!(ctx.contains::<Marker>().await);
179    /// # }
180    /// ```
181    pub async fn contains<T: Send + Sync + 'static>(&self) -> bool {
182        self.state.read().await.contains_key(&TypeId::of::<T>())
183    }
184
185    /// Removes a value of type `T`, returning it if present.
186    ///
187    /// After removal, [`get`](Self::get) and [`contains`](Self::contains) for
188    /// type `T` will return `None` and `false` respectively.
189    ///
190    /// # Examples
191    ///
192    /// ```rust
193    /// use adk_plugin::PluginContext;
194    ///
195    /// #[derive(Clone, Debug, PartialEq)]
196    /// struct Token(String);
197    ///
198    /// # #[tokio::main]
199    /// # async fn main() {
200    /// let ctx = PluginContext::new();
201    /// ctx.insert(Token("abc".to_string())).await;
202    ///
203    /// let removed = ctx.remove::<Token>().await;
204    /// assert_eq!(removed, Some(Token("abc".to_string())));
205    /// assert_eq!(ctx.get::<Token>().await, None);
206    /// # }
207    /// ```
208    pub async fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
209        self.state
210            .write()
211            .await
212            .remove(&TypeId::of::<T>())
213            .and_then(|v| v.downcast::<T>().ok())
214            .map(|b| *b)
215    }
216}
217
218impl Default for PluginContext {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct Counter(u32);

    #[derive(Clone, Debug, PartialEq)]
    struct Name(String);

    #[tokio::test]
    async fn test_new_context_is_empty() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);
        assert_eq!(ctx.get::<Counter>().await, None);
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(42)).await;

        let value = ctx.get::<Counter>().await;
        assert_eq!(value, Some(Counter(42)));
    }

    #[tokio::test]
    async fn test_insert_overwrites_previous() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(1)).await;
        ctx.insert(Counter(99)).await;

        assert_eq!(ctx.get::<Counter>().await, Some(Counter(99)));
    }

    #[tokio::test]
    async fn test_multiple_types() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(10)).await;
        ctx.insert(Name("hello".to_string())).await;

        assert_eq!(ctx.get::<Counter>().await, Some(Counter(10)));
        assert_eq!(ctx.get::<Name>().await, Some(Name("hello".to_string())));
    }

    #[tokio::test]
    async fn test_contains() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);

        ctx.insert(Counter(0)).await;
        assert!(ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_returns_value() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;

        let removed = ctx.remove::<Counter>().await;
        assert_eq!(removed, Some(Counter(7)));
    }

    #[tokio::test]
    async fn test_remove_makes_get_return_none() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;
        ctx.remove::<Counter>().await;

        assert_eq!(ctx.get::<Counter>().await, None);
        assert!(!ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_nonexistent_returns_none() {
        let ctx = PluginContext::new();
        let removed = ctx.remove::<Counter>().await;
        assert_eq!(removed, None);
    }

    #[tokio::test]
    async fn test_remove_only_affects_its_own_type() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(3)).await;
        ctx.insert(Name("kept".to_string())).await;

        assert_eq!(ctx.remove::<Counter>().await, Some(Counter(3)));
        // A second removal finds nothing.
        assert_eq!(ctx.remove::<Counter>().await, None);
        // Values stored under other types are untouched.
        assert_eq!(ctx.get::<Name>().await, Some(Name("kept".to_string())));
    }

    #[tokio::test]
    async fn test_default_creates_empty_context() {
        let ctx = PluginContext::default();
        assert!(!ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        use std::sync::Arc;

        let ctx = Arc::new(PluginContext::new());
        ctx.insert(Counter(0)).await;

        let ctx_clone = Arc::clone(&ctx);
        let writer = tokio::spawn(async move {
            for i in 1..=100 {
                ctx_clone.insert(Counter(i)).await;
                // Yield so the reader really interleaves with the writes. Otherwise the
                // writer could finish every insert before the reader is polled at all.
                tokio::task::yield_now().await;
            }
        });

        let ctx_clone2 = Arc::clone(&ctx);
        let reader = tokio::spawn(async move {
            let mut last_seen = 0;
            for _ in 0..100 {
                let Counter(seen) = ctx_clone2
                    .get::<Counter>()
                    .await
                    .expect("Counter is inserted before the tasks start and never removed");
                // The single writer stores increasing values, so reads never go backwards.
                assert!(seen >= last_seen, "read {seen} after already observing {last_seen}");
                last_seen = seen;
                tokio::task::yield_now().await;
            }
        });

        writer.await.unwrap();
        reader.await.unwrap();

        // Final value should be 100 (last write)
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(100)));
    }
}