use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};
use super::{YufmathError, ComputeProgress};
/// Heap-allocated, pinned, `Send` trait-object future whose output is
/// `Result<T, YufmathError>`.
///
/// The `'static` bound is spelled out here; it is the same default bound that
/// `Box<dyn Trait>` would otherwise receive implicitly.
pub type AsyncResult<T> =
    Pin<Box<dyn Future<Output = Result<T, YufmathError>> + Send + 'static>>;
/// Lifecycle state of an [`AsyncTask`].
///
/// A task is created `Pending`. `AsyncTask::start` moves it to `Running`, and
/// `Completed` / `Cancelled` mark that it has finished. The enum is fieldless,
/// so it is `Copy`, and it is `Eq + Hash`, so statuses can be compared or used
/// as set/map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskStatus {
    /// Created but not yet started.
    Pending,
    /// Started and not yet finished.
    Running,
    /// Finished with a recorded result.
    Completed,
    /// Cancelled by a caller.
    Cancelled,
    /// Failed. NOTE(review): nothing in this file ever sets this state; it is
    /// only observed (e.g. when a task lock is poisoned) — confirm intended use.
    Error,
}
pub struct AsyncTask<T> {
pub id: u64,
pub status: TaskStatus,
pub result: Option<Result<T, YufmathError>>,
pub progress: Option<ComputeProgress>,
pub start_time: Option<Instant>,
pub end_time: Option<Instant>,
waker: Option<Waker>,
}
impl<T> AsyncTask<T> {
pub fn new(id: u64) -> Self {
Self {
id,
status: TaskStatus::Pending,
result: None,
progress: None,
start_time: None,
end_time: None,
waker: None,
}
}
pub fn start(&mut self) {
self.status = TaskStatus::Running;
self.start_time = Some(Instant::now());
}
pub fn complete(&mut self, result: Result<T, YufmathError>) {
self.status = TaskStatus::Completed;
self.end_time = Some(Instant::now());
self.result = Some(result);
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
pub fn cancel(&mut self) {
self.status = TaskStatus::Cancelled;
self.end_time = Some(Instant::now());
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
pub fn update_progress(&mut self, progress: ComputeProgress) {
self.progress = Some(progress);
if let Some(waker) = &self.waker {
waker.wake_by_ref();
}
}
pub fn execution_time(&self) -> Option<Duration> {
match (self.start_time, self.end_time) {
(Some(start), Some(end)) => Some(end.duration_since(start)),
(Some(start), None) => Some(Instant::now().duration_since(start)),
_ => None,
}
}
}
pub struct AsyncComputation<T> {
task: Arc<Mutex<AsyncTask<T>>>,
}
impl<T> AsyncComputation<T> {
pub fn new(task: Arc<Mutex<AsyncTask<T>>>) -> Self {
Self { task }
}
pub fn status(&self) -> TaskStatus {
if let Ok(task) = self.task.lock() {
task.status.clone()
} else {
TaskStatus::Error
}
}
pub fn progress(&self) -> Option<ComputeProgress> {
if let Ok(task) = self.task.lock() {
task.progress.clone()
} else {
None
}
}
pub fn cancel(&self) {
if let Ok(mut task) = self.task.lock() {
task.cancel();
}
}
pub fn execution_time(&self) -> Option<Duration> {
if let Ok(task) = self.task.lock() {
task.execution_time()
} else {
None
}
}
}
impl<T: Clone> Future for AsyncComputation<T> {
type Output = Result<T, YufmathError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Ok(mut task) = self.task.lock() {
match task.status {
TaskStatus::Completed => {
if let Some(result) = task.result.take() {
Poll::Ready(result)
} else {
Poll::Ready(Err(YufmathError::internal("任务已完成但没有结果")))
}
}
TaskStatus::Cancelled => {
Poll::Ready(Err(YufmathError::internal("任务已被取消")))
}
TaskStatus::Error => {
Poll::Ready(Err(YufmathError::internal("任务执行出错")))
}
_ => {
task.waker = Some(cx.waker().clone());
Poll::Pending
}
}
} else {
Poll::Ready(Err(YufmathError::internal("无法获取任务锁")))
}
}
}
pub struct BatchAsyncComputer {
task_counter: Arc<Mutex<u64>>,
active_tasks: Arc<Mutex<Vec<Arc<Mutex<AsyncTask<String>>>>>>,
max_concurrent: usize,
}
impl BatchAsyncComputer {
pub fn new(max_concurrent: usize) -> Self {
Self {
task_counter: Arc::new(Mutex::new(0)),
active_tasks: Arc::new(Mutex::new(Vec::new())),
max_concurrent,
}
}
pub fn submit_batch(&self, expressions: Vec<String>) -> Vec<AsyncComputation<String>> {
let mut computations = Vec::new();
for expr in expressions {
let task_id = {
let mut counter = self.task_counter.lock().unwrap();
*counter += 1;
*counter
};
let task = Arc::new(Mutex::new(AsyncTask::new(task_id)));
let computation = AsyncComputation::new(Arc::clone(&task));
if let Ok(mut active_tasks) = self.active_tasks.lock() {
active_tasks.push(Arc::clone(&task));
}
self.spawn_computation_task(task, expr);
computations.push(computation);
}
computations
}
fn spawn_computation_task(&self, task: Arc<Mutex<AsyncTask<String>>>, expression: String) {
thread::spawn(move || {
if let Ok(mut t) = task.lock() {
t.start();
}
let result = Self::simulate_computation(&expression, &task);
if let Ok(mut t) = task.lock() {
t.complete(result);
}
});
}
fn simulate_computation(
expression: &str,
task: &Arc<Mutex<AsyncTask<String>>>
) -> Result<String, YufmathError> {
let steps = vec![
("解析表达式", 0.2),
("简化表达式", 0.5),
("计算结果", 0.8),
("格式化输出", 1.0),
];
for (step_name, progress) in steps {
if let Ok(task_guard) = task.lock() {
if task_guard.status == TaskStatus::Cancelled {
return Err(YufmathError::internal("计算被取消"));
}
}
let progress_info = ComputeProgress::new(step_name)
.with_progress(progress);
if let Ok(mut task_guard) = task.lock() {
task_guard.update_progress(progress_info);
}
thread::sleep(Duration::from_millis(100));
}
Ok(format!("计算结果: {}", expression))
}
pub fn active_task_count(&self) -> usize {
if let Ok(active_tasks) = self.active_tasks.lock() {
active_tasks.len()
} else {
0
}
}
pub fn cancel_all(&self) {
if let Ok(active_tasks) = self.active_tasks.lock() {
for task in active_tasks.iter() {
if let Ok(mut t) = task.lock() {
t.cancel();
}
}
}
}
pub fn cleanup_completed(&self) {
if let Ok(mut active_tasks) = self.active_tasks.lock() {
active_tasks.retain(|task| {
if let Ok(t) = task.lock() {
!matches!(t.status, TaskStatus::Completed | TaskStatus::Cancelled | TaskStatus::Error)
} else {
false
}
});
}
}
}
/// Configuration values for asynchronous computation.
///
/// NOTE(review): nothing in this file reads these settings. The descriptions
/// below follow the field names; confirm them against the actual consumers.
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Presumably the cap on simultaneously running tasks.
    pub max_concurrent_tasks: usize,
    /// Presumably the per-task time limit.
    pub task_timeout: Duration,
    /// Whether progress reporting is switched on.
    pub enable_progress: bool,
    /// Presumably the spacing between progress reports.
    pub progress_interval: Duration,
}

