use async_trait::async_trait;
use echo_core::llm::{ChatRequest, LlmClient, Message};
use serde_json::Value;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use echo_core::error::Result;
/// Decision point for tool-call permission checks.
///
/// Implementations inspect the tool name, its JSON input and the surrounding
/// [`ClassifierContext`] and return a [`ClassifierResult`] verdict
/// (allow / block plus a reason and a confidence).
#[async_trait]
pub trait Classifier: Send + Sync {
/// Classifies a single tool invocation.
///
/// An `Err` means the check itself failed (e.g. a backend call error);
/// a "deny" decision is expressed via the returned `ClassifierResult`,
/// not via `Err`.
async fn classify(
&self,
tool_name: &str,
tool_input: &Value,
context: &ClassifierContext,
) -> Result<ClassifierResult>;
}
/// Risk signals attached to a [`ClassifierContext`]; rendered verbatim into
/// the LLM prompt's `{risk_context}` placeholder by `LlmClassifier`.
#[derive(Debug, Clone, Default)]
pub struct RiskContext {
// Operation touches files flagged as sensitive (flag is set by the caller).
pub has_sensitive_files: bool,
// Operation is considered destructive by the caller.
pub is_destructive: bool,
// Directory nesting depth — presumably of the target path; set by caller.
pub directory_depth: usize,
// How many times a similar call has been repeated, per the caller.
pub repetition_count: u32,
}
/// Everything a [`Classifier`] may consult besides the tool call itself:
/// conversation history, agent/session identity, rule lists, and optional
/// workspace / project / risk metadata.
#[derive(Debug, Clone)]
pub struct ClassifierContext {
    pub messages: Vec<Message>,
    pub agent_name: String,
    pub session_id: String,
    pub allow_rules: Vec<String>,
    pub soft_deny_rules: Vec<String>,
    pub workspace_path: Option<String>,
    pub project_type: Option<String>,
    pub recent_files: Vec<String>,
    pub risk_context: Option<RiskContext>,
}

impl ClassifierContext {
    /// Creates a context for the given agent and session; every other field
    /// starts out empty / unset and can be filled via the `with_*` builders.
    pub fn new(agent_name: String, session_id: String) -> Self {
        Self {
            agent_name,
            session_id,
            messages: Vec::default(),
            allow_rules: Vec::default(),
            soft_deny_rules: Vec::default(),
            recent_files: Vec::default(),
            workspace_path: None,
            project_type: None,
            risk_context: None,
        }
    }

    /// Builder: attaches the conversation history.
    pub fn with_messages(self, history: Vec<Message>) -> Self {
        Self {
            messages: history,
            ..self
        }
    }

    /// Builder: replaces the allow-rule list.
    pub fn with_allow_rules(self, allow: Vec<String>) -> Self {
        Self {
            allow_rules: allow,
            ..self
        }
    }

    /// Builder: replaces the soft-deny-rule list.
    pub fn with_soft_deny_rules(self, soft_deny: Vec<String>) -> Self {
        Self {
            soft_deny_rules: soft_deny,
            ..self
        }
    }

    /// Builder: sets the workspace path.
    pub fn with_workspace_path(self, workspace: String) -> Self {
        Self {
            workspace_path: Some(workspace),
            ..self
        }
    }

    /// Builder: sets the project type.
    pub fn with_project_type(self, kind: String) -> Self {
        Self {
            project_type: Some(kind),
            ..self
        }
    }

    /// Builder: replaces the recently-touched-files list.
    pub fn with_recent_files(self, recent: Vec<String>) -> Self {
        Self {
            recent_files: recent,
            ..self
        }
    }

    /// Builder: attaches risk signals.
    pub fn with_risk_context(self, risk: RiskContext) -> Self {
        Self {
            risk_context: Some(risk),
            ..self
        }
    }
}
/// Verdict produced by a [`Classifier`].
#[derive(Debug, Clone)]
pub struct ClassifierResult {
    /// `true` when the tool call should be blocked.
    pub should_block: bool,
    /// Human-readable explanation of the verdict.
    pub reason: String,
    /// Confidence in the verdict, kept in `[0.0, 1.0]` by the builders.
    pub confidence: f32,
}

