use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio::time::timeout;
use crate::error::{Error, Result};
use crate::protocol::SessionId;
use crate::tool::PermissionLevel;
/// Optional upper bounds on the resources a tool execution may consume.
///
/// Every field defaults to `None`, which means the corresponding limit is not
/// enforced.
#[derive(Debug, Clone, Default)]
pub struct ResourceLimits {
    /// Maximum recorded memory usage, in bytes.
    pub max_memory: Option<u64>,
    /// Maximum recorded CPU time, in milliseconds.
    pub max_cpu_time: Option<u64>,
    /// Maximum number of recorded API calls.
    pub max_api_calls: Option<u64>,
    /// Maximum number of requests that may be in flight at the same time.
    pub max_concurrent_requests: Option<u32>,
}
/// Running totals of the resources consumed so far.
///
/// Every counter is an atomic accessed with relaxed ordering, so the totals
/// can be read and updated through a shared reference. The counters are
/// independent of each other; no ordering between them is implied.
#[derive(Debug)]
pub struct ResourceUsage {
    /// Bytes accumulated through [`ResourceUsage::add_memory`].
    memory_used: AtomicU64,
    /// Milliseconds accumulated through [`ResourceUsage::add_cpu_time`].
    cpu_time: AtomicU64,
    /// Number of calls counted by [`ResourceUsage::increment_api_calls`].
    api_calls: AtomicU64,
    /// Number of requests currently counted as in flight.
    current_concurrent_requests: AtomicU64,
    /// Moment this tracker was created; basis for [`ResourceUsage::elapsed_time`].
    start_time: Instant,
}

impl Default for ResourceUsage {
    /// Creates a tracker with all counters at zero and the clock started now.
    fn default() -> Self {
        Self {
            memory_used: AtomicU64::new(0),
            cpu_time: AtomicU64::new(0),
            api_calls: AtomicU64::new(0),
            current_concurrent_requests: AtomicU64::new(0),
            start_time: Instant::now(),
        }
    }
}

