#[cfg(feature = "async")]
use std::collections::HashMap;
#[cfg(feature = "async")]
use std::future::Future;
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
use tensorlogic_ir::EinsumGraph;
#[cfg(feature = "async")]
use crate::batch::BatchResult;
#[cfg(feature = "async")]
use crate::streaming::{StreamResult, StreamingConfig};
#[cfg(feature = "async")]
/// A boxed, `Send`able future with lifetime `'a` — the return type used by the
/// async executor traits below, since traits here cannot use `async fn` directly.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
#[cfg(feature = "async")]
/// Core async execution trait: runs an `EinsumGraph` against named input
/// tensors and resolves to the graph's output tensors.
pub trait TlAsyncExecutor {
/// Tensor type produced and consumed by this executor.
type Tensor: Send;
/// Error type reported by this executor.
type Error: Send;
/// Execute `graph` with `inputs`, resolving to the output tensors or an
/// executor-specific error.
/// NOTE(review): `inputs` keys are presumably tensor names referenced by the
/// graph — confirm against implementors.
fn execute_async<'a>(
&'a mut self,
graph: &'a EinsumGraph,
inputs: &'a HashMap<String, Self::Tensor>,
) -> BoxFuture<'a, Result<Vec<Self::Tensor>, Self::Error>>;
/// Whether the executor can accept work right now. Defaults to always ready.
fn is_ready(&self) -> bool {
true
}
/// Resolve once `is_ready` returns true; the default implementation polls
/// every 10ms via `tokio::time::sleep`.
/// The `Self: Send` bound is required so the boxed future (which captures
/// `&mut self`) is itself `Send`, as `BoxFuture` demands.
fn wait_ready(&mut self) -> BoxFuture<'_, ()>
where
Self: Send,
{
Box::pin(async move {
while !self.is_ready() {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
})
}
}
#[cfg(feature = "async")]
/// Extension of `TlAsyncExecutor` for running many input sets against the
/// same graph in a single batched call.
pub trait TlAsyncBatchExecutor: TlAsyncExecutor {
/// Execute `graph` once per entry of `batch_inputs`, resolving to a
/// combined `BatchResult` or the first executor error.
fn execute_batch_async<'a>(
&'a mut self,
graph: &'a EinsumGraph,
batch_inputs: Vec<HashMap<String, Self::Tensor>>,
) -> BoxFuture<'a, Result<BatchResult<Self::Tensor>, Self::Error>>;
}
#[cfg(feature = "async")]
/// Results of a streaming execution: one `Result` per processed stream chunk.
pub type AsyncStreamResults<T, E> = Vec<Result<StreamResult<T>, E>>;
#[cfg(feature = "async")]
/// Extension of `TlAsyncExecutor` for streaming execution.
pub trait TlAsyncStreamExecutor: TlAsyncExecutor {
/// Execute `graph` over `input_stream` under `config`, resolving once the
/// whole stream has been consumed.
/// NOTE(review): the `Vec<Vec<Vec<Self::Tensor>>>` nesting is not explained
/// here — presumably chunks of batches of tensors; confirm with implementors.
fn execute_stream_async<'a>(
&'a mut self,
graph: &'a EinsumGraph,
input_stream: Vec<Vec<Vec<Self::Tensor>>>,
config: &'a StreamingConfig,
) -> BoxFuture<'a, AsyncStreamResults<Self::Tensor, Self::Error>>;
}
/// Errors surfaced by the async execution layer.
///
/// Generic over `E`, the underlying executor's error type, which is wrapped
/// in [`AsyncExecutionError::ExecutorError`].
#[derive(Debug, Clone)]
pub enum AsyncExecutionError<E> {
    /// The execution exceeded its time budget.
    Timeout { elapsed_ms: u64 },
    /// The executor cannot accept more work right now.
    ExecutorBusy { queue_size: usize },
    /// The execution was cancelled before completing.
    Cancelled,
    /// The underlying executor reported an error of its own.
    ExecutorError(E),
    /// The driving future was dropped before the execution finished.
    Dropped,
}

impl<E: std::fmt::Display> std::fmt::Display for AsyncExecutionError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout { elapsed_ms } => {
                write!(f, "Execution timed out after {}ms", elapsed_ms)
            }
            Self::ExecutorBusy { queue_size } => write!(
                f,
                "Executor is busy (queue size: {}), try again later",
                queue_size
            ),
            Self::Cancelled => f.write_str("Execution was cancelled"),
            Self::ExecutorError(e) => write!(f, "Executor error: {}", e),
            Self::Dropped => f.write_str("Future was dropped before completion"),
        }
    }
}