impl ClassifierResult {
    /// Builds an "allow" verdict with full (1.0) confidence.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...),
    /// so callers no longer need an explicit `.to_string()`; existing
    /// `String`-passing call sites keep working unchanged.
    pub fn allow(reason: impl Into<String>) -> Self {
        Self {
            should_block: false,
            reason: reason.into(),
            confidence: 1.0,
        }
    }

    /// Builds a "block" verdict with full (1.0) confidence.
    pub fn block(reason: impl Into<String>) -> Self {
        Self {
            should_block: true,
            reason: reason.into(),
            confidence: 1.0,
        }
    }

    /// Builder: overrides the confidence, clamped into `[0.0, 1.0]`.
    /// NOTE: `f32::clamp` propagates a NaN input unchanged.
    pub fn with_confidence(mut self, confidence: f32) -> Self {
        self.confidence = confidence.clamp(0.0, 1.0);
        self
    }
}
/// Tracks how often a classifier has denied tool calls, so callers can stop
/// auto-classifying and fall back to user interaction after too many denials.
#[derive(Debug)]
pub struct DenialTracker {
    // Denials since the last allowed call; cleared by `reset`.
    consecutive_denials: u32,
    // Consecutive-denial threshold that triggers fallback.
    max_consecutive: u32,
    // Lifetime denial count; never reset.
    total_denials: u32,
    // Total-denial threshold that triggers fallback.
    max_total_denials: u32,
    // When the most recent denial was recorded.
    last_denial_time: Option<Instant>,
}

impl DenialTracker {
    /// Default cap on consecutive denials before fallback.
    pub const DEFAULT_MAX_CONSECUTIVE: u32 = 3;
    /// Default cap on total denials before fallback.
    pub const DEFAULT_MAX_TOTAL: u32 = 20;

    /// Creates a tracker with default thresholds and no denials recorded.
    pub fn new() -> Self {
        Self {
            consecutive_denials: 0,
            max_consecutive: Self::DEFAULT_MAX_CONSECUTIVE,
            total_denials: 0,
            max_total_denials: Self::DEFAULT_MAX_TOTAL,
            last_denial_time: None,
        }
    }

    /// Creates a tracker with a custom consecutive-denial threshold.
    /// (A constructor, not a builder method — signature kept for existing
    /// callers; delegates to `new` for the remaining defaults.)
    pub fn with_max_consecutive(max: u32) -> Self {
        Self {
            max_consecutive: max,
            ..Self::new()
        }
    }

    /// Builder: overrides the total-denial threshold.
    pub fn with_max_total(mut self, max: u32) -> Self {
        self.max_total_denials = max;
        self
    }

    /// Records one denial and timestamps it.
    ///
    /// Saturating arithmetic ensures a pathological number of denials can
    /// never overflow the counters (which would panic in debug builds).
    pub fn record_denial(&mut self) {
        self.consecutive_denials = self.consecutive_denials.saturating_add(1);
        self.total_denials = self.total_denials.saturating_add(1);
        self.last_denial_time = Some(Instant::now());
    }

    /// Clears the consecutive-denial streak (e.g. after an allowed call).
    /// The lifetime total is deliberately preserved.
    pub fn reset(&mut self) {
        self.consecutive_denials = 0;
    }

    /// `true` once either threshold is reached and the caller should fall
    /// back to asking the user instead of auto-classifying.
    pub fn should_fallback(&self) -> bool {
        self.consecutive_denials >= self.max_consecutive
            || self.total_denials >= self.max_total_denials
    }

    /// Current consecutive-denial streak.
    pub fn consecutive_denials(&self) -> u32 {
        self.consecutive_denials
    }

    /// Lifetime denial count.
    pub fn total_denials(&self) -> u32 {
        self.total_denials
    }

    /// Timestamp of the most recent denial, if any.
    pub fn last_denial_time(&self) -> Option<Instant> {
        self.last_denial_time
    }
}

impl Default for DenialTracker {
    /// Equivalent to [`DenialTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Pattern-based classifier: compares the tool name against configured
/// allow / deny patterns (matching semantics live in `super::pattern`).
pub struct RuleClassifier {
    deny_patterns: Vec<String>,
    allow_patterns: Vec<String>,
    // Counts deny hits; shared so clones of the Arc observe the same state.
    denial_tracker: Arc<Mutex<DenialTracker>>,
}

impl RuleClassifier {
    /// Creates a classifier with no patterns configured.
    pub fn new() -> Self {
        Self {
            deny_patterns: Vec::default(),
            allow_patterns: Vec::default(),
            denial_tracker: Arc::new(Mutex::new(DenialTracker::default())),
        }
    }

