gooty_proxy/orchestration/threading.rs

//! # Threading Module
//!
//! Provides task-management utilities for running concurrent work in the
//! Gooty Proxy system on top of the Tokio runtime.
//!
//! ## Overview
//!
//! This module includes abstractions and helpers for:
//! - Spawning and managing asynchronous tasks (see [`TaskManager`])
//! - Distributing work to a pool of workers over a bounded channel
//! - Running batches of jobs with a limit on parallelism
//!
//! ## Examples
//!
//! ```no_run
//! use gooty_proxy::orchestration::threading::TaskManager;
//!
//! # async fn example() {
//! // Spawn a task and wait for it to complete.
//! let mut manager = TaskManager::new();
//! manager.spawn(async {
//!     println!("Task running");
//! });
//! manager.join_all().await;
//! # }
//! ```

use futures::{StreamExt, stream};
use std::future::Future;
use std::pin::Pin;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

/// Manages a collection of task handles for concurrent execution.
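///
/// # Examples
///
/// A minimal sketch (assumes a Tokio runtime is running):
///
/// ```no_run
/// use gooty_proxy::orchestration::threading::TaskManager;
///
/// # async fn example() {
/// let mut manager = TaskManager::new();
/// manager.spawn(async {
///     // ... do some work ...
/// });
/// // Wait for everything to finish (or call `cancel_all` to abort instead).
/// manager.join_all().await;
/// # }
/// ```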
#[derive(Default)]
pub struct TaskManager {
    tasks: Vec<JoinHandle<()>>,
}

impl TaskManager {
    /// Creates a new, empty task manager.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Spawns a task onto the Tokio runtime and adds it to the managed set.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let handle = tokio::spawn(future);
        self.tasks.push(handle);
    }

    /// Waits for all managed tasks to complete.
    ///
    /// Join errors (for example from tasks that panicked or were aborted)
    /// are ignored.
    pub async fn join_all(&mut self) {
        while let Some(task) = self.tasks.pop() {
            // Ignore the result; a failed task should not abort the join.
            let _ = task.await;
        }
    }

    /// Cancels all running tasks by aborting them.
    pub fn cancel_all(&mut self) {
        for task in self.tasks.drain(..) {
            task.abort();
        }
    }
}

/// Creates a pool of worker tasks that pull items from a bounded channel.
///
/// The channel capacity equals the worker count, so senders are backpressured
/// once every worker is busy and the queue is full. Workers run until every
/// sender has been dropped and the channel is drained, after which the
/// returned [`TaskManager`] can be joined.
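///
/// # Examples
///
/// A minimal sketch (assumes a Tokio runtime is available):
///
/// ```no_run
/// use gooty_proxy::orchestration::threading::create_worker_pool;
///
/// # async fn example() {
/// let (tx, mut manager) = create_worker_pool(4, |item: u32| async move {
///     println!("processing {item}");
/// });
/// tx.send(42).await.unwrap();
/// drop(tx); // close the channel so the workers can exit
/// manager.join_all().await;
/// # }
/// ```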
pub fn create_worker_pool<T, F, Fut>(
    concurrency: usize,
    worker_fn: F,
) -> (mpsc::Sender<T>, TaskManager)
where
    T: Send + 'static,
    F: FnMut(T) -> Fut + Send + Clone + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // Clamp to at least one worker; `mpsc::channel` panics on a capacity of 0.
    let concurrency = concurrency.max(1);
    let (tx, rx) = mpsc::channel::<T>(concurrency);
    // `mpsc` receivers are single-consumer, so the receiver is shared behind
    // an async mutex to let every worker pull from the same queue.
    let rx = std::sync::Arc::new(tokio::sync::Mutex::new(rx));

    let mut task_manager = TaskManager::new();

    for _ in 0..concurrency {
        let mut worker_fn = worker_fn.clone();
        let rx = rx.clone();

        task_manager.spawn(async move {
            loop {
                // Hold the lock only while receiving, not while the job runs,
                // so other workers can pick up items concurrently.
                let message = {
                    let mut rx_lock = rx.lock().await;
                    rx_lock.recv().await
                };

                match message {
                    Some(item) => {
                        worker_fn(item).await;
                    }
                    // `recv` returns `None` once all senders are dropped.
                    None => break,
                }
            }
        });
    }

    (tx, task_manager)
}

/// Builds a set of futures that run jobs with a limit on parallelism.
///
/// The returned futures are inert until awaited (for example with
/// `futures::future::join_all`). Each one acquires a semaphore permit when it
/// is first polled, so at most `concurrency` jobs make progress at a time,
/// no matter how many of the futures are awaited together.
///
/// # Panics
///
/// A returned future panics if the semaphore is closed while it waits for a
/// permit. The semaphore is owned by the futures themselves and is never
/// closed explicitly, so this should not happen in practice.
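///
/// # Examples
///
/// A minimal sketch (assumes a Tokio runtime and the `futures` crate):
///
/// ```no_run
/// use gooty_proxy::orchestration::threading::execute_with_concurrency_limit;
///
/// # async fn example() {
/// let futures = execute_with_concurrency_limit(vec![1u32, 2, 3], 2, |n| async move {
///     println!("job {n}");
/// })
/// .await;
/// // At most 2 jobs run at once, however the futures are driven.
/// futures::future::join_all(futures).await;
/// # }
/// ```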
pub async fn execute_with_concurrency_limit<T, F, Fut>(
    items: Vec<T>,
    concurrency: usize,
    mut job_fn: F,
) -> Vec<Pin<Box<dyn Future<Output = ()> + Send>>>
where
    T: Send + 'static,
    F: FnMut(T) -> Fut + Send,
    Fut: Future<Output = ()> + Send + 'static,
{
    let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
    let mut futures = Vec::with_capacity(items.len());

    for item in items {
        let semaphore = semaphore.clone();
        let future = job_fn(item);

        futures.push(Box::pin(async move {
            // Acquire the permit inside the returned future rather than while
            // building the batch: acquiring eagerly here would deadlock once
            // `concurrency` permits were held by futures that had not yet
            // been polled.
            let _permit = semaphore
                .acquire_owned()
                .await
                .expect("semaphore closed while futures were pending");
            future.await;
        }) as Pin<Box<dyn Future<Output = ()> + Send>>);
    }

    futures
}

/// Runs a batch of operations concurrently with limited parallelism.
///
/// Takes a collection of items, a concurrency limit, and a job function, then
/// processes the items with at most `concurrency` jobs in flight at once,
/// returning once all operations are complete.
///
/// # Type Parameters
///
/// * `T` - The input item type
/// * `R` - The result type
/// * `F` - The function type that processes each item and returns a boxed future
///
/// # Arguments
///
/// * `items` - Vector of items to process
/// * `concurrency` - Maximum number of concurrent operations (clamped to at least 1)
/// * `job_fn` - Function that processes each item and returns a future
///
/// # Returns
///
/// A vector of `(result, flag)` pairs in completion order, which may differ
/// from the input order because the underlying stream is buffered unordered.
///
/// # Examples
///
/// ```no_run
/// use std::future::Future;
/// use std::pin::Pin;
/// use gooty_proxy::orchestration::threading::run_concurrent_batch;
///
/// # async fn example() {
/// let items = vec![1u32, 2, 3, 4, 5];
/// // The boxed future yields the result paired with a caller-defined flag.
/// let job = |item: u32| -> Pin<Box<dyn Future<Output = (u32, bool)> + Send>> {
///     Box::pin(async move { (item * 2, true) })
/// };
/// let results = run_concurrent_batch(items, 2, &job).await;
/// # }
/// ```
pub async fn run_concurrent_batch<T, R, F>(
    items: Vec<T>,
    concurrency: usize,
    job_fn: &F,
) -> Vec<(R, bool)>
where
    T: Send + 'static,
    R: Send + 'static,
    F: Fn(T) -> Pin<Box<dyn Future<Output = (R, bool)> + Send>> + Send + Sync + Clone + 'static,
{
    // Drive the jobs through a buffered stream with the specified concurrency.
    stream::iter(items)
        .map(|item| {
            let job = job_fn.clone();
            async move { job(item).await }
        })
        .buffer_unordered(concurrency.max(1)) // Clamp to at least 1
        .collect::<Vec<_>>()
        .await
}

/// Processes items concurrently with a shared state.
///
/// Similar to [`run_concurrent_batch`], but gives every job access to a
/// shared state. The state is cloned per job, so it should be a cheaply
/// cloneable handle such as an `Arc`; mutation requires interior mutability.
///
/// # Type Parameters
///
/// * `T` - The type of items to process
/// * `R` - The result type returned by the job function
/// * `S` - The shared state type
/// * `F` - The job function type
///
/// # Arguments
///
/// * `items` - Vector of items to process
/// * `state` - Shared state accessible by all jobs (must be thread-safe)
/// * `concurrency` - Maximum number of concurrent operations (clamped to at least 1)
/// * `job_fn` - Function to process each item with access to the shared state
///
/// # Returns
///
/// A vector of `(result, flag)` pairs in completion order.
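///
/// # Examples
///
/// A minimal sketch (the `Arc<AtomicUsize>` counter is illustrative, not part
/// of the API):
///
/// ```no_run
/// use std::future::Future;
/// use std::pin::Pin;
/// use std::sync::Arc;
/// use std::sync::atomic::{AtomicUsize, Ordering};
/// use gooty_proxy::orchestration::threading::run_concurrent_batch_with_state;
///
/// # async fn example() {
/// let counter = Arc::new(AtomicUsize::new(0));
/// let job = |item: u32, state: Arc<AtomicUsize>| -> Pin<Box<dyn Future<Output = (u32, bool)> + Send>> {
///     Box::pin(async move {
///         state.fetch_add(1, Ordering::Relaxed);
///         (item * 2, true)
///     })
/// };
/// let results = run_concurrent_batch_with_state(vec![1, 2, 3], counter, 2, job).await;
/// # }
/// ```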
pub async fn run_concurrent_batch_with_state<T, R, S, F>(
    items: Vec<T>,
    state: S,
    concurrency: usize,
    job_fn: F,
) -> Vec<(R, bool)>
where
    T: Send + 'static,
    R: Send + 'static,
    S: Clone + Send + Sync + 'static,
    F: Fn(T, S) -> Pin<Box<dyn Future<Output = (R, bool)> + Send>> + Send + Sync + Clone + 'static,
{
    // Drive the jobs through a buffered stream with the specified concurrency,
    // handing each job its own clone of the shared state.
    stream::iter(items)
        .map(move |item| {
            let job = job_fn.clone();
            let state = state.clone();
            async move { job(item, state).await }
        })
        .buffer_unordered(concurrency.max(1)) // Clamp to at least 1
        .collect::<Vec<_>>()
        .await
}

/// Runs a batch of operations with progress reporting.
///
/// Similar to [`run_concurrent_batch`], but also reports progress through a
/// callback after each item completes. This is useful for long-running
/// operations where you want to update a progress bar or log periodic status.
///
/// Items are processed in chunks of `concurrency`, and the callback receives
/// each item's original index along with a reference to its result.
///
/// # Type Parameters
///
/// * `T` - The input item type
/// * `R` - The result type
/// * `Fut` - The future type returned by the job function
///
/// # Arguments
///
/// * `items` - Vector of items to process
/// * `concurrency` - Maximum number of concurrent operations (clamped to at least 1)
/// * `job_fn` - Function that processes each item and returns a future
/// * `progress_fn` - Callback invoked with `(index, &result)` after each item completes
///
/// # Returns
///
/// A vector containing the results of all operations in the same order as the
/// input items.
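///
/// # Examples
///
/// A minimal sketch (assumes a Tokio runtime):
///
/// ```no_run
/// use gooty_proxy::orchestration::threading::run_concurrent_batch_with_progress;
///
/// # async fn example() {
/// let results = run_concurrent_batch_with_progress(
///     vec![1u32, 2, 3, 4],
///     2,
///     |item| async move { item * 2 },
///     |idx, result: &u32| println!("item {idx} -> {result}"),
/// )
/// .await;
/// assert_eq!(results, vec![2, 4, 6, 8]);
/// # }
/// ```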
pub async fn run_concurrent_batch_with_progress<T, R, Fut>(
    items: Vec<T>,
    concurrency: usize,
    job_fn: impl Fn(T) -> Fut + Send + Sync + Clone + 'static,
    progress_fn: impl Fn(usize, &R) + Send + Sync + Clone + 'static,
) -> Vec<R>
where
    T: Send + 'static,
    R: Send + 'static,
    Fut: Future<Output = R> + Send,
{
    // Clamp to at least 1 so a zero limit cannot silently drop every item.
    let concurrency = concurrency.max(1);
    let mut results = Vec::with_capacity(items.len());

    // Process in batches to allow for progress reporting between chunks.
    let mut iter = items.into_iter().enumerate();

    loop {
        let batch: Vec<(usize, T)> = iter.by_ref().take(concurrency).collect();
        if batch.is_empty() {
            break;
        }

        // Process this batch concurrently, keeping each item's original index.
        let mut batch_results = stream::iter(batch)
            .map(|(idx, item)| {
                let job = job_fn.clone();
                async move { (idx, job(item).await) }
            })
            .buffer_unordered(concurrency)
            .collect::<Vec<(usize, R)>>()
            .await;

        // `buffer_unordered` yields in completion order; restore input order
        // so the returned vector matches the documented ordering.
        batch_results.sort_by_key(|(idx, _)| *idx);

        // Report progress for each completed item.
        for (idx, result) in &batch_results {
            progress_fn(*idx, result);
        }

        // Store results.
        results.extend(batch_results.into_iter().map(|(_, r)| r));
    }

    results
}