use log::debug;
use crate::client::connection::Connection;
use crate::client::tree::Tree;
use crate::error::Result;
use crate::msg::change_notify::{
ChangeNotifyRequest, ChangeNotifyResponse, FILE_NOTIFY_CHANGE_ATTRIBUTES,
FILE_NOTIFY_CHANGE_CREATION, FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME,
FILE_NOTIFY_CHANGE_LAST_WRITE, FILE_NOTIFY_CHANGE_SIZE, SMB2_WATCH_TREE,
};
use crate::pack::{ReadCursor, Unpack};
use crate::types::status::NtStatus;
use crate::types::{Command, FileId};
use crate::Error;
/// `CompletionFilter` sent with every CHANGE_NOTIFY request.
///
/// Asks the server to report creation, last-write, size, attribute, directory-name
/// and file-name changes.
const DEFAULT_COMPLETION_FILTER: u32 = FILE_NOTIFY_CHANGE_CREATION
    | FILE_NOTIFY_CHANGE_LAST_WRITE
    | FILE_NOTIFY_CHANGE_SIZE
    | FILE_NOTIFY_CHANGE_ATTRIBUTES
    | FILE_NOTIFY_CHANGE_DIR_NAME
    | FILE_NOTIFY_CHANGE_FILE_NAME;
/// `OutputBufferLength` sent with every CHANGE_NOTIFY request (64 KiB).
///
/// Per MS-SMB2 this caps how many bytes of FILE_NOTIFY_INFORMATION data the
/// server may return for a single request.
const OUTPUT_BUFFER_LENGTH: u32 = 64 * 1024;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileNotifyAction {
Added,
Removed,
Modified,
RenamedOldName,
RenamedNewName,
}
impl FileNotifyAction {
fn from_u32(value: u32) -> Result<Self> {
match value {
0x0000_0001 => Ok(FileNotifyAction::Added),
0x0000_0002 => Ok(FileNotifyAction::Removed),
0x0000_0003 => Ok(FileNotifyAction::Modified),
0x0000_0004 => Ok(FileNotifyAction::RenamedOldName),
0x0000_0005 => Ok(FileNotifyAction::RenamedNewName),
other => Err(Error::invalid_data(format!(
"unknown FILE_NOTIFY_INFORMATION action: {other:#010X}"
))),
}
}
}
impl std::fmt::Display for FileNotifyAction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FileNotifyAction::Added => write!(f, "added"),
FileNotifyAction::Removed => write!(f, "removed"),
FileNotifyAction::Modified => write!(f, "modified"),
FileNotifyAction::RenamedOldName => write!(f, "renamed (old name)"),
FileNotifyAction::RenamedNewName => write!(f, "renamed (new name)"),
}
}
}
/// One change reported by the server, decoded from a single
/// FILE_NOTIFY_INFORMATION entry.
#[derive(Debug, Clone)]
pub struct FileNotifyEvent {
    /// What kind of change the entry describes.
    pub action: FileNotifyAction,
    /// The entry's `FileName` field, decoded from UTF-16LE.
    // NOTE(review): presumably relative to the watched directory — confirm
    // against the server behaviour / MS-FSCC before relying on it.
    pub filename: String,
}
/// Issues SMB2 CHANGE_NOTIFY requests against an open handle and returns the
/// change events the server reports.
///
/// The watcher borrows the tree and the connection for its whole lifetime.
/// The watched handle is released only by [`Watcher::close`]. This file defines
/// no `Drop` impl for `Watcher`, so dropping one without calling `close` leaves
/// the handle open on the server.
pub struct Watcher<'a> {
    /// Tree whose `tree_id` is attached to every request and which closes the handle.
    tree: &'a Tree,
    /// Connection used to send requests.
    conn: &'a mut Connection,
    /// Handle sent as `FileId` in every CHANGE_NOTIFY request (the directory being watched).
    file_id: FileId,
    /// When true, requests set `SMB2_WATCH_TREE` so the whole subtree is watched.
    recursive: bool,
}
impl<'a> Watcher<'a> {
    /// Wraps an already-open handle; no request is sent until [`Watcher::next_events`].
    pub(crate) fn new(
        tree: &'a Tree,
        conn: &'a mut Connection,
        file_id: FileId,
        recursive: bool,
    ) -> Self {
        Watcher {
            tree,
            conn,
            file_id,
            recursive,
        }
    }