impl Default for AsyncConfig {
    /// Defaults: 4 concurrent tasks, a 300 s (5 min) timeout, progress enabled,
    /// and a 100 ms progress interval.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        AsyncConfig {
            max_concurrent_tasks: 4,
            task_timeout: five_minutes,
            enable_progress: true,
            progress_interval: Duration::from_millis(100),
        }
    }
}

impl AsyncConfig {
    /// Same as [`AsyncConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder: replaces `max_concurrent_tasks`.
    pub fn with_max_concurrent_tasks(self, max_tasks: usize) -> Self {
        Self { max_concurrent_tasks: max_tasks, ..self }
    }

    /// Builder: replaces `task_timeout`.
    pub fn with_task_timeout(self, timeout: Duration) -> Self {
        Self { task_timeout: timeout, ..self }
    }

    /// Builder: replaces `enable_progress`.
    pub fn with_progress(self, enable: bool) -> Self {
        Self { enable_progress: enable, ..self }
    }

    /// Builder: replaces `progress_interval`.
    pub fn with_progress_interval(self, interval: Duration) -> Self {
        Self { progress_interval: interval, ..self }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A freshly created task carries its id, is pending, and has no result.
    #[test]
    fn test_async_task_creation() {
        let fresh = AsyncTask::<String>::new(1);
        assert_eq!(1, fresh.id);
        assert_eq!(TaskStatus::Pending, fresh.status);
        assert!(fresh.result.is_none());
    }

    /// start -> Running with a start time; complete -> Completed with an end
    /// time and a stored result.
    #[test]
    fn test_async_task_lifecycle() {
        let mut job: AsyncTask<String> = AsyncTask::new(1);

        job.start();
        assert_eq!(TaskStatus::Running, job.status);
        assert!(job.start_time.is_some());

        job.complete(Ok(String::from("test result")));
        assert_eq!(TaskStatus::Completed, job.status);
        assert!(job.end_time.is_some());
        assert!(job.result.is_some());
    }

    /// Submitting a batch yields one handle per expression, and every task is
    /// tracked immediately.
    #[test]
    fn test_batch_async_computer() {
        let batch = BatchAsyncComputer::new(2);
        let inputs: Vec<String> = ["2+3", "x^2"].iter().map(|s| s.to_string()).collect();
        let handles = batch.submit_batch(inputs);
        assert_eq!(2, handles.len());
        assert_eq!(2, batch.active_task_count());
    }

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_async_config() {
        let ten_minutes = Duration::from_secs(600);
        let cfg = AsyncConfig::new()
            .with_progress(false)
            .with_task_timeout(ten_minutes)
            .with_max_concurrent_tasks(8);
        assert_eq!(8, cfg.max_concurrent_tasks);
        assert_eq!(ten_minutes, cfg.task_timeout);
        assert!(!cfg.enable_progress);
    }
}