use std::path::{Path, PathBuf};
use crate::error::OlError;
use super::{Supervisor, SupervisorKind, SupervisorStatus, ERR_SUPERVISION_INSTALL_FAILED};
/// Task name passed to `schtasks /TN` for create, delete and query; the backslash
/// places the task `Client` inside the Task Scheduler folder `OpenLatch`.
const TASK_NAME: &str = "OpenLatch\\Client";
/// [`Supervisor`] backend that keeps the OpenLatch daemon alive through the
/// Windows Task Scheduler, registered via `schtasks.exe`.
pub struct TaskSchedulerSupervisor;

impl Default for TaskSchedulerSupervisor {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskSchedulerSupervisor {
    /// Creates the supervisor. It carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Renders the Task Scheduler (schema 1.2) XML definition for the daemon task.
    ///
    /// The generated task:
    /// - starts when `user` logs on;
    /// - runs `<binary_path> daemon start --foreground` with that user's
    ///   interactive token at least privilege;
    /// - has no execution time limit (`PT0S`);
    /// - ignores new instances while one is already running;
    /// - is restarted every minute, up to five times, on failure.
    ///
    /// `binary_path` and `user` are XML-escaped before interpolation. Characters
    /// such as `&` are legal in Windows paths and account names, and would
    /// otherwise produce malformed XML that `schtasks /Create /XML` rejects.
    ///
    /// The prolog declares `UTF-16`, so the caller must write the returned string
    /// in a matching encoding (e.g. with `encode_utf16le_bom`).
    fn generate_xml(&self, binary_path: &Path, user: &str) -> String {
        let bin = Self::xml_escape(&binary_path.display().to_string());
        let user = Self::xml_escape(user);
        format!(
            r#"<?xml version="1.0" encoding="UTF-16"?>
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
<Triggers>
<LogonTrigger>
<Enabled>true</Enabled>
<UserId>{user}</UserId>
</LogonTrigger>
</Triggers>
<Settings>
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
<Hidden>false</Hidden>
<RestartOnFailure>
<Interval>PT1M</Interval>
<Count>5</Count>
</RestartOnFailure>
</Settings>
<Principals>
<Principal id="Author">
<UserId>{user}</UserId>
<LogonType>InteractiveToken</LogonType>
<RunLevel>LeastPrivilege</RunLevel>
</Principal>
</Principals>
<Actions Context="Author">
<Exec>
<Command>{bin}</Command>
<Arguments>daemon start --foreground</Arguments>
</Exec>
</Actions>
</Task>"#
        )
    }

    /// Escapes the five XML special characters so `raw` can be embedded safely in
    /// element text or attribute values. All other characters pass through unchanged.
    fn xml_escape(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&apos;"),
                other => escaped.push(other),
            }
        }
        escaped
    }
}
/// Encodes `s` as UTF-16 little-endian bytes preceded by the `FF FE` byte-order mark.
///
/// This matches the `encoding="UTF-16"` declaration emitted by the task XML, which
/// is the form `schtasks /XML` is given. Characters outside the BMP become
/// surrogate pairs, each unit written low byte first.
fn encode_utf16le_bom(s: &str) -> Vec<u8> {
    const BOM_LE: [u8; 2] = [0xFF, 0xFE];
    // UTF-16 never needs more code units than UTF-8 needs bytes, so 2 * len is an upper bound.
    let mut encoded = Vec::with_capacity(BOM_LE.len() + s.len() * 2);
    encoded.extend_from_slice(&BOM_LE);
    encoded.extend(s.encode_utf16().flat_map(u16::to_le_bytes));
    encoded
}
/// Identifies the account the task should run as.
///
/// Returns:
/// - `DOMAIN\user`, built from `USERDOMAIN` and `USERNAME`, when both are non-empty;
/// - just `user` when only `USERNAME` is non-empty;
/// - otherwise the well-known SID `S-1-5-32-545` (BUILTIN\Users).
///
/// An unset or non-Unicode variable counts as empty.
fn current_user_identifier() -> String {
    let read = |key: &str| std::env::var(key).unwrap_or_default();
    let (domain, username) = (read("USERDOMAIN"), read("USERNAME"));
    match (domain.is_empty(), username.is_empty()) {
        (false, false) => format!("{domain}\\{username}"),
        (true, false) => username,
        _ => "S-1-5-32-545".to_string(),
    }
}
impl Supervisor for TaskSchedulerSupervisor {
    fn kind(&self) -> SupervisorKind {
        SupervisorKind::TaskScheduler
    }

