silverwind/
lib.rs

1#![no_std]
2
3mod sys;
4use core::marker::*;
5use core::mem::*;
6use core::ptr::*;
7use core::sync::atomic::*;
8
9use sys::*;
10
/// Handle of the registered minifilter.
///
/// Stays null until `register_filter` stores the handle returned by a successful
/// `FltRegisterFilter`; `flt_unload` reads it back to unregister the filter.
static G_FILTER: AtomicPtr<core::ffi::c_void> = AtomicPtr::new(core::ptr::null_mut());
12
13#[cfg(not(test))]
14mod panic_handler {
15    use super::*;
16    #[panic_handler]
17    fn panic(_info: &core::panic::PanicInfo) -> ! {
18        unsafe { DbgPrint("PANIC!\n".as_ptr()) };
19        loop {}
20    }
21}
22
23struct ContextRc {
24    context: POpaque,
25}
26
27impl ContextRc {
28    pub fn wrap(context: POpaque) -> ContextRc {
29        ContextRc { context }
30    }
31    pub fn raw(&self) -> POpaque {
32        self.context
33    }
34}
35
36impl Clone for ContextRc {
37    fn clone(&self) -> Self {
38        unsafe { FltReferenceContext(self.context) }
39        ContextRc {
40            context: self.context,
41        }
42    }
43}
44
45impl Drop for ContextRc {
46    fn drop(&mut self) {
47        unsafe { FltReleaseContext(self.context) }
48    }
49}
50
/// Compile-time description of a Filter Manager context kind.
///
/// Provides the `FLT_CONTEXT_TYPE` value and the pool tag used when the context
/// type is registered.
trait ContextTypeTag {
    fn sys_context_type() -> u16;
    fn sys_tag() -> u32;
}

/// Marker type selecting the per-instance context kind.
#[derive(Clone)]
struct InstanceContextTag;

impl ContextTypeTag for InstanceContextTag {
    fn sys_context_type() -> u16 {
        const FLT_INSTANCE_CONTEXT: u16 = 0x0002;
        FLT_INSTANCE_CONTEXT
    }

    fn sys_tag() -> u32 {
        const INSTANCE_CONTEXT_POOL_TAG: u32 = 0x40;
        INSTANCE_CONTEXT_POOL_TAG
    }
}
66
67#[derive(Clone)]
68struct Context<T, ContextType> {
69    rc: ContextRc,
70    phantom_t: core::marker::PhantomData<T>,
71    phantom_context: core::marker::PhantomData<ContextType>,
72}
73
74extern "C" fn context_cleanup<T>(context: POpaque, _context_type: u16) {
75    unsafe { drop_in_place(context as *mut T) }
76}
77
78impl<T: 'static + Send + Sync, ContextType: ContextTypeTag> Context<T, ContextType> {
79    fn context_size() -> u64 {
80        // TODO: align
81        let size = size_of::<T>();
82        size as u64
83    }
84
85    fn registration() -> FLT_CONTEXT_REGISTRATION {
86        let cleanup = if needs_drop::<T>() {
87            context_cleanup::<T> as *const _
88        } else {
89            null()
90        };
91
92        FLT_CONTEXT_REGISTRATION {
93            ContextType: ContextType::sys_context_type(),
94            Flags: 0,
95            ContextCleanupCallback: cleanup,
96            Size: Self::context_size(),
97            PoolTag: ContextType::sys_tag(),
98            ContextAllocateCallback: null(),
99            ContextFreeCallback: null(),
100            Reserved1: null_mut(),
101        }
102    }
103
104    pub fn wrap(context: POpaque) -> Self {
105        Context {
106            rc: ContextRc { context },
107            phantom_t: PhantomData,
108            phantom_context: PhantomData,
109        }
110    }
111
112    pub fn new(filter: POpaque, value: T, pool_type: i32) -> Result<Self, (T, NTSTATUS)> {
113        unsafe {
114            // TODO: align
115            let mut context = null_mut();
116            let status = FltAllocateContext(
117                filter,
118                ContextType::sys_context_type(),
119                Self::context_size(),
120                pool_type,
121                addr_of_mut!(context),
122            );
123            if status < 0 {
124                return Err((value, status));
125            }
126            write(context as *mut T, value);
127            Ok(Context {
128                rc: ContextRc { context },
129                phantom_t: PhantomData,
130                phantom_context: PhantomData,
131            })
132        }
133    }
134
135    pub fn get(&self) -> &T {
136        unsafe { &*(self.rc.context as *const T) }
137    }
138
139    pub fn raw(&self) -> POpaque {
140        self.rc.raw()
141    }
142}
143
144struct InstanceContextInner {
145    total_creates: AtomicU32,
146}
147
148type InstanceContext = Context<InstanceContextInner, InstanceContextTag>;
149
150struct CallbackPackage {
151    data: POpaque,
152    flt_objects: *const FLT_RELATED_OBJECTS,
153}
154
/// Outcome of a typed pre-operation handler.
enum PreOpResult<T> {
    /// Let the operation continue and request the post-operation callback, which
    /// receives the carried completion context.
    SuccessWithCallback(T),
    /// Let the operation continue without a post-operation callback.
    SuccessNoCallback,
}

