use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::event::AgentEvent;
use crate::tokens::{TokenEstimator, CHAR_HEURISTIC};
use crate::tool::{ToolCall, ToolResult};
use crate::types::{AgentMessage, AssistantContent, Usage};
/// Base trait shared by every agent plugin.
///
/// The `Send + Sync + 'static` bound lets plugin values be shared across tasks, e.g.
/// behind an `Arc` as [`ChannelSteering::new`] does.
pub trait Plugin: Send + Sync + 'static {
    /// Static name identifying this plugin (e.g. `"channel_steering"` for
    /// [`ChannelSteering`]).
    fn name(&self) -> &'static str;
    /// Which hook capabilities this plugin declares.
    ///
    /// The default is `PluginCapabilities::default()`, i.e. every flag `false`.
    /// NOTE(review): presumably the plugin registry only dispatches hooks whose flag is set
    /// here — confirm against the runtime before relying on it.
    fn capabilities(&self) -> PluginCapabilities {
        PluginCapabilities::default()
    }
}
/// Set of capability flags a [`Plugin`] reports from [`Plugin::capabilities`].
///
/// Each hook flag corresponds by name to one hook trait (e.g. `before_tool_call` ↔
/// [`BeforeToolCall`], `steering` ↔ [`SteeringSource`]). The `Default` value has every
/// flag cleared.
///
/// Single-flag constructors (e.g. [`PluginCapabilities::before_tool_call`]) enable one
/// flag, and the `with_*` builders add more, so multi-hook plugins can compose:
///
/// ```ignore
/// let caps = PluginCapabilities::before_tool_call().with_after_tool_call();
/// ```
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PluginCapabilities {
    /// Plugin implements [`BeforeToolCall`].
    pub before_tool_call: bool,
    /// Plugin implements [`AfterToolCall`].
    pub after_tool_call: bool,
    /// Plugin implements [`ContextTransform`].
    pub context_transform: bool,
    /// Plugin implements [`EventObserver`].
    pub event_observer: bool,
    /// Plugin implements [`SteeringSource`].
    pub steering: bool,
    /// Plugin implements [`FollowUpSource`].
    pub follow_up: bool,
    /// Plugin implements [`ToolGate`].
    pub tool_gate: bool,
    /// NOTE(review): presumably marks the plugin as inherited by child agents — confirm
    /// against the code that spawns children.
    pub inheritable_to_child: bool,
}

// `PluginCapabilities` is `Copy`, so a builder call whose result is discarded
// (`caps.with_tool_gate();`) compiles but changes nothing; `#[must_use]` turns that
// mistake into a compiler warning.
impl PluginCapabilities {
    /// Capabilities with only `before_tool_call` set.
    #[must_use]
    pub fn before_tool_call() -> Self {
        Self::default().with_before_tool_call()
    }
    /// Capabilities with only `after_tool_call` set.
    #[must_use]
    pub fn after_tool_call() -> Self {
        Self::default().with_after_tool_call()
    }
    /// Capabilities with only `context_transform` set.
    #[must_use]
    pub fn context_transform() -> Self {
        Self::default().with_context_transform()
    }
    /// Capabilities with only `event_observer` set.
    #[must_use]
    pub fn event_observer() -> Self {
        Self::default().with_event_observer()
    }
    /// Capabilities with only `steering` set.
    #[must_use]
    pub fn steering() -> Self {
        Self::default().with_steering()
    }
    /// Capabilities with only `follow_up` set.
    #[must_use]
    pub fn follow_up() -> Self {
        Self::default().with_follow_up()
    }
    /// Capabilities with only `tool_gate` set.
    #[must_use]
    pub fn tool_gate() -> Self {
        Self::default().with_tool_gate()
    }
    /// Returns `self` with `before_tool_call` also set.
    #[must_use]
    pub fn with_before_tool_call(mut self) -> Self {
        self.before_tool_call = true;
        self
    }
    /// Returns `self` with `after_tool_call` also set.
    #[must_use]
    pub fn with_after_tool_call(mut self) -> Self {
        self.after_tool_call = true;
        self
    }
    /// Returns `self` with `context_transform` also set.
    #[must_use]
    pub fn with_context_transform(mut self) -> Self {
        self.context_transform = true;
        self
    }
    /// Returns `self` with `event_observer` also set.
    #[must_use]
    pub fn with_event_observer(mut self) -> Self {
        self.event_observer = true;
        self
    }
    /// Returns `self` with `steering` also set.
    #[must_use]
    pub fn with_steering(mut self) -> Self {
        self.steering = true;
        self
    }
    /// Returns `self` with `follow_up` also set.
    #[must_use]
    pub fn with_follow_up(mut self) -> Self {
        self.follow_up = true;
        self
    }
    /// Returns `self` with `tool_gate` also set.
    #[must_use]
    pub fn with_tool_gate(mut self) -> Self {
        self.tool_gate = true;
        self
    }
    /// Returns `self` with `inheritable_to_child` also set.
    #[must_use]
    pub fn with_inheritable_to_child(mut self) -> Self {
        self.inheritable_to_child = true;
        self
    }
}
/// Read-only view of a pending tool call, handed to [`BeforeToolCall::on_before_tool_call`].
///
/// Every field is a shared borrow, so the context derives `Copy`. The hook method takes it
/// by value, and a caller can now build it once and hand the same value to several hooks
/// in turn instead of rebuilding it per hook.
#[derive(Clone, Copy)]
pub struct BeforeToolCallContext<'a> {
    /// Assistant message the call belongs to (presumably the one that issued `tool_call`).
    pub assistant_message: &'a AgentMessage,
    /// Assistant content associated with the call.
    /// NOTE(review): relationship to `assistant_message` not visible here — confirm.
    pub assistant_content: &'a AssistantContent,
    /// The call about to be decided on.
    pub tool_call: &'a ToolCall,
    /// Arguments for the call, as a JSON value.
    pub args: &'a Value,
    /// Message history visible to the hook (presumably the conversation so far — confirm).
    pub messages: &'a [AgentMessage],
}
/// Verdict returned by [`BeforeToolCall::on_before_tool_call`].
///
/// The `Default` value allows the call: `block == false`, no reason, no details.
#[derive(Debug, Clone, Default)]
pub struct BeforeToolDecision {
    /// `true` when the hook wants the tool call blocked.
    pub block: bool,
    /// Explanation for the block, if one was given.
    pub reason: Option<String>,
    /// Optional structured JSON payload accompanying the block.
    pub details: Option<Value>,
}

