use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
/// SMB protocol dialects known to this module.
///
/// Variants are declared oldest-first. The derived `PartialOrd`/`Ord` follow
/// declaration order, so "this dialect or newer" can be expressed as a plain
/// comparison between variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SmbVersion {
    /// Dialect labelled "SMB 2.0".
    Smb20,
    /// Dialect labelled "SMB 2.1".
    Smb21,
    /// Dialect labelled "SMB 3.0".
    Smb30,
    /// Dialect labelled "SMB 3.02".
    Smb302,
    /// Dialect labelled "SMB 3.11".
    Smb311,
}

impl SmbVersion {
    /// Returns the human-readable label for this dialect, e.g. `"SMB 3.11"`.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Smb20 => "SMB 2.0",
            Self::Smb21 => "SMB 2.1",
            Self::Smb30 => "SMB 3.0",
            Self::Smb302 => "SMB 3.02",
            Self::Smb311 => "SMB 3.11",
        }
    }

    /// Returns `true` for SMB 3.0 and every later dialect.
    pub fn supports_multichannel(&self) -> bool {
        matches!(self, Self::Smb30 | Self::Smb302 | Self::Smb311)
    }

    /// Returns `true` for SMB 3.0 and every later dialect.
    pub fn supports_rdma(&self) -> bool {
        Self::Smb30 <= *self
    }
}
/// Transport used by a single SMB channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelType {
    /// TCP transport.
    Tcp,
    /// RDMA transport.
    Rdma,
}

/// One transport connection belonging to a session, with usage accounting.
#[derive(Debug, Clone)]
pub struct SmbChannel {
    /// Identifier supplied by whoever creates the channel.
    pub id: u64,
    /// Transport this channel runs over.
    pub channel_type: ChannelType,
    /// Nominal bandwidth as given by the creator.
    // NOTE(review): the unit is not fixed here; the server log divides this by
    // 1e9 and labels the result "GB/s" — confirm bytes vs. bits with callers.
    pub bandwidth: u64,
    /// Whether the channel is usable; new channels start active.
    pub active: bool,
    /// Running total of bytes attributed to this channel (saturates at `u64::MAX`).
    pub bytes_transferred: u64,
    /// Timestamp passed to the most recent `record_transfer` call (0 if none yet).
    pub last_activity: u64,
}

impl SmbChannel {
    /// Creates an active channel with zeroed transfer accounting.
    pub fn new(id: u64, channel_type: ChannelType, bandwidth: u64) -> Self {
        Self {
            id,
            channel_type,
            bandwidth,
            active: true,
            bytes_transferred: 0,
            last_activity: 0,
        }
    }

    /// Attributes `bytes` to this channel and stamps it with `timestamp`.
    ///
    /// The byte total saturates at `u64::MAX` instead of overflowing. A plain
    /// `+=` would panic in builds with overflow checks and silently wrap in
    /// builds without them; a wrapped (tiny) total would make a heavily used
    /// channel look idle to any least-loaded selection based on this field.
    pub fn record_transfer(&mut self, bytes: u64, timestamp: u64) {
        self.bytes_transferred = self.bytes_transferred.saturating_add(bytes);
        self.last_activity = timestamp;
    }
}
/// A client session together with the channels attached to it.
#[derive(Debug, Clone)]
pub struct SmbSession {
    /// Session identifier.
    pub id: u64,
    /// Client address exactly as supplied at creation time.
    pub client_addr: String,
    /// Dialect recorded for this session.
    pub version: SmbVersion,
    /// Channels attached so far, in the order they were added.
    pub channels: Vec<SmbChannel>,
    /// User name, if any; `new` leaves it as `None`.
    pub user: Option<String>,
    /// Timestamp supplied when the session was created.
    pub start_time: u64,
}

impl SmbSession {
    /// Builds a session with no channels and no user.
    pub fn new(id: u64, client_addr: String, version: SmbVersion, timestamp: u64) -> Self {
        let channels = Vec::new();
        let start_time = timestamp;
        Self {
            id,
            client_addr,
            version,
            channels,
            user: None,
            start_time,
        }
    }

