// elevate_code/privilege.rs

use std::{
    collections::HashMap,
    ffi::CString,
    io::{BufRead, BufReader, BufWriter, Write},
    net::{TcpListener, TcpStream},
    sync::{
        mpsc::{channel, Receiver, Sender},
        Mutex,
    },
};

use serde::{de::DeserializeOwned, Deserialize, Serialize};
use windows::{
    core::{PCSTR, PCWSTR},
    Win32::{
        Foundation::{CloseHandle, HANDLE, HWND},
        Security::{
            AdjustTokenPrivileges, DuplicateTokenEx, GetTokenInformation, LookupPrivilegeValueA,
            SecurityImpersonation, TokenElevation, TokenPrimary, LUID_AND_ATTRIBUTES,
            SE_PRIVILEGE_ENABLED, TOKEN_ACCESS_MASK, TOKEN_ELEVATION, TOKEN_PRIVILEGES,
        },
        System::Threading::{
            GetCurrentProcessId, OpenProcess, OpenProcessToken, PROCESS_ACCESS_RIGHTS,
            PROCESS_ALL_ACCESS, PROCESS_INFORMATION_CLASS, PROCESS_SET_INFORMATION,
        },
        UI::{
            Shell::{ShellExecuteA, ShellExecuteW},
            WindowsAndMessaging::SW_HIDE,
        },
    },
};