impl BeforeToolDecision {
    /// Lets the call proceed; identical to `BeforeToolDecision::default()`.
    pub fn allow() -> Self {
        Self {
            block: false,
            reason: None,
            details: None,
        }
    }

    /// Blocks the call with `reason` and no structured details.
    pub fn block(reason: impl Into<String>) -> Self {
        Self::blocked(reason.into(), None)
    }

    /// Blocks the call with `reason` plus a structured JSON `details` payload.
    pub fn block_with_details(reason: impl Into<String>, details: Value) -> Self {
        Self::blocked(reason.into(), Some(details))
    }

    /// Shared constructor behind both blocking variants.
    fn blocked(reason: String, details: Option<Value>) -> Self {
        Self {
            block: true,
            reason: Some(reason),
            details,
        }
    }
}
/// Plugin hook consulted for a tool call before it runs; its decision can block the call.
#[async_trait]
pub trait BeforeToolCall: Plugin {
    /// Inspects the pending call described by `ctx` and returns a verdict.
    ///
    /// Return [`BeforeToolDecision::allow`] to let it proceed, or
    /// [`BeforeToolDecision::block`] / [`BeforeToolDecision::block_with_details`] to stop it.
    /// NOTE(review): how a block's `reason`/`details` are surfaced is decided by the
    /// caller — confirm there.
    async fn on_before_tool_call(&self, ctx: BeforeToolCallContext<'_>) -> BeforeToolDecision;
}
/// Read-only view of a finished tool call, handed to [`AfterToolCall::on_after_tool_call`].
///
/// Every field is a shared borrow or a `bool`, so the context derives `Copy`. The hook
/// method takes it by value, and a caller can now pass the same context to several hooks
/// without rebuilding it for each one.
#[derive(Clone, Copy)]
pub struct AfterToolCallContext<'a> {
    /// Assistant message the call belongs to (presumably the one that issued `tool_call`).
    pub assistant_message: &'a AgentMessage,
    /// The call that ran.
    pub tool_call: &'a ToolCall,
    /// Arguments the call ran with, as a JSON value.
    pub args: &'a Value,
    /// Result produced for the call.
    pub result: &'a ToolResult,
    /// Whether `result` is currently flagged as an error.
    pub is_error: bool,
    /// Message history visible to the hook (presumably the conversation so far — confirm).
    pub messages: &'a [AgentMessage],
}
/// Adjustments an [`AfterToolCall`] hook asks for after a tool call finished.
///
/// Every field is optional, and `None` means "no opinion". The `Default` value (same as
/// [`AfterToolDecision::passthrough`]) has all fields `None`.
#[derive(Debug, Clone, Default)]
pub struct AfterToolDecision {
    /// Replacement result for the call, if the hook wants to override it.
    pub result: Option<ToolResult>,
    /// Override for the call's error flag, if any.
    pub mark_error: Option<bool>,
    /// Termination request, if any.
    /// NOTE(review): presumably stops the agent loop when `Some(true)` — confirm in the caller.
    pub terminate: Option<bool>,
}