    /// Appends `channel` to this session.
    pub fn add_channel(&mut self, channel: SmbChannel) {
        self.channels.push(channel);
    }

    /// Sums the `bandwidth` of every channel currently marked active.
    pub fn total_bandwidth(&self) -> u64 {
        let mut total = 0;
        for ch in &self.channels {
            if ch.active {
                total += ch.bandwidth;
            }
        }
        total
    }

    /// Counts the channels currently marked active.
    pub fn active_channels(&self) -> usize {
        self.channels.iter().map(|ch| usize::from(ch.active)).sum()
    }

    /// Picks the active channel with the smallest `bytes_transferred`.
    ///
    /// Ties go to the channel added earliest. Returns `None` when the session
    /// has no active channel.
    pub fn select_channel(&mut self) -> Option<&mut SmbChannel> {
        self.channels
            .iter_mut()
            .filter(|ch| ch.active)
            .min_by(|lhs, rhs| lhs.bytes_transferred.cmp(&rhs.bytes_transferred))
    }
}
/// Newtype around the numeric id of an open file handle.
///
/// It derives `Ord` so it can serve as the key of the server's
/// `file_handles` `BTreeMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SmbFileHandle(pub u64);
/// An exported share and its access flags.
#[derive(Debug, Clone)]
pub struct SmbShare {
    /// Share name.
    pub name: String,
    /// Identifier of the dataset backing this share.
    pub dataset_id: u64,
    /// When `true`, the share is flagged read-only.
    pub read_only: bool,
    /// Guest-access flag.
    pub guest_ok: bool,
    /// Browseable flag; `new` sets it to `true`.
    pub browseable: bool,
}

impl SmbShare {
    /// Creates a writable, non-guest, browseable share.
    pub fn new(name: String, dataset_id: u64) -> Self {
        let (read_only, guest_ok, browseable) = (false, false, true);
        Self {
            name,
            dataset_id,
            read_only,
            guest_ok,
            browseable,
        }
    }

    /// Builder step: returns the share with `read_only` switched on.
    pub fn read_only(self) -> Self {
        Self {
            read_only: true,
            ..self
        }
    }

    /// Builder step: returns the share with `guest_ok` switched on.
    pub fn allow_guest(self) -> Self {
        Self {
            guest_ok: true,
            ..self
        }
    }
}
/// Counters maintained by `SmbServer`; `SmbServer::stats` hands out a clone.
/// All fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct SmbStats {
/// Sessions created so far; never decremented.
pub total_sessions: u64,
/// Sessions currently open: raised by `create_session`, lowered by `close_session`.
pub active_sessions: u64,
/// Channels successfully added across all sessions.
pub total_channels: u64,
/// Number of successful `read` calls.
pub reads: u64,
/// Number of successful `write` calls.
pub writes: u64,
/// Sum of the `size` argument of successful reads.
pub bytes_read: u64,
/// Sum of the `size` argument of successful writes.
pub bytes_written: u64,
/// Multichannel counter maintained by `add_channel` for sessions that hold
/// more than one channel.
pub multichannel_sessions: u64,
/// Channels added with `ChannelType::Rdma`.
pub rdma_channels: u64,
}
// Global server instance shared by the `Smb` facade. `lazy_static!` builds it
// on first access, and the `spin::Mutex` serializes every operation on it.
lazy_static! {
static ref SMB_SERVER: Mutex<SmbServer> = Mutex::new(SmbServer::new());
}
/// In-memory SMB server state: shares, sessions, file handles, id counters
/// and statistics.
pub struct SmbServer {
    /// Registered shares, in the order they were added.
    shares: Vec<SmbShare>,
    /// Open sessions keyed by session id.
    sessions: BTreeMap<u64, SmbSession>,
    /// Open handles mapped to the `(dataset_id, offset)` pair passed to `allocate_fh`.
    file_handles: BTreeMap<SmbFileHandle, (u64, u64)>,
    /// Id handed to the next session.
    next_session_id: u64,
    /// Id handed to the next file handle.
    next_fh: u64,
    /// Id handed to the next channel (shared across all sessions).
    next_channel_id: u64,
    /// Running counters.
    stats: SmbStats,
}