impl<E: std::error::Error> std::error::Error for AsyncExecutionError<E> {}
#[cfg(feature = "async")]
/// Handle to an in-flight async execution: carries its identifier, tracks
/// elapsed wall-clock time, and can request cancellation over a one-slot
/// channel whose receiver is held by the executor.
pub struct AsyncExecutionHandle {
    // Caller-supplied identifier for this execution.
    execution_id: String,
    // Creation time; basis for `elapsed()`.
    started_at: std::time::Instant,
    // Sender half of the cancellation channel.
    cancel_token: tokio::sync::mpsc::Sender<()>,
}

#[cfg(feature = "async")]
impl AsyncExecutionHandle {
    /// Create a handle plus the receiver an executor should poll for the
    /// cancellation signal.
    pub fn new(execution_id: String) -> (Self, tokio::sync::mpsc::Receiver<()>) {
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        let handle = AsyncExecutionHandle {
            execution_id,
            started_at: std::time::Instant::now(),
            cancel_token: sender,
        };
        (handle, receiver)
    }

    /// The identifier this handle was created with.
    pub fn execution_id(&self) -> &str {
        &self.execution_id
    }

    /// Wall-clock time since the handle was created.
    pub fn elapsed(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Signal cancellation to the executor.
    ///
    /// Fails with `AsyncExecutionError::Cancelled` when the send fails —
    /// which `tokio::sync::mpsc` only does once the receiver has been dropped.
    pub async fn cancel(&self) -> Result<(), AsyncExecutionError<std::io::Error>> {
        if self.cancel_token.send(()).await.is_ok() {
            Ok(())
        } else {
            Err(AsyncExecutionError::Cancelled)
        }
    }
}
#[cfg(feature = "async")]
/// A fixed pool of executors with round-robin and first-ready dispatch.
pub struct AsyncExecutorPool<E: TlAsyncExecutor> {
    // The pooled executors; never empty (enforced by `new`).
    executors: Vec<E>,
    // Shared round-robin counter; wraps on overflow (fetch_add semantics).
    next_index: std::sync::atomic::AtomicUsize,
}

#[cfg(feature = "async")]
impl<E: TlAsyncExecutor> AsyncExecutorPool<E> {
    /// Build a pool over `executors`.
    ///
    /// # Panics
    /// Panics if `executors` is empty. Previously an empty pool was accepted
    /// and then failed later with an opaque divide-by-zero in
    /// `get_next_index` (`% len()` with len 0) or an out-of-bounds index in
    /// `execute_any`; failing here gives a clear message at the call site.
    pub fn new(executors: Vec<E>) -> Self {
        assert!(
            !executors.is_empty(),
            "AsyncExecutorPool requires at least one executor"
        );
        AsyncExecutorPool {
            executors,
            next_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Number of executors in the pool.
    pub fn size(&self) -> usize {
        self.executors.len()
    }

    /// Round-robin index: advances the shared counter and wraps modulo the
    /// pool size. `Relaxed` ordering suffices because the counter does not
    /// synchronize access to any other data.
    pub fn get_next_index(&self) -> usize {
        self.next_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            % self.executors.len()
    }

    /// Index of the first executor reporting `is_ready()`, falling back to 0
    /// when none are ready.
    /// NOTE(review): despite the name this is "first ready", not a true
    /// least-loaded selection — the trait exposes no load metric.
    pub fn get_least_loaded_index(&self) -> usize {
        for (idx, executor) in self.executors.iter().enumerate() {
            if executor.is_ready() {
                return idx;
            }
        }
        0
    }

    /// Execute `graph` on the first ready executor (executor 0 if none are
    /// ready), resolving to its outputs or error.
    pub async fn execute_any<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        inputs: &'a HashMap<String, E::Tensor>,
    ) -> Result<Vec<E::Tensor>, E::Error> {
        let index = self.get_least_loaded_index();
        self.executors[index].execute_async(graph, inputs).await
    }
}
/// Tuning knobs for asynchronous execution.
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Maximum number of executions allowed in flight at once.
    pub max_concurrent: usize,
    /// Per-execution timeout in milliseconds; `None` disables timeouts.
    pub timeout_ms: Option<u64>,
    /// Whether failed executions should be retried.
    pub enable_retry: bool,
    /// Upper bound on retry attempts (consulted only when `enable_retry` is set).
    pub max_retries: usize,
    /// Delay between retry attempts, in milliseconds.
    pub backoff_ms: u64,
}

impl Default for AsyncConfig {
    /// Defaults: 4 concurrent executions, no timeout, retries disabled
    /// (with a 3-attempt / 100ms backoff baseline if later enabled).
    fn default() -> Self {
        Self {
            max_concurrent: 4,
            timeout_ms: None,
            enable_retry: false,
            max_retries: 3,
            backoff_ms: 100,
        }
    }
}

impl AsyncConfig {
    /// Create a config with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the concurrency limit.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }

