use async_trait::async_trait;
use super::{Plugin, PluginResult};
use crate::context::InvocationContext;
/// Configuration holder for the `context_filter` plugin.
///
/// Built with [`ContextFilterPlugin::new`] (or [`Default`]) and customised through the
/// by-value `with_*` builder methods.
///
/// NOTE(review): the settings are stored and exposed via getters, but the
/// `Plugin::before_model` hook implemented for this type in this file does not read
/// them yet. Confirm where the filtering is meant to be applied.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContextFilterPlugin {
    /// Configured maximum number of turns; `None` (the default) means unset.
    max_turns: Option<usize>,
    /// Whether tool turns should be excluded; `false` by default.
    exclude_tool_turns: bool,
}

impl ContextFilterPlugin {
    /// Creates a plugin with no turn limit set and tool-turn exclusion disabled.
    ///
    /// Equivalent to [`ContextFilterPlugin::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the plugin with `max_turns` set to `Some(max_turns)`.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_max_turns(mut self, max_turns: usize) -> Self {
        self.max_turns = Some(max_turns);
        self
    }

    /// Returns the plugin with the tool-turn exclusion flag set to `exclude`.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_exclude_tool_turns(mut self, exclude: bool) -> Self {
        self.exclude_tool_turns = exclude;
        self
    }

    /// Returns the configured maximum number of turns, or `None` if none was set.
    #[must_use]
    pub fn max_turns(&self) -> Option<usize> {
        self.max_turns
    }

    /// Returns whether tool turns are configured to be excluded.
    #[must_use]
    pub fn exclude_tool_turns(&self) -> bool {
        self.exclude_tool_turns
    }
}
#[async_trait]
impl Plugin for ContextFilterPlugin {
    /// Identifier under which this plugin is registered: `"context_filter"`.
    fn name(&self) -> &str {
        const PLUGIN_NAME: &str = "context_filter";
        PLUGIN_NAME
    }

    /// Hook invoked before a model call; this implementation always lets the
    /// request through unchanged by returning `PluginResult::Continue`.
    ///
    /// Neither the request nor the invocation context is inspected.
    async fn before_model(
        &self,
        _request: &crate::llm::LlmRequest,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        // NOTE(review): `max_turns` / `exclude_tool_turns` are not consulted here, so the
        // plugin currently filters nothing. `_request` is only borrowed immutably, so real
        // filtering may need a different hook or a mutable request — confirm with the
        // `Plugin` trait definition.
        PluginResult::Continue
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let plugin = ContextFilterPlugin::new();
        assert!(plugin.max_turns().is_none());
        assert!(!plugin.exclude_tool_turns());
    }

    #[test]
    fn default_matches_new() {
        // `Default` must produce the same configuration as `new()`.
        let from_default = ContextFilterPlugin::default();
        let from_new = ContextFilterPlugin::new();
        assert_eq!(from_default.max_turns(), from_new.max_turns());
        assert_eq!(from_default.exclude_tool_turns(), from_new.exclude_tool_turns());
    }

    #[test]
    fn custom_config() {
        let plugin = ContextFilterPlugin::new()
            .with_max_turns(10)
            .with_exclude_tool_turns(true);
        assert_eq!(plugin.max_turns(), Some(10));
        assert!(plugin.exclude_tool_turns());
    }

    #[test]
    fn builder_last_call_wins() {
        // Repeated builder calls overwrite earlier values rather than accumulating.
        let plugin = ContextFilterPlugin::new()
            .with_max_turns(3)
            .with_max_turns(7)
            .with_exclude_tool_turns(true)
            .with_exclude_tool_turns(false);
        assert_eq!(plugin.max_turns(), Some(7));
        assert!(!plugin.exclude_tool_turns());
    }

    #[test]
    fn zero_max_turns_is_stored_verbatim() {
        // Zero is kept as `Some(0)`, distinct from the unset `None` state.
        let plugin = ContextFilterPlugin::new().with_max_turns(0);
        assert_eq!(plugin.max_turns(), Some(0));
    }

    #[test]
    fn plugin_name() {
        let plugin = ContextFilterPlugin::new();
        assert_eq!(plugin.name(), "context_filter");
    }
}