    /// Sends one CHANGE_NOTIFY request and returns the events in its response.
    ///
    /// Each call builds a fresh request with the default completion filter and a
    /// 64 KiB output buffer. It then awaits the response through the connection.
    /// Events are returned in the order the server listed them, and an empty
    /// output buffer yields an empty `Vec`.
    ///
    /// # Errors
    /// - `Error::Protocol` when the response status is anything other than
    ///   `STATUS_SUCCESS`. This includes `STATUS_NOTIFY_ENUM_DIR`, which per
    ///   MS-SMB2 means the changes did not fit in the output buffer. In that case
    ///   the caller should re-enumerate the directory itself.
    /// - Any error from sending the request, unpacking the response, or parsing
    ///   its FILE_NOTIFY_INFORMATION entries.
    pub async fn next_events(&mut self) -> Result<Vec<FileNotifyEvent>> {
        let flags = if self.recursive { SMB2_WATCH_TREE } else { 0 };
        let req = ChangeNotifyRequest {
            flags,
            output_buffer_length: OUTPUT_BUFFER_LENGTH,
            file_id: self.file_id,
            completion_filter: DEFAULT_COMPLETION_FILTER,
        };
        let frame = self
            .conn
            .execute(Command::ChangeNotify, &req, Some(self.tree.tree_id))
            .await?;
        if frame.header.status != NtStatus::SUCCESS {
            // NOTIFY_ENUM_DIR is surfaced as the same protocol error as any other
            // failure. Log it separately because it signals lost events rather
            // than a broken request.
            if frame.header.status == NtStatus::NOTIFY_ENUM_DIR {
                debug!("watcher: server returned NOTIFY_ENUM_DIR; directory must be re-enumerated");
            }
            return Err(Error::Protocol {
                status: frame.header.status,
                command: Command::ChangeNotify,
            });
        }
        let mut cursor = ReadCursor::new(&frame.body);
        let resp = ChangeNotifyResponse::unpack(&mut cursor)?;
        let events = parse_notify_information(&resp.output_data)?;
        debug!("watcher: received {} change event(s)", events.len());
        Ok(events)
    }