/// Outcome of a typed post-operation handler.
enum PostOpResult {
    /// No further work is needed for this operation.
    FinishedProcessing,
}
163
164trait CompletionContext {
165    fn into_ptr(self) -> POpaque;
166    fn from_ptr(ptr: POpaque) -> Self;
167}
168
169struct PoolBox<T> {
170    raw: *mut T,
171    phantom: PhantomData<T>,
172}
173
174impl<T> CompletionContext for PoolBox<T> {
175    fn into_ptr(self) -> POpaque {
176        let raw = self.raw as POpaque;
177        forget(self);
178        raw
179    }
180    fn from_ptr(ptr: POpaque) -> Self {
181        Self {
182            raw: ptr as *mut T,
183            phantom: PhantomData,
184        }
185    }
186}
187
188impl<T> Drop for PoolBox<T> {
189    fn drop(&mut self) {
190        unsafe { ExFreePoolWithTag(self.raw as POpaque, 0) }
191    }
192}
193
194impl<T> PoolBox<T> {
195    fn from_raw(raw: *mut T) -> Self {
196        Self {
197            raw,
198            phantom: PhantomData,
199        }
200    }
201
202    fn get_raw(&self) -> *mut T {
203        self.raw
204    }
205}
206
207trait FilterInterface {
208    type CreateCompletionContext: CompletionContext;
209    fn pre_create(callback_package: &CallbackPackage)
210        -> PreOpResult<Self::CreateCompletionContext>;
211    fn post_create(
212        callback_package: &CallbackPackage,
213        completion_context: Self::CreateCompletionContext,
214    ) -> PostOpResult;
215}
216
217struct FilterImpl;
218
219impl FilterInterface for FilterImpl {
220    type CreateCompletionContext = PoolBox<UNICODE_STRING>;
221    fn pre_create(
222        callback_package: &CallbackPackage,
223    ) -> PreOpResult<Self::CreateCompletionContext> {
224        unsafe {
225            if (*(*callback_package.flt_objects).FileObject).Flags & 0x00400000 != 0 {
226                // volume open
227                return PreOpResult::SuccessNoCallback;
228            }
229
230            let mut raw_context: POpaque = null_mut();
231            let status = FltGetInstanceContext(
232                (*callback_package.flt_objects).Instance,
233                addr_of_mut!(raw_context),
234            );
235            if status < 0 {
236                DbgPrint(
237                    "pre_create: FltGetInstanceContext failed (0x%x)\n\0".as_ptr(),
238                    status,
239                );
240                return PreOpResult::SuccessNoCallback;
241            }
242            let instance_context = InstanceContext::wrap(raw_context);
243            let total_creates = instance_context
244                .get()
245                .total_creates
246                .fetch_add(1, Ordering::SeqCst)
247                + 1;
248            DbgPrint("pre_create: %u created so far\n\0".as_ptr(), total_creates);
249
250            let process = FltGetRequestorProcess(callback_package.data);
251            let mut image_file_name: *mut UNICODE_STRING = null_mut();
252            let status = SeLocateProcessImageName(process, addr_of_mut!(image_file_name));
253            if status < 0 {
254                DbgPrint(
255                    "pre_create: SeLocateProcessImageName  failed (0x%x)\n\0".as_ptr(),
256                    status,
257                );
258                return PreOpResult::SuccessNoCallback;
259            }
260            let image_file_name = PoolBox::from_raw(image_file_name);
261
262            DbgPrint(
263                "pre_create: process %wZ\n\0".as_ptr(),
264                image_file_name.get_raw(),
265            );
266
267            PreOpResult::SuccessWithCallback(image_file_name)
268        }
269    }
270
271    fn post_create(
272        callback_package: &CallbackPackage,
273        image_file_name: Self::CreateCompletionContext,
274    ) -> PostOpResult {
275        unsafe {
276            DbgPrint(
277                "post_create: process %wZ\n\0".as_ptr(),
278                image_file_name.get_raw(),
279            );
280
281            PostOpResult::FinishedProcessing
282        }
283    }
284}
285
286struct Filter<Interface> {
287    filter: POpaque,
288    marker: core::marker::PhantomData<Interface>,
289}
290
291impl<Interface> Filter<Interface> {
292    unsafe fn start(self) -> Result<(), NTSTATUS> {
293        let status = FltStartFiltering(self.filter);
294        if status < 0 {
295            FltUnregisterFilter(self.filter);
296            return Err(status);
297        }
298
299        Ok(())
300    }
301}
302
303unsafe extern "C" fn flt_unload(_flags: u32) -> NTSTATUS {
304    DbgPrint("flt_unload\n\0".as_ptr());
305    FltUnregisterFilter(G_FILTER.load(core::sync::atomic::Ordering::Acquire));
306    0
307}
308
309unsafe extern "C" fn flt_instance_setup(
310    flt_objects: *const FLT_RELATED_OBJECTS,
311    _flags: u32,
312    _volume_device_type: u32,
313    volume_filesystem_type: i32,
314) -> NTSTATUS {
315    if volume_filesystem_type != 3 {
316        DbgPrint(
317            "flt_instance_setup: File System is not FAT (0x%x), ignoring\n\0".as_ptr(),
318            volume_filesystem_type,
319        );
320        return 0xC01C000Fu32 as NTSTATUS;
321    }
322
323    DbgPrint("flt_instance_setup: Attaching to FAT volume\n\0".as_ptr());
324
325    let context = match InstanceContext::new(
326        (*flt_objects).Filter,
327        InstanceContextInner {
328            total_creates: AtomicU32::new(0),
329        },
330        1,
331    ) {
332        Ok(context) => context,
333        Err((_, status)) => {
334            DbgPrint(
335                "flt_instance_setup: context allocation failed (0x%x)\n\0".as_ptr(),
336                status,
337            );
338            return status;
339        }
340    };
341
342    let status = FltSetInstanceContext((*flt_objects).Instance, 0, context.raw(), null_mut());
343
344    if status < 0 {
345        DbgPrint(
346            "flt_instance_setup: FltSetInstanceContext failed (0x%x)\n\0".as_ptr(),
347            status,
348        );
349        return status;
350    }
351
352    0
353}
354
355unsafe extern "C" fn flt_query_teardown(flt_objects: POpaque, _flags: u32) -> NTSTATUS {
356    DbgPrint("flt_query_teardown\n\0".as_ptr());
357
358    0
359}
360
361unsafe extern "C" fn pre_create<Interface: FilterInterface>(
362    data: POpaque,
363    flt_objects: *const FLT_RELATED_OBJECTS,
364    completion_context: *mut POpaque,
365) -> i32 {
366    let callback_package = CallbackPackage { data, flt_objects };
367
368    match Interface::pre_create(&callback_package) {
369        PreOpResult::SuccessWithCallback(completion_context_value) => {
370            *completion_context = completion_context_value.into_ptr();
371            0
372        }
373        PreOpResult::SuccessNoCallback => 1,
374    }
375}
376
377unsafe extern "C" fn post_create<Interface: FilterInterface>(
378    data: POpaque,
379    flt_objects: *const FLT_RELATED_OBJECTS,
380    completion_context: POpaque,
381    flags: u32,
382) -> i32 {
383    let callback_package = CallbackPackage { data, flt_objects };
384    let completion_context = Interface::CreateCompletionContext::from_ptr(completion_context);
385
386    match Interface::post_create(&callback_package, completion_context) {
387        PostOpResult::FinishedProcessing => 0,
388    }
389}
390
391unsafe fn register_filter<Interface: FilterInterface>(
392    driver_object: POpaque,
393) -> Result<Filter<Interface>, NTSTATUS> {
394    let context_registration = [
395        <InstanceContext>::registration(),
396        FLT_CONTEXT_REGISTRATION {
397            ContextType: 0xFFFF,
398            Flags: 0,
399            ContextCleanupCallback: core::ptr::null(),
400            Size: 0,
401            PoolTag: 0,
402            ContextAllocateCallback: core::ptr::null(),
403            ContextFreeCallback: core::ptr::null(),
404            Reserved1: core::ptr::null_mut(),
405        },
406    ];
407    let operation_registration = [
408        FLT_OPERATION_REGISTRATION {
409            MajorFunction: 0, //create
410            Flags: 0,
411            PreOperation: pre_create::<Interface> as *const _,
412            PostOperation: post_create::<Interface> as *const _,
413            Reserved1: core::ptr::null_mut(),
414        },
415        FLT_OPERATION_REGISTRATION {
416            MajorFunction: 0x80,
417            Flags: 0,
418            PreOperation: core::ptr::null(),
419            PostOperation: core::ptr::null(),
420            Reserved1: core::ptr::null_mut(),
421        },
422    ];
423
424    let registration = FLT_REGISTRATION {
425        Size: core::mem::size_of::<FLT_REGISTRATION>() as u16,
426        Version: 0x0202, // ehh
427        Flags: 0,
428        ContextRegistration: context_registration.as_ptr(),
429        OperationRegistration: operation_registration.as_ptr(),
430        FilterUnloadCallback: flt_unload as *const _,
431        InstanceSetupCallback: flt_instance_setup as *const _,
432        InstanceQueryTeardownCallback: flt_query_teardown as *const _,
433        InstanceTeardownStartCallback: core::ptr::null(),
434        InstanceTeardownCompleteCallback: core::ptr::null(),
435        GenerateFileNameCallback: core::ptr::null(),
436        NormalizeNameComponentCallback: core::ptr::null(),
437        NormalizeContextCleanupCallback: core::ptr::null(),
438        TransactionNotificationCallback: core::ptr::null(),
439        NormalizeNameComponentExCallback: core::ptr::null(),
440        // SectionNotificationCallback: core::ptr::null(),
441    };
442
443    let mut filter: POpaque = core::ptr::null_mut();
444    let status = FltRegisterFilter(driver_object, addr_of!(registration), addr_of_mut!(filter));
445    if status < 0 {
446        return Err(status);
447    }
448    G_FILTER.store(filter, core::sync::atomic::Ordering::Release);
449    Ok(Filter {
450        filter,
451        marker: core::marker::PhantomData,
452    })
453}
454
455#[export_name = "entry"]
456pub unsafe extern "C" fn entry(driver_object: POpaque) -> NTSTATUS {
457    DbgPrint("Silverwind - Loaded\n\0".as_ptr());
458
459    let filter = match register_filter::<FilterImpl>(driver_object) {
460        Ok(filter) => filter,
461        Err(status) => {
462            DbgPrint("register failed! 0x%x\n\0".as_ptr(), status);
463            return status;
464        }
465    };
466
467    if let Err(status) = filter.start() {
468        DbgPrint("start failed! 0x%x\n\0".as_ptr(), status);
469        return status;
470    }
471
472    0
473}