use anyhow::{Result, anyhow};
use serde_json::Value;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock, Semaphore, mpsc};
use tokio::time::timeout;
use tracing::{debug, error};
use crate::core::memory_pool::global_pool;
use crate::tools::request_response::ToolCallRequest;
/// A single tool invocation waiting to be scheduled by [`AsyncToolPipeline`].
#[derive(Clone, Debug)]
pub struct ToolRequest {
    /// The tool call itself; the pipeline reads its `id`, `tool_name` and `args`.
    pub call: ToolCallRequest,
    /// Scheduling priority. Higher priorities are queued ahead of lower ones.
    pub priority: ExecutionPriority,
    /// Maximum time the tool may run before the execution is reported as timed out.
    pub timeout: Duration,
    /// Caller-supplied context. The pipeline code in this file carries it along
    /// but does not inspect it.
    pub context: ExecutionContext,
}
/// Relative urgency of a [`ToolRequest`].
///
/// The derived ordering follows the discriminants, so
/// `Low < Normal < High < Critical`. The pipeline queue relies on that ordering
/// to place more urgent requests ahead of less urgent ones.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExecutionPriority {
    /// Lowest urgency; queued behind everything else.
    Low = 0,
    /// Middle-of-the-road urgency.
    Normal = 1,
    /// Queued ahead of `Normal` and `Low` requests.
    High = 2,
    /// Highest urgency; queued ahead of every other priority.
    Critical = 3,
}
/// Ambient information attached to a [`ToolRequest`].
///
/// NOTE(review): nothing in this file reads these fields; the descriptions
/// below follow the field names and should be confirmed against the callers.
#[derive(Clone, Debug)]
pub struct ExecutionContext {
    /// Identifier of the session the request belongs to.
    pub session_id: String,
    /// Identifier of the requesting user, when one is known.
    pub user_id: Option<String>,
    /// Workspace path associated with the request.
    pub workspace_path: String,
    /// Identifier of the request that spawned this one, if any.
    pub parent_request_id: Option<String>,
}
/// Outcome of executing one [`ToolRequest`], as stored in the result cache.
#[derive(Debug)]
pub struct ToolResult {
    /// Id of the tool call that produced this result.
    pub request_id: String,
    /// Value returned by the tool, or the error that ended the execution.
    pub result: Result<Value>,
    /// Wall-clock time measured by the pipeline for this execution.
    pub execution_time: Duration,
    /// Memory consumed by the execution, if measured. The pipeline in this file
    /// always records `None`.
    pub memory_used: Option<usize>,
    /// Whether the result was served from the cache. The pipeline in this file
    /// always records `false` for freshly executed requests.
    pub cache_hit: bool,
}
/// A group of requests drained from the queue and executed together.
#[derive(Debug)]
pub struct ToolBatch {
    /// Requests in the order they were taken from the queue.
    pub requests: Vec<ToolRequest>,
    /// Unique identifier of the batch, used in log output.
    pub batch_id: String,
    /// Moment the batch was assembled; used to time the whole batch.
    pub created_at: Instant,
}
/// Batching, priority-aware executor for tool calls.
///
/// Requests are queued with `submit_request`. A background task, spawned by
/// `start`, drains them in batches and runs each request under a concurrency
/// limit. Successful results are stored in an LRU cache.
pub struct AsyncToolPipeline {
    /// Pending requests, kept in descending priority order by `submit_request`.
    request_queue: Arc<RwLock<VecDeque<ToolRequest>>>,
    /// Wakes the background worker when new work has been queued.
    work_notify: Arc<Notify>,
    /// Maximum number of requests drained into one batch.
    batch_size: usize,
    /// How long a partially filled batch may wait before it is flushed anyway.
    batch_timeout: Duration,
    /// Limits how many tool executions run at the same time.
    execution_semaphore: Arc<Semaphore>,
    /// Results of successful executions, keyed by the request cache key.
    result_cache: Arc<RwLock<lru::LruCache<String, ToolResult>>>,
    /// Running counters and averages describing pipeline activity.
    metrics: Arc<RwLock<PipelineMetrics>>,
    /// Sender used to stop the worker; `Some` only while the pipeline is started.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// Handle of the background worker task, if one has been spawned.
    processing_task: Option<tokio::task::JoinHandle<()>>,
}
/// Counters and running averages describing pipeline activity.
///
/// All values start at zero (see the derived `Default`).
#[derive(Clone, Debug, Default)]
pub struct PipelineMetrics {
    /// Number of requests that were placed on the queue.
    pub total_requests: u64,
    /// Number of executions that produced an `Ok` result.
    pub successful_executions: u64,
    /// Number of executions that failed or timed out.
    pub failed_executions: u64,
    /// Number of times a cached result made an execution unnecessary.
    pub cache_hits: u64,
    /// Exponentially weighted moving average of execution time, in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Requests processed per millisecond in the most recent batch.
    pub batch_efficiency: f64,
}
impl AsyncToolPipeline {
/// Creates a pipeline that is configured but not yet running; call `start`
/// to spawn the background worker.
///
/// * `max_concurrent_tools` – upper bound on simultaneously executing tools.
/// * `cache_size` – capacity of the LRU result cache.
/// * `batch_size` – maximum number of requests drained into one batch.
/// * `batch_timeout` – how long a partially filled batch may wait before it
///   is flushed.
///
/// A value of `0` for `max_concurrent_tools`, `cache_size` or `batch_size` is
/// treated as `1`.
pub fn new(
    max_concurrent_tools: usize,
    cache_size: usize,
    batch_size: usize,
    batch_timeout: Duration,
) -> Self {
    // A zero batch size would make the worker drain empty batches in a tight
    // loop without ever making progress, and a zero-permit semaphore would park
    // every spawned execution forever. Clamp both to 1, mirroring the clamp
    // already applied to a zero cache size below.
    let batch_size = batch_size.max(1);
    let max_concurrent_tools = max_concurrent_tools.max(1);
    let cache_capacity =
        std::num::NonZeroUsize::new(cache_size).unwrap_or(std::num::NonZeroUsize::MIN);

    Self {
        request_queue: Arc::new(RwLock::new(VecDeque::with_capacity(256))),
        work_notify: Arc::new(Notify::new()),
        batch_size,
        batch_timeout,
        execution_semaphore: Arc::new(Semaphore::new(max_concurrent_tools)),
        result_cache: Arc::new(RwLock::new(lru::LruCache::new(cache_capacity))),
        metrics: Arc::new(RwLock::new(PipelineMetrics::default())),
        shutdown_tx: None,
        processing_task: None,
    }
}
/// Spawns the background worker that batches and executes queued requests.
///
/// The worker sleeps until it is notified of new work. It then processes the
/// queue until it is empty: a batch is executed as soon as `batch_size`
/// requests are waiting, or once `batch_timeout` has elapsed since the last
/// flush (or since the worker woke up). A shutdown signal, or the shutdown
/// channel closing, ends the worker.
///
/// # Errors
///
/// Returns an error if the pipeline has already been started.
pub async fn start(&mut self) -> Result<()> {
    if self.shutdown_tx.is_some() {
        return Err(anyhow!("AsyncToolPipeline already started"));
    }

    let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
    self.shutdown_tx = Some(stop_tx);

    // The worker gets its own handles to the shared state so that it can
    // outlive this borrow of `self`.
    let queue = Arc::clone(&self.request_queue);
    let wakeup = Arc::clone(&self.work_notify);
    let permits = Arc::clone(&self.execution_semaphore);
    let cache = Arc::clone(&self.result_cache);
    let stats = Arc::clone(&self.metrics);
    let max_batch = self.batch_size;
    let flush_after = self.batch_timeout;

    let worker = tokio::spawn(async move {
        'idle: loop {
            // Idle phase: nothing to do until work arrives or we are stopped.
            tokio::select! {
                _ = stop_rx.recv() => {
                    debug!("Pipeline shutdown requested");
                    break 'idle;
                }
                _ = wakeup.notified() => {}
            }

            // Busy phase: keep flushing batches until the queue is empty.
            let mut deadline = tokio::time::Instant::now() + flush_after;
            loop {
                let pending = queue.read().await.len();
                if pending == 0 {
                    break;
                }

                if pending < max_batch {
                    // Not enough for a full batch yet: wait for the flush
                    // deadline, more work, or a shutdown request.
                    let flush_timer = tokio::time::sleep_until(deadline);
                    tokio::pin!(flush_timer);
                    tokio::select! {
                        _ = stop_rx.recv() => {
                            debug!("Pipeline shutdown requested");
                            break 'idle;
                        }
                        _ = &mut flush_timer => {}
                        // New work arrived: re-check the queue length without
                        // flushing and without moving the deadline.
                        _ = wakeup.notified() => continue,
                    }
                }

                // Reached with a full batch waiting or an expired deadline.
                Self::process_batch(&queue, max_batch, &permits, &cache, &stats).await;
                deadline = tokio::time::Instant::now() + flush_after;
            }
        }
    });