    /// Builder-style setter enabling a timeout of `timeout_ms` milliseconds.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }

    /// Builder-style setter enabling retries with the given attempt limit
    /// and backoff delay.
    pub fn with_retry(mut self, max_retries: usize, backoff_ms: u64) -> Self {
        self.enable_retry = true;
        self.max_retries = max_retries;
        self.backoff_ms = backoff_ms;
        self
    }
}
/// Aggregate counters describing async execution activity.
#[derive(Debug, Clone, Default)]
pub struct AsyncStats {
    /// Total executions started.
    pub total_executions: usize,
    /// Executions that completed successfully.
    pub successful: usize,
    /// Executions that failed.
    pub failed: usize,
    /// Executions that timed out.
    pub timeouts: usize,
    /// Executions that were cancelled.
    pub cancelled: usize,
    /// Mean execution time in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Highest number of concurrent executions observed.
    pub peak_concurrent: usize,
}

impl AsyncStats {
    /// Create a zeroed stats record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of executions that succeeded, in `[0.0, 1.0]`.
    /// Returns `0.0` before anything has run (avoids 0/0).
    pub fn success_rate(&self) -> f64 {
        match self.total_executions {
            0 => 0.0,
            n => self.successful as f64 / n as f64,
        }
    }

    /// Multi-line, human-readable report of all counters.
    pub fn summary(&self) -> String {
        format!(
            "Async Execution Stats:\n- Total: {}\n- Successful: {} ({:.1}%)\n- Failed: {}\n- Timeouts: {}\n- Cancelled: {}\n- Avg time: {:.2}ms\n- Peak concurrent: {}",
            self.total_executions,
            self.successful,
            self.success_rate() * 100.0,
            self.failed,
            self.timeouts,
            self.cancelled,
            self.avg_execution_time_ms,
            self.peak_concurrent
        )
    }
}
#[cfg(all(test, feature = "async"))]
mod tests {
use super::*;
// Builder methods should set exactly the fields they advertise.
#[test]
fn test_async_config() {
let config = AsyncConfig::new()
.with_max_concurrent(8)
.with_timeout(5000)
.with_retry(3, 200);
assert_eq!(config.max_concurrent, 8);
assert_eq!(config.timeout_ms, Some(5000));
assert!(config.enable_retry);
assert_eq!(config.max_retries, 3);
assert_eq!(config.backoff_ms, 200);
}
// success_rate is successful/total; summary embeds it as a percentage.
#[test]
fn test_async_stats() {
let mut stats = AsyncStats::new();
stats.total_executions = 100;
stats.successful = 95;
stats.failed = 3;
stats.timeouts = 2;
assert_eq!(stats.success_rate(), 0.95);
assert!(stats.summary().contains("95.0%"));
}
// Display output for the error variants is part of the public contract.
#[test]
fn test_async_error_display() {
let err = AsyncExecutionError::<String>::Timeout { elapsed_ms: 5000 };
assert_eq!(err.to_string(), "Execution timed out after 5000ms");
let err2 = AsyncExecutionError::<String>::ExecutorBusy { queue_size: 10 };
assert!(err2.to_string().contains("queue size: 10"));
}
// Cancelling via the handle should deliver a signal on the paired receiver.
// NOTE(review): the `expect("unwrap")` message is uninformative — consider a
// descriptive message such as "cancel should succeed while receiver is alive".
#[tokio::test]
async fn test_execution_handle() {
let (handle, mut rx) = AsyncExecutionHandle::new("test-123".to_string());
assert_eq!(handle.execution_id(), "test-123");
assert!(handle.elapsed().as_millis() < 100);
handle.cancel().await.expect("unwrap");
assert!(rx.recv().await.is_some());
}
}