Skip to main content

co_actor/
tokio_task_spawner.rs

// SPDX-License-Identifier: AGPL-3.0-only
// Copyright (C) 2026 1io BRANDGUARDIAN GmbH

4use crate::TaskOptions;
5use futures::Future;
6use std::{panic::Location, sync::Arc};
7use tokio::task::JoinHandle;
8use tokio_util::task::TaskTracker;
9use tracing::Instrument;
10
/// Handle returned by the `spawn*` methods of [`TaskSpawner`]; an alias for Tokio's [`JoinHandle`].
pub type TaskHandle<T> = JoinHandle<T>;
12
/// Spawns Tokio tasks wrapped in `tracing` spans and (unless asked otherwise) registers them
/// with a shared [`TaskTracker`].
#[derive(Debug, Clone)]
pub struct TaskSpawner {
	/// Identifier recorded as the `application` field on every task span.
	// NOTE(review): misspelling of `identifier`; left as-is because the field is `pub(crate)`
	// and may be referenced elsewhere in the crate.
	pub(crate) idenitfier: Arc<String>,
	/// Tracker that tracked tasks are registered with; clones of the spawner share it.
	pub(crate) inner: TaskTracker,
}
18impl TaskSpawner {
19	pub fn new(idenitfier: String) -> Self {
20		Self { idenitfier: Arc::new(idenitfier), inner: TaskTracker::new() }
21	}
22
23	/// Spawn task.
24	#[inline]
25	#[track_caller]
26	pub fn spawn<F>(&self, task: F) -> TaskHandle<F::Output>
27	where
28		F: Future + Send + 'static,
29		F::Output: Send + 'static,
30	{
31		let caller_file = Location::caller().file();
32		let caller_line = Location::caller().line();
33		let caller_column = Location::caller().column();
34		let span = tracing::trace_span!(
35			"task",
36			application = self.idenitfier.as_str(),
37			caller_file,
38			caller_line,
39			caller_column,
40		);
41		self.inner.spawn(task.instrument(span))
42	}
43
44	/// Spawn task.
45	#[inline]
46	#[track_caller]
47	#[allow(unexpected_cfgs)]
48	pub fn spawn_named<F>(&self, name: &str, task: F) -> TaskHandle<F::Output>
49	where
50		F: Future + Send + 'static,
51		F::Output: Send + 'static,
52	{
53		let caller_file = Location::caller().file();
54		let caller_line = Location::caller().line();
55		let caller_column = Location::caller().column();
56		let span = tracing::trace_span!(
57			"task",
58			task_name = name,
59			application = self.idenitfier.as_str(),
60			caller_file,
61			caller_line,
62			caller_column,
63		);
64		#[cfg(tokio_unstable)]
65		{
66			tokio::task::Builder::new()
67				.name(name)
68				.spawn(self.inner.track_future(task.instrument(span)))
69				.expect("tokio runtime")
70		}
71		#[cfg(not(tokio_unstable))]
72		{
73			self.inner.spawn(task.instrument(span))
74		}
75	}
76
77	/// Spawn task.
78	#[inline]
79	#[track_caller]
80	#[allow(unexpected_cfgs)]
81	pub fn spawn_options<F>(&self, options: TaskOptions, task: F) -> TaskHandle<F::Output>
82	where
83		F: Future + Send + 'static,
84		F::Output: Send + 'static,
85	{
86		let caller_file = Location::caller().file();
87		let caller_line = Location::caller().line();
88		let caller_column = Location::caller().column();
89		let span = tracing::trace_span!(
90			"task",
91			task_name = options.name,
92			application = self.idenitfier.as_str(),
93			caller_file,
94			caller_line,
95			caller_column,
96		);
97		#[cfg(tokio_unstable)]
98		{
99			let mut builder = tokio::task::Builder::new();
100			if let Some(name) = options.name {
101				builder = builder.name(name);
102			}
103			builder
104				.spawn(if options.untracked {
105					futures::future::Either::Left(task.instrument(span))
106				} else {
107					futures::future::Either::Right(self.inner.track_future(task.instrument(span)))
108				})
109				.expect("tokio runtime")
110		}
111		#[cfg(not(tokio_unstable))]
112		if options.untracked {
113			tokio::spawn(task.instrument(span))
114		} else {
115			self.inner.spawn(task.instrument(span))
116		}
117	}
118
119	/// Spawn blocking task.
120	#[inline]
121	#[track_caller]
122	#[allow(unexpected_cfgs)]
123	pub fn spawn_blocking<F, R>(&self, options: TaskOptions, task: F) -> TaskHandle<R>
124	where
125		F: FnOnce() -> R + Send + 'static,
126		R: Send + 'static,
127	{
128		let caller_file = Location::caller().file();
129		let caller_line = Location::caller().line();
130		let caller_column = Location::caller().column();
131		let span = tracing::trace_span!(
132			"task-blocking",
133			task_name = options.name,
134			application = self.idenitfier.as_str(),
135			caller_file,
136			caller_line,
137			caller_column,
138		);
139		let task = move || {
140			let _span_guard = span.enter();
141			task()
142		};
143		#[cfg(tokio_unstable)]
144		{
145			let mut builder = tokio::task::Builder::new();
146			if let Some(name) = options.name {
147				builder = builder.name(name);
148			}
149			builder
150				.spawn_blocking(if options.untracked {
151					futures::future::Either::Left(task)
152				} else {
153					futures::future::Either::Right(self.inner.track_future(task))
154				})
155				.expect("tokio runtime")
156		}
157		#[cfg(not(tokio_unstable))]
158		if options.untracked {
159			tokio::task::spawn_blocking(task)
160		} else {
161			self.inner.spawn_blocking(task)
162		}
163	}
164
165	pub fn tracker(&self) -> TaskTracker {
166		self.inner.clone()
167	}
168}
169impl Default for TaskSpawner {
170	fn default() -> Self {
171		Self { idenitfier: Arc::new("default".to_string()), inner: Default::default() }
172	}
173}