use std::{
error::Error,
fmt::Display,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use log::{LevelFilter, debug, info, trace};
use singleton_task::*;
use tokio::{task::JoinHandle, time::sleep};
/// First test error type; the single placeholder variant exists only so the
/// enum has a value.
#[derive(Debug, Clone)]
enum Error1 {
    _A,
}

impl TError for Error1 {}

impl Error for Error1 {}

impl Display for Error1 {
    /// User-facing text is simply the `Debug` rendering of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
struct Task1 {
tx: Option<Sender<u32>>,
}
#[async_trait]
impl Task<Error1> for Task1 {
async fn on_start(&mut self, ctx: Context<Error1>) -> Result<(), Error1> {
trace!("[{}]on_start", ctx.id());
let tx = self.tx.take().unwrap();
let id = ctx.id();
ctx.spawn(async move {
for i in 0..10 {
let _ = tx.try_send(i);
info!("[{id}]send {i}");
sleep(Duration::from_millis(100)).await;
}
});
Ok(())
}
}
/// Builder for [`Task1`].
// NOTE(review): "Tasl1" looks like a typo for "Task1"; a rename must update
// every use site at once.
struct Tasl1Builder {}

impl TaskBuilder for Tasl1Builder {
    type Output = u32;
    type Error = Error1;
    type Task = Task1;

    /// Wraps the supplied sender in a fresh [`Task1`].
    fn build(self, tx: Sender<u32>) -> Self::Task {
        let sender = Some(tx);
        Task1 { tx: sender }
    }
}
/// Second test error type, carrying a free-form message.
#[derive(Debug, Clone)]
enum Error2 {
    Custom(String),
}

impl TError for Error2 {}

impl Error for Error2 {}

impl Display for Error2 {
    /// Renders as `Custom: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Custom(msg) => write!(f, "Custom: {msg}"),
        }
    }
}
/// Test task that emits `"{task_name}:{i}"` for `i` in `0..20`, pausing
/// `duration_ms` milliseconds after each successful send.
struct LongRunningTask {
    /// Output channel; moved out of the struct in `on_start`.
    tx: Option<Sender<String>>,
    /// Label included in every emitted message and log line.
    task_name: String,
    /// Pause between sends, in milliseconds.
    duration_ms: u64,
}

#[async_trait]
impl Task<Error1> for LongRunningTask {
    /// Spawns the producer on `ctx` and returns `Ok(())` immediately.
    ///
    /// The producer stops early as soon as a send fails (presumably because
    /// the receiving side is gone).
    async fn on_start(&mut self, ctx: Context<Error1>) -> Result<(), Error1> {
        trace!("[{}] LongRunningTask {} starting", ctx.id(), self.task_name);
        let tx = self.tx.take().unwrap();
        let task_name = self.task_name.clone();
        let pause = Duration::from_millis(self.duration_ms);
        let id = ctx.id();
        ctx.spawn(async move {
            let mut i = 0;
            while i < 20 {
                let delivered = tx.send(format!("{task_name}:{i}")).await.is_ok();
                if !delivered {
                    break;
                }
                info!("[{id}] {task_name} sending: {i}");
                sleep(pause).await;
                i += 1;
            }
            info!("[{id}] {task_name} finished sending");
        });
        Ok(())
    }

    /// Logs that the task is stopping; no other cleanup is performed.
    async fn on_stop(&mut self, ctx: Context<Error1>) -> Result<(), Error1> {
        info!("[{}] LongRunningTask {} stopping", ctx.id(), self.task_name);
        Ok(())
    }
}
/// Builder for [`LongRunningTask`]; its settings are moved into the task.
struct LongRunningTaskBuilder {
    /// Label for the task's messages.
    task_name: String,
    /// Pause between sends, in milliseconds.
    duration_ms: u64,
}

impl TaskBuilder for LongRunningTaskBuilder {
    type Output = String;
    type Error = Error1;
    type Task = LongRunningTask;