    /// Appends one deny pattern.
    pub fn add_deny_pattern(&mut self, pattern: String) {
        self.deny_patterns.push(pattern);
    }

    /// Appends one allow pattern.
    pub fn add_allow_pattern(&mut self, pattern: String) {
        self.allow_patterns.push(pattern);
    }

    /// Builder: replaces the entire deny-pattern list.
    pub fn with_deny_patterns(self, patterns: Vec<String>) -> Self {
        Self {
            deny_patterns: patterns,
            ..self
        }
    }

    /// Builder: replaces the entire allow-pattern list.
    pub fn with_allow_patterns(self, patterns: Vec<String>) -> Self {
        Self {
            allow_patterns: patterns,
            ..self
        }
    }

    /// Delegates tool-name matching to the shared pattern helper.
    fn matches_pattern(pattern: &str, tool_name: &str) -> bool {
        super::pattern::matches_tool_pattern(pattern, tool_name)
    }
}

impl Default for RuleClassifier {
    /// Equivalent to [`RuleClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Classifier for RuleClassifier {
    /// Rule evaluation order: allow patterns win, then deny patterns (which
    /// also bump the denial tracker), and anything unmatched is allowed.
    async fn classify(
        &self,
        tool_name: &str,
        _tool_input: &Value,
        _context: &ClassifierContext,
    ) -> Result<ClassifierResult> {
        // Explicit allow rules take precedence over everything.
        let allowed = self
            .allow_patterns
            .iter()
            .find(|p| Self::matches_pattern(p, tool_name));
        if let Some(pattern) = allowed {
            return Ok(ClassifierResult::allow(format!(
                "工具 '{}' 匹配允许规则 '{}'",
                tool_name, pattern
            )));
        }
        // Deny rules next; a hit counts towards the fallback thresholds.
        let denied = self
            .deny_patterns
            .iter()
            .find(|p| Self::matches_pattern(p, tool_name));
        if let Some(pattern) = denied {
            self.denial_tracker.lock().await.record_denial();
            return Ok(ClassifierResult::block(format!(
                "工具 '{}' 匹配阻止规则 '{}'",
                tool_name, pattern
            )));
        }
        // No rule matched: default-allow.
        Ok(ClassifierResult::allow(format!(
            "工具 '{}' 未匹配任何规则,默认允许",
            tool_name
        )))
    }
}
/// Classifier that delegates the allow/block decision to an LLM.
/// Generic over `?Sized` so it also works with `Arc<dyn LlmClient>`.
pub struct LlmClassifier<C: LlmClient + ?Sized> {
// LLM backend used for the decision call.
client: Arc<C>,
// Model name; marked dead_code and never read in this impl — presumably
// reserved for future request routing. TODO confirm.
#[allow(dead_code)] model: String,
// Shared denial counters driving the fallback-to-user short-circuit.
denial_tracker: Arc<Mutex<DenialTracker>>,
// Prompt with `{placeholder}` slots, filled in by `build_prompt`.
prompt_template: String,
}
impl<C: LlmClient + ?Sized> LlmClassifier<C> {
    /// Byte budget for the pretty-printed tool input embedded in the prompt.
    const MAX_TOOL_INPUT_BYTES: usize = 2000;

    /// Creates a classifier around `client` using the default prompt template.
    pub fn new(client: Arc<C>, model: String) -> Self {
        Self {
            client,
            model,
            denial_tracker: Arc::new(Mutex::new(DenialTracker::new())),
            prompt_template: Self::default_prompt_template(),
        }
    }

    /// Builder: replaces the prompt template. The template may contain the
    /// placeholders `{tool_name}`, `{tool_input}`, `{allow_rules}`,
    /// `{soft_deny_rules}`, `{conversation_summary}`, `{workspace_path}`,
    /// `{project_type}`, `{recent_files}` and `{risk_context}`.
    pub fn with_prompt_template(mut self, template: String) -> Self {
        self.prompt_template = template;
        self
    }