    /// Registers the logon task for the current user, replacing any existing task
    /// of the same name (`/F`).
    ///
    /// The task definition is written to a temporary UTF-16LE XML file and passed
    /// to `schtasks /Create /XML`. The file is removed again whether or not
    /// registration succeeded.
    ///
    /// # Errors
    ///
    /// Returns an `ERR_SUPERVISION_INSTALL_FAILED` error in three cases:
    /// - the XML file cannot be written;
    /// - `schtasks` cannot be spawned;
    /// - `schtasks` exits unsuccessfully. The message then carries its stderr,
    ///   else its stdout, else its exit status.
    fn install(&self, binary_path: &Path) -> Result<(), OlError> {
        let user = current_user_identifier();
        let xml = self.generate_xml(binary_path, &user);
        let bytes = encode_utf16le_bom(&xml);

        let xml_path = task_xml_path();
        std::fs::write(&xml_path, &bytes).map_err(|e| {
            OlError::new(
                ERR_SUPERVISION_INSTALL_FAILED,
                format!("Cannot write task XML: {e}"),
            )
        })?;

        let xml_arg = xml_path.display().to_string();
        let output = std::process::Command::new("schtasks")
            .args([
                "/Create",
                "/TN",
                TASK_NAME,
                "/XML",
                &xml_arg,
                "/RU",
                &user,
                "/IT",
                "/F",
            ])
            .output();

        // Best-effort cleanup: the definition is no longer needed once schtasks has
        // run, and a leftover temp file must not mask the real outcome below.
        let _ = std::fs::remove_file(&xml_path);

        match output {
            Ok(o) if o.status.success() => Ok(()),
            Ok(o) => Err(OlError::new(
                ERR_SUPERVISION_INSTALL_FAILED,
                format!("schtasks /Create failed: {}", schtasks_failure_detail(&o)),
            )),
            Err(e) => Err(OlError::new(
                ERR_SUPERVISION_INSTALL_FAILED,
                format!("Cannot run schtasks: {e}"),
            )),
        }
    }

    /// Deletes the task with `schtasks /Delete /F`.
    ///
    /// Always returns `Ok(())`. Any failure is ignored: a missing task, a spawn
    /// error, or a non-zero exit status. Calling this repeatedly is therefore harmless.
    fn uninstall(&self) -> Result<(), OlError> {
        let _ = std::process::Command::new("schtasks")
            .args(["/Delete", "/TN", TASK_NAME, "/F"])
            .output();
        Ok(())
    }

    /// Reports the task as installed when `schtasks /Query /TN` exits successfully.
    /// Any other outcome, including failure to spawn `schtasks`, is reported as
    /// "not installed".
    ///
    /// NOTE(review): `running` is set to `true` whenever the task exists. It is
    /// inferred from registration alone and does not check that the daemon process
    /// is actually alive. Confirm that callers expect this.
    fn status(&self) -> Result<SupervisorStatus, OlError> {
        let output = std::process::Command::new("schtasks")
            .args(["/Query", "/TN", TASK_NAME])
            .output();
        match output {
            Ok(o) if o.status.success() => Ok(SupervisorStatus {
                installed: true,
                running: true,
                description: "Task Scheduler (RestartOnFailure active)".into(),
            }),
            _ => Ok(SupervisorStatus {
                installed: false,
                running: false,
                description: "not installed".into(),
            }),
        }
    }
}