    self.processing_task = Some(worker);
    Ok(())
}
/// Queues `request` for execution and returns the id of its tool call.
///
/// If a result for the same cache key is already cached, the request is not
/// queued: the `cache_hits` metric is incremented and the id is returned
/// immediately. Otherwise the request is inserted behind every queued request
/// of equal or higher priority, `total_requests` is incremented and the worker
/// is notified.
pub async fn submit_request(&self, request: ToolRequest) -> Result<String> {
    let key = self.generate_cache_key(&request);

    // `peek` leaves the LRU order untouched, so a read lock is sufficient.
    let already_cached = self.result_cache.read().await.peek(&key).is_some();
    if already_cached {
        let mut stats = self.metrics.write().await;
        stats.cache_hits += 1;
        return Ok(request.call.id.clone());
    }

    let request_id = request.call.id.clone();
    {
        let mut pending = self.request_queue.write().await;
        // Skip past everything at least as urgent as the new request, which
        // keeps FIFO order within a priority level.
        let slot = pending
            .iter()
            .take_while(|queued| queued.priority >= request.priority)
            .count();
        pending.insert(slot, request);
    }

    {
        let mut stats = self.metrics.write().await;
        stats.total_requests += 1;
    }
    self.work_notify.notify_one();
    Ok(request_id)
}
async fn process_batch(
request_queue: &Arc<RwLock<VecDeque<ToolRequest>>>,
batch_size: usize,
execution_semaphore: &Arc<Semaphore>,
result_cache: &Arc<RwLock<lru::LruCache<String, ToolResult>>>,
metrics: &Arc<RwLock<PipelineMetrics>>,
) {
let batch = {
let mut queue = request_queue.write().await;
if queue.is_empty() {
return;
}
let current_batch_size = std::cmp::min(queue.len(), batch_size);
let requests: Vec<_> = queue.drain(..current_batch_size).collect();
ToolBatch {
requests,
batch_id: uuid::Uuid::new_v4().to_string(),
created_at: Instant::now(),
}
};
if batch.requests.is_empty() {
return;
}
debug!(
"Processing batch {} with {} requests",
batch.batch_id,
batch.requests.len()
);
let mut handles = Vec::with_capacity(batch.requests.len());
let batch_size = batch.requests.len();
for request in batch.requests {
let semaphore = Arc::clone(execution_semaphore);
let cache = Arc::clone(result_cache);
let metrics_ref = Arc::clone(metrics);
let handle = tokio::spawn(async move {
let Ok(_permit) = semaphore.acquire().await else {
error!("Pipeline semaphore closed before request execution");
return Ok(());
};
Self::execute_single_request(request, cache, metrics_ref).await
});
handles.push(handle);
}
for handle in handles {
if let Err(e) = handle.await {
error!("Tool execution failed: {}", e);
}
}
let batch_time = batch.created_at.elapsed();
let mut metrics_guard = metrics.write().await;
let batch_time_ms = batch_time.as_millis() as f64;
metrics_guard.batch_efficiency = if batch_time_ms > 0.0 {
batch_size as f64 / batch_time_ms
} else {
batch_size as f64
};
}
async fn execute_single_request(
request: ToolRequest,
result_cache: Arc<RwLock<lru::LruCache<String, ToolResult>>>,
metrics: Arc<RwLock<PipelineMetrics>>,
) -> Result<()> {
let start_time = Instant::now();
let cache_key = format!(
"{}:{}",
request.call.tool_name,
serde_json::to_string(&request.call.args).unwrap_or_default()
);
{
let cache_guard = result_cache.read().await;
if cache_guard.peek(&cache_key).is_some() {
metrics.write().await.cache_hits += 1;
return Ok(());
}
}
let execution_result = timeout(
request.timeout,
Self::execute_tool_impl(&request.call.tool_name, &request.call.args),
)
.await;
let execution_time = start_time.elapsed();
let result = match execution_result {
Ok(Ok(value)) => Ok(value),
Ok(Err(e)) => Err(e),
Err(_) => Err(anyhow!(
"Tool execution timed out after {:?}",
request.timeout
)),
};
let result_for_cache = result.is_ok();
let tool_result = ToolResult {
request_id: request.call.id.clone(),
result: result.map_err(|e| anyhow::anyhow!(e.to_string())),
execution_time,
memory_used: None, cache_hit: false,
};
if result_for_cache {
result_cache.write().await.put(cache_key, tool_result);
}
let mut metrics_guard = metrics.write().await;
if result_for_cache {
metrics_guard.successful_executions += 1;
} else {
metrics_guard.failed_executions += 1;
}
let alpha = 0.1; metrics_guard.avg_execution_time_ms = alpha * execution_time.as_millis() as f64
+ (1.0 - alpha) * metrics_guard.avg_execution_time_ms;
Ok(())
}
fn generate_cache_key(&self, request: &ToolRequest) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
request.call.tool_name.hash(&mut hasher);
request.call.args.to_string().hash(&mut hasher);
format!("{}:{:x}", request.call.tool_name, hasher.finish())
}
/// Runs the tool and returns its result.
///
/// NOTE(review): this looks like a placeholder. Both arguments are ignored; the
/// function borrows a string from the global pool, waits 10 ms, hands the
/// string back and returns a fixed string value.
async fn execute_tool_impl(_tool_name: &str, _args: &Value) -> Result<Value> {
    let string_pool = global_pool();

    let mut scratch = string_pool.get_string();
    scratch.push_str("Executed tool with args");

    let simulated_latency = Duration::from_millis(10);
    tokio::time::sleep(simulated_latency).await;

    string_pool.return_string(scratch);

    let payload = String::from("Tool execution result");
    Ok(Value::String(payload))
}
pub async fn get_metrics(&self) -> PipelineMetrics {
self.metrics.read().await.clone()
}
/// Asks the background worker to stop and waits for it to finish.
///
/// Does nothing for a pipeline that was never started. A failed send of the
/// stop signal and any join error from the worker are ignored, so this always
/// returns `Ok(())`.
pub async fn shutdown(&mut self) -> Result<()> {
    if let Some(stop_tx) = self.shutdown_tx.take() {
        // Best effort: the worker may already be gone.
        stop_tx.send(()).await.ok();
    }

    if let Some(worker) = self.processing_task.take() {
        worker.await.ok();
    }

    Ok(())
}
}
impl Drop for AsyncToolPipeline {
    /// Aborts the background worker if it is still attached, so that it does
    /// not keep running after the pipeline has been dropped.
    fn drop(&mut self) {
        let Some(worker) = self.processing_task.take() else {
            return;
        };
        worker.abort();
    }
}