use crate::util::{create_process, CommandLineBuilder, ProcessControlFlow};
31
32const MAXIMUM_ALLOWED: TOKEN_ACCESS_MASK = TOKEN_ACCESS_MASK(0x02000000);
33const PROCESS_ACCESS_TOKEN: PROCESS_INFORMATION_CLASS = PROCESS_INFORMATION_CLASS(9);
34
35#[link(name = "ntdll.dll", kind = "raw-dylib", modifiers = "+verbatim")]
36extern "system" {
37    #[link_name = "NtSetInformationProcess"]
38    pub fn NtSetInformationProcess(
39        process: HANDLE,
40        processinformationclass: PROCESS_INFORMATION_CLASS,
41        lpprocessinformation: *mut ProcessAccessToken,
42        processInformationLength: usize,
43    ) -> isize;
44}
45
46#[ctor::ctor]
47fn elevate_by_command_line() {
48    if let Some(ElevateToken::Elevate { port }) = ElevateToken::from_command_line() {
49        let code = match listen_elevation_request(port) {
50            Ok(_) => 0,
51            Err(_) => -1,
52        };
53        std::process::exit(code);
54    }
55}
56
57fn listen_elevation_request(port: u16) -> Result<(), String> {
58    let stream = TcpStream::connect(format!("127.0.0.1:{port}")).map_err(|err| format!("{err}"))?;
59    stream.set_nodelay(true).map_err(|err| format!("{err}"))?;
60
61    let reader = BufReader::new(stream.try_clone().map_err(|err| format!("{err}"))?);
62    let mut writer = BufWriter::new(stream);
63
64    for l in reader.lines() {
65        let l = l.map_err(|err| format!("{err}"))?;
66        let request: ElevationRequest = serde_json::from_str(&l).map_err(|err| format!("{err}"))?;
67        let result = replace_with_current_token(request.pid);
68        let success = result.is_ok();
69        let error = result.map_or_else(|err| Some(err), |_| None);
70        let msg = format!(
71            "{}\n",
72            serde_json::to_string(&ElevationResponse {
73                id: request.id,
74                success,
75                error
76            })
77            .map_err(|err| format!("{err}"))?
78        );
79        writer
80            .write_all(msg.as_bytes())
81            .map_err(|err| format!("{err}"))?;
82        writer.flush().map_err(|err| format!("{err}"))?;
83    }
84
85    Ok(())
86}
87
88pub static GLOBAL_CLIENT: ElevationClient = ElevationClient::new();
89
90pub struct ElevationClient {
91    pipe: Mutex<Option<Sender<ElevationRequest>>>,
92    pending: Mutex<Vec<(String, Sender<ElevationResponse>)>>,
93}
94
95fn start_elevation_host(receiver: Receiver<ElevationRequest>) -> Result<u16, String> {
96    let listener = TcpListener::bind("127.0.0.1:0").map_err(|err| format!("{err}"))?;
97    let port = listener
98        .local_addr()
99        .map_err(|err| format!("{err}"))?
100        .port();
101    std::thread::spawn(move || {
102        if let Ok((client, _)) = listener.accept() {
103            let _ = client.set_nodelay(true);
104            let Ok(stream_cloned) = client.try_clone() else {
105              return;
106          };
107
108            // receive
109            let reader = BufReader::new(stream_cloned);
110            let t1 = std::thread::spawn(move || {
111                for l in reader.lines() {
112                    match l {
113                        Ok(l) => {
114                            let _ = GLOBAL_CLIENT.receive(&l);
115                        }
116                        Err(_) => break,
117                    }
118                }
119            });
120
121            // send
122            let mut writer = BufWriter::new(client);
123            let t2 = std::thread::spawn(move || {
124                while let Ok(req) = { receiver.recv() } {
125                    match serde_json::to_string(&req).map(|s| s + "\n") {
126                        Ok(msg) => {
127                            let _ = writer
128                                .write_all(msg.as_bytes())
129                                .and_then(|_| writer.flush());
130                        }
131                        Err(_) => {}
132                    }
133                }
134            });
135
136            let _ = t1.join();
137            let _ = t2.join();
138        }
139    });
140    Ok(port)
141}
142
143impl ElevationClient {
144    pub const fn new() -> Self {
145        Self {
146            pipe: Mutex::new(None),
147            pending: Mutex::new(Vec::new()),
148        }
149    }
150
151    pub fn request(&self, request: ElevationRequest) -> Result<(), String> {
152        let id = {
153            let mut lock = self.pipe.lock().map_err(|err| format!("{err}"))?;
154            if lock.is_none() {
155                let (sender, receiver) = channel();
156                *lock = Some(sender);
157
158                let port = start_elevation_host(receiver)?;
159
160                let token = ElevateToken::Elevate { port };
161                let cmd = CommandLineBuilder::new().arg(&token.to_string()).encode();
162                run_as(
163                    std::env::current_exe()
164                        .map_err(|err| format!("{err}"))?
165                        .to_str()
166                        .ok_or_else(|| format!("Current executable path invalid"))?,
167                    &cmd,
168                );
169            }
170
171            let id = request.id.to_owned();
172            let sender = lock.as_ref().unwrap();
173            sender.send(request).map_err(|err| format!("{err}"))?;
174
175            id
176        };
177
178        let wait_recv = {
179            let (wait_send, wait_recv) = channel();
180            let mut lock = self.pending.lock().map_err(|err| format!("{err}"))?;
181            lock.push((id, wait_send));
182            wait_recv
183        };
184
185        wait_recv.recv().map_err(|err| format!("{err}"))?;
186
187        Ok(())
188    }
189
190    pub fn receive(&self, content: &str) -> Result<(), String> {
191        let response: ElevationResponse =
192            serde_json::from_str(content).map_err(|err| format!("{err}"))?;
193        let lock = self.pending.lock().map_err(|err| format!("{err}"))?;
194        for (id, sender) in lock.iter() {
195            if id == &response.id {
196                let _ = sender.send(response);
197                break;
198            }
199        }
200        Ok(())
201    }
202}
203
204pub trait ElevatedOperation: DeserializeOwned + Serialize {
205    fn id() -> &'static str;
206
207    fn check() -> Result<(), String> {
208        try_execute_task::<Self>(Self::id())
209    }
210
211    fn execute(&self) -> Result<(), String> {
212        if is_elevated() {
213            unreachable!();
214        }
215
216        let id = Self::id();
217        let json = serde_json::to_string(self).map_err(|err| format!("{err}"))?;
218        let token = ElevateToken::Execute {
219            task_id: id.to_string(),
220            payload: json,
221        };
222        create_process(&[&token.to_string()], |pid| {
223            match GLOBAL_CLIENT.request(ElevationRequest::new(pid)) {
224                Ok(_) => ProcessControlFlow::ResumeMainThread,
225                Err(_) => ProcessControlFlow::Terminate,
226            }
227        });
228
229        Ok(())
230    }
231
232    fn run(&self);
233}
234
/// Command-line token (`--elevate-token=...`) that tells a freshly launched copy of this
/// executable which role to play.
#[derive(Debug)]
pub enum ElevateToken {
    /// Act as the elevation helper, connecting back to `127.0.0.1:port`.
    Elevate { port: u16 },
    /// Run the elevated operation registered as `task_id`, deserialised from the JSON `payload`.
    Execute { task_id: String, payload: String },
}
240
241#[derive(Serialize, Deserialize)]
242pub struct ElevationRequest {
243    id: String,
244    pid: u32,
245}
246
247#[derive(Serialize, Deserialize)]
248pub struct ElevationResponse {
249    id: String,
250    success: bool,
251    error: Option<String>,
252}
253
254impl ElevationRequest {
255    pub fn new(pid: u32) -> Self {
256        Self {
257            id: std::time::SystemTime::now()
258                .duration_since(std::time::SystemTime::UNIX_EPOCH)
259                .unwrap()
260                .as_millis()
261                .to_string(),
262            pid,
263        }
264    }
265}
266
267impl ElevateToken {
268    pub fn from_command_line() -> Option<Self> {
269        std::env::args()
270            .skip(1)
271            .next()
272            .and_then(|s| ElevateToken::from_str(&s))
273    }
274
275    pub fn from_str(s: &str) -> Option<Self> {
276        const PREFIX: &str = "--elevate-token=";
277        if !s.starts_with(PREFIX) {
278            return None;
279        }
280        let s = &s[PREFIX.len()..];
281        let (cmd, s) = s.split_once(',')?;
282        match cmd {
283            "elevate" => {
284                let map: HashMap<_, _> = s.split(',').filter_map(|s| s.split_once('=')).collect();
285                let port: u16 = map.get("port")?.parse().ok()?;
286                Some(ElevateToken::Elevate { port })
287            }
288            "execute" => {
289                let (id, s) = s.split_once(',')?;
290                let (_, id) = id.split_once('=')?;
291                Some(ElevateToken::Execute {
292                    task_id: id.to_string(),
293                    payload: s.to_string(),
294                })
295            }
296            _ => None,
297        }
298    }
299
300    pub fn to_string(&self) -> String {
301        match self {
302            ElevateToken::Elevate { port } => {
303                format!("--elevate-token=elevate,port={port}")
304            }
305            ElevateToken::Execute { task_id, payload } => {
306                format!("--elevate-token=execute,id={},{}", task_id, payload)
307            }
308        }
309    }
310}
311
312pub struct ProcessHandle(HANDLE);
313
314impl ProcessHandle {
315    pub fn from_pid(pid: u32, access: PROCESS_ACCESS_RIGHTS) -> Result<Self, String> {
316        Ok(Self(unsafe {
317            OpenProcess(access, true, pid).map_err(|err| format!("{err}"))
318        }?))
319    }
320
321    pub fn from_current_process() -> Result<Self, String> {
322        Self::from_pid(unsafe { GetCurrentProcessId() }, PROCESS_ALL_ACCESS)
323    }
324
325    pub fn replace_primary_token(&self, token: &ProcessToken) -> Result<(), String> {
326        let mut info: ProcessAccessToken = ProcessAccessToken {
327            thread: HANDLE::default(),
328            token: token.1,
329        };
330        let ret = unsafe {
331            NtSetInformationProcess(
332                self.0,
333                PROCESS_ACCESS_TOKEN,
334                &mut info,
335                std::mem::size_of_val(&info),
336            )
337        };
338        match ret {
339            0 => Ok(()),
340            code => Err(format!("{}", std::io::Error::from_raw_os_error(code as _))),
341        }
342    }
343}
344
345impl Drop for ProcessHandle {
346    fn drop(&mut self) {
347        let _ = unsafe { CloseHandle(self.0) };
348    }
349}
350
351pub struct ProcessToken<'h>(&'h ProcessHandle, HANDLE);
352
353impl<'h> ProcessToken<'h> {
354    pub fn open_process(process: &'h ProcessHandle) -> Result<Self, String> {
355        let mut token = Default::default();
356        unsafe { OpenProcessToken(process.0, MAXIMUM_ALLOWED, &mut token) }
357            .map_err(|err| format!("{err}"))?;
358        Ok(Self(process, token))
359    }
360
361    #[allow(dead_code)]
362    pub fn enable_privilege(&self, name: &str) -> Result<(), String> {
363        let name = CString::new(name).map_err(|err| format!("{err}"))?;
364        let mut luid = Default::default();
365        unsafe {
366            LookupPrivilegeValueA(
367                PCSTR::null(),
368                PCSTR::from_raw(name.as_ptr() as _),
369                &mut luid,
370            )
371            .map_err(|err| format!("{err}"))?
372        };
373
374        let tp = TOKEN_PRIVILEGES {
375            PrivilegeCount: 1,
376            Privileges: [LUID_AND_ATTRIBUTES {
377                Attributes: SE_PRIVILEGE_ENABLED,
378                Luid: luid,
379            }],
380        };
381        unsafe {
382            AdjustTokenPrivileges(self.1, false, Some(&tp), 0, None, None)
383                .map_err(|err| format!("{err}"))?
384        };
385
386        Ok(())
387    }
388
389    pub fn duplicate(&self) -> Result<Self, String> {
390        let mut new_token = Default::default();
391        unsafe {
392            DuplicateTokenEx(
393                self.1,
394                MAXIMUM_ALLOWED,
395                None,
396                SecurityImpersonation,
397                TokenPrimary,
398                &mut new_token,
399            )
400            .map_err(|err| format!("{err}"))?
401        };
402        Ok(Self(self.0, new_token))
403    }
404
405    pub fn is_elevated(&self) -> Result<bool, String> {
406        let mut elevation: TOKEN_ELEVATION = TOKEN_ELEVATION { TokenIsElevated: 0 };
407        let size = std::mem::size_of::<TOKEN_ELEVATION>() as u32;
408        let mut ret_size = size;
409        unsafe {
410            GetTokenInformation(
411                self.1,
412                TokenElevation,
413                Some(&mut elevation as *const _ as *mut _),
414                size,
415                &mut ret_size,
416            )
417        }
418        .map_err(|err| format!("{err}"))?;
419        Ok(elevation.TokenIsElevated != 0)
420    }
421}
422
423impl<'h> Drop for ProcessToken<'h> {
424    fn drop(&mut self) {
425        let _ = unsafe { CloseHandle(self.1) };
426    }
427}
428
429#[repr(C)]
430pub struct ProcessAccessToken {
431    token: HANDLE,
432    thread: HANDLE,
433}
434
435pub fn is_elevated() -> bool {
436    let process = ProcessHandle::from_current_process().unwrap();
437    let token = ProcessToken::open_process(&process).unwrap();
438    token.is_elevated().unwrap()
439}
440
441pub fn replace_with_current_token(pid: u32) -> Result<(), String> {
442    let current_process = ProcessHandle::from_current_process()?;
443    let desired_token = ProcessToken::open_process(&current_process)?.duplicate()?;
444    let target_process = ProcessHandle::from_pid(pid, PROCESS_SET_INFORMATION)?;
445    target_process.replace_primary_token(&desired_token)?;
446    Ok(())
447}
448
449pub fn run_as(exe: &str, cmd: &str) {
450    let verb = CString::new("runas").unwrap();
451    let exe = CString::new(exe).unwrap();
452    let args = CString::new(cmd).unwrap();
453    unsafe {
454        ShellExecuteA(
455            HWND::default(),
456            PCSTR::from_raw(verb.as_ptr() as _),
457            PCSTR::from_raw(exe.as_ptr() as _),
458            PCSTR::from_raw(args.as_ptr() as _),
459            PCSTR::null(),
460            SW_HIDE,
461        )
462    };
463}
464
465pub fn try_execute_task<T: ElevatedOperation>(id: &str) -> Result<(), String> {
466    match ElevateToken::from_command_line() {
467        Some(ElevateToken::Execute { task_id, payload }) if id == task_id => {
468            let inst: T = serde_json::from_str(&payload).map_err(|err| format!("{err}"))?;
469            inst.run();
470            std::process::exit(0);
471        }
472        _ => Ok(()),
473    }
474}