use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;
use serde_json::Value;
use std::sync::Arc;
use crate::BoxFuture;
use crate::agents::directive::Directive;
use crate::agents::streaming::AgentEvent;
use crate::tools::Tool;
/// Marker type that names one plugin's state slot in a `PluginStateMap`.
///
/// The map is keyed by the `TypeId` of the implementing type, and the slot
/// holds a value of the associated `State` type.
pub trait PluginStateKey: Send + Sync + 'static {
    /// String name of this slot; `PluginStateMap::serialize_all` uses it as the
    /// JSON object key for this slot's state.
    const KEY: &'static str;

    /// Type of the state stored under this key.
    type State: Send + Sync + 'static;
}
/// Zero-data typed token for the state slot named by key type `P`.
///
/// It holds only a `PhantomData<P>`. Because that field is private, code outside
/// this module cannot build a handle with a struct literal.
pub struct PluginHandle<P: PluginStateKey> {
    _marker: PhantomData<P>,
}

// `Copy`, `Clone` and `Debug` are written by hand instead of derived: the derive
// macros would require `P: Copy` / `P: Clone` / `P: Debug`, which key types do
// not have to satisfy.
impl<P: PluginStateKey> Copy for PluginHandle<P> {}

impl<P: PluginStateKey> Clone for PluginHandle<P> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<P: PluginStateKey> std::fmt::Debug for PluginHandle<P> {
    /// Prints the handle's `KEY` string, e.g. `PluginHandle { key: "counter" }`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = out.debug_struct("PluginHandle");
        builder.field("key", &P::KEY);
        builder.finish()
    }
}
/// One type-erased entry in a `PluginStateMap`.
struct PluginStateMeta {
    /// The key type's `KEY` string, used as the JSON key when serializing.
    key: &'static str,
    /// The stored state, with its concrete type erased.
    value: Box<dyn Any + Send + Sync>,
    /// Serializer monomorphized for the concrete state type. It remembers that
    /// type so `value` can be turned into JSON without knowing it here.
    serialize: fn(&dyn Any) -> Option<Value>,
}
/// Serializes `v` to JSON, treating it as a `T`.
///
/// Returns `None` when `v` is not actually a `T` or when serialization fails.
/// The serialization error itself is discarded.
fn serialize_fn<T: serde::Serialize + 'static>(v: &dyn Any) -> Option<Value> {
    let typed: &T = v.downcast_ref()?;
    serde_json::to_value(typed).ok()
}
/// Store of per-plugin state, keyed by the `TypeId` of each plugin's
/// `PluginStateKey` marker type.
///
/// Each entry keeps its value type-erased, together with a serializer for it
/// and its `KEY` string.
#[derive(Default)]
pub struct PluginStateMap {
    entries: HashMap<TypeId, PluginStateMeta>,
}
/// Prints only the number of entries, because the stored values are type-erased
/// and cannot be formatted generically.
impl std::fmt::Debug for PluginStateMap {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.entries.len();
        formatter
            .debug_struct("PluginStateMap")
            .field("len", &entry_count)
            .finish()
    }
}
impl PluginStateMap {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register<P>(&mut self, state: P::State) -> Result<PluginHandle<P>, &'static str>
where
P: PluginStateKey,
P::State: serde::Serialize + 'static,
{
let id = TypeId::of::<P>();
if self.entries.contains_key(&id) {
return Err(P::KEY);
}
let _ = self.entries.insert(
id,
PluginStateMeta {
value: Box::new(state),
serialize: serialize_fn::<P::State>,
key: P::KEY,
},
);
Ok(PluginHandle {
_marker: PhantomData,
})
}
#[must_use]
pub fn get<P: PluginStateKey>(&self) -> Option<&P::State> {
self.entries
.get(&TypeId::of::<P>())
.and_then(|m| m.value.downcast_ref::<P::State>())
}
pub fn get_mut<P: PluginStateKey>(&mut self) -> Option<&mut P::State> {
self.entries
.get_mut(&TypeId::of::<P>())
.and_then(|m| m.value.downcast_mut::<P::State>())
}
pub fn insert<P: PluginStateKey>(&mut self, state: P::State)
where
P::State: serde::Serialize + 'static,
{
let _ = self.entries.insert(
TypeId::of::<P>(),
PluginStateMeta {
value: Box::new(state),
serialize: serialize_fn::<P::State>,
key: P::KEY,
},
);
}
#[must_use]
pub fn serialize_all(&self) -> Value {
let mut map = serde_json::Map::new();
for meta in self.entries.values() {
if let Some(v) = (meta.serialize)(meta.value.as_ref()) {
let _ = map.insert(meta.key.to_string(), v);
}
}
Value::Object(map)
}
}
/// Input handed to `Plugin::on_user_message`.
#[derive(Debug, Clone)]
pub struct PluginInput {
    /// Turn number supplied by the caller. Whether it starts at 0 or 1 is not
    /// established here — TODO confirm with the caller.
    pub turn: u32,
    /// Message text, or `None` when the caller supplies none.
    pub message: Option<String>,
}
/// An agent plugin: a named set of optional async hooks, plus the signal routes
/// and tools it contributes.
///
/// Every method except `name` has a default that contributes nothing (an empty
/// `Vec`), so implementors override only what they need. The async hooks all
/// receive the plugin state map by shared (`&`) reference.
pub trait Plugin: Send + Sync {
    /// Name of this plugin.
    fn name(&self) -> &str;

