use log::{debug, info};
use crate::client::connection::Connection;
use crate::error::Result;
use crate::msg::close::CloseRequest;
use crate::msg::create::{
CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess,
};
use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE};
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
use crate::msg::tree_disconnect::TreeDisconnectRequest;
use crate::msg::write::{WriteRequest, WriteResponse};
use crate::pack::{ReadCursor, Unpack};
use crate::rpc;
use crate::rpc::srvsvc::{self, ShareInfo};
use crate::types::flags::FileAccessMask;
use crate::types::status::NtStatus;
use crate::types::{Command, FileId, OplockLevel, TreeId};
use crate::Error;
/// Number of bytes requested by each SMB2 READ issued against the `srvsvc` pipe (64 KiB).
const PIPE_READ_BUFFER_SIZE: u32 = 64 * 1024;
pub async fn list_shares(conn: &mut Connection) -> Result<Vec<ShareInfo>> {
let tree_id = tree_connect_ipc(conn).await?;
let result = pipe_rpc_exchange(conn, tree_id).await;
let _ = tree_disconnect(conn, tree_id).await;
let all_shares = result?;
let filtered = srvsvc::filter_disk_shares(all_shares);
info!("shares: found {} disk shares", filtered.len());
Ok(filtered)
}
async fn tree_connect_ipc(conn: &mut Connection) -> Result<TreeId> {
let server = conn.server_name().to_string();
let unc_path = format!(r"\\{}\IPC$", server);
let req = TreeConnectRequest {
flags: TreeConnectRequestFlags::default(),
path: unc_path,
};
let frame = conn.execute(Command::TreeConnect, &req, None).await?;
if frame.header.command != Command::TreeConnect {
return Err(Error::invalid_data(format!(
"expected TreeConnect response, got {:?}",
frame.header.command
)));
}
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::TreeConnect,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let _resp = TreeConnectResponse::unpack(&mut cursor)?;
let tree_id = frame
.header
.tree_id
.ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;
info!("shares: connected to IPC$, tree_id={}", tree_id);
Ok(tree_id)
}
async fn pipe_rpc_exchange(conn: &mut Connection, tree_id: TreeId) -> Result<Vec<ShareInfo>> {
let file_id = open_srvsvc_pipe(conn, tree_id).await?;
let result = rpc_bind_and_request(conn, tree_id, file_id).await;
let _ = close_handle(conn, tree_id, file_id).await;
result
}
/// Runs the DCE/RPC conversation on an already-open `srvsvc` pipe.
///
/// The conversation is a BIND (call ID 1) followed by a `NetShareEnumAll`
/// request (call ID 2) addressed to `\\<server>`. Returns every share the
/// server reported, unfiltered.
///
/// NOTE(review): each reply is taken from a single `read_pipe` call. Whether a
/// response split across several RPC fragments is handled depends on
/// `rpc::parse_bind_ack` / `srvsvc::parse_net_share_enum_all_response`.
/// Confirm before relying on very large share lists.
///
/// # Errors
///
/// Propagates pipe I/O errors, a rejected or malformed BIND_ACK, and parse
/// failures of the enumeration response.
async fn rpc_bind_and_request(
    conn: &mut Connection,
    tree_id: TreeId,
    file_id: FileId,
) -> Result<Vec<ShareInfo>> {
    // Step 1: bind to the srvsvc interface and make sure the context was accepted.
    let bind_pdu = rpc::build_srvsvc_bind(1);
    write_pipe(conn, tree_id, file_id, &bind_pdu).await?;
    debug!("shares: sent RPC BIND ({} bytes)", bind_pdu.len());

    let ack_pdu = read_pipe(conn, tree_id, file_id).await?;
    rpc::parse_bind_ack(&ack_pdu)?;
    debug!("shares: received BIND_ACK, context accepted");

    // Step 2: ask the server to enumerate all of its shares.
    let target = format!(r"\\{}", conn.server_name());
    let enum_pdu = srvsvc::build_net_share_enum_all(2, &target);
    write_pipe(conn, tree_id, file_id, &enum_pdu).await?;
    debug!(
        "shares: sent NetShareEnumAll request ({} bytes)",
        enum_pdu.len()
    );

    let reply_pdu = read_pipe(conn, tree_id, file_id).await?;
    let reported = srvsvc::parse_net_share_enum_all_response(&reply_pdu)?;
    debug!("shares: received {} shares in response", reported.len());
    Ok(reported)
}
async fn open_srvsvc_pipe(conn: &mut Connection, tree_id: TreeId) -> Result<FileId> {
let req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA,
),
file_attributes: 0,
share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: r"srvsvc".to_string(),
create_contexts: vec![],
};
let frame = conn.execute(Command::Create, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Create,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = CreateResponse::unpack(&mut cursor)?;
debug!("shares: opened srvsvc pipe, file_id={:?}", resp.file_id);
Ok(resp.file_id)
}
async fn write_pipe(
conn: &mut Connection,
tree_id: TreeId,
file_id: FileId,
data: &[u8],
) -> Result<()> {
let req = WriteRequest {
data_offset: 0x70,
offset: 0,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: data.to_vec(),
};
let frame = conn.execute(Command::Write, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Write,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = WriteResponse::unpack(&mut cursor)?;
debug!("shares: wrote {} bytes to pipe", resp.count);
Ok(())
}
async fn read_pipe(conn: &mut Connection, tree_id: TreeId, file_id: FileId) -> Result<Vec<u8>> {
let req = ReadRequest {
padding: 0x50,
flags: 0,
length: PIPE_READ_BUFFER_SIZE,
offset: 0,
file_id,
minimum_count: 0,
channel: SMB2_CHANNEL_NONE,
remaining_bytes: 0,
read_channel_info: vec![],
};
let frame = conn.execute(Command::Read, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Read,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = ReadResponse::unpack(&mut cursor)?;
debug!("shares: read {} bytes from pipe", resp.data.len());
Ok(resp.data)
}
async fn close_handle(conn: &mut Connection, tree_id: TreeId, file_id: FileId) -> Result<()> {
let req = CloseRequest { flags: 0, file_id };
let frame = conn.execute(Command::Close, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Close,
});
}
Ok(())
}
/// Sends an SMB2 TREE_DISCONNECT for `tree_id` and logs on success.
///
/// # Errors
///
/// `Error::Protocol` if the server answered with a non-success status, or any
/// error from `conn.execute`.
async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> {
    let request = TreeDisconnectRequest;
    let reply = conn
        .execute(Command::TreeDisconnect, &request, Some(tree_id))
        .await?;
    let status = reply.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::TreeDisconnect,
        });
    }
    info!("shares: disconnected from IPC$");
    Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::client::connection::{pack_message, NegotiatedParams};
    use crate::client::test_helpers::{
        build_close_response, build_create_response, build_tree_connect_response, setup_connection,
    };
    use crate::msg::header::Header;
    use crate::msg::read::ReadResponse as ReadResp;
    use crate::msg::tree_connect::ShareType;
    use crate::msg::tree_disconnect::TreeDisconnectResponse;
    use crate::msg::write::WriteResponse as WriteResp;
    use crate::pack::{Guid, Pack, WriteCursor};
    use crate::rpc::srvsvc::{STYPE_DISKTREE, STYPE_IPC, STYPE_SPECIAL};
    use crate::transport::MockTransport;
    use crate::types::flags::Capabilities;
    use crate::types::{Dialect, SessionId, TreeId};
    use std::sync::Arc;

    /// Builds a successful SMB2 WRITE response reporting `count` bytes written.
    fn build_write_response(count: u32) -> Vec<u8> {
        let mut h = Header::new_request(Command::Write);
        h.flags.set_response();
        h.credits = 32;
        let body = WriteResp {
            count,
            remaining: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
        };
        pack_message(&h, &body)
    }

    /// Builds a successful SMB2 READ response carrying `data`.
    fn build_read_response(data: Vec<u8>) -> Vec<u8> {
        let mut h = Header::new_request(Command::Read);
        h.flags.set_response();
        h.credits = 32;
        let body = ReadResp {
            data_offset: 0x50,
            data_remaining: 0,
            flags: 0,
            data,
        };
        pack_message(&h, &body)
    }

    /// Builds a successful SMB2 TREE_DISCONNECT response.
    fn build_tree_disconnect_response() -> Vec<u8> {
        let mut h = Header::new_request(Command::TreeDisconnect);
        h.flags.set_response();
        h.credits = 32;
        pack_message(&h, &TreeDisconnectResponse)
    }

    /// Builds a DCE/RPC BIND_ACK PDU (call ID 1) that accepts one presentation
    /// context with the NDR transfer syntax.
    fn build_bind_ack() -> Vec<u8> {
        let mut w = WriteCursor::with_capacity(64);

        // Common header.
        w.write_u8(5); // rpc_vers
        w.write_u8(0); // rpc_vers_minor
        w.write_u8(12); // PTYPE = bind_ack
        w.write_u8(0x03); // pfc_flags = FIRST_FRAG | LAST_FRAG
        w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // data representation: little-endian
        let frag_len_pos = w.position();
        w.write_u16_le(0); // frag_length, patched below
        w.write_u16_le(0); // auth_length
        w.write_u32_le(1); // call_id

        // BIND_ACK body.
        w.write_u16_le(4280); // max_xmit_frag
        w.write_u16_le(4280); // max_recv_frag
        w.write_u32_le(0x12345); // assoc_group_id
        w.write_u16_le(0); // secondary address length (empty)
        w.write_bytes(&[0, 0]); // pad to 4-byte alignment

        // Result list: one accepted context.
        w.write_u8(1); // n_results
        w.write_bytes(&[0, 0, 0]); // reserved
        w.write_u16_le(0); // result = acceptance
        w.write_u16_le(0); // reason

        // Transfer syntax: NDR 8a885d04-1ceb-11c9-9fe8-08002b104860 v2.
        let ndr_uuid = Guid {
            data1: 0x8A885D04,
            data2: 0x1CEB,
            data3: 0x11C9,
            data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60],
        };
        ndr_uuid.pack(&mut w);
        w.write_u32_le(2);

        let total_len = w.position();
        w.set_u16_le_at(frag_len_pos, total_len as u16);
        w.into_inner()
    }

    /// Builds a DCE/RPC response PDU (call ID 2) whose NDR stub encodes a
    /// level-1 `NetShareEnumAll` result for `shares` = `(name, type, comment)`.
    fn build_share_enum_response(shares: &[(&str, u32, &str)]) -> Vec<u8> {
        let mut w = WriteCursor::with_capacity(512);
        let count = shares.len() as u32;
        w.write_u32_le(1); // info level
        w.write_u32_le(1); // union switch value (level 1)
        if count == 0 {
            w.write_u32_le(0); // container pointer (NULL)
            w.write_u32_le(0); // total entries
            w.write_u32_le(0); // resume handle pointer (NULL)
            w.write_u32_le(0); // return code
        } else {
            w.write_u32_le(0x0002_0000); // container referent ID
            w.write_u32_le(count); // entries read
            w.write_u32_le(0x0002_0004); // array referent ID
            w.write_u32_le(count); // conformant array max count
            for (i, &(_, share_type, _)) in shares.iter().enumerate() {
                w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name referent ID
                w.write_u32_le(share_type);
                w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment referent ID
            }
            // Deferred string data, in the same order as the referents above.
            for &(name, _, comment) in shares {
                write_ndr_string(&mut w, name);
                write_ndr_string(&mut w, comment);
            }
            w.write_u32_le(count); // total entries
            w.write_u32_le(0); // resume handle pointer (NULL)
            w.write_u32_le(0); // return code
        }
        let stub = w.into_inner();

        let mut w2 = WriteCursor::with_capacity(24 + stub.len());
        w2.write_u8(5); // rpc_vers
        w2.write_u8(0); // rpc_vers_minor
        w2.write_u8(2); // PTYPE = response
        w2.write_u8(0x03); // pfc_flags = FIRST_FRAG | LAST_FRAG
        w2.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // data representation: little-endian
        let frag_len_pos = w2.position();
        w2.write_u16_le(0); // frag_length, patched below
        w2.write_u16_le(0); // auth_length
        w2.write_u32_le(2); // call_id
        w2.write_u32_le(stub.len() as u32); // alloc_hint
        w2.write_u16_le(0); // context ID
        w2.write_u8(0); // cancel count
        w2.write_u8(0); // reserved
        w2.write_bytes(&stub);
        let total_len = w2.position();
        w2.set_u16_le_at(frag_len_pos, total_len as u16);
        w2.into_inner()
    }

    /// Writes `s` as an NDR conformant varying UTF-16LE string (NUL-terminated),
    /// then pads to a 4-byte boundary.
    fn write_ndr_string(w: &mut WriteCursor, s: &str) {
        let utf16: Vec<u16> = s.encode_utf16().chain(std::iter::once(0)).collect();
        let char_count = utf16.len() as u32;
        w.write_u32_le(char_count); // max count
        w.write_u32_le(0); // offset
        w.write_u32_le(char_count); // actual count
        for &code_unit in &utf16 {
            w.write_u16_le(code_unit);
        }
        w.align_to(4);
    }

    /// Queues on `mock` the eight responses a successful `list_shares` call
    /// consumes, in order, with `shares` as the enumeration result.
    ///
    /// The WRITE counts (72, 100) are not cross-checked by `write_pipe`.
    pub(crate) fn queue_share_listing_responses(
        mock: &MockTransport,
        shares: &[(&str, u32, &str)],
    ) {
        let tree_id = TreeId(42);
        let file_id = FileId {
            persistent: 0xAAAA,
            volatile: 0xBBBB,
        };
        mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
        mock.queue_response(build_create_response(file_id, 0));
        mock.queue_response(build_write_response(72));
        mock.queue_response(build_read_response(build_bind_ack()));
        mock.queue_response(build_write_response(100));
        mock.queue_response(build_read_response(build_share_enum_response(shares)));
        mock.queue_response(build_close_response());
        mock.queue_response(build_tree_disconnect_response());
    }

    #[tokio::test]
    async fn list_shares_returns_disk_shares() {
        let mock = Arc::new(MockTransport::new());
        let mut conn = setup_connection(&mock);
        queue_share_listing_responses(
            &mock,
            &[
                ("Documents", STYPE_DISKTREE, "Shared docs"),
                ("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
                ("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"),
                ("Photos", STYPE_DISKTREE, "Family photos"),
            ],
        );
        let shares = list_shares(&mut conn).await.unwrap();
        assert_eq!(shares.len(), 2);
        assert_eq!(shares[0].name, "Documents");
        assert_eq!(shares[0].comment, "Shared docs");
        assert_eq!(shares[1].name, "Photos");
        assert_eq!(shares[1].comment, "Family photos");
    }

    #[tokio::test]
    async fn list_shares_sends_correct_number_of_messages() {
        let mock = Arc::new(MockTransport::new());
        let mut conn = setup_connection(&mock);
        queue_share_listing_responses(&mock, &[("TestShare", STYPE_DISKTREE, "A test share")]);
        let _shares = list_shares(&mut conn).await.unwrap();
        assert_eq!(mock.sent_count(), 8);
    }

    #[tokio::test]
    async fn list_shares_empty_server() {
        let mock = Arc::new(MockTransport::new());
        let mut conn = setup_connection(&mock);
        queue_share_listing_responses(&mock, &[]);
        let shares = list_shares(&mut conn).await.unwrap();
        assert!(shares.is_empty());
    }

    #[tokio::test]
    async fn list_shares_filters_non_disk_shares() {
        let mock = Arc::new(MockTransport::new());
        let mut conn = setup_connection(&mock);
        queue_share_listing_responses(
            &mock,
            &[
                ("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
                ("ADMIN$", STYPE_DISKTREE | STYPE_SPECIAL, "Remote Admin"),
            ],
        );
        let shares = list_shares(&mut conn).await.unwrap();
        assert!(shares.is_empty());
    }

    #[tokio::test]
    async fn list_shares_uses_correct_server_name() {
        let mock = Arc::new(MockTransport::new());
        mock.enable_auto_rewrite_msg_id();
        let mut conn =
            Connection::from_transport(Box::new(mock.clone()), Box::new(mock.clone()), "my-nas");
        conn.set_test_params(NegotiatedParams {
            dialect: Dialect::Smb2_0_2,
            max_read_size: 65536,
            max_write_size: 65536,
            max_transact_size: 65536,
            server_guid: Guid::ZERO,
            signing_required: false,
            capabilities: Capabilities::default(),
            gmac_negotiated: false,
            cipher: None,
            compression_supported: false,
        });
        conn.set_session_id(SessionId(0x1234));
        queue_share_listing_responses(&mock, &[("share1", STYPE_DISKTREE, "")]);
        let shares = list_shares(&mut conn).await.unwrap();
        assert_eq!(shares.len(), 1);

        // The TREE_CONNECT path travels as UTF-16LE. Search for the full UNC
        // path, not a single character, so a wrong server name fails the test.
        let expected_unc: Vec<u8> = r"\\my-nas\IPC$"
            .encode_utf16()
            .flat_map(|unit| unit.to_le_bytes())
            .collect();
        let sent = mock.sent_messages();
        let tree_connect_bytes = &sent[0];
        assert!(
            tree_connect_bytes
                .windows(expected_unc.len())
                .any(|w| w == expected_unc.as_slice()),
            "TREE_CONNECT should reference the server name"
        );
    }
}