use std::collections::HashMap;
use awaken_contract::StateError;
use awaken_contract::contract::active_agent::ActiveAgentIdKey;
use awaken_contract::model::Phase;
use awaken_contract::registry_spec::AgentSpec;
use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
use crate::state::{KeyScope, MutationBatch, StateKeyOptions};
use super::action::HandoffAction;
use super::hook::HandoffSyncHook;
use super::state::{ActiveAgentKey, HandoffState};
use super::types::AgentOverlay;
pub const HANDOFF_PLUGIN_ID: &str = "agent_handoff";
pub struct HandoffPlugin {
overlays: HashMap<String, AgentOverlay>,
}
impl HandoffPlugin {
pub fn new(overlays: HashMap<String, AgentOverlay>) -> Self {
Self { overlays }
}
pub fn overlay(&self, agent: &str) -> Option<&AgentOverlay> {
self.overlays.get(agent)
}
pub fn effective_agent(state: &HandoffState) -> Option<&String> {
state
.requested_agent
.as_ref()
.or(state.active_agent.as_ref())
}
}
impl Plugin for HandoffPlugin {
fn descriptor(&self) -> PluginDescriptor {
PluginDescriptor {
name: HANDOFF_PLUGIN_ID,
}
}
fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
let thread_scope = StateKeyOptions {
scope: KeyScope::Thread,
persistent: true,
..StateKeyOptions::default()
};
registrar.register_key::<ActiveAgentKey>(thread_scope)?;
registrar.register_key::<ActiveAgentIdKey>(thread_scope)?;
registrar.register_phase_hook(HANDOFF_PLUGIN_ID, Phase::RunStart, HandoffSyncHook)?;
registrar.register_phase_hook(HANDOFF_PLUGIN_ID, Phase::StepEnd, HandoffSyncHook)?;
Ok(())
}
fn on_activate(
&self,
_agent_spec: &AgentSpec,
_patch: &mut MutationBatch,
) -> Result<(), StateError> {
Ok(())
}
fn on_deactivate(&self, patch: &mut MutationBatch) -> Result<(), StateError> {
patch.update::<ActiveAgentKey>(HandoffAction::Clear);
patch.update::<ActiveAgentIdKey>(None);
Ok(())
}
}