    /// Creates the task with the given sender and this builder's settings.
    fn build(self, tx: Sender<String>) -> Self::Task {
        let Self {
            task_name,
            duration_ms,
        } = self;
        LongRunningTask {
            task_name,
            duration_ms,
            tx: Some(tx),
        }
    }
}
/// Test task whose startup fails when `fail_after < 3`. Otherwise it emits
/// `0..fail_after` (at most 10 values), 50 ms apart.
struct ErrorTask {
    /// Output channel; moved out of the struct in `on_start`.
    tx: Option<Sender<u32>>,
    /// Number of values to emit. Values below 3 make `on_start` return an error.
    fail_after: u32,
}

#[async_trait]
impl Task<Error2> for ErrorTask {
    /// Returns `Err(Error2::Custom(..))` when `fail_after < 3`. Otherwise it
    /// spawns the producer on `ctx` and returns `Ok(())`.
    async fn on_start(&mut self, ctx: Context<Error2>) -> Result<(), Error2> {
        // `fail_after` values below this make startup fail.
        const STARTUP_FAILURE_THRESHOLD: u32 = 3;
        // Upper bound on emitted values, whatever `fail_after` is.
        const MAX_VALUES: u32 = 10;

        trace!("[{}] ErrorTask starting", ctx.id());
        // Take the sender unconditionally so it is dropped on the failure path.
        let tx = self.tx.take().unwrap();
        // Decide failure before spawning, so a failing startup never emits values.
        if self.fail_after < STARTUP_FAILURE_THRESHOLD {
            return Err(Error2::Custom("Task failed during startup".to_string()));
        }
        let limit = self.fail_after.min(MAX_VALUES);
        let id = ctx.id();
        ctx.spawn(async move {
            for i in 0..limit {
                if tx.send(i).await.is_err() {
                    break;
                }
                info!("[{id}] ErrorTask sending: {i}");
                sleep(Duration::from_millis(50)).await;
            }
        });
        Ok(())
    }
}
/// Builder for [`ErrorTask`]; `fail_after` is forwarded unchanged.
struct ErrorTaskBuilder {
    fail_after: u32,
}

impl TaskBuilder for ErrorTaskBuilder {
    type Output = u32;
    type Error = Error2;
    type Task = ErrorTask;

    /// Creates the task with the given sender and this builder's `fail_after`.
    fn build(self, tx: Sender<u32>) -> Self::Task {
        let Self { fail_after } = self;
        ErrorTask {
            fail_after,
            tx: Some(tx),
        }
    }
}
/// Test task that bumps a shared counter five times. Each new value is sent
/// together with the task's `task_id`, 100 ms apart.
struct CounterTask {
    /// Output channel; moved out of the struct in `on_start`.
    tx: Option<Sender<(u32, String)>>,
    /// Counter shared with other tasks and the test body.
    counter: Arc<Mutex<u32>>,
    /// Label sent alongside every count.
    task_id: String,
}

#[async_trait]
impl Task<Error1> for CounterTask {
    /// Spawns the producer on `ctx` and returns `Ok(())` immediately.
    async fn on_start(&mut self, ctx: Context<Error1>) -> Result<(), Error1> {
        trace!("[{}] CounterTask {} starting", ctx.id(), self.task_id);
        let tx = self.tx.take().unwrap();
        let counter = Arc::clone(&self.counter);
        let task_id = self.task_id.clone();
        let id = ctx.id();
        ctx.spawn(async move {
            // Increment under the lock and return the new value. The guard is
            // released before the caller reaches any `.await`.
            let bump = |shared: &Mutex<u32>| -> u32 {
                let mut guard = shared.lock().unwrap();
                *guard += 1;
                *guard
            };
            for _ in 0..5 {
                let count = bump(&*counter);
                if tx.send((count, task_id.clone())).await.is_err() {
                    break;
                }
                info!("[{id}] CounterTask {task_id} count: {count}");
                sleep(Duration::from_millis(100)).await;
            }
        });
        Ok(())
    }
}
/// Builder for [`CounterTask`]; the shared counter and label move into the task.
struct CounterTaskBuilder {
    counter: Arc<Mutex<u32>>,
    task_id: String,
}