/// Returns the location of the temporary task-definition file handed to
/// `schtasks /XML`.
///
/// Uses the OS temp directory via `std::env::temp_dir`. On Windows that already
/// falls back through `TMP`, `TEMP`, `USERPROFILE` and the Windows directory, so
/// unset variables no longer yield a path relative to the current working
/// directory. The process id in the file name keeps concurrent installs from
/// overwriting or deleting each other's file.
fn task_xml_path() -> PathBuf {
    std::env::temp_dir().join(format!("openlatch-task-{}.xml", std::process::id()))
}

/// Picks the most useful diagnostic from a failed `schtasks` run: trimmed
/// stderr, else trimmed stdout, else the exit status.
///
/// NOTE(review): output is decoded as lossy UTF-8. schtasks presumably writes in
/// the console code page, so non-ASCII messages may be garbled. Confirm this.
fn schtasks_failure_detail(output: &std::process::Output) -> String {
    for stream in [&output.stderr, &output.stdout].iter() {
        let text = String::from_utf8_lossy(stream);
        let text = text.trim();
        if !text.is_empty() {
            return text.to_string();
        }
    }
    output.status.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binary path used by every XML-shape test below.
    const SAMPLE_BINARY: &str = "C:\\openlatch\\openlatch.exe";

    /// Renders the task XML for `SAMPLE_BINARY` running as `user`.
    fn render_for(user: &str) -> String {
        TaskSchedulerSupervisor::new().generate_xml(Path::new(SAMPLE_BINARY), user)
    }

    /// Renders the task XML with the placeholder account most tests use.
    fn render_default() -> String {
        render_for("USER\\me")
    }

    #[test]
    fn xml_hidden_is_false() {
        assert!(render_default().contains("<Hidden>false</Hidden>"));
    }

    #[test]
    fn xml_has_execution_time_limit_zero() {
        assert!(render_default().contains("<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>"));
    }

    #[test]
    fn xml_has_logon_trigger() {
        assert!(render_default().contains("<LogonTrigger>"));
    }

    #[test]
    fn xml_has_restart_on_failure() {
        assert!(render_default().contains("<RestartOnFailure>"));
    }

    #[test]
    fn xml_has_least_privilege() {
        assert!(render_default().contains("<RunLevel>LeastPrivilege</RunLevel>"));
    }

    #[test]
    fn xml_declares_utf16_to_match_on_disk_bytes() {
        assert!(render_default().starts_with("<?xml version=\"1.0\" encoding=\"UTF-16\"?>"));
    }

    #[test]
    fn utf16le_bom_encoder_emits_bom_and_little_endian_units() {
        assert_eq!(encode_utf16le_bom("A"), vec![0xFF, 0xFE, 0x41, 0x00]);
    }

    #[test]
    fn utf16le_bom_encoder_roundtrips_ascii_xml_prolog() {
        let prolog = "<?xml version=\"1.0\"?>";
        let encoded = encode_utf16le_bom(prolog);
        assert_eq!(encoded.len(), 2 + prolog.len() * 2);
        assert_eq!(&encoded[..2], &[0xFF, 0xFE]);
        assert_eq!(&encoded[2..4], &[0x3C, 0x00]);
    }

    #[test]
    fn xml_principal_has_user_id() {
        assert!(render_for("MYHOST\\alice").contains("<UserId>MYHOST\\alice</UserId>"));
    }

    #[test]
    fn xml_actions_context_is_author() {
        assert!(render_default().contains("<Actions Context=\"Author\">"));
    }

    #[test]
    fn current_user_identifier_prefers_domain_backslash_user() {
        // Only checkable when the environment provides both parts; a missing or
        // empty variable reads as "" and skips the assertion.
        let domain = std::env::var("USERDOMAIN").unwrap_or_default();
        let username = std::env::var("USERNAME").unwrap_or_default();
        let id = current_user_identifier();
        if !domain.is_empty() && !username.is_empty() {
            assert_eq!(id, format!("{domain}\\{username}"));
        }
    }
}