    /// Hook that receives a `PluginInput` and may return directives.
    /// Presumably called per user message — the call site is not in this file.
    ///
    /// Default: resolves immediately to no directives.
    fn on_user_message<'a>(
        &'a self,
        _input: &'a PluginInput,
        _state: &'a PluginStateMap,
    ) -> BoxFuture<'a, Vec<Directive>> {
        Box::pin(std::future::ready(Vec::new()))
    }

    /// Hook that receives an `AgentEvent` and may return directives.
    ///
    /// Default: resolves immediately to no directives.
    fn on_event<'a>(
        &'a self,
        _event: &'a AgentEvent,
        _state: &'a PluginStateMap,
    ) -> BoxFuture<'a, Vec<Directive>> {
        Box::pin(std::future::ready(Vec::new()))
    }

    /// Hook presumably run before an agent run starts (caller not shown here).
    ///
    /// Default: resolves immediately to no directives.
    fn before_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
        Box::pin(std::future::ready(Vec::new()))
    }

    /// Hook presumably run after an agent run finishes (caller not shown here).
    ///
    /// Default: resolves immediately to no directives.
    fn after_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
        Box::pin(std::future::ready(Vec::new()))
    }

    /// Signal routes contributed by this plugin. Default: none.
    fn signal_routes(&self) -> Vec<crate::agents::signal::SignalRoute> {
        Vec::default()
    }

    /// Tools contributed by this plugin. Default: none.
    fn tools(&self) -> Vec<Arc<dyn Tool>> {
        Vec::default()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Example state: a simple counter.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct CounterState {
        count: u32,
    }

    /// Key type for `CounterState`, serialized under `"counter"`.
    struct CounterKey;

    impl PluginStateKey for CounterKey {
        type State = CounterState;
        const KEY: &'static str = "counter";
    }

    /// Example state: a boolean flag.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct FlagState {
        enabled: bool,
    }

    /// Key type for `FlagState`, serialized under `"flag"`.
    struct FlagKey;

    impl PluginStateKey for FlagKey {
        type State = FlagState;
        const KEY: &'static str = "flag";
    }

    /// `get` and `get_mut` expose the concrete state type of a registered key.
    #[test]
    fn test_type_safe_access() {
        let mut states = PluginStateMap::new();
        let _handle = states
            .register::<CounterKey>(CounterState { count: 0 })
            .expect("register");

        let initial = states.get::<CounterKey>().expect("get");
        assert_eq!(initial.count, 0);

        let counter = states.get_mut::<CounterKey>().expect("get_mut");
        counter.count = 42;

        let updated = states.get::<CounterKey>().expect("get after mut");
        assert_eq!(updated.count, 42);
    }

    /// Mutating one plugin's state leaves another plugin's state untouched.
    #[test]
    fn test_cross_plugin_isolation() {
        let mut states = PluginStateMap::new();
        let _ = states.register::<CounterKey>(CounterState { count: 10 });
        let _ = states.register::<FlagKey>(FlagState { enabled: true });

        assert!(states.get::<CounterKey>().is_some());
        assert!(states.get::<FlagKey>().is_some());

        states.get_mut::<CounterKey>().expect("mut").count = 99;

        let flag = states.get::<FlagKey>().expect("flag");
        assert!(flag.enabled);
    }

    /// Registering the same key type twice fails and reports its `KEY`.
    #[test]
    fn test_key_collision_detection() {
        let mut states = PluginStateMap::new();
        let first = states.register::<CounterKey>(CounterState { count: 0 });
        let _ = first.expect("first register");

        let rejected = states
            .register::<CounterKey>(CounterState { count: 1 })
            .expect_err("second register should fail");
        assert_eq!(rejected, CounterKey::KEY);
    }

    /// `serialize_all` nests each plugin's state under its `KEY`.
    #[test]
    fn test_serialization_round_trip() {
        let mut states = PluginStateMap::new();
        let _ = states.register::<CounterKey>(CounterState { count: 7 });
        let _ = states.register::<FlagKey>(FlagState { enabled: false });

        let json = states.serialize_all();
        assert_eq!(json["counter"]["count"], 7);
        assert_eq!(json["flag"]["enabled"], false);
    }
}