impl ResourceUsage {
    /// Adds `delta` to `counter`, clamping at `u64::MAX` instead of wrapping.
    fn saturating_add(counter: &AtomicU64, delta: u64) {
        // The closure always returns `Some`, so the update cannot fail and the
        // result carries no information worth propagating.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            Some(current.saturating_add(delta))
        });
    }

    /// Returns the total number of bytes recorded so far.
    pub fn memory_used(&self) -> u64 {
        self.memory_used.load(Ordering::Relaxed)
    }

    /// Returns the total CPU time recorded so far, in milliseconds.
    pub fn cpu_time(&self) -> u64 {
        self.cpu_time.load(Ordering::Relaxed)
    }

    /// Returns the number of API calls recorded so far.
    pub fn api_calls(&self) -> u64 {
        self.api_calls.load(Ordering::Relaxed)
    }

    /// Returns the number of requests currently counted as in flight.
    pub fn current_concurrent_requests(&self) -> u64 {
        self.current_concurrent_requests.load(Ordering::Relaxed)
    }

    /// Returns the wall-clock time elapsed since this tracker was created.
    pub fn elapsed_time(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Adds `bytes` to the recorded memory usage.
    ///
    /// The total saturates at `u64::MAX`; a wrapping add would let an overflow
    /// silently reset the counter below any configured limit.
    pub fn add_memory(&self, bytes: u64) {
        Self::saturating_add(&self.memory_used, bytes);
    }

    /// Adds `ms` milliseconds to the recorded CPU time.
    ///
    /// The total saturates at `u64::MAX` for the same reason as
    /// [`ResourceUsage::add_memory`].
    pub fn add_cpu_time(&self, ms: u64) {
        Self::saturating_add(&self.cpu_time, ms);
    }

    /// Counts one additional API call.
    pub fn increment_api_calls(&self) {
        self.api_calls.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one more in-flight request.
    ///
    /// Returns the number of in-flight requests *before* the increment.
    pub fn increment_concurrent_requests(&self) -> u64 {
        self.current_concurrent_requests
            .fetch_add(1, Ordering::Relaxed)
    }

    /// Counts one fewer in-flight request.
    ///
    /// Returns the number of in-flight requests *before* the decrement. When
    /// the counter is already zero it is left untouched and `0` is returned,
    /// so an unbalanced decrement can never wrap the counter to `u64::MAX`.
    pub fn decrement_concurrent_requests(&self) -> u64 {
        let outcome = self.current_concurrent_requests.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |current| current.checked_sub(1),
        );
        // `Ok` holds the value that was replaced, `Err` the value that was
        // observed when the update was refused; both are the previous value.
        match outcome {
            Ok(previous) | Err(previous) => previous,
        }
    }
}
/// Per-invocation state handed to a tool: identity, cancellation flag,
/// optional timeout and partial-result channel, permissions, and resource
/// accounting.
#[derive(Debug)]
pub struct ToolContext {
    /// Session this context was created for.
    pub session_id: SessionId,
    /// Set by `cancel` and read by `is_cancelled` / `check_cancelled`.
    cancelled: AtomicBool,
    /// Channel used by `send_partial_result`; `None` means partial results are
    /// silently dropped.
    partial_results_tx: Option<mpsc::Sender<Value>>,
    /// Deadline applied by `with_timeout_future`; `None` means no deadline.
    timeout: Option<Duration>,
    /// Permission level currently granted to this context.
    permission_level: PermissionLevel,
    /// User attached through `with_user_id`, if any.
    user_id: Option<String>,
    /// Limits enforced by `check_resource_limits`.
    resource_limits: ResourceLimits,
    /// Counters compared against `resource_limits`.
    resource_usage: ResourceUsage,
}
impl ToolContext {
pub fn new(session_id: SessionId) -> Self {
Self {
session_id,
cancelled: AtomicBool::new(false),
partial_results_tx: None,
timeout: None,
permission_level: PermissionLevel::Public,
user_id: None,
resource_limits: ResourceLimits::default(),
resource_usage: ResourceUsage::default(),
}
}
pub fn with_timeout(session_id: SessionId, timeout_duration: Duration) -> Self {
Self {
session_id,
cancelled: AtomicBool::new(false),
partial_results_tx: None,
timeout: Some(timeout_duration),
permission_level: PermissionLevel::Public,
user_id: None,
resource_limits: ResourceLimits::default(),
resource_usage: ResourceUsage::default(),
}
}
pub fn with_permission_level(mut self, level: PermissionLevel) -> Self {
self.permission_level = level;
self
}
pub fn with_user_id(mut self, user_id: String) -> Self {
self.user_id = Some(user_id);
self.permission_level = if self.permission_level == PermissionLevel::Public {
PermissionLevel::Authenticated
} else {
self.permission_level
};
self
}
pub fn permission_level(&self) -> PermissionLevel {
self.permission_level
}
pub fn user_id(&self) -> Option<&str> {
self.user_id.as_deref()
}
pub fn elevate_to_admin(&mut self) {
self.permission_level = PermissionLevel::Admin;
}
pub fn with_partial_results(mut self, tx: mpsc::Sender<Value>) -> Self {
self.partial_results_tx = Some(tx);
self
}
pub fn cancel(&self) {
self.cancelled.store(true, Ordering::SeqCst);
}
pub fn is_cancelled(&self) -> bool {
self.cancelled.load(Ordering::SeqCst)
}
pub fn check_cancelled(&self) -> Result<()> {
if self.is_cancelled() {
return Err(Error::Cancelled("Tool execution cancelled".to_string()));
}
Ok(())
}
pub async fn with_timeout_future<F, T, E>(&self, future: F) -> Result<T>
where
F: Future<Output = std::result::Result<T, E>>,
E: Into<Error>,
{
self.check_cancelled()?;
match self.timeout {
Some(duration) => {
match timeout(duration, future).await {
Ok(result) => result.map_err(|e| e.into()),
Err(_) => Err(Error::Timeout("Tool execution timed out".to_string())),
}
}
None => future.await.map_err(|e| e.into()),
}
}
pub async fn send_partial_result(&self, result: Value) -> Result<()> {
if let Some(tx) = &self.partial_results_tx {
tx.send(result)
.await
.map_err(|_| Error::Transport("Failed to send partial result".to_string()))?;
}
Ok(())
}
pub fn with_resource_limits(mut self, limits: ResourceLimits) -> Self {
self.resource_limits = limits;
self
}
pub fn check_resource_limits(&self) -> Result<()> {
if let Some(max_memory) = self.resource_limits.max_memory {
let memory_used = self.resource_usage.memory_used();
if memory_used > max_memory {
return Err(Error::ResourceAccess(format!(
"内存使用超出限制:已使用 {memory_used} 字节,最大限制 {max_memory} 字节"
)));
}
}
if let Some(max_cpu_time) = self.resource_limits.max_cpu_time {
let cpu_time = self.resource_usage.cpu_time();
if cpu_time > max_cpu_time {
return Err(Error::ResourceAccess(format!(
"CPU时间超出限制:已使用 {cpu_time} 毫秒,最大限制 {max_cpu_time} 毫秒"
)));
}
}
if let Some(max_api_calls) = self.resource_limits.max_api_calls {
let api_calls = self.resource_usage.api_calls();
if api_calls > max_api_calls {
return Err(Error::ResourceAccess(format!(
"API调用次数超出限制:已调用 {api_calls} 次,最大限制 {max_api_calls} 次"
)));
}
}
if let Some(max_concurrent_requests) = self.resource_limits.max_concurrent_requests {
let concurrent_requests = self.resource_usage.current_concurrent_requests();
if concurrent_requests > max_concurrent_requests as u64 {
return Err(Error::ResourceAccess(format!(
"并发请求数超出限制:当前 {concurrent_requests} 个,最大限制 {max_concurrent_requests} 个"
)));
}
}
Ok(())
}
pub fn record_api_call(&self) -> Result<()> {
self.resource_usage.increment_api_calls();
self.check_resource_limits()
}
pub fn record_memory_usage(&self, bytes: u64) -> Result<()> {
self.resource_usage.add_memory(bytes);
self.check_resource_limits()
}
pub fn resource_usage(&self) -> &ResourceUsage {
&self.resource_usage
}
pub fn resource_limits(&self) -> &ResourceLimits {
&self.resource_limits
}
}