use std::collections::hash_map::{Entry, VacantEntry};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use awaken_contract::StateError;
use awaken_contract::contract::tool::Tool;
use awaken_contract::model::Phase;

use crate::plugins::{KeyRegistration, ProfileKeyRegistration, RequestTransformArc};
use crate::plugins::{Plugin, PluginRegistrar};
use super::{EffectHandlerArc, PhaseHookArc, ScheduledActionHandlerArc, ToolGateHookArc};
/// A phase hook together with the plugin id it was registered under.
///
/// `plugin_id` is copied from the registrar entry when the environment is built;
/// it is not compared against the plugin's descriptor name here.
#[derive(Clone)]
pub(crate) struct TaggedPhaseHook {
    /// Plugin id supplied when the hook was registered.
    pub(crate) plugin_id: String,
    /// The hook itself (shared handle).
    pub(crate) hook: PhaseHookArc,
}
/// A tool-gate hook together with the plugin id it was registered under.
#[derive(Clone)]
pub(crate) struct TaggedToolGateHook {
    /// Plugin id supplied when the hook was registered.
    pub(crate) plugin_id: String,
    /// The gate hook itself (shared handle).
    pub(crate) hook: ToolGateHookArc,
}
/// An inference-request transform together with the plugin id it was registered under.
#[derive(Clone)]
pub(crate) struct TaggedRequestTransform {
    /// Plugin id supplied when the transform was registered; also emitted in trace logs.
    pub(crate) plugin_id: String,
    /// The transform itself (shared handle).
    pub(crate) transform: RequestTransformArc,
}
/// Everything collected from a set of plugins: hooks, tools, handlers,
/// request transforms and key registrations.
///
/// Built by [`ExecutionEnv::from_plugins`]. [`ExecutionEnv::empty`] yields an
/// environment with nothing registered.
#[derive(Clone)]
pub struct ExecutionEnv {
    /// Phase hooks from active plugins, grouped by phase, in plugin order.
    pub(crate) phase_hooks: HashMap<Phase, Vec<TaggedPhaseHook>>,
    /// Tool-gate hooks from active plugins, in plugin order.
    pub(crate) tool_gate_hooks: Vec<TaggedToolGateHook>,
    /// Scheduled-action handlers by key, collected from every plugin (active or not).
    pub(crate) scheduled_action_handlers: HashMap<String, ScheduledActionHandlerArc>,
    /// Effect handlers by key, collected from every plugin (active or not).
    pub(crate) effect_handlers: HashMap<String, EffectHandlerArc>,
    /// Request transforms from active plugins, in plugin order.
    pub(crate) request_transforms: Vec<TaggedRequestTransform>,
    /// Key registrations from every plugin (active or not).
    pub(crate) key_registrations: Vec<KeyRegistration>,
    /// Tools from active plugins, keyed by the id they were registered under.
    pub(crate) tools: HashMap<String, Arc<dyn Tool>>,
    /// Every plugin that was passed in, including ones excluded by the active filter.
    pub(crate) plugins: Vec<Arc<dyn Plugin>>,
    /// Profile-key registrations from every plugin (active or not).
    pub profile_key_registrations: Vec<ProfileKeyRegistration>,
}
impl ExecutionEnv {
pub fn from_plugins(
plugins: &[Arc<dyn Plugin>],
active_plugin_filter: &HashSet<String>,
) -> Result<Self, StateError> {
let mut all_hooks: HashMap<Phase, Vec<TaggedPhaseHook>> = HashMap::new();
let mut all_tool_gate_hooks: Vec<TaggedToolGateHook> = Vec::new();
let mut all_action_handlers: HashMap<String, ScheduledActionHandlerArc> = HashMap::new();
let mut all_effect_handlers: HashMap<String, EffectHandlerArc> = HashMap::new();
let mut all_transforms: Vec<TaggedRequestTransform> = Vec::new();
let mut all_key_registrations: Vec<KeyRegistration> = Vec::new();
let mut all_profile_key_registrations: Vec<ProfileKeyRegistration> = Vec::new();
let mut all_tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
for plugin in plugins {
let plugin_name = plugin.descriptor().name.to_string();
let plugin_active =
active_plugin_filter.is_empty() || active_plugin_filter.contains(&plugin_name);
let mut registrar = PluginRegistrar::new();
plugin.register(&mut registrar)?;
if plugin_active {
for entry in registrar.tools {
if all_tools.contains_key(&entry.id) {
return Err(StateError::ToolAlreadyRegistered { tool_id: entry.id });
}
tracing::debug!(
plugin_id = %plugin_name,
tool_id = %entry.id,
"registered_plugin_tool"
);
all_tools.insert(entry.id, entry.tool);
}
} else {
tracing::debug!(
plugin_id = %plugin_name,
tools_skipped = registrar.tools.len(),
"plugin_tools_filtered"
);
}
if plugin_active {
for entry in registrar.phase_hooks {
all_hooks
.entry(entry.phase)
.or_default()
.push(TaggedPhaseHook {
plugin_id: entry.plugin_id,
hook: entry.hook,
});
}
for entry in registrar.tool_gate_hooks {
all_tool_gate_hooks.push(TaggedToolGateHook {
plugin_id: entry.plugin_id,
hook: entry.hook,
});
}
}
for entry in registrar.scheduled_actions {
if all_action_handlers.contains_key(&entry.key) {
return Err(StateError::HandlerAlreadyRegistered { key: entry.key });
}
all_action_handlers.insert(entry.key, entry.handler);
}
for entry in registrar.effects {
if all_effect_handlers.contains_key(&entry.key) {
return Err(StateError::EffectHandlerAlreadyRegistered { key: entry.key });
}
all_effect_handlers.insert(entry.key, entry.handler);
}
if plugin_active {
for entry in registrar.request_transforms {
tracing::debug!(plugin_id = %entry.plugin_id, "registered_request_transform");
all_transforms.push(TaggedRequestTransform {
plugin_id: entry.plugin_id,
transform: entry.transform,
});
}
} else {
tracing::debug!(
plugin_id = %plugin_name,
transforms_skipped = registrar.request_transforms.len(),
"plugin_transforms_filtered"
);
}
all_key_registrations.extend(registrar.keys);
all_profile_key_registrations.extend(registrar.profile_keys);
}
Ok(Self {
phase_hooks: all_hooks,
tool_gate_hooks: all_tool_gate_hooks,
scheduled_action_handlers: all_action_handlers,
effect_handlers: all_effect_handlers,
request_transforms: all_transforms,
key_registrations: all_key_registrations,
profile_key_registrations: all_profile_key_registrations,
tools: all_tools,
plugins: plugins.to_vec(),
})
}
pub fn empty() -> Self {
Self {
phase_hooks: HashMap::new(),
tool_gate_hooks: Vec::new(),
scheduled_action_handlers: HashMap::new(),
effect_handlers: HashMap::new(),
request_transforms: Vec::new(),
key_registrations: Vec::new(),
profile_key_registrations: Vec::new(),
tools: HashMap::new(),
plugins: Vec::new(),
}
}
pub(crate) fn transform_arcs(&self) -> Vec<RequestTransformArc> {
self.request_transforms
.iter()
.map(|t| {
tracing::trace!(plugin_id = %t.plugin_id, "collecting_request_transform");
Arc::clone(&t.transform)
})
.collect()
}
pub(crate) fn hooks_for_phase(&self, phase: Phase) -> &[TaggedPhaseHook] {
self.phase_hooks
.get(&phase)
.map(|v| v.as_slice())
.unwrap_or(&[])
}
pub(crate) fn tool_gate_hooks(&self) -> &[TaggedToolGateHook] {
&self.tool_gate_hooks
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::PhaseContext;
    use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
    use crate::state::StateCommand;
    use async_trait::async_trait;
    use awaken_contract::StateError;
    use awaken_contract::contract::message::Message;
    use awaken_contract::contract::profile_store::ProfileKey;
    use awaken_contract::contract::tool::{
        Tool, ToolCallContext, ToolDescriptor, ToolError, ToolOutput, ToolResult,
    };
    use awaken_contract::contract::transform::{InferenceRequestTransform, TransformOutput};
    use awaken_contract::model::Phase;
    use serde_json::Value;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Phase hook that succeeds with an empty state command.
    struct NoOpHook;

    #[async_trait]
    impl crate::hooks::PhaseHook for NoOpHook {
        async fn run(&self, _ctx: &PhaseContext) -> Result<StateCommand, StateError> {
            Ok(StateCommand::default())
        }
    }

    /// Tool whose id/name is the wrapped string and which always succeeds with `null`.
    struct NoOpTool(String);

    #[async_trait]
    impl Tool for NoOpTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(&self.0, &self.0, format!("{} tool", self.0))
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolResult::success(&self.0, Value::Null).into())
        }
    }

    /// Request transform that passes messages through unchanged.
    struct NoOpTransform;

    impl InferenceRequestTransform for NoOpTransform {
        fn transform(&self, messages: Vec<Message>, _tools: &[ToolDescriptor]) -> TransformOutput {
            TransformOutput { messages }
        }
    }

    /// Plugin that registers one tool, one `StepStart` phase hook and one request transform.
    struct FullPlugin {
        name: &'static str,
        tool_id: &'static str,
    }

    impl Plugin for FullPlugin {
        fn descriptor(&self) -> PluginDescriptor {
            PluginDescriptor { name: self.name }
        }

        fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
            registrar.register_tool(self.tool_id, Arc::new(NoOpTool(self.tool_id.into())))?;
            registrar.register_phase_hook(self.name, Phase::StepStart, NoOpHook)?;
            registrar.register_request_transform(self.name, NoOpTransform);
            Ok(())
        }
    }

    /// Profile key shared by the profile-key tests.
    struct ProfileLocale;

    impl ProfileKey for ProfileLocale {
        const KEY: &'static str = "locale";
        type Value = String;
    }

    /// Standard two-plugin fixture: `alpha` (tool `alpha_tool`) and `beta` (tool `beta_tool`).
    fn alpha_and_beta() -> Vec<Arc<dyn Plugin>> {
        vec![
            Arc::new(FullPlugin {
                name: "alpha",
                tool_id: "alpha_tool",
            }),
            Arc::new(FullPlugin {
                name: "beta",
                tool_id: "beta_tool",
            }),
        ]
    }

    /// Builds an active-plugin filter from plugin names.
    fn filter_of(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    #[test]
    fn empty_filter_includes_all_plugins() {
        let env = ExecutionEnv::from_plugins(&alpha_and_beta(), &HashSet::new()).unwrap();
        assert!(env.tools.contains_key("alpha_tool"));
        assert!(env.tools.contains_key("beta_tool"));
        assert_eq!(env.hooks_for_phase(Phase::StepStart).len(), 2);
        assert_eq!(env.request_transforms.len(), 2);
    }

    #[test]
    fn filter_excludes_tools_from_inactive_plugin() {
        let env = ExecutionEnv::from_plugins(&alpha_and_beta(), &filter_of(&["alpha"])).unwrap();
        assert!(
            env.tools.contains_key("alpha_tool"),
            "active plugin tool should be present"
        );
        assert!(
            !env.tools.contains_key("beta_tool"),
            "inactive plugin tool should be filtered"
        );
    }

    #[test]
    fn filter_excludes_hooks_from_inactive_plugin() {
        let env = ExecutionEnv::from_plugins(&alpha_and_beta(), &filter_of(&["alpha"])).unwrap();
        let hooks = env.hooks_for_phase(Phase::StepStart);
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].plugin_id, "alpha");
    }

    #[test]
    fn filter_excludes_transforms_from_inactive_plugin() {
        let env = ExecutionEnv::from_plugins(&alpha_and_beta(), &filter_of(&["beta"])).unwrap();
        assert_eq!(env.request_transforms.len(), 1);
        assert_eq!(env.request_transforms[0].plugin_id, "beta");
    }

    #[test]
    fn filter_with_all_plugins_active_includes_everything() {
        let env =
            ExecutionEnv::from_plugins(&alpha_and_beta(), &filter_of(&["alpha", "beta"])).unwrap();
        assert_eq!(env.tools.len(), 2);
        assert_eq!(env.hooks_for_phase(Phase::StepStart).len(), 2);
        assert_eq!(env.request_transforms.len(), 2);
    }

    #[test]
    fn filter_with_no_matching_plugins_produces_empty_env() {
        let plugins: Vec<Arc<dyn Plugin>> = vec![Arc::new(FullPlugin {
            name: "alpha",
            tool_id: "alpha_tool",
        })];
        let env = ExecutionEnv::from_plugins(&plugins, &filter_of(&["nonexistent"])).unwrap();
        assert!(env.tools.is_empty());
        assert!(env.hooks_for_phase(Phase::StepStart).is_empty());
        assert!(env.request_transforms.is_empty());
    }

    #[test]
    fn filter_still_registers_keys_from_inactive_plugins() {
        /// Registers a filterable tool alongside a key that must survive filtering.
        struct ToolAndKeyPlugin;

        impl Plugin for ToolAndKeyPlugin {
            fn descriptor(&self) -> PluginDescriptor {
                PluginDescriptor { name: "mixed" }
            }

            fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
                registrar.register_tool("mixed_tool", Arc::new(NoOpTool("mixed_tool".into())))?;
                registrar.register_profile_key::<ProfileLocale>()?;
                Ok(())
            }
        }

        let plugins: Vec<Arc<dyn Plugin>> = vec![Arc::new(ToolAndKeyPlugin)];
        let env = ExecutionEnv::from_plugins(&plugins, &filter_of(&["nonexistent"])).unwrap();
        assert!(env.tools.is_empty(), "inactive plugin tool should be filtered");
        assert!(env.hooks_for_phase(Phase::StepStart).is_empty());
        assert_eq!(
            env.profile_key_registrations.len(),
            1,
            "inactive plugin keys should still be registered"
        );
        assert_eq!(env.profile_key_registrations[0].key, "locale");
        assert_eq!(env.plugins.len(), 1, "inactive plugins are retained");
    }

    #[test]
    fn duplicate_tool_across_active_plugins_is_rejected() {
        let plugins: Vec<Arc<dyn Plugin>> = vec![
            Arc::new(FullPlugin {
                name: "alpha",
                tool_id: "shared_tool",
            }),
            Arc::new(FullPlugin {
                name: "beta",
                tool_id: "shared_tool",
            }),
        ];
        let result = ExecutionEnv::from_plugins(&plugins, &HashSet::new());
        assert!(matches!(
            &result,
            Err(StateError::ToolAlreadyRegistered { tool_id, .. }) if tool_id == "shared_tool"
        ));
    }

    #[test]
    fn duplicate_tool_from_inactive_plugin_is_ignored() {
        let plugins: Vec<Arc<dyn Plugin>> = vec![
            Arc::new(FullPlugin {
                name: "alpha",
                tool_id: "shared_tool",
            }),
            Arc::new(FullPlugin {
                name: "beta",
                tool_id: "shared_tool",
            }),
        ];
        let env = ExecutionEnv::from_plugins(&plugins, &filter_of(&["alpha"])).unwrap();
        assert_eq!(env.tools.len(), 1);
        assert!(env.tools.contains_key("shared_tool"));
    }

    #[test]
    fn from_plugins_collects_profile_keys() {
        /// Plugin that registers only the `locale` profile key.
        struct ProfilePlugin;

        impl Plugin for ProfilePlugin {
            fn descriptor(&self) -> PluginDescriptor {
                PluginDescriptor {
                    name: "profile-plugin",
                }
            }

            fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
                registrar.register_profile_key::<ProfileLocale>()?;
                Ok(())
            }
        }

        let plugins: Vec<Arc<dyn Plugin>> = vec![Arc::new(ProfilePlugin)];
        let env = ExecutionEnv::from_plugins(&plugins, &filter_of(&["nonexistent"])).unwrap();
        assert_eq!(env.profile_key_registrations.len(), 1);
        assert_eq!(env.profile_key_registrations[0].key, "locale");
    }
}