impl AfterToolDecision {
    /// Requests no changes: every field is `None`.
    pub fn passthrough() -> Self {
        Self {
            result: None,
            mark_error: None,
            terminate: None,
        }
    }

    /// Replaces the call's result with `result` and leaves every other field unset.
    pub fn override_result(result: ToolResult) -> Self {
        let mut decision = Self::passthrough();
        decision.result = Some(result);
        decision
    }
}
/// Plugin hook consulted after a tool call has produced a result.
#[async_trait]
pub trait AfterToolCall: Plugin {
    /// Inspects the finished call described by `ctx` and returns any adjustments.
    ///
    /// Return [`AfterToolDecision::passthrough`] to change nothing, or
    /// [`AfterToolDecision::override_result`] to replace the result.
    /// NOTE(review): how `mark_error`/`terminate` are applied is decided by the caller —
    /// confirm there.
    async fn on_after_tool_call(&self, ctx: AfterToolCallContext<'_>) -> AfterToolDecision;
}
/// Per-call inputs handed to [`ContextTransform`] hooks alongside the message list.
pub struct TransformContext<'a> {
    /// Cancellation token for the surrounding run, which a transform may check.
    pub signal: &'a CancellationToken,
    /// Model identifier. NOTE(review): presumably the model the messages will be sent to.
    pub model_id: &'a str,
    /// Iteration counter of the surrounding loop.
    pub iteration: usize,
    /// Provider-reported usage, when available (presumably from the most recent response).
    pub last_provider_usage: Option<&'a Usage>,
    /// Token estimator a transform can use to size messages.
    pub estimator: &'a dyn TokenEstimator,
}

