#include "thread_pool.h"
#include <assert.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdlib.h>
/*
 * One unit of queued work. Allocated by thread_pool_add_work() and owned by
 * the pool from then on: a worker thread calls function(aux) and then frees
 * the record (see get_work_from_queue()).
 */
struct task {
/* Callback to run; the worker asserts it is non-NULL before calling it. */
work_function_t function;
/* Opaque argument forwarded unchanged to `function`; never freed by the pool. */
void *aux;
};
/*
 * State shared by a pool and all of its worker threads.
 */
struct thread_pool {
/* Heap array of `num_threads` worker handles, joined in thread_pool_finish(). */
pthread_t *threads;
/* Number of entries in `threads`. */
size_t num_threads;
/*
 * Queue shared by every worker. Entries are task_t pointers; a NULL entry
 * is a shutdown sentinel that makes the worker which dequeues it exit.
 */
queue_t *work_queue;
};
void *get_work_from_queue(void *pool) {
thread_pool_t *thread_pool = (thread_pool_t *) pool;
task_t *task = NULL;
for (task = (task_t *) queue_dequeue(thread_pool->work_queue); task != NULL;
task = (task_t *) queue_dequeue(thread_pool->work_queue)) {
assert(task != NULL);
assert(task->function != NULL);
task->function(task->aux);
free(task);
}
return NULL;
}
/*
 * Creates a pool of `num_worker_threads` threads. Each thread runs
 * get_work_from_queue() against one shared work queue.
 *
 * Returns the new pool, or NULL if the pool allocation, queue_init(), the
 * thread-handle allocation, or any pthread_create() call fails. On failure,
 * everything acquired so far is released before returning, including
 * workers that were already started. A zero-thread pool is valid.
 * The returned pool must be released with thread_pool_finish().
 */
thread_pool_t *thread_pool_init(size_t num_worker_threads) {
    thread_pool_t *pool = calloc(1, sizeof *pool);
    if (pool == NULL) {
        return NULL;
    }
    pool->num_threads = num_worker_threads;

    /* NOTE(review): assumes queue_init() reports failure by returning NULL. */
    pool->work_queue = queue_init();
    if (pool->work_queue == NULL) {
        free(pool);
        return NULL;
    }

    /* calloc(0, ...) may legitimately return NULL, so NULL only counts as a
     * failure when at least one thread was requested. */
    pool->threads = calloc(num_worker_threads, sizeof *pool->threads);
    if (pool->threads == NULL && num_worker_threads > 0) {
        queue_free(pool->work_queue);
        free(pool);
        return NULL;
    }

    for (size_t i = 0; i < num_worker_threads; i++) {
        int pthread_error = pthread_create(&pool->threads[i], NULL,
                                           &get_work_from_queue, pool);
        if (pthread_error != 0) {
            /* Stop the i workers already running. Queue every sentinel
             * before joining any thread: the queue is shared, so any
             * worker may consume any sentinel. */
            for (size_t j = 0; j < i; j++) {
                queue_enqueue(pool->work_queue, NULL);
            }
            for (size_t j = 0; j < i; j++) {
                pthread_join(pool->threads[j], NULL);
            }
            free(pool->threads);
            queue_free(pool->work_queue);
            free(pool);
            return NULL;
        }
    }
    return pool;
}
/*
 * Queues function(aux) to be run by one of `pool`'s worker threads.
 *
 * The task record is heap-allocated here; the worker that runs it frees it
 * afterwards. `aux` is passed through untouched and is never freed by the
 * pool. `function` must not be NULL.
 *
 * This function cannot report failure. If memory runs out it therefore calls
 * abort() explicitly instead of relying on assert(), which compiles away
 * under NDEBUG and would leave a NULL dereference behind.
 */
void thread_pool_add_work(thread_pool_t *pool, work_function_t function, void *aux) {
    /* Catch a NULL callback here, at the caller, instead of in a worker. */
    assert(function != NULL);

    task_t *task = calloc(1, sizeof *task);
    if (task == NULL) {
        abort();
    }
    task->function = function;
    task->aux = aux;
    queue_enqueue(pool->work_queue, task);
}
/*
 * Shuts `pool` down and releases it.
 *
 * One NULL sentinel is appended per worker, after any tasks still pending.
 * Assuming the queue is FIFO, queued work therefore runs before the workers
 * exit. Every worker is then joined, and the queue, thread array, and pool
 * are freed. `pool` must not be used after this call.
 */
void thread_pool_finish(thread_pool_t *pool) {
    const size_t count = pool->num_threads;

    /* Enqueue all sentinels before joining anyone: the queue is shared, so
     * a sentinel may be taken by any worker, not the one being joined. */
    size_t remaining = count;
    while (remaining-- > 0) {
        queue_enqueue(pool->work_queue, NULL);
    }

    for (size_t idx = 0; idx < count; ++idx) {
        pthread_join(pool->threads[idx], NULL);
    }

    queue_free(pool->work_queue);
    free(pool->threads);
    free(pool);
}