1use std::{
2 collections::HashMap,
3 ffi::CString,
4 io::{BufRead, BufReader, BufWriter, Write},
5 net::{TcpListener, TcpStream},
6 sync::{
7 mpsc::{channel, Receiver, Sender},
8 Mutex,
9 },
10};
11
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use windows::{
14 core::PCSTR,
15 Win32::{
16 Foundation::{CloseHandle, HANDLE, HWND},
17 Security::{
18 AdjustTokenPrivileges, DuplicateTokenEx, GetTokenInformation, LookupPrivilegeValueA,
19 SecurityImpersonation, TokenElevation, TokenPrimary, LUID_AND_ATTRIBUTES,
20 SE_PRIVILEGE_ENABLED, TOKEN_ACCESS_MASK, TOKEN_ELEVATION, TOKEN_PRIVILEGES,
21 },
22 System::Threading::{
23 GetCurrentProcessId, OpenProcess, OpenProcessToken, PROCESS_ACCESS_RIGHTS,
24 PROCESS_ALL_ACCESS, PROCESS_INFORMATION_CLASS, PROCESS_SET_INFORMATION,
25 },
26 UI::{Shell::ShellExecuteA, WindowsAndMessaging::SW_HIDE},
27 },
28};
29
30use crate::util::{create_process, CommandLineBuilder, ProcessControlFlow};
31
32const MAXIMUM_ALLOWED: TOKEN_ACCESS_MASK = TOKEN_ACCESS_MASK(0x02000000);
33const PROCESS_ACCESS_TOKEN: PROCESS_INFORMATION_CLASS = PROCESS_INFORMATION_CLASS(9);
34
35#[link(name = "ntdll.dll", kind = "raw-dylib", modifiers = "+verbatim")]
36extern "system" {
37 #[link_name = "NtSetInformationProcess"]
38 pub fn NtSetInformationProcess(
39 process: HANDLE,
40 processinformationclass: PROCESS_INFORMATION_CLASS,
41 lpprocessinformation: *mut ProcessAccessToken,
42 processInformationLength: usize,
43 ) -> isize;
44}
45
46#[ctor::ctor]
47fn elevate_by_command_line() {
48 if let Some(ElevateToken::Elevate { port }) = ElevateToken::from_command_line() {
49 let code = match listen_elevation_request(port) {
50 Ok(_) => 0,
51 Err(_) => -1,
52 };
53 std::process::exit(code);
54 }
55}
56
57fn listen_elevation_request(port: u16) -> Result<(), String> {
58 let stream = TcpStream::connect(format!("127.0.0.1:{port}")).map_err(|err| format!("{err}"))?;
59 stream.set_nodelay(true).map_err(|err| format!("{err}"))?;
60
61 let reader = BufReader::new(stream.try_clone().map_err(|err| format!("{err}"))?);
62 let mut writer = BufWriter::new(stream);
63
64 for l in reader.lines() {
65 let l = l.map_err(|err| format!("{err}"))?;
66 let request: ElevationRequest = serde_json::from_str(&l).map_err(|err| format!("{err}"))?;
67 let result = replace_with_current_token(request.pid);
68 let success = result.is_ok();
69 let error = result.map_or_else(|err| Some(err), |_| None);
70 let msg = format!(
71 "{}\n",
72 serde_json::to_string(&ElevationResponse {
73 id: request.id,
74 success,
75 error
76 })
77 .map_err(|err| format!("{err}"))?
78 );
79 writer
80 .write_all(msg.as_bytes())
81 .map_err(|err| format!("{err}"))?;
82 writer.flush().map_err(|err| format!("{err}"))?;
83 }
84
85 Ok(())
86}
87
88pub static GLOBAL_CLIENT: ElevationClient = ElevationClient::new();
89
90pub struct ElevationClient {
91 pipe: Mutex<Option<Sender<ElevationRequest>>>,
92 pending: Mutex<Vec<(String, Sender<ElevationResponse>)>>,
93}
94
95fn start_elevation_host(receiver: Receiver<ElevationRequest>) -> Result<u16, String> {
96 let listener = TcpListener::bind("127.0.0.1:0").map_err(|err| format!("{err}"))?;
97 let port = listener
98 .local_addr()
99 .map_err(|err| format!("{err}"))?
100 .port();
101 std::thread::spawn(move || {
102 if let Ok((client, _)) = listener.accept() {
103 let _ = client.set_nodelay(true);
104 let Ok(stream_cloned) = client.try_clone() else {
105 return;
106 };
107
108 let reader = BufReader::new(stream_cloned);
110 let t1 = std::thread::spawn(move || {
111 for l in reader.lines() {
112 match l {
113 Ok(l) => {
114 let _ = GLOBAL_CLIENT.receive(&l);
115 }
116 Err(_) => break,
117 }
118 }
119 });
120
121 let mut writer = BufWriter::new(client);
123 let t2 = std::thread::spawn(move || {
124 while let Ok(req) = { receiver.recv() } {
125 match serde_json::to_string(&req).map(|s| s + "\n") {
126 Ok(msg) => {
127 let _ = writer
128 .write_all(msg.as_bytes())
129 .and_then(|_| writer.flush());
130 }
131 Err(_) => {}
132 }
133 }
134 });
135
136 let _ = t1.join();
137 let _ = t2.join();
138 }
139 });
140 Ok(port)
141}
142
143impl ElevationClient {
144 pub const fn new() -> Self {
145 Self {
146 pipe: Mutex::new(None),
147 pending: Mutex::new(Vec::new()),
148 }
149 }
150
151 pub fn request(&self, request: ElevationRequest) -> Result<(), String> {
152 let id = {
153 let mut lock = self.pipe.lock().map_err(|err| format!("{err}"))?;
154 if lock.is_none() {
155 let (sender, receiver) = channel();
156 *lock = Some(sender);
157
158 let port = start_elevation_host(receiver)?;
159
160 let token = ElevateToken::Elevate { port };
161 let cmd = CommandLineBuilder::new().arg(&token.to_string()).encode();
162 run_as(
163 std::env::current_exe()
164 .map_err(|err| format!("{err}"))?
165 .to_str()
166 .ok_or_else(|| format!("Current executable path invalid"))?,
167 &cmd,
168 );
169 }
170
171 let id = request.id.to_owned();
172 let sender = lock.as_ref().unwrap();
173 sender.send(request).map_err(|err| format!("{err}"))?;
174
175 id
176 };
177
178 let wait_recv = {
179 let (wait_send, wait_recv) = channel();
180 let mut lock = self.pending.lock().map_err(|err| format!("{err}"))?;
181 lock.push((id, wait_send));
182 wait_recv
183 };
184
185 wait_recv.recv().map_err(|err| format!("{err}"))?;
186
187 Ok(())
188 }
189
190 pub fn receive(&self, content: &str) -> Result<(), String> {
191 let response: ElevationResponse =
192 serde_json::from_str(content).map_err(|err| format!("{err}"))?;
193 let lock = self.pending.lock().map_err(|err| format!("{err}"))?;
194 for (id, sender) in lock.iter() {
195 if id == &response.id {
196 let _ = sender.send(response);
197 break;
198 }
199 }
200 Ok(())
201 }
202}
203
204pub trait ElevatedOperation: DeserializeOwned + Serialize {
205 fn id() -> &'static str;
206
207 fn check() -> Result<(), String> {
208 try_execute_task::<Self>(Self::id())
209 }
210
211 fn execute(&self) -> Result<(), String> {
212 if is_elevated() {
213 unreachable!();
214 }
215
216 let id = Self::id();
217 let json = serde_json::to_string(self).map_err(|err| format!("{err}"))?;
218 let token = ElevateToken::Execute {
219 task_id: id.to_string(),
220 payload: json,
221 };
222 create_process(&[&token.to_string()], |pid| {
223 match GLOBAL_CLIENT.request(ElevationRequest::new(pid)) {
224 Ok(_) => ProcessControlFlow::ResumeMainThread,
225 Err(_) => ProcessControlFlow::Terminate,
226 }
227 });
228
229 Ok(())
230 }
231
232 fn run(&self);
233}
234
/// Command-line marker that switches this executable into a helper role.
///
/// Its textual form is `--elevate-token=elevate,port=<port>` or
/// `--elevate-token=execute,id=<task_id>,<payload>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ElevateToken {
    /// Act as the elevated token broker, connecting back to the host on `127.0.0.1:port`.
    Elevate { port: u16 },
    /// Run the elevated operation whose id is `task_id`, deserialized from the JSON `payload`.
    Execute { task_id: String, payload: String },
}
240
241#[derive(Serialize, Deserialize)]
242pub struct ElevationRequest {
243 id: String,
244 pid: u32,
245}
246
247#[derive(Serialize, Deserialize)]
248pub struct ElevationResponse {
249 id: String,
250 success: bool,
251 error: Option<String>,
252}
253
254impl ElevationRequest {
255 pub fn new(pid: u32) -> Self {
256 Self {
257 id: std::time::SystemTime::now()
258 .duration_since(std::time::SystemTime::UNIX_EPOCH)
259 .unwrap()
260 .as_millis()
261 .to_string(),
262 pid,
263 }
264 }
265}
266
267impl ElevateToken {
268 pub fn from_command_line() -> Option<Self> {
269 std::env::args()
270 .skip(1)
271 .next()
272 .and_then(|s| ElevateToken::from_str(&s))
273 }
274
275 pub fn from_str(s: &str) -> Option<Self> {
276 const PREFIX: &str = "--elevate-token=";
277 if !s.starts_with(PREFIX) {
278 return None;
279 }
280 let s = &s[PREFIX.len()..];
281 let (cmd, s) = s.split_once(',')?;
282 match cmd {
283 "elevate" => {
284 let map: HashMap<_, _> = s.split(',').filter_map(|s| s.split_once('=')).collect();
285 let port: u16 = map.get("port")?.parse().ok()?;
286 Some(ElevateToken::Elevate { port })
287 }
288 "execute" => {
289 let (id, s) = s.split_once(',')?;
290 let (_, id) = id.split_once('=')?;
291 Some(ElevateToken::Execute {
292 task_id: id.to_string(),
293 payload: s.to_string(),
294 })
295 }
296 _ => None,
297 }
298 }
299
300 pub fn to_string(&self) -> String {
301 match self {
302 ElevateToken::Elevate { port } => {
303 format!("--elevate-token=elevate,port={port}")
304 }
305 ElevateToken::Execute { task_id, payload } => {
306 format!("--elevate-token=execute,id={},{}", task_id, payload)
307 }
308 }
309 }
310}
311
312pub struct ProcessHandle(HANDLE);
313
314impl ProcessHandle {
315 pub fn from_pid(pid: u32, access: PROCESS_ACCESS_RIGHTS) -> Result<Self, String> {
316 Ok(Self(unsafe {
317 OpenProcess(access, true, pid).map_err(|err| format!("{err}"))
318 }?))
319 }
320
321 pub fn from_current_process() -> Result<Self, String> {
322 Self::from_pid(unsafe { GetCurrentProcessId() }, PROCESS_ALL_ACCESS)
323 }
324
325 pub fn replace_primary_token(&self, token: &ProcessToken) -> Result<(), String> {
326 let mut info: ProcessAccessToken = ProcessAccessToken {
327 thread: HANDLE::default(),
328 token: token.1,
329 };
330 let ret = unsafe {
331 NtSetInformationProcess(
332 self.0,
333 PROCESS_ACCESS_TOKEN,
334 &mut info,
335 std::mem::size_of_val(&info),
336 )
337 };
338 match ret {
339 0 => Ok(()),
340 code => Err(format!("{}", std::io::Error::from_raw_os_error(code as _))),
341 }
342 }
343}
344
345impl Drop for ProcessHandle {
346 fn drop(&mut self) {
347 let _ = unsafe { CloseHandle(self.0) };
348 }
349}
350
351pub struct ProcessToken<'h>(&'h ProcessHandle, HANDLE);
352
353impl<'h> ProcessToken<'h> {
354 pub fn open_process(process: &'h ProcessHandle) -> Result<Self, String> {
355 let mut token = Default::default();
356 unsafe { OpenProcessToken(process.0, MAXIMUM_ALLOWED, &mut token) }
357 .map_err(|err| format!("{err}"))?;
358 Ok(Self(process, token))
359 }
360
361 #[allow(dead_code)]
362 pub fn enable_privilege(&self, name: &str) -> Result<(), String> {
363 let name = CString::new(name).map_err(|err| format!("{err}"))?;
364 let mut luid = Default::default();
365 unsafe {
366 LookupPrivilegeValueA(
367 PCSTR::null(),
368 PCSTR::from_raw(name.as_ptr() as _),
369 &mut luid,
370 )
371 .map_err(|err| format!("{err}"))?
372 };
373
374 let tp = TOKEN_PRIVILEGES {
375 PrivilegeCount: 1,
376 Privileges: [LUID_AND_ATTRIBUTES {
377 Attributes: SE_PRIVILEGE_ENABLED,
378 Luid: luid,
379 }],
380 };
381 unsafe {
382 AdjustTokenPrivileges(self.1, false, Some(&tp), 0, None, None)
383 .map_err(|err| format!("{err}"))?
384 };
385
386 Ok(())
387 }
388
389 pub fn duplicate(&self) -> Result<Self, String> {
390 let mut new_token = Default::default();
391 unsafe {
392 DuplicateTokenEx(
393 self.1,
394 MAXIMUM_ALLOWED,
395 None,
396 SecurityImpersonation,
397 TokenPrimary,
398 &mut new_token,
399 )
400 .map_err(|err| format!("{err}"))?
401 };
402 Ok(Self(self.0, new_token))
403 }
404
405 pub fn is_elevated(&self) -> Result<bool, String> {
406 let mut elevation: TOKEN_ELEVATION = TOKEN_ELEVATION { TokenIsElevated: 0 };
407 let size = std::mem::size_of::<TOKEN_ELEVATION>() as u32;
408 let mut ret_size = size;
409 unsafe {
410 GetTokenInformation(
411 self.1,
412 TokenElevation,
413 Some(&mut elevation as *const _ as *mut _),
414 size,
415 &mut ret_size,
416 )
417 }
418 .map_err(|err| format!("{err}"))?;
419 Ok(elevation.TokenIsElevated != 0)
420 }
421}
422
423impl<'h> Drop for ProcessToken<'h> {
424 fn drop(&mut self) {
425 let _ = unsafe { CloseHandle(self.1) };
426 }
427}
428
429#[repr(C)]
430pub struct ProcessAccessToken {
431 token: HANDLE,
432 thread: HANDLE,
433}
434
435pub fn is_elevated() -> bool {
436 let process = ProcessHandle::from_current_process().unwrap();
437 let token = ProcessToken::open_process(&process).unwrap();
438 token.is_elevated().unwrap()
439}
440
441pub fn replace_with_current_token(pid: u32) -> Result<(), String> {
442 let current_process = ProcessHandle::from_current_process()?;
443 let desired_token = ProcessToken::open_process(¤t_process)?.duplicate()?;
444 let target_process = ProcessHandle::from_pid(pid, PROCESS_SET_INFORMATION)?;
445 target_process.replace_primary_token(&desired_token)?;
446 Ok(())
447}
448
449pub fn run_as(exe: &str, cmd: &str) {
450 let verb = CString::new("runas").unwrap();
451 let exe = CString::new(exe).unwrap();
452 let args = CString::new(cmd).unwrap();
453 unsafe {
454 ShellExecuteA(
455 HWND::default(),
456 PCSTR::from_raw(verb.as_ptr() as _),
457 PCSTR::from_raw(exe.as_ptr() as _),
458 PCSTR::from_raw(args.as_ptr() as _),
459 PCSTR::null(),
460 SW_HIDE,
461 )
462 };
463}
464
465pub fn try_execute_task<T: ElevatedOperation>(id: &str) -> Result<(), String> {
466 match ElevateToken::from_command_line() {
467 Some(ElevateToken::Execute { task_id, payload }) if id == task_id => {
468 let inst: T = serde_json::from_str(&payload).map_err(|err| format!("{err}"))?;
469 inst.run();
470 std::process::exit(0);
471 }
472 _ => Ok(()),
473 }
474}