impl<'a> TransformContext<'a> {
    /// Builds a context with neutral placeholder values for tests.
    ///
    /// Only `signal` is supplied by the caller. The rest are fixed: an empty `model_id`,
    /// iteration `0`, no provider usage, and the [`CHAR_HEURISTIC`] estimator.
    pub fn for_test(signal: &'a CancellationToken) -> Self {
        Self {
            iteration: 0,
            model_id: "",
            last_provider_usage: None,
            estimator: &CHAR_HEURISTIC,
            signal,
        }
    }
}
/// Plugin hook that can rewrite the message list.
/// NOTE(review): presumably applied before the list is sent to the model — confirm in the caller.
#[async_trait]
pub trait ContextTransform: Plugin {
    /// Pre-check on the current messages and context; the default always returns `true`.
    /// NOTE(review): presumably a `false` result lets the caller skip
    /// [`ContextTransform::transform`] — confirm.
    fn should_run(&self, _messages: &[AgentMessage], _cx: &TransformContext<'_>) -> bool {
        true
    }
    /// Takes ownership of `messages` and returns the (possibly modified) list to use instead.
    async fn transform(
        &self,
        messages: Vec<AgentMessage>,
        cx: &TransformContext<'_>,
    ) -> Vec<AgentMessage>;
}
/// Plugin hook that receives [`AgentEvent`]s by reference.
#[async_trait]
pub trait EventObserver: Plugin {
    /// Observes one event.
    /// NOTE(review): which events are delivered, and in what order, is decided by the caller —
    /// confirm there.
    async fn on_event(&self, event: &AgentEvent);
}
/// Plugin that supplies steering messages on request.
///
/// [`ChannelSteering`] implements it by returning every message queued through its
/// [`SteeringHandle`] since the previous call.
#[async_trait]
pub trait SteeringSource: Plugin {
    /// Returns the steering messages available now.
    /// NOTE(review): presumably the caller injects them into the running conversation — confirm.
    async fn next_steering_messages(&self) -> Vec<AgentMessage>;
}
/// Plugin that supplies follow-up messages on request.
/// NOTE(review): presumably consulted when the agent would otherwise finish, to continue with
/// new input — confirm in the caller.
#[async_trait]
pub trait FollowUpSource: Plugin {
    /// Returns the follow-up messages available now.
    async fn next_follow_up_messages(&self) -> Vec<AgentMessage>;
}
pub struct ToolGateContext<'a> {
pub iteration: usize,
pub messages: &'a [AgentMessage],
pub conversation_id: Option<&'a str>,
pub available_tool_names: &'a [&'a str],
}
/// Classification a [`ToolGate`] reports from [`ToolGate::tool_gate_class`].
///
/// NOTE(review): how the two classes are enforced lives in the caller. From this file,
/// `Required` is the trait's default, and `ToolGate::suppresses_advisory_gates` lets a gate
/// ask for `Advisory` gates to be set aside. Confirm the exact semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolGateClass {
    /// Default class returned by `ToolGate::tool_gate_class`.
    Required,
    /// Class whose gates can presumably be suppressed by another gate (see
    /// `ToolGate::suppresses_advisory_gates`).
    Advisory,
}
/// Plugin that restricts which tools may be used on the next turn (per its method names).
#[async_trait]
pub trait ToolGate: Plugin {
    /// Returns the set of tool names this gate permits for the next turn, or `None`.
    /// NOTE(review): `None` presumably means "no restriction from this gate" — confirm against
    /// the caller that combines gate results.
    async fn next_turn_tool_allowlist(
        &self,
        ctx: ToolGateContext<'_>,
    ) -> Option<std::collections::HashSet<String>>;
    /// Optional explanation for why `_tool_name` is not allowed; the default gives none.
    async fn denial_reason(&self, _tool_name: &str, _ctx: ToolGateContext<'_>) -> Option<String> {
        None
    }
    /// Priority used when gates conflict; defaults to `0`.
    /// NOTE(review): whether higher or lower wins is decided by the caller — confirm there.
    fn conflict_priority(&self) -> i32 {
        0
    }
    /// Class of this gate; defaults to [`ToolGateClass::Required`].
    fn tool_gate_class(&self) -> ToolGateClass {
        ToolGateClass::Required
    }
    /// Whether this gate asks for advisory gates to be ignored for `_ctx`; defaults to `false`.
    fn suppresses_advisory_gates(&self, _ctx: ToolGateContext<'_>) -> bool {
        false
    }
}
/// [`SteeringSource`] backed by an unbounded tokio mpsc channel.
///
/// Create one with [`ChannelSteering::new`], which also returns the [`SteeringHandle`] used
/// to push messages into it.
pub struct ChannelSteering {
    // Behind an async mutex: draining needs `&mut` access to the receiver, while
    // `next_steering_messages` only receives `&self`.
    rx: tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
}
/// Cloneable sending half paired with a [`ChannelSteering`] source.
#[derive(Clone)]
pub struct SteeringHandle {
    tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
}
impl SteeringHandle {
    /// Queues `message` for the paired [`ChannelSteering`] source.
    ///
    /// # Errors
    ///
    /// Returns the message back inside `SendError` if the channel's receiving side is gone,
    /// per tokio's `UnboundedSender::send` contract.
    // Returning `SendError<AgentMessage>` hands the undelivered message back to the caller
    // instead of dropping it. That makes the error as large as a message, which trips
    // clippy's `result_large_err`; the size is intentional here.
    #[allow(clippy::result_large_err)]
    pub fn steer(
        &self,
        message: AgentMessage,
    ) -> Result<(), tokio::sync::mpsc::error::SendError<AgentMessage>> {
        tokio::sync::mpsc::UnboundedSender::send(&self.tx, message)
    }
}
impl ChannelSteering {
pub fn new() -> (Arc<Self>, SteeringHandle) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
(
Arc::new(Self {
rx: tokio::sync::Mutex::new(rx),
}),
SteeringHandle { tx },
)
}
}
impl Plugin for ChannelSteering {
fn name(&self) -> &'static str {
"channel_steering"
}
fn capabilities(&self) -> PluginCapabilities {
PluginCapabilities::steering()
}
}
#[async_trait]
impl SteeringSource for ChannelSteering {
    /// Returns every message currently queued on the channel, without waiting for more.
    ///
    /// Draining stops at the first `try_recv` error, which covers both an empty queue and a
    /// disconnected sender. An empty `Vec` therefore means nothing was pending.
    async fn next_steering_messages(&self) -> Vec<AgentMessage> {
        let mut receiver = self.rx.lock().await;
        std::iter::from_fn(|| receiver.try_recv().ok()).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::UserContent;

    /// Builds a plain-text user message with no timestamp.
    fn user_text(text: &'static str) -> AgentMessage {
        AgentMessage::User {
            content: UserContent::Text(text.into()),
            timestamp: None,
        }
    }

    /// Queued messages come out in one batch, and a second poll finds nothing left.
    #[tokio::test]
    async fn channel_steering_drains() {
        let (source, handle) = ChannelSteering::new();
        for text in ["hi", "again"] {
            handle.steer(user_text(text)).unwrap();
        }

        let first_batch = source.next_steering_messages().await;
        assert_eq!(first_batch.len(), 2);

        let second_batch = source.next_steering_messages().await;
        assert!(second_batch.is_empty());
    }
}