    /// Pretty-prints `tool_input`, truncating to at most
    /// `MAX_TOOL_INPUT_BYTES` bytes, always cutting on a UTF-8 char boundary.
    fn truncate_tool_input(tool_input: &Value) -> String {
        let full = serde_json::to_string_pretty(tool_input).unwrap_or_default();
        if full.len() <= Self::MAX_TOOL_INPUT_BYTES {
            return full;
        }
        // Keep only characters that END within the budget. (The previous
        // version kept any char that merely STARTED before the limit, so
        // the kept prefix could overshoot 2000 bytes by up to 3.)
        let end = full
            .char_indices()
            .map(|(idx, c)| idx + c.len_utf8())
            .take_while(|&e| e <= Self::MAX_TOOL_INPUT_BYTES)
            .last()
            .unwrap_or(0);
        format!("{}...(truncated)", &full[..end])
    }

    /// Default Chinese-language decision prompt. Placeholders are substituted
    /// by [`Self::build_prompt`]; the literal JSON examples contain no
    /// placeholder names, so substitution leaves them untouched.
    fn default_prompt_template() -> String {
        r#"你是一个权限决策助手。请根据以下信息判断是否允许执行工具调用。
## 工具信息
- 工具名称: {tool_name}
- 工具参数: {tool_input}
## 环境上下文
- 工作区: {workspace_path}
- 项目类型: {project_type}
- 最近文件: {recent_files}
- 风险信号: {risk_context}
## 规则
### 允许规则:
{allow_rules}
### 需谨慎的规则:
{soft_deny_rules}
## 对话上下文
{conversation_summary}
请判断是否应该阻止此工具调用。回复格式:
- 如果允许: {"should_block": false, "reason": "允许原因"}
- 如果阻止: {"should_block": true, "reason": "阻止原因"}
请只返回 JSON,不要其他内容。"#
            .to_string()
    }

    /// Renders a rule list as markdown bullets, or "无" ("none") when empty.
    fn render_rules(rules: &[String]) -> String {
        if rules.is_empty() {
            "无".to_string()
        } else {
            // Bullet the first entry too — previously only the 2nd..nth
            // rules received the "- " prefix.
            format!("- {}", rules.join("\n- "))
        }
    }

    /// Fills every placeholder in the template from the tool call + context.
    fn build_prompt(
        &self,
        tool_name: &str,
        tool_input: &Value,
        context: &ClassifierContext,
    ) -> String {
        let allow_rules = Self::render_rules(&context.allow_rules);
        let soft_deny_rules = Self::render_rules(&context.soft_deny_rules);
        // NOTE(review): this summarizes the FIRST five textual messages; if
        // the intent is "most recent context", it should take the tail.
        let conversation_summary = context
            .messages
            .iter()
            .filter_map(|m| m.content.as_text())
            .take(5)
            .collect::<Vec<_>>()
            .join("\n");
        self.prompt_template
            .replace("{tool_name}", tool_name)
            .replace("{tool_input}", &Self::truncate_tool_input(tool_input))
            .replace("{allow_rules}", &allow_rules)
            .replace("{soft_deny_rules}", &soft_deny_rules)
            .replace("{conversation_summary}", &conversation_summary)
            .replace(
                "{workspace_path}",
                context.workspace_path.as_deref().unwrap_or("unknown"),
            )
            .replace(
                "{project_type}",
                context.project_type.as_deref().unwrap_or("unknown"),
            )
            .replace("{recent_files}", &context.recent_files.join(", "))
            .replace(
                "{risk_context}",
                &context
                    .risk_context
                    .as_ref()
                    .map_or("none".to_string(), |r| {
                        format!(
                            "sensitive_files={}, destructive={}, depth={}, repetition={}",
                            r.has_sensitive_files,
                            r.is_destructive,
                            r.directory_depth,
                            r.repetition_count
                        )
                    }),
            )
    }

    /// Parses the LLM reply into a verdict: first as JSON, then by keyword.
    /// Returns `None` when neither strategy yields a decision.
    fn parse_response(response: &str) -> Option<ClassifierResult> {
        if let Some(json_str) = Self::extract_json(response)
            && let Ok(json) = serde_json::from_str::<Value>(&json_str)
        {
            // Missing/invalid "should_block" defaults to allow.
            let should_block = json
                .get("should_block")
                .and_then(|v| v.as_bool())
                .unwrap_or(false);
            let reason = json
                .get("reason")
                .and_then(|v| v.as_str())
                .unwrap_or("未提供原因")
                .to_string();
            return Some(if should_block {
                ClassifierResult::block(reason)
            } else {
                ClassifierResult::allow(reason)
            });
        }
        // Keyword fallback: "block" is checked first, so ambiguous replies
        // containing both keywords err towards the safe (blocking) side.
        if response.contains("阻止") || response.contains("block") {
            Some(ClassifierResult::block("LLM 判断应阻止".to_string()))
        } else if response.contains("允许") || response.contains("allow") {
            Some(ClassifierResult::allow("LLM 判断应允许".to_string()))
        } else {
            None
        }
    }

    /// Returns the first balanced top-level `{ ... }` span in `text`.
    /// NOTE(review): braces inside JSON string values are not skipped; that
    /// is acceptable for the short single-object replies requested here.
    fn extract_json(text: &str) -> Option<String> {
        let trimmed = text.trim();
        let mut depth = 0i32;
        let mut start: Option<usize> = None;
        for (i, c) in trimmed.char_indices() {
            if c == '{' {
                if depth == 0 {
                    start = Some(i);
                }
                depth += 1;
            } else if c == '}' {
                depth -= 1;
                if depth == 0
                    && let Some(s) = start
                {
                    return Some(trimmed[s..=i].to_string());
                }
            }
        }
        None
    }
}
#[async_trait]
impl<C: LlmClient + ?Sized> Classifier for LlmClassifier<C> {
    /// Asks the LLM for a verdict, unless the denial thresholds have already
    /// been hit — in which case it blocks immediately with confidence 0.0 so
    /// the caller falls back to user interaction.
    async fn classify(
        &self,
        tool_name: &str,
        tool_input: &Value,
        context: &ClassifierContext,
    ) -> Result<ClassifierResult> {
        // Short-circuit before spending an LLM call. The guard is a
        // temporary, so the lock is released at the end of this statement.
        if self.denial_tracker.lock().await.should_fallback() {
            return Ok(
                ClassifierResult::block("连续拒绝过多,回退到用户交互模式".to_string())
                    .with_confidence(0.0),
            );
        }
        let prompt = self.build_prompt(tool_name, tool_input, context);
        let response = self
            .client
            .chat(ChatRequest::new(vec![Message::user(prompt)]))
            .await?;
        let content = response.content().unwrap_or_default();
        // An unparseable reply degrades to a low-confidence allow.
        let result = Self::parse_response(&content).unwrap_or_else(|| {
            ClassifierResult::allow("无法解析 LLM 响应,默认允许".to_string()).with_confidence(0.5)
        });
        // Update the shared counters under a single lock acquisition:
        // a block extends the streak, an allow clears it.
        let mut tracker = self.denial_tracker.lock().await;
        if result.should_block {
            tracker.record_denial();
        } else {
            tracker.reset();
        }
        Ok(result)
    }
}
/// Chains several classifiers and returns the first sufficiently confident
/// verdict (see the `Classifier` impl for the threshold).
pub struct CompositeClassifier {
    classifiers: Vec<Arc<dyn Classifier>>,
}

impl CompositeClassifier {
    /// Creates an empty composite.
    pub fn new() -> Self {
        Self {
            classifiers: Vec::default(),
        }
    }

    /// Appends a classifier; insertion order is evaluation order.
    pub fn add(&mut self, classifier: Arc<dyn Classifier>) {
        self.classifiers.push(classifier);
    }