impl Default for SmbServer {
    /// Same as [`SmbServer::new`].
    fn default() -> Self {
        SmbServer::new()
    }
}
impl SmbServer {
    /// Creates an empty server. Session, file-handle and channel ids all start at 1.
    pub fn new() -> Self {
        Self {
            shares: Vec::new(),
            sessions: BTreeMap::new(),
            file_handles: BTreeMap::new(),
            next_session_id: 1,
            next_fh: 1,
            next_channel_id: 1,
            stats: SmbStats::default(),
        }
    }

    /// Registers `share` and logs its name, dataset id and read-only flag.
    pub fn add_share(&mut self, share: SmbShare) {
        crate::lcpfs_println!(
            "[ SMB ] Shared \\\\server\\{} (dataset: {}, ro: {})",
            share.name,
            share.dataset_id,
            share.read_only
        );
        self.shares.push(share);
    }

    /// Opens a session for `client_addr` and returns its freshly assigned id.
    ///
    /// Bumps both `total_sessions` and `active_sessions`.
    pub fn create_session(
        &mut self,
        client_addr: String,
        version: SmbVersion,
        timestamp: u64,
    ) -> u64 {
        let session_id = self.next_session_id;
        self.next_session_id += 1;
        // Log while we still own the address; it is moved into the session
        // afterwards, so no clone is needed.
        crate::lcpfs_println!(
            "[ SMB ] Session {} created from {} ({})",
            session_id,
            client_addr,
            version.name()
        );
        let session = SmbSession::new(session_id, client_addr, version, timestamp);
        self.sessions.insert(session_id, session);
        self.stats.total_sessions += 1;
        self.stats.active_sessions += 1;
        session_id
    }

