pezsc-rpc 29.0.0

Bizinikiwi Client RPC
Documentation
// This file is part of Bizinikiwi.

// Copyright (C) Parity Technologies (UK) Ltd. and Dijital Kurdistan Tech Institute
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0

// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use std::{
	sync::{
		atomic::{AtomicBool, Ordering},
		Arc,
	},
	time::Duration,
};

/// Error returned when a task is cancelled because it exceeded its timeout.
#[derive(Debug)]
pub struct Timeout;

impl std::fmt::Display for Timeout {
	/// Writes the fixed, human-readable description of the timeout.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "task has been running too long")
	}
}

// The provided `Error` methods are sufficient; nothing is overridden here.
impl std::error::Error for Timeout {}

/// Handle given to a task so it can find out whether it has been cancelled due to a timeout.
///
/// Wraps the shared boolean flag that signals the cancellation.
#[repr(transparent)]
pub struct IsTimedOut(Arc<AtomicBool>);

impl IsTimedOut {
	/// Returns `Err(Timeout)` if the shared flag is set, and `Ok(())` otherwise.
	///
	/// Returning a `Result` lets the caller bail out of its work with `?`.
	#[must_use]
	pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
		let cancelled = self.0.load(Ordering::Relaxed);
		if cancelled {
			return Err(Timeout);
		}
		Ok(())
	}
}

/// Failure modes of a task spawned with a timeout: joining it failed (for example because it
/// panicked), or it was cancelled due to a timeout.
#[derive(Debug)]
pub enum SpawnWithTimeoutError {
	/// Awaiting the spawned task returned a join error.
	JoinError(tokio::task::JoinError),
	/// The task was cancelled due to a timeout.
	Timeout,
}

impl std::fmt::Display for SpawnWithTimeoutError {
	/// Delegates to the `Display` implementation of the wrapped join error or of [`Timeout`].
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			Self::JoinError(join_error) => std::fmt::Display::fmt(join_error, f),
			Self::Timeout => std::fmt::Display::fmt(&Timeout, f),
		}
	}
}

impl std::error::Error for SpawnWithTimeoutError {}

/// Guard that sets the shared flag to `true` when it is dropped.
struct CancelOnDrop(Arc<AtomicBool>);

impl Drop for CancelOnDrop {
	/// Raises the wrapped flag.
	fn drop(&mut self) {
		let Self(flag) = self;
		flag.store(true, Ordering::Relaxed);
	}
}

/// Runs `callback` as a blocking task, optionally bounded by `timeout`.
///
/// The `callback` receives an [`IsTimedOut`] handle and should call
/// [`IsTimedOut::check_if_timed_out`] continuously; that call returns an error
/// once the task has been running for longer than `timeout`.
///
/// When `timeout` is `None` this behaves just like a regular `spawn_blocking`.
///
/// # Errors
///
/// Returns [`SpawnWithTimeoutError::Timeout`] if the timeout elapses before the task finishes
/// or if the `callback` itself returns [`Timeout`], and [`SpawnWithTimeoutError::JoinError`]
/// if awaiting the spawned task fails.
pub async fn spawn_blocking_with_timeout<R>(
	timeout: Option<Duration>,
	callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
) -> Result<R, SpawnWithTimeoutError>
where
	R: Send + 'static,
{
	let flag = Arc::new(AtomicBool::new(false));
	// The guard raises the flag when it is dropped, i.e. whenever this function's
	// locals are torn down.
	let _cancel_guard = CancelOnDrop(Arc::clone(&flag));
	let handle = tokio::task::spawn_blocking(move || callback(IsTimedOut(flag)));

	let joined = match timeout {
		None => handle.await,
		Some(limit) => tokio::select! {
			// Shouldn't really matter, but poll the task ahead of the timer so that a
			// finished task wins even when the timeout is really short.
			biased;

			outcome = handle => outcome,
			_ = tokio::time::sleep(limit) => Ok(Err(Timeout)),
		},
	};

	match joined {
		Err(join_error) => Err(SpawnWithTimeoutError::JoinError(join_error)),
		Ok(outcome) => outcome.map_err(|Timeout| SpawnWithTimeoutError::Timeout),
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Covers both outcomes: a task that overruns its timeout and one that finishes in time.
	#[tokio::test]
	async fn spawn_blocking_with_timeout_works() {
		let limit = Some(Duration::from_millis(100));

		// The callback sleeps for longer than the limit, so a timeout is expected.
		let overrun: Result<(), SpawnWithTimeoutError> =
			spawn_blocking_with_timeout(limit, |handle| {
				std::thread::sleep(Duration::from_millis(200));
				handle.check_if_timed_out()?;
				unreachable!();
			})
			.await;

		assert_matches::assert_matches!(overrun, Err(SpawnWithTimeoutError::Timeout));

		// The callback sleeps for less than the limit, so it is expected to succeed.
		let in_time = spawn_blocking_with_timeout(limit, |handle| {
			std::thread::sleep(Duration::from_millis(20));
			handle.check_if_timed_out()?;
			Ok(())
		})
		.await;

		assert_matches::assert_matches!(in_time, Ok(()));
	}
}