    /// Closes the watched handle via the tree, consuming the watcher.
    ///
    /// # Errors
    /// Propagates any error returned by the tree's handle close.
    pub async fn close(self) -> Result<()> {
        self.tree.close_handle(self.conn, self.file_id).await
    }
}
fn parse_notify_information(data: &[u8]) -> Result<Vec<FileNotifyEvent>> {
let mut events = Vec::new();
let mut offset = 0usize;
if data.is_empty() {
return Ok(events);
}
loop {
if offset + 12 > data.len() {
return Err(Error::invalid_data(
"FILE_NOTIFY_INFORMATION truncated: not enough bytes for fixed fields",
));
}
let next_entry_offset =
u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
let action_raw = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap());
let filename_length =
u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()) as usize;
let filename_start = offset + 12;
let filename_end = filename_start + filename_length;
if filename_end > data.len() {
return Err(Error::invalid_data(format!(
"FILE_NOTIFY_INFORMATION filename extends beyond buffer: \
need {} bytes at offset {}, buffer is {} bytes",
filename_length,
filename_start,
data.len()
)));
}
let filename_bytes = &data[filename_start..filename_end];
let filename = decode_utf16le(filename_bytes)?;
let action = FileNotifyAction::from_u32(action_raw)?;
events.push(FileNotifyEvent { action, filename });
if next_entry_offset == 0 {
break;
}
offset += next_entry_offset;
}
Ok(events)
}
fn decode_utf16le(bytes: &[u8]) -> Result<String> {
if bytes.len() % 2 != 0 {
return Err(Error::invalid_data("UTF-16LE filename has odd byte count"));
}
let u16s: Vec<u16> = bytes
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect();
String::from_utf16(&u16s)
.map_err(|e| Error::invalid_data(format!("invalid UTF-16LE filename: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `s` as UTF-16LE, the wire encoding of FILE_NOTIFY_INFORMATION names.
    fn utf16le_bytes(s: &str) -> Vec<u8> {
        s.encode_utf16().flat_map(u16::to_le_bytes).collect()
    }

    /// Builds one FILE_NOTIFY_INFORMATION entry, zero-padded to a 4-byte boundary.
    /// `NextEntryOffset` is 0 when `is_last`, otherwise the padded entry size.
    fn notify_entry(name: &str, action: u32, is_last: bool) -> Vec<u8> {
        let name_bytes = utf16le_bytes(name);
        let aligned_size = (12 + name_bytes.len() + 3) & !3;
        let next_offset = if is_last { 0u32 } else { aligned_size as u32 };
        let mut entry = Vec::with_capacity(aligned_size);
        entry.extend_from_slice(&next_offset.to_le_bytes());
        entry.extend_from_slice(&action.to_le_bytes());
        entry.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
        entry.extend_from_slice(&name_bytes);
        entry.resize(aligned_size, 0);
        entry
    }

    #[test]
    fn parse_single_notify_entry() {
        // 12-byte header + 16 name bytes is already aligned, so there is no padding.
        let data = notify_entry("test.txt", 1, true);
        assert_eq!(data.len(), 28);
        let events = parse_notify_information(&data).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].action, FileNotifyAction::Added);
        assert_eq!(events[0].filename, "test.txt");
    }

    #[test]
    fn parse_multiple_notify_entries() {
        let mut data = notify_entry("added.txt", 1, false);
        data.extend(notify_entry("removed.txt", 2, true));
        let events = parse_notify_information(&data).unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].action, FileNotifyAction::Added);
        assert_eq!(events[0].filename, "added.txt");
        assert_eq!(events[1].action, FileNotifyAction::Removed);
        assert_eq!(events[1].filename, "removed.txt");
    }

    #[test]
    fn parse_rename_pair() {
        let mut data = notify_entry("old.txt", 4, false);
        data.extend(notify_entry("new.txt", 5, true));
        let events = parse_notify_information(&data).unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].action, FileNotifyAction::RenamedOldName);
        assert_eq!(events[0].filename, "old.txt");
        assert_eq!(events[1].action, FileNotifyAction::RenamedNewName);
        assert_eq!(events[1].filename, "new.txt");
    }

    #[test]
    fn parse_empty_buffer_returns_no_events() {
        let events = parse_notify_information(&[]).unwrap();
        assert!(events.is_empty());
    }

    #[test]
    fn parse_truncated_buffer_returns_error() {
        let data = vec![0u8; 8];
        assert!(parse_notify_information(&data).is_err());
    }

    #[test]
    fn parse_filename_overrunning_buffer_is_error() {
        // Header claims a 100-byte filename but only 4 bytes follow it.
        let mut data = Vec::new();
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend_from_slice(&1u32.to_le_bytes());
        data.extend_from_slice(&100u32.to_le_bytes());
        data.extend_from_slice(&utf16le_bytes("ab"));
        assert!(parse_notify_information(&data).is_err());
    }

    #[test]
    fn parse_next_entry_past_end_is_error() {
        // NextEntryOffset points at the end of the buffer, where no entry exists.
        let data = notify_entry("x.txt", 1, false);
        assert!(parse_notify_information(&data).is_err());
    }

    #[test]
    fn parse_unknown_action_is_error() {
        let data = notify_entry("x.txt", 0x42, true);
        assert!(parse_notify_information(&data).is_err());
    }

    #[test]
    fn decode_utf16le_basic() {
        assert_eq!(decode_utf16le(&utf16le_bytes("hello")).unwrap(), "hello");
    }

    #[test]
    fn decode_utf16le_non_ascii() {
        let input = "photos/\u{00E9}t\u{00E9}";
        assert_eq!(decode_utf16le(&utf16le_bytes(input)).unwrap(), input);
    }

    #[test]
    fn decode_utf16le_odd_bytes_is_error() {
        assert!(decode_utf16le(&[0x41, 0x00, 0x42]).is_err());
    }

    #[test]
    fn file_notify_action_display() {
        assert_eq!(format!("{}", FileNotifyAction::Added), "added");
        assert_eq!(format!("{}", FileNotifyAction::Removed), "removed");
        assert_eq!(format!("{}", FileNotifyAction::Modified), "modified");
        assert_eq!(
            format!("{}", FileNotifyAction::RenamedOldName),
            "renamed (old name)"
        );
        assert_eq!(
            format!("{}", FileNotifyAction::RenamedNewName),
            "renamed (new name)"
        );
    }

    #[test]
    fn file_notify_action_from_u32_known_values() {
        let expected = [
            (1u32, FileNotifyAction::Added),
            (2, FileNotifyAction::Removed),
            (3, FileNotifyAction::Modified),
            (4, FileNotifyAction::RenamedOldName),
            (5, FileNotifyAction::RenamedNewName),
        ];
        for (raw, action) in expected {
            assert_eq!(FileNotifyAction::from_u32(raw).unwrap(), action);
        }
    }

    #[test]
    fn file_notify_action_from_u32_unknown_is_error() {
        assert!(FileNotifyAction::from_u32(0x9999).is_err());
    }
}