    /// Attaches a new channel to `session_id` and returns the channel id.
    ///
    /// # Errors
    ///
    /// * `"Session not found"` if `session_id` is unknown.
    /// * `"Version does not support multichannel"` if the session's dialect
    ///   fails `supports_multichannel` (this applies to the first channel too).
    pub fn add_channel(
        &mut self,
        session_id: u64,
        channel_type: ChannelType,
        bandwidth: u64,
    ) -> Result<u64, &'static str> {
        let session = self
            .sessions
            .get_mut(&session_id)
            .ok_or("Session not found")?;
        if !session.version.supports_multichannel() {
            return Err("Version does not support multichannel");
        }
        let channel_id = self.next_channel_id;
        self.next_channel_id += 1;
        session.add_channel(SmbChannel::new(channel_id, channel_type, bandwidth));
        self.stats.total_channels += 1;
        // Count a session as multichannel exactly once: at the moment it goes
        // from one channel to two. Channels are never removed here, so the
        // length passes through 2 a single time. (Testing `> 1` would count
        // the same session again for its third, fourth, ... channel.)
        if session.channels.len() == 2 {
            self.stats.multichannel_sessions += 1;
        }
        if channel_type == ChannelType::Rdma {
            self.stats.rdma_channels += 1;
        }
        crate::lcpfs_println!(
            "[ SMB ] Channel {} added to session {} ({:?}, {} GB/s)",
            channel_id,
            session_id,
            channel_type,
            bandwidth / 1_000_000_000
        );
        Ok(channel_id)
    }

    /// Accounts for a read of `size` bytes on `fh` within `session_id`.
    ///
    /// No data is moved here. The call validates the session and handle,
    /// attributes the bytes to the session's least-loaded active channel (if
    /// any), updates the read counters and returns `size`. `offset` is
    /// accepted but currently unused.
    ///
    /// # Errors
    ///
    /// * `"Session not found"` if `session_id` is unknown.
    /// * `"Invalid file handle"` if `fh` was never allocated.
    pub fn read(
        &mut self,
        session_id: u64,
        fh: SmbFileHandle,
        offset: u64,
        size: u64,
        timestamp: u64,
    ) -> Result<u64, &'static str> {
        // Not consumed yet; bound explicitly so the parameter keeps its name.
        let _ = offset;
        let session = self
            .sessions
            .get_mut(&session_id)
            .ok_or("Session not found")?;
        if !self.file_handles.contains_key(&fh) {
            return Err("Invalid file handle");
        }
        if let Some(channel) = session.select_channel() {
            channel.record_transfer(size, timestamp);
        }
        self.stats.reads += 1;
        // `size` is caller-supplied, so clamp instead of overflowing.
        self.stats.bytes_read = self.stats.bytes_read.saturating_add(size);
        Ok(size)
    }

    /// Accounts for a write of `size` bytes on `fh` within `session_id`.
    ///
    /// The handle's dataset id is matched against the registered shares. If
    /// the first share with that dataset id is read-only, the write is
    /// rejected; a dataset with no share at all is allowed through.
    /// `offset` is accepted but currently unused.
    ///
    /// # Errors
    ///
    /// * `"Session not found"` if `session_id` is unknown.
    /// * `"Invalid file handle"` if `fh` was never allocated.
    /// * `"Share is read-only"` if the matching share is read-only.
    pub fn write(
        &mut self,
        session_id: u64,
        fh: SmbFileHandle,
        offset: u64,
        size: u64,
        timestamp: u64,
    ) -> Result<u64, &'static str> {
        // Not consumed yet; bound explicitly so the parameter keeps its name.
        let _ = offset;
        let session = self
            .sessions
            .get_mut(&session_id)
            .ok_or("Session not found")?;
        let (dataset_id, _) = self.file_handles.get(&fh).ok_or("Invalid file handle")?;
        let share = self.shares.iter().find(|s| s.dataset_id == *dataset_id);
        if share.map_or(false, |s| s.read_only) {
            return Err("Share is read-only");
        }
        if let Some(channel) = session.select_channel() {
            channel.record_transfer(size, timestamp);
        }
        self.stats.writes += 1;
        // `size` is caller-supplied, so clamp instead of overflowing.
        self.stats.bytes_written = self.stats.bytes_written.saturating_add(size);
        Ok(size)
    }

    /// Allocates a new handle and records `(dataset_id, offset)` for it.
    ///
    /// Handles are not tied to a session.
    pub fn allocate_fh(&mut self, dataset_id: u64, offset: u64) -> SmbFileHandle {
        let fh = SmbFileHandle(self.next_fh);
        self.next_fh += 1;
        self.file_handles.insert(fh, (dataset_id, offset));
        fh
    }

    /// Removes `session_id` if it exists, lowering `active_sessions` and logging.
    ///
    /// Unknown ids are ignored. File handles are left untouched.
    pub fn close_session(&mut self, session_id: u64) {
        if self.sessions.remove(&session_id).is_some() {
            self.stats.active_sessions = self.stats.active_sessions.saturating_sub(1);
            crate::lcpfs_println!("[ SMB ] Session {} closed", session_id);
        }
    }

    /// Returns a snapshot (clone) of the current counters.
    pub fn stats(&self) -> SmbStats {
        self.stats.clone()
    }

    /// Returns how many shares are registered.
    pub fn share_count(&self) -> usize {
        self.shares.len()
    }
}
/// Stateless facade over the global `SMB_SERVER`.
///
/// Each function locks the global server for the duration of one operation
/// and forwards to the `SmbServer` method of the same name.
pub struct Smb;

impl Smb {
    /// Registers `share` on the global server.
    pub fn add_share(share: SmbShare) {
        SMB_SERVER.lock().add_share(share);
    }

    /// Opens a session on the global server and returns its id.
    pub fn create_session(client_addr: String, version: SmbVersion, timestamp: u64) -> u64 {
        SMB_SERVER
            .lock()
            .create_session(client_addr, version, timestamp)
    }

    /// Closes `session_id` on the global server; unknown ids are ignored.
    ///
    /// Without this wrapper, sessions opened through the facade could never be
    /// released, because the global server is private to this file.
    pub fn close_session(session_id: u64) {
        SMB_SERVER.lock().close_session(session_id);
    }

