pub mod lifecycle;
pub mod scheduler;
pub mod worker_pool;

use crate::context::AppContext;
use crate::context::docker_client::DockerClient;
use crate::docker::proxy_manager::ProxyManager;
use crate::task::TaskStatus;
use crate::tui::events::{ServerEvent, ServerEventSender};
use lifecycle::ServerLifecycle;
use scheduler::TaskScheduler;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
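
/// Long-running task server: drives a [`TaskScheduler`] across a fixed pool of
/// workers and coordinates graceful shutdown of task containers, proxies, and
/// the server PID file.
///
/// Rough usage sketch (not compiled as a doctest); it assumes an `AppContext`
/// and a `DockerClient` have already been built elsewhere:
///
/// ```ignore
/// let server = TskServer::with_workers(app_context, docker_client, 4, true, None);
/// server.run().await?;
/// ```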
pub struct TskServer {
app_context: Arc<AppContext>,
docker_client: Arc<dyn DockerClient>,
quit_signal: Arc<tokio::sync::Notify>,
scheduler: Arc<Mutex<TaskScheduler>>,
scheduler_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
submitted_tasks: Arc<Mutex<HashSet<String>>>,
running_flag: Arc<Mutex<bool>>,
lifecycle: ServerLifecycle,
workers: u32,
event_sender: Option<ServerEventSender>,
}

impl TskServer {
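/// Builds a server that runs tasks on `workers` concurrent workers.
///
/// `quit_when_done` and the internal quit signal are handed to the
/// [`TaskScheduler`]; `event_sender`, when present, routes status and warning
/// messages to the TUI instead of printing them.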
pub fn with_workers(
app_context: Arc<AppContext>,
docker_client: Arc<dyn DockerClient>,
workers: u32,
quit_when_done: bool,
event_sender: Option<ServerEventSender>,
) -> Self {
let tsk_env = app_context.tsk_env();
let storage = app_context.task_storage();
let quit_signal = Arc::new(tokio::sync::Notify::new());
let scheduler = TaskScheduler::new(
app_context.clone(),
docker_client.clone(),
storage.clone(),
quit_when_done,
quit_signal.clone(),
event_sender.clone(),
);
let submitted_tasks = scheduler.submitted_task_ids();
let running_flag = scheduler.running_flag();
let scheduler = Arc::new(Mutex::new(scheduler));
let lifecycle = ServerLifecycle::new(tsk_env);
Self {
app_context,
docker_client,
quit_signal,
scheduler,
scheduler_handle: Mutex::new(None),
submitted_tasks,
running_flag,
lifecycle,
workers,
event_sender,
}
}
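
/// Sends `event` through the configured event sender, falling back to printing
/// when no sender is set.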
fn emit(&self, event: ServerEvent) {
crate::tui::events::emit_or_print(&self.event_sender, event);
}
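
/// Claims the PID file, starts the scheduler on a background task, and blocks
/// until the quit signal fires, then joins the scheduler task and cleans up.
///
/// Returns an error if another server instance is already running.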
pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
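// Refuse to start a second instance; the PID file marks this server as the
// active one.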
if self.lifecycle.is_server_running() {
return Err("Server is already running".into());
}
self.lifecycle.write_pid()?;
self.emit(ServerEvent::StatusMessage(format!(
"tsk server started (PID {})",
std::process::id()
)));
let scheduler = self.scheduler.clone();
let workers = self.workers;
let event_sender_clone = self.event_sender.clone();
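// Run the scheduler on its own task so this function can wait on the quit
// signal below.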
let scheduler_handle = tokio::spawn(async move {
if let Err(e) = scheduler.lock().await.start(workers).await {
crate::tui::events::emit_or_print(
&event_sender_clone,
ServerEvent::WarningMessage(format!("Scheduler error: {e}")),
);
}
});
*self.scheduler_handle.lock().await = Some(scheduler_handle);
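// Block until the quit signal fires, then join the scheduler task before
// cleaning up.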
self.quit_signal.notified().await;
self.emit(ServerEvent::StatusMessage(
"Received quit signal from scheduler...".to_string(),
));
if let Some(handle) = self.scheduler_handle.lock().await.take() {
let _ = handle.await;
}
self.lifecycle.cleanup()?;
self.app_context.terminal_operations().restore_title();
Ok(())
}
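
/// Shuts the server down in place: stops the scheduler, kills the containers
/// of all submitted tasks, marks still-running tasks as cancelled, stops their
/// proxies, and removes the PID file.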
pub async fn graceful_shutdown(&self) {
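// Clear the running flag shared with the scheduler so it stops picking up new work.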
*self.running_flag.lock().await = false;
let task_ids: Vec<String> = self.submitted_tasks.lock().await.iter().cloned().collect();
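// The proxy manager is only needed here to force-stop any proxies the tasks
// were using.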
let proxy_manager = ProxyManager::new(
self.docker_client.clone(),
self.app_context.tsk_env(),
self.app_context.tsk_config().container_engine.clone(),
self.event_sender.clone(),
);
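// Kill the container for every task submitted to this server (container names
// follow the `tsk-{id}` convention).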
if !task_ids.is_empty() {
let docker_client = &self.docker_client;
for id in &task_ids {
let container_name = format!("tsk-{id}");
if let Err(e) = docker_client.kill_container(&container_name).await {
self.emit(ServerEvent::WarningMessage(format!(
"Note: Could not kill container {container_name}: {e}"
)));
}
}
}
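// Give the scheduler task up to five seconds to finish before moving on.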
if let Some(handle) = self.scheduler_handle.lock().await.take()
&& tokio::time::timeout(std::time::Duration::from_secs(5), handle)
.await
.is_err()
{
self.emit(ServerEvent::WarningMessage(
"Warning: Scheduler did not stop within 5 seconds".to_string(),
));
}
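// Cancel tasks that were still running and stop each distinct proxy
// configuration exactly once (deduplicated by fingerprint).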
let storage = self.app_context.task_storage();
let mut stopped_fingerprints = std::collections::HashSet::new();
for task_id in &task_ids {
if let Ok(Some(task)) = storage.get_task(task_id).await {
if task.status == TaskStatus::Running {
let _ = storage.mark_cancelled(task_id).await;
}
let resolved = crate::docker::resolve_config_from_task(
&task,
&self.app_context,
&self.event_sender,
);
let proxy_config = resolved.proxy_config();
let fp = proxy_config.fingerprint();
if stopped_fingerprints.insert(fp)
&& let Err(e) = proxy_manager.force_stop_proxy(&proxy_config).await
{
self.emit(ServerEvent::WarningMessage(format!(
"Warning: Failed to stop proxy during shutdown: {e}"
)));
}
}
}
if let Err(e) = self.lifecycle.cleanup() {
self.emit(ServerEvent::WarningMessage(format!(
"Warning: Failed to clean up PID file: {e}"
)));
}
self.app_context.terminal_operations().restore_title();
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::task::Task;
use crate::test_utils::TrackedDockerClient;

#[tokio::test]
async fn test_graceful_shutdown_kills_containers_and_marks_tasks_cancelled() {
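// Two running tasks are registered with the server; shutdown should kill both
// containers and mark both tasks as cancelled.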
let mock_client = Arc::new(TrackedDockerClient::default());
let ctx = Arc::new(AppContext::builder().build());
let server = TskServer::with_workers(ctx.clone(), mock_client.clone(), 1, false, None);
let storage = ctx.task_storage();
let task1 = Task {
id: "task-1".to_string(),
name: "test-task-1".to_string(),
branch_name: "tsk/test/task-1".to_string(),
..Task::test_default()
};
let task2 = Task {
id: "task-2".to_string(),
name: "test-task-2".to_string(),
branch_name: "tsk/test/task-2".to_string(),
..Task::test_default()
};
storage.add_task(task1).await.unwrap();
storage.add_task(task2).await.unwrap();
storage.mark_running("task-1").await.unwrap();
storage.mark_running("task-2").await.unwrap();
{
let mut submitted = server.submitted_tasks.lock().await;
submitted.insert("task-1".to_string());
submitted.insert("task-2".to_string());
}
server.graceful_shutdown().await;
{
let kill_calls = mock_client.kill_container_calls.lock().unwrap();
assert_eq!(kill_calls.len(), 2);
assert!(kill_calls.contains(&"tsk-task-1".to_string()));
assert!(kill_calls.contains(&"tsk-task-2".to_string()));
}
{
let disconnect_calls = mock_client.disconnect_network_calls.lock().unwrap();
assert_eq!(disconnect_calls.len(), 0);
}
{
let remove_calls = mock_client.remove_network_calls.lock().unwrap();
assert_eq!(remove_calls.len(), 0);
}
let t1 = storage.get_task("task-1").await.unwrap().unwrap();
assert_eq!(t1.status, TaskStatus::Cancelled);
let t2 = storage.get_task("task-2").await.unwrap().unwrap();
assert_eq!(t2.status, TaskStatus::Cancelled);
}

#[tokio::test]
async fn test_graceful_shutdown_skips_completed_tasks() {
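// A task that already completed keeps its status: its container is still
// killed, but it is not marked cancelled.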
let mock_client = Arc::new(TrackedDockerClient::default());
let ctx = Arc::new(AppContext::builder().build());
let server = TskServer::with_workers(ctx.clone(), mock_client.clone(), 1, false, None);
let storage = ctx.task_storage();
let task = Task {
id: "task-done".to_string(),
name: "done-task".to_string(),
branch_name: "tsk/test/task-done".to_string(),
..Task::test_default()
};
storage.add_task(task).await.unwrap();
storage.mark_running("task-done").await.unwrap();
storage
.mark_complete("task-done", "tsk/test/task-done")
.await
.unwrap();
server
.submitted_tasks
.lock()
.await
.insert("task-done".to_string());
server.graceful_shutdown().await;
{
let kill_calls = mock_client.kill_container_calls.lock().unwrap();
assert_eq!(kill_calls.len(), 1);
}
{
let disconnect_calls = mock_client.disconnect_network_calls.lock().unwrap();
assert_eq!(disconnect_calls.len(), 0);
}
{
let remove_calls = mock_client.remove_network_calls.lock().unwrap();
assert_eq!(remove_calls.len(), 0);
}
let t = storage.get_task("task-done").await.unwrap().unwrap();
assert_eq!(t.status, TaskStatus::Complete);
}
}