impl TaskBuilder for CounterTaskBuilder {
    type Output = (u32, String);
    type Error = Error1;
    type Task = CounterTask;

    /// Creates the task with the given sender and this builder's state.
    fn build(self, tx: Sender<(u32, String)>) -> Self::Task {
        let Self { counter, task_id } = self;
        CounterTask {
            counter,
            task_id,
            tx: Some(tx),
        }
    }
}
/// Installs `env_logger` at `Trace` level in test mode.
///
/// Safe to call from every test: the result of `try_init` is discarded, so a
/// logger that is already installed does not cause a failure.
fn init_log() {
    let mut builder = env_logger::builder();
    builder.filter_level(LevelFilter::Trace).is_test(true);
    let _ = builder.try_init();
}
/// Reads five values from a `Task1` and then stops it through the returned handle.
#[tokio::test(flavor = "multi_thread")]
async fn test_stop() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let mut rx = st.start(Tasl1Builder {}).await.unwrap();
    let mut remaining = 5;
    while remaining > 0 {
        let r = rx.recv().await.unwrap();
        debug!("rcv {r}");
        remaining -= 1;
    }
    let r = rx.stop().await;
    debug!("stop: {r:?}");
}
/// Starts a `Task1`, then starts a second one 30 ms later while a reader is
/// still consuming the first.
///
/// The first reader must see its channel yield `None` before it has read ten
/// values. The second handle is then drained until it yields `None`.
#[tokio::test(flavor = "multi_thread")]
async fn test_stop2() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let mut rx = st.start(Tasl1Builder {}).await.unwrap();
    let begin = Instant::now();
    let h1: JoinHandle<Option<()>> = tokio::spawn(async move {
        for _ in 0..10 {
            let begin = Instant::now();
            // `?` ends the reader with `None` as soon as `recv` yields `None`.
            let v = rx.recv().await?;
            debug!("rcv {v}");
            debug!("rcv cost: {:?}", begin.elapsed());
        }
        Some(())
    });
    sleep(Duration::from_millis(30)).await;
    debug!("start 2, delay {:?}", begin.elapsed());
    let mut t2 = st.start(Tasl1Builder {}).await.unwrap();
    let r = h1.await.unwrap();
    debug!("h1 end");
    assert!(r.is_none());
    loop {
        let Some(v) = t2.recv().await else {
            break;
        };
        debug!("2 rcv {v}");
    }
}
#[tokio::test(flavor = "multi_thread")]
async fn test_concurrent_task_start() {
init_log();
let st = SingletonTask::<Error1>::new();
let st_arc = Arc::new(st);
let mut handles = vec![];
for i in 0..10 {
let st_clone = st_arc.clone();
let handle = tokio::spawn(async move {
let builder = LongRunningTaskBuilder {
task_name: format!("Task{i}"),
duration_ms: 50,
};
match st_clone.start(builder).await {
Ok(mut rx) => {
debug!("Task {i} started successfully");
for _ in 0..3 {
if let Some(msg) = rx.recv().await {
debug!("Task {i} received: {msg}");
}
}
Ok(i)
}
Err(e) => {
debug!("Task {i} failed to start: {e}");
Err(e)
}
}
});
handles.push(handle);
}
let mut successful_tasks = 0;
for handle in handles {
match handle.await.unwrap() {
Ok(task_id) => {
debug!("Task {task_id} completed successfully");
successful_tasks += 1;
}
Err(e) => {
debug!("Task failed: {e}");
}
}
}
debug!("Total successful task starts: {successful_tasks}");
assert!(successful_tasks >= 1);
}
/// Starts five `LongRunningTask`s 50 ms apart, each start replacing the running
/// task. Checks that the last successfully started task delivers messages.
///
/// # Panics
/// Panics if no start succeeded, or if the final task delivers nothing within
/// two seconds.
#[tokio::test(flavor = "multi_thread")]
async fn test_rapid_task_replacement() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let mut last_rx = None;
    for i in 0..5 {
        let builder = LongRunningTaskBuilder {
            task_name: format!("RapidTask{i}"),
            duration_ms: 200,
        };
        match st.start(builder).await {
            Ok(rx) => {
                debug!("RapidTask{i} started");
                last_rx = Some(rx);
            }
            Err(e) => {
                debug!("RapidTask{i} failed: {e}");
            }
        }
        sleep(Duration::from_millis(50)).await;
    }
    // Fail loudly instead of passing vacuously when every start failed.
    let mut rx = last_rx.expect("at least one RapidTask should have started");
    let deadline = Instant::now() + Duration::from_secs(2);
    let mut received_count = 0;
    while received_count < 3 {
        // Bound each wait by the remaining time rather than polling on a timer.
        let remaining = deadline.saturating_duration_since(Instant::now());
        match tokio::time::timeout(remaining, rx.recv()).await {
            Ok(Some(msg)) => {
                debug!("Final task received: {msg}");
                received_count += 1;
            }
            // `None`: the channel is closed, so nothing more will arrive.
            // `Err(_)`: the deadline elapsed.
            Ok(None) | Err(_) => break,
        }
    }
    assert!(
        received_count > 0,
        "Should receive at least one message from the final task"
    );
}
#[tokio::test(flavor = "multi_thread")]
async fn test_error_handling_multithreaded() {
init_log();
let st = SingletonTask::<Error2>::new();
let st_arc = Arc::new(st);
let mut handles = vec![];
for i in 0..5 {
let st_clone = st_arc.clone();
let handle = tokio::spawn(async move {
let builder = ErrorTaskBuilder {
fail_after: if i < 2 { 0 } else { 5 }, };
let result = st_clone.start(builder).await;
sleep(Duration::from_millis(10)).await;
(i, result)
});
handles.push(handle);
sleep(Duration::from_millis(5)).await;
}
let mut success_count = 0;
let mut error_count = 0;
for handle in handles {
let (task_id, result) = handle.await.unwrap();
match result {
Ok(_) => {
debug!("Task {task_id} succeeded");
success_count += 1;
}
Err(e) => {
debug!("Task {task_id} failed: {e}");
error_count += 1;
}
}
}
debug!("Success: {success_count}, Errors: {error_count}");
assert!(error_count > 0, "Should have some errors");
assert!(
success_count + error_count == 5,
"Should have processed all tasks"
);
}
/// Starts three `LongRunningTask`s one after another, 100 ms apart.
///
/// After each successful start it polls that handle's receiver for up to
/// 500 ms or three messages. The test requires that some message arrived.
#[tokio::test(flavor = "multi_thread")]
async fn test_multiple_task_startup() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let mut total_messages = 0;
    for i in 0..3 {
        let builder = LongRunningTaskBuilder {
            task_name: format!("MultiTask{i}"),
            duration_ms: 50,
        };
        match st.start(builder).await {
            Err(e) => debug!("MultiTask{i} failed: {e}"),
            Ok(mut handle) => {
                debug!("MultiTask{i} started");
                let start = Instant::now();
                let timeout = Duration::from_millis(500);
                let mut received_count = 0;
                loop {
                    if start.elapsed() >= timeout || received_count >= 3 {
                        break;
                    }
                    if let Ok(msg) = handle.rx.try_recv() {
                        debug!("MultiTask{i} received: {msg}");
                        received_count += 1;
                        total_messages += 1;
                    } else {
                        sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        }
        sleep(Duration::from_millis(100)).await;
    }
    debug!("Total messages received: {total_messages}");
    assert!(total_messages > 0, "Should receive some messages");
}
/// Calls `stop` three times concurrently, 20 ms apart, on clones of one task
/// context while a reader drains the task's channel.
///
/// The reader gives up after one second, when the channel yields `None`, or
/// when the deadline passes while it is waiting. The test requires either a
/// successful stop or at least one received message.
#[tokio::test(flavor = "multi_thread")]
async fn test_concurrent_stop() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let builder = LongRunningTaskBuilder {
        task_name: "StopTask".to_string(),
        duration_ms: 200,
    };
    let mut handle = st.start(builder).await.unwrap();
    let ctx = handle.ctx.clone();
    sleep(Duration::from_millis(50)).await;
    let mut stop_handles = vec![];
    for i in 0..3 {
        let ctx_clone = ctx.clone();
        let stop_handle = tokio::spawn(async move {
            sleep(Duration::from_millis(i * 20)).await;
            let result = ctx_clone.stop().await;
            debug!("Stop attempt {i} result: {result:?}");
            result
        });
        stop_handles.push(stop_handle);
    }
    let read_handle = tokio::spawn(async move {
        let mut messages = vec![];
        // The deadline bounds each `recv()` wait itself. Checking elapsed time
        // only between messages cannot interrupt a `recv()` that never returns.
        let deadline = Instant::now() + Duration::from_secs(1);
        loop {
            let remaining = deadline.saturating_duration_since(Instant::now());
            match tokio::time::timeout(remaining, handle.recv()).await {
                Ok(Some(msg)) => {
                    debug!("Read message: {msg}");
                    messages.push(msg);
                }
                Ok(None) => {
                    debug!("Channel closed or error");
                    break;
                }
                Err(_) => {
                    debug!("Read deadline reached");
                    break;
                }
            }
        }
        messages
    });
    let mut stop_results = vec![];
    for stop_handle in stop_handles {
        let result = stop_handle.await.unwrap();
        stop_results.push(result);
    }
    let messages = read_handle.await.unwrap();
    debug!("Stop results: {stop_results:?}");
    debug!("Total messages read: {}", messages.len());
    let successful_stops = stop_results.iter().filter(|r| r.is_ok()).count();
    debug!("Successful stops: {successful_stops}");
    assert!(
        successful_stops > 0 || !messages.is_empty(),
        "Should either have successful stops or received messages"
    );
}
/// Fires 20 `start` calls for `CounterTask`s that share one counter, pausing
/// 10 ms after every fifth spawn.
///
/// Each reader stops at the first of:
/// - three messages received;
/// - its channel yielding `None` (its task was replaced or finished);
/// - one second elapsing.
///
/// The test requires at least one successful start and a non-zero counter.
#[tokio::test(flavor = "multi_thread")]
async fn test_high_concurrency() {
    init_log();
    let st = SingletonTask::<Error1>::new();
    let st_arc = Arc::new(st);
    let counter = Arc::new(Mutex::new(0u32));
    let mut handles = vec![];
    for i in 0..20 {
        let st_clone = st_arc.clone();
        let counter_clone = counter.clone();
        let handle = tokio::spawn(async move {
            let builder = CounterTaskBuilder {
                counter: counter_clone,
                task_id: format!("HighConcurrency{i}"),
            };
            match st_clone.start(builder).await {
                Ok(mut rx) => {
                    let mut received = 0;
                    let deadline = Instant::now() + Duration::from_secs(1);
                    while received < 3 {
                        let remaining = deadline.saturating_duration_since(Instant::now());
                        match tokio::time::timeout(remaining, rx.recv()).await {
                            Ok(Some((count, task_id))) => {
                                debug!("Task {i} - Count: {count}, TaskId: {task_id}");
                                received += 1;
                            }
                            // `None`: the channel is closed, and further `recv()`
                            // calls would return immediately, so continuing would
                            // busy-spin until the deadline.
                            // `Err(_)`: the deadline elapsed.
                            Ok(None) | Err(_) => break,
                        }
                    }
                    Ok(received)
                }
                Err(e) => {
                    debug!("Task {i} failed: {e}");
                    Err(e)
                }
            }
        });
        handles.push(handle);
        if i % 5 == 0 {
            sleep(Duration::from_millis(10)).await;
        }
    }
    let mut total_received = 0;
    let mut successful_starts = 0;
    for handle in handles {
        if let Ok(received) = handle.await.unwrap() {
            total_received += received;
            successful_starts += 1;
        }
    }
    let final_counter = *counter.lock().unwrap();
    debug!("Successful starts: {successful_starts}");
    debug!("Total messages received: {total_received}");
    debug!("Final counter value: {final_counter}");
    assert!(successful_starts > 0, "Should have successful task starts");
    assert!(final_counter > 0, "Counter should be incremented");
}