    /// Attaches a channel to `session_id` on the global server.
    ///
    /// # Errors
    ///
    /// Forwards the error string from `SmbServer::add_channel`.
    pub fn add_channel(
        session_id: u64,
        channel_type: ChannelType,
        bandwidth: u64,
    ) -> Result<u64, &'static str> {
        SMB_SERVER
            .lock()
            .add_channel(session_id, channel_type, bandwidth)
    }

    /// Allocates a file handle on the global server.
    ///
    /// This is the only way to obtain a handle that `Smb::read` and
    /// `Smb::write` will accept, since they validate against the same global
    /// instance.
    pub fn allocate_fh(dataset_id: u64, offset: u64) -> SmbFileHandle {
        SMB_SERVER.lock().allocate_fh(dataset_id, offset)
    }

    /// Accounts for a read on the global server.
    ///
    /// # Errors
    ///
    /// Forwards the error string from `SmbServer::read`.
    pub fn read(
        session_id: u64,
        fh: SmbFileHandle,
        offset: u64,
        size: u64,
        timestamp: u64,
    ) -> Result<u64, &'static str> {
        SMB_SERVER
            .lock()
            .read(session_id, fh, offset, size, timestamp)
    }

    /// Accounts for a write on the global server.
    ///
    /// # Errors
    ///
    /// Forwards the error string from `SmbServer::write`.
    pub fn write(
        session_id: u64,
        fh: SmbFileHandle,
        offset: u64,
        size: u64,
        timestamp: u64,
    ) -> Result<u64, &'static str> {
        SMB_SERVER
            .lock()
            .write(session_id, fh, offset, size, timestamp)
    }

    /// Returns a snapshot of the global server's counters.
    pub fn stats() -> SmbStats {
        SMB_SERVER.lock().stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Client address used by every session in these tests.
    const CLIENT: &str = "192.168.1.100";
    /// Timestamp used when opening sessions.
    const T0: u64 = 1000;
    /// Bandwidth value used for the server-level channel tests.
    const BW_10G: u64 = 10_000_000_000;
    /// Bandwidth value used for the session-level channel tests.
    const BW_1G: u64 = 1_000_000_000;

    /// Opens a session of the given dialect on `server` and returns its id.
    fn open_session(server: &mut SmbServer, version: SmbVersion) -> u64 {
        server.create_session(CLIENT.into(), version, T0)
    }

    /// Adds a channel and unwraps the result, failing the test on error.
    fn must_add_channel(
        server: &mut SmbServer,
        session_id: u64,
        channel_type: ChannelType,
        bandwidth: u64,
    ) -> u64 {
        server
            .add_channel(session_id, channel_type, bandwidth)
            .expect("test: operation should succeed")
    }

    /// Builds a detached SMB 3.11 session with id 1.
    fn bare_session() -> SmbSession {
        SmbSession::new(1, CLIENT.into(), SmbVersion::Smb311, T0)
    }

    #[test]
    fn test_smb_version() {
        assert_eq!(SmbVersion::Smb311.name(), "SMB 3.11");
        assert!(SmbVersion::Smb30.supports_multichannel());
        assert!(!SmbVersion::Smb21.supports_multichannel());
    }

    #[test]
    fn test_version_ordering() {
        assert!(SmbVersion::Smb311 > SmbVersion::Smb30);
        assert!(SmbVersion::Smb30 >= SmbVersion::Smb30);
    }

    #[test]
    fn test_rdma_support() {
        assert!(SmbVersion::Smb311.supports_rdma());
        assert!(!SmbVersion::Smb21.supports_rdma());
    }

    #[test]
    fn test_channel_creation() {
        let ch = SmbChannel::new(1, ChannelType::Tcp, BW_10G);
        assert_eq!(ch.id, 1);
        assert_eq!(ch.channel_type, ChannelType::Tcp);
        assert_eq!(ch.bandwidth, 10_000_000_000);
        assert!(ch.active);
    }

    #[test]
    fn test_channel_transfer() {
        let mut ch = SmbChannel::new(1, ChannelType::Tcp, BW_10G);
        ch.record_transfer(4096, 1000);
        assert_eq!(ch.bytes_transferred, 4096);
        assert_eq!(ch.last_activity, 1000);
    }

    #[test]
    fn test_session_creation() {
        let session = bare_session();
        assert_eq!(session.id, 1);
        assert_eq!(session.version, SmbVersion::Smb311);
        assert_eq!(session.channels.len(), 0);
    }

    #[test]
    fn test_session_multichannel() {
        let mut session = bare_session();
        for id in 1..=2 {
            session.add_channel(SmbChannel::new(id, ChannelType::Tcp, BW_1G));
        }
        assert_eq!(session.active_channels(), 2);
        assert_eq!(session.total_bandwidth(), 2_000_000_000);
    }

    #[test]
    fn test_channel_selection() {
        let mut session = bare_session();
        // Channel 1 already carries traffic, so channel 2 must be preferred.
        let mut busy = SmbChannel::new(1, ChannelType::Tcp, BW_1G);
        busy.bytes_transferred = 1_000_000;
        let idle = SmbChannel::new(2, ChannelType::Tcp, BW_1G);
        session.add_channel(busy);
        session.add_channel(idle);
        let picked = session
            .select_channel()
            .expect("test: operation should succeed");
        assert_eq!(picked.id, 2);
    }

    #[test]
    fn test_share_creation() {
        let share = SmbShare::new("data".into(), 100);
        assert_eq!(share.name, "data");
        assert_eq!(share.dataset_id, 100);
        assert!(!share.read_only);
        assert!(share.browseable);
    }

    #[test]
    fn test_share_builder() {
        let share = SmbShare::new("data".into(), 100)
            .read_only()
            .allow_guest();
        assert!(share.read_only);
        assert!(share.guest_ok);
    }

    #[test]
    fn test_server_add_share() {
        let mut server = SmbServer::new();
        server.add_share(SmbShare::new("data".into(), 100));
        assert_eq!(server.share_count(), 1);
    }

    #[test]
    fn test_server_create_session() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        assert_eq!(sid, 1);
        assert_eq!(server.stats.total_sessions, 1);
        assert_eq!(server.stats.active_sessions, 1);
    }

    #[test]
    fn test_server_add_channel() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        let cid = must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        assert_eq!(cid, 1);
        assert_eq!(server.stats.total_channels, 1);
    }

    #[test]
    fn test_multichannel_stats() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        assert_eq!(server.stats.multichannel_sessions, 1);
    }

    #[test]
    fn test_rdma_channel() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        must_add_channel(&mut server, sid, ChannelType::Rdma, 25_000_000_000);
        assert_eq!(server.stats.rdma_channels, 1);
    }

    #[test]
    fn test_read_operation() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        let fh = server.allocate_fh(100, 0);
        let got = server
            .read(sid, fh, 0, 4096, 1100)
            .expect("test: operation should succeed");
        assert_eq!(got, 4096);
        assert_eq!(server.stats.reads, 1);
        assert_eq!(server.stats.bytes_read, 4096);
    }

    #[test]
    fn test_write_operation() {
        let mut server = SmbServer::new();
        server.add_share(SmbShare::new("data".into(), 100));
        let sid = open_session(&mut server, SmbVersion::Smb311);
        must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        let fh = server.allocate_fh(100, 0);
        let got = server
            .write(sid, fh, 0, 8192, 1100)
            .expect("test: operation should succeed");
        assert_eq!(got, 8192);
        assert_eq!(server.stats.writes, 1);
        assert_eq!(server.stats.bytes_written, 8192);
    }

    #[test]
    fn test_read_only_share() {
        let mut server = SmbServer::new();
        server.add_share(SmbShare::new("data".into(), 100).read_only());
        let sid = open_session(&mut server, SmbVersion::Smb311);
        must_add_channel(&mut server, sid, ChannelType::Tcp, BW_10G);
        let fh = server.allocate_fh(100, 0);
        let outcome = server.write(sid, fh, 0, 4096, 1100);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_close_session() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb311);
        assert_eq!(server.stats.active_sessions, 1);
        server.close_session(sid);
        assert_eq!(server.stats.active_sessions, 0);
    }

    #[test]
    fn test_old_version_no_multichannel() {
        let mut server = SmbServer::new();
        let sid = open_session(&mut server, SmbVersion::Smb21);
        let outcome = server.add_channel(sid, ChannelType::Tcp, BW_1G);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_file_handle() {
        let first = SmbFileHandle(1);
        let second = SmbFileHandle(2);
        assert_ne!(first, second);
    }
}