echo_core 0.1.0

Core traits and types for the echo-agent framework
Documentation
//! 护栏系统核心 trait 和类型

#[cfg(feature = "guard")]
pub mod content;
#[cfg(feature = "guard")]
pub mod llm;
#[cfg(feature = "guard")]
pub mod rule;

use crate::error::Result;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;

/// The flow a piece of content belongs to when it is handed to a guard.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GuardDirection {
    /// Input coming from the user.
    Input,
    /// Output produced by the model.
    Output,
    /// Arguments about to be passed to a tool.
    ToolInput,
    /// Result returned by a tool.
    ToolOutput,
}

/// Renders the direction as a lowercase snake_case label (e.g. `tool_input`).
///
/// Note: this label differs from the serde representation, which uses the
/// variant names as written (e.g. `ToolInput`).
impl std::fmt::Display for GuardDirection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Input => "input",
            Self::Output => "output",
            Self::ToolInput => "tool_input",
            Self::ToolOutput => "tool_output",
        };
        f.write_str(label)
    }
}

/// Outcome of a guard check (or of several checks merged together).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GuardResult {
    /// The content passed the check.
    Pass,
    /// The content must be rejected.
    Block {
        /// Why the content was blocked.
        reason: String,
    },
    /// The content is allowed, but one or more warnings were raised
    /// (possibly collected from several guards).
    Warn {
        /// The individual warning reasons.
        reasons: Vec<String>,
    },
}

impl GuardResult {
    /// Returns `true` only for the [`GuardResult::Block`] variant.
    pub fn is_blocked(&self) -> bool {
        match self {
            Self::Block { .. } => true,
            Self::Pass | Self::Warn { .. } => false,
        }
    }
}

/// A guard (guardrail) that inspects content and decides whether it may pass.
///
/// The `Send + Sync` bound lets guards be shared as `Arc<dyn Guard>` and used
/// from spawned tasks.
pub trait Guard: Send + Sync {
    /// Returns the guard's name; it is used in log fields and in warning messages.
    fn name(&self) -> &str;

    /// Checks `content` in the context of `direction`.
    ///
    /// # Parameters
    /// * `content` - the content to check
    /// * `direction` - which flow the content belongs to (user input, model output,
    ///   tool input, or tool output)
    ///
    /// # Returns
    /// A boxed future resolving to a [`GuardResult`], or an error if the check itself
    /// could not be performed. `GuardManager::check_all` treats such an error as a
    /// warning, not as a block.
    ///
    /// NOTE(review): `GuardManager::check_all` races this future against a cancellation
    /// token and may drop it before it completes, so implementations should tolerate
    /// being cancelled mid-check.
    fn check<'a>(
        &'a self,
        content: &'a str,
        direction: GuardDirection,
    ) -> BoxFuture<'a, Result<GuardResult>>;
}

/// Holds a set of guards and runs them together.
///
/// The derived `Default` produces a manager with no guards, exactly like
/// `GuardManager::new()`.
#[derive(Default)]
pub struct GuardManager {
    /// Registered guards, kept in registration order.
    guards: Vec<Arc<dyn Guard>>,
}

impl GuardManager {
    /// 创建空的护栏管理器
    pub fn new() -> Self {
        Self { guards: Vec::new() }
    }

    /// 添加护栏
    pub fn add(&mut self, guard: Arc<dyn Guard>) {
        self.guards.push(guard);
    }

    /// 从护栏列表创建管理器
    pub fn from_guards(guards: Vec<Arc<dyn Guard>>) -> Self {
        Self { guards }
    }

    /// 检查是否为空(未添加任何护栏)
    pub fn is_empty(&self) -> bool {
        self.guards.is_empty()
    }

    /// 并行执行所有护栏检查。
    ///
    /// - 所有护栏同时启动(受并发度上限约束),而非串行执行。
    /// - 一旦发现 `Block` 结果,取消其他仍在运行的检查(通过 `CancellationToken`)。
    /// - 收集所有 `Warn` 理由到 `Vec<String>`。
    ///
    /// 并发度上限为 16,防止护栏数量过多时创建海量任务。
    pub async fn check_all(&self, content: &str, direction: GuardDirection) -> Result<GuardResult> {
        if self.guards.is_empty() {
            return Ok(GuardResult::Pass);
        }

        // Concurrency limit to avoid spawning unbounded tasks
        let semaphore = Arc::new(tokio::sync::Semaphore::new(16));
        let cancel = CancellationToken::new();
        let mut handles = Vec::with_capacity(self.guards.len());

        for guard in &self.guards {
            let guard = guard.clone();
            let content = content.to_string();
            let cancel_child = cancel.clone();
            let permit = semaphore.clone().acquire_owned().await;
            handles.push(tokio::spawn(async move {
                let _permit = permit; // hold until task completes
                let result = tokio::select! {
                    _ = cancel_child.cancelled() => {
                        return (guard.name().to_string(), Ok(GuardResult::Pass));
                    }
                    r = guard.check(&content, direction) => r,
                };
                (guard.name().to_string(), result)
            }));
        }

        let mut warnings = Vec::new();

        for (i, handle) in handles.into_iter().enumerate() {
            let (guard_name, result) = handle.await.map_err(|e| {
                crate::error::ReactError::Other(format!("Guard task {} panicked: {}", i, e))
            })?;

            match result {
                Ok(GuardResult::Block { reason }) => {
                    cancel.cancel(); // 取消其他仍在运行的检查
                    tracing::warn!(
                        guard = guard_name,
                        direction = %direction,
                        reason = %reason,
                        "护栏阻断"
                    );
                    return Ok(GuardResult::Block { reason });
                }
                Ok(GuardResult::Warn { reasons }) => {
                    warnings.extend(reasons);
                }
                Ok(GuardResult::Pass) => {}
                Err(e) => {
                    tracing::error!(guard = guard_name, error = %e, "护栏检查出错");
                    warnings.push(format!("{} error: {}", guard_name, e));
                }
            }
        }

        if !warnings.is_empty() {
            Ok(GuardResult::Warn { reasons: warnings })
        } else {
            Ok(GuardResult::Pass)
        }
    }
}