    /// Builder: replaces the whole classifier chain.
    pub fn with_classifiers(self, chain: Vec<Arc<dyn Classifier>>) -> Self {
        Self { classifiers: chain }
    }
}

impl Default for CompositeClassifier {
    /// Equivalent to [`CompositeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Classifier for CompositeClassifier {
    /// Queries each delegate in order and returns the first verdict whose
    /// confidence reaches the threshold — whether it allows or blocks.
    /// If none is confident enough, falls back to a 0.5-confidence allow.
    async fn classify(
        &self,
        tool_name: &str,
        tool_input: &Value,
        context: &ClassifierContext,
    ) -> Result<ClassifierResult> {
        // Minimum confidence required to accept a delegate's verdict.
        const CONFIDENCE_THRESHOLD: f32 = 0.8;
        for delegate in self.classifiers.iter() {
            let verdict = delegate.classify(tool_name, tool_input, context).await?;
            if verdict.confidence >= CONFIDENCE_THRESHOLD {
                return Ok(verdict);
            }
            // NOTE(review): low-confidence BLOCK verdicts are discarded here
            // exactly like low-confidence allows — confirm that is intended.
        }
        Ok(
            ClassifierResult::allow("所有分类器低置信度,默认允许".to_string())
                .with_confidence(0.5),
        )
    }
}
#[cfg(test)]
mod tests {
use super::*;
// allow(): non-blocking verdict, reason passed through, confidence 1.0.
#[test]
fn test_classifier_result_allow() {
let result = ClassifierResult::allow("test reason".to_string());
assert!(!result.should_block);
assert_eq!(result.reason, "test reason");
assert_eq!(result.confidence, 1.0);
}
// block(): blocking verdict, reason passed through, confidence 1.0.
#[test]
fn test_classifier_result_block() {
let result = ClassifierResult::block("dangerous".to_string());
assert!(result.should_block);
assert_eq!(result.reason, "dangerous");
assert_eq!(result.confidence, 1.0);
}
// with_confidence(): in-range values are stored unchanged.
#[test]
fn test_classifier_result_with_confidence() {
let result = ClassifierResult::allow("test".to_string()).with_confidence(0.5);
assert_eq!(result.confidence, 0.5);
}
// A fresh tracker has zero counts and does not request fallback.
#[test]
fn test_denial_tracker_new() {
let tracker = DenialTracker::new();
assert_eq!(tracker.consecutive_denials(), 0);
assert_eq!(tracker.total_denials(), 0);
assert!(!tracker.should_fallback());
}
// record_denial(): bumps both counters and timestamps the denial.
#[test]
fn test_denial_tracker_record() {
let mut tracker = DenialTracker::new();
tracker.record_denial();
assert_eq!(tracker.consecutive_denials(), 1);
assert_eq!(tracker.total_denials(), 1);
assert!(tracker.last_denial_time().is_some());
}
// Fallback triggers once the consecutive-denial threshold is reached.
#[test]
fn test_denial_tracker_fallback() {
let mut tracker = DenialTracker::with_max_consecutive(2);
tracker.record_denial();
assert!(!tracker.should_fallback());
tracker.record_denial();
assert!(tracker.should_fallback());
}
// reset(): clears the consecutive streak but keeps the lifetime total.
#[test]
fn test_denial_tracker_reset() {
let mut tracker = DenialTracker::new();
tracker.record_denial();
tracker.record_denial();
tracker.reset();
assert_eq!(tracker.consecutive_denials(), 0);
assert_eq!(tracker.total_denials(), 2); }
// new(): stores the identifiers and leaves collections empty.
#[test]
fn test_classifier_context_new() {
let ctx = ClassifierContext::new("agent".to_string(), "session".to_string());
assert_eq!(ctx.agent_name, "agent");
assert_eq!(ctx.session_id, "session");
assert!(ctx.messages.is_empty());
}
// Matching delegates to super::pattern::matches_tool_pattern; note that a
// bare "Bash" pattern also matches the parameterized "Bash(git:*)" name.
#[test]
fn test_rule_classifier_matches_pattern() {
assert!(RuleClassifier::matches_pattern("*", "Bash"));
assert!(RuleClassifier::matches_pattern("Bash", "Bash"));
assert!(RuleClassifier::matches_pattern("Bash", "Bash(git:*)"));
assert!(!RuleClassifier::matches_pattern("Bash", "BashExtra"));
}
// A tool matching an allow pattern is not blocked.
#[tokio::test]
async fn test_rule_classifier_allow() {
let classifier =
RuleClassifier::new().with_allow_patterns(vec!["Read".to_string(), "Glob".to_string()]);
let ctx = ClassifierContext::new("agent".to_string(), "session".to_string());
let result = classifier
.classify("Read", &serde_json::json!({}), &ctx)
.await
.unwrap();
assert!(!result.should_block);
}
// A tool matching a deny pattern is blocked.
#[tokio::test]
async fn test_rule_classifier_deny() {
let classifier = RuleClassifier::new().with_deny_patterns(vec!["Bash(rm:*)".to_string()]);
let ctx = ClassifierContext::new("agent".to_string(), "session".to_string());
let result = classifier
.classify(
"Bash(rm:rf)",
&serde_json::json!({"command": "rm -rf"}),
&ctx,
)
.await
.unwrap();
assert!(result.should_block);
}
// With no rules configured, unknown tools are allowed by default.
#[tokio::test]
async fn test_rule_classifier_default() {
let classifier = RuleClassifier::new();
let ctx = ClassifierContext::new("agent".to_string(), "session".to_string());
let result = classifier
.classify("UnknownTool", &serde_json::json!({}), &ctx)
.await
.unwrap();
assert!(!result.should_block); }
// build_prompt(): placeholders get filled with the tool name + rule lists.
// A stub LlmClient suffices because build_prompt never performs a chat call.
#[test]
fn test_llm_classifier_build_prompt() {
use echo_core::llm::{ChatChunk, ChatCompletionResponse, ChatRequest, ChatResponse};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
struct MockClient;
impl LlmClient for MockClient {
fn chat(&self, _request: ChatRequest) -> BoxFuture<'_, Result<ChatResponse>> {
Box::pin(async move {
Ok(ChatResponse {
message: Message::assistant(
"{\"should_block\": false, \"reason\": \"test\"}".to_string(),
),
finish_reason: None,
raw: ChatCompletionResponse::default(),
})
})
}
fn chat_stream(
&self,
_request: ChatRequest,
) -> BoxFuture<'_, Result<BoxStream<'_, Result<ChatChunk>>>> {
Box::pin(async move {
Err(echo_core::error::ReactError::Other(
"not implemented".to_string(),
))
})
}
fn model_name(&self) -> &str {
"mock"
}
}
let client = Arc::new(MockClient);
let classifier = LlmClassifier::new(client, "model".to_string());
let ctx = ClassifierContext::new("agent".to_string(), "session".to_string())
.with_allow_rules(vec!["Read:*".to_string()])
.with_soft_deny_rules(vec!["Bash(rm:*)".to_string()]);
let prompt = classifier.build_prompt("Read", &serde_json::json!({"path": "/tmp"}), &ctx);
assert!(prompt.contains("Read"));
assert!(prompt.contains("Read:*"));
assert!(prompt.contains("Bash(rm:*)"));
}
// Well-formed JSON replies parse into a block/allow verdict with reason.
#[test]
fn test_llm_classifier_parse_response_json() {
let result =
parse_hook_output_for_test("{\"should_block\": true, \"reason\": \"危险操作\"}");
assert!(result.is_some());
let r = result.unwrap();
assert!(r.should_block);
assert_eq!(r.reason, "危险操作");
}
// Non-JSON replies fall back to keyword detection ("阻止" => block).
#[test]
fn test_llm_classifier_parse_response_text() {
let result = parse_hook_output_for_test("我认为应该阻止此操作");
assert!(result.is_some());
assert!(result.unwrap().should_block);
}
// Test-local stand-in for LlmClassifier::parse_response.
// NOTE(review): it extracts JSON by keeping only lines that start with
// '{' or '}' — not the brace-depth scan used in production — so it only
// approximates the real parser's behaviour on multi-line responses.
fn parse_hook_output_for_test(response: &str) -> Option<ClassifierResult> {
let json_str = response
.trim()
.lines()
.filter(|line| line.starts_with('{') || line.starts_with('}'))
.collect::<String>();
if let Ok(json) = serde_json::from_str::<Value>(&json_str) {
let should_block = json
.get("should_block")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let reason = json
.get("reason")
.and_then(|v| v.as_str())
.unwrap_or("未提供原因")
.to_string();
return Some(if should_block {
ClassifierResult::block(reason)
} else {
ClassifierResult::allow(reason)
});
}
// Keyword fallback mirrors production: block is checked before allow.
if response.contains("阻止") || response.contains("block") {
Some(ClassifierResult::block("LLM 判断应阻止".to_string()))
} else if response.contains("允许") || response.contains("allow") {
Some(ClassifierResult::allow("LLM 判断应允许".to_string()))
} else {
None
}
}
}