use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
use super::error::{NfsError, NfsResult, NfsStatus};
use super::types::{ClientId, FileHandle, SessionId};
/// Option set attached to an export as a whole or to one access rule.
///
/// This type only stores the settings; nothing in this block enforces them.
#[derive(Debug, Clone)]
pub struct ExportOptions {
    /// `true` when the export is to be treated as read-only.
    pub read_only: bool,
    /// Root-squash opt-out flag; off by default.
    pub no_root_squash: bool,
    /// All-squash flag; off by default.
    pub all_squash: bool,
    /// Anonymous uid; defaults to 65534.
    pub anon_uid: u32,
    /// Anonymous gid; defaults to 65534.
    pub anon_gid: u32,
    /// Synchronous-mode flag; on by default, cleared by [`ExportOptions::with_async`].
    pub sync: bool,
    /// Subtree-check flag; off by default.
    pub subtree_check: bool,
    /// Security flavors listed for the export; defaults to `[SecurityFlavor::Sys]`.
    pub security: Vec<SecurityFlavor>,
    /// Client limit; defaults to 0.
    /// NOTE(review): presumably 0 means "no limit" — confirm against the consumer.
    pub max_clients: usize,
}

impl Default for ExportOptions {
    /// Read-write, synchronous options with the `Sys` flavor and anonymous ids of 65534.
    fn default() -> Self {
        // Both anonymous ids share the same default value.
        const ANON_ID: u32 = 65534;
        Self {
            security: vec![SecurityFlavor::Sys],
            anon_uid: ANON_ID,
            anon_gid: ANON_ID,
            sync: true,
            max_clients: 0,
            read_only: false,
            no_root_squash: false,
            all_squash: false,
            subtree_check: false,
        }
    }
}

impl ExportOptions {
    /// Default options with `read_only` switched on.
    pub fn read_only() -> Self {
        let base = Self::default();
        Self {
            read_only: true,
            ..base
        }
    }

    /// Default options; the defaults already have `read_only` off.
    pub fn read_write() -> Self {
        Default::default()
    }

    /// Returns `self` with `no_root_squash` switched on.
    pub fn with_no_root_squash(self) -> Self {
        Self {
            no_root_squash: true,
            ..self
        }
    }

    /// Returns `self` with the security flavor list replaced by `flavors`.
    pub fn with_security(self, flavors: Vec<SecurityFlavor>) -> Self {
        Self {
            security: flavors,
            ..self
        }
    }

    /// Returns `self` with `sync` switched off.
    pub fn with_async(self) -> Self {
        Self {
            sync: false,
            ..self
        }
    }
}
/// Security flavor identifiers that can be listed on an export.
///
/// Each variant's discriminant is the numeric value that
/// [`SecurityFlavor::from_u32`] maps back to that variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum SecurityFlavor {
    None = 0,
    Sys = 1,
    Krb5 = 390003,
    Krb5i = 390004,
    Krb5p = 390005,
}

impl SecurityFlavor {
    /// Converts a raw flavor number into the matching variant.
    ///
    /// Returns `None` for any number that is not the discriminant of a variant.
    pub fn from_u32(v: u32) -> Option<Self> {
        // Keep this table in sync with the enum: a variant missing here can
        // never be produced by this function.
        const KNOWN: [SecurityFlavor; 5] = [
            SecurityFlavor::None,
            SecurityFlavor::Sys,
            SecurityFlavor::Krb5,
            SecurityFlavor::Krb5i,
            SecurityFlavor::Krb5p,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|flavor| (*flavor as u32) == v)
    }
}
/// Pairs a client selector with the options granted to the clients it selects.
#[derive(Debug, Clone)]
pub struct AccessRule {
/// Which clients this rule applies to.
pub client: ClientSpec,
/// Options handed back by `Export::check_access` when `client` matches.
pub options: ExportOptions,
}
/// Selects which IPv4 clients an access rule applies to.
#[derive(Debug, Clone)]
pub enum ClientSpec {
    /// Exactly one IPv4 address.
    Ip([u8; 4]),
    /// An IPv4 network given as address plus netmask, most significant octet first.
    Network {
        ip: [u8; 4],
        mask: [u8; 4],
    },
    /// A host name. Stored verbatim; [`ClientSpec::matches`] never matches it
    /// because no name resolution is performed here.
    Host(String),
    /// Wildcard: every client.
    All,
}

impl ClientSpec {
    /// Parses a textual client specification.
    ///
    /// Accepted forms (surrounding whitespace is ignored):
    /// * `*` — [`ClientSpec::All`]
    /// * `a.b.c.d/len` with `len` in `0..=32` — [`ClientSpec::Network`]
    /// * `a.b.c.d` — [`ClientSpec::Ip`]
    /// * anything else that is non-empty — [`ClientSpec::Host`]
    ///
    /// Returns `None` for an empty specification and for a `/` form whose
    /// address or prefix length is malformed.
    pub fn parse(s: &str) -> Option<Self> {
        let s = s.trim();
        if s.is_empty() {
            // An empty host name can never identify a client.
            return None;
        }
        if s == "*" {
            return Some(Self::All);
        }
        if let Some((ip_str, prefix_str)) = s.split_once('/') {
            let ip = Self::parse_ip(ip_str)?;
            let prefix_len = Self::parse_decimal_u8(prefix_str).filter(|len| *len <= 32)?;
            return Some(Self::Network {
                ip,
                mask: Self::prefix_to_mask(prefix_len),
            });
        }
        Some(match Self::parse_ip(s) {
            Some(ip) => Self::Ip(ip),
            None => Self::Host(s.into()),
        })
    }

    /// Parses a plain decimal number in `0..=255`.
    ///
    /// `u8::from_str` also accepts a leading `+`, which is not valid in an
    /// address or prefix length, so only ASCII digits are let through.
    fn parse_decimal_u8(s: &str) -> Option<u8> {
        if s.is_empty() || !s.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        s.parse().ok()
    }

    /// Parses a dotted-quad IPv4 address; requires exactly four decimal octets.
    fn parse_ip(s: &str) -> Option<[u8; 4]> {
        let mut octets = [0u8; 4];
        let mut parts = s.split('.');
        for octet in octets.iter_mut() {
            *octet = Self::parse_decimal_u8(parts.next()?)?;
        }
        // A fifth component means this is not an IPv4 address.
        if parts.next().is_some() {
            return None;
        }
        Some(octets)
    }

    /// Builds the netmask for a prefix length; lengths above 32 are treated as 32.
    fn prefix_to_mask(prefix: u8) -> [u8; 4] {
        let host_bits = 32u32.saturating_sub(u32::from(prefix));
        // A shift by the full width (prefix 0) is reported as `None` by
        // `checked_shl`; that case is the all-zero mask.
        u32::MAX.checked_shl(host_bits).unwrap_or(0).to_be_bytes()
    }

    /// Returns `true` when `client_ip` is selected by this specification.
    ///
    /// Network comparison masks both sides, so host bits left set in the
    /// stored network address do not affect the result.
    pub fn matches(&self, client_ip: [u8; 4]) -> bool {
        match self {
            Self::All => true,
            Self::Ip(ip) => client_ip == *ip,
            Self::Network { ip, mask } => client_ip
                .iter()
                .zip(ip.iter())
                .zip(mask.iter())
                .all(|((client, net), m)| (client & m) == (net & m)),
            // Host names cannot be resolved here, so they never match.
            Self::Host(_) => false,
        }
    }
}
/// One exported path of a dataset together with its access rules.
#[derive(Debug, Clone)]
pub struct Export {
    /// Identifier taken from a process-wide counter by [`Export::new`].
    pub id: u64,
    /// Name of the dataset being exported.
    pub dataset: String,
    /// Path appended to the dataset name by [`Export::export_path`]; may be empty.
    pub path: String,
    /// Rules consulted in order by [`Export::check_access`].
    pub access_rules: Vec<AccessRule>,
    /// Options used when no rule matches; `None` means such clients get nothing.
    pub default_options: Option<ExportOptions>,
    /// Set to `true` by [`Export::new`]; not read anywhere in this block.
    pub active: bool,
}

impl Export {
    /// Creates an active export with a fresh id, no rules and read-only defaults.
    pub fn new(dataset: impl Into<String>, path: impl Into<String>) -> Self {
        use core::sync::atomic::{AtomicU64, Ordering};

        // Shared by every call; ids start at 1.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);

        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        let dataset = dataset.into();
        let path = path.into();
        Self {
            id,
            dataset,
            path,
            access_rules: Vec::new(),
            default_options: Some(ExportOptions::read_only()),
            active: true,
        }
    }

    /// Appends a rule granting `options` to clients selected by `client`.
    pub fn with_rule(mut self, client: ClientSpec, options: ExportOptions) -> Self {
        let rule = AccessRule { client, options };
        self.access_rules.push(rule);
        self
    }

    /// Replaces the fallback options used when no rule matches.
    pub fn with_default(self, options: ExportOptions) -> Self {
        Self {
            default_options: Some(options),
            ..self
        }
    }

    /// Appends a wildcard rule granting `options` to every client.
    pub fn allow_all(self, options: ExportOptions) -> Self {
        self.with_rule(ClientSpec::All, options)
    }

    /// File handle built from this export's id with the two remaining arguments zero.
    pub fn root_handle(&self) -> FileHandle {
        FileHandle::new(self.id, 0, 0)
    }

    /// Returns the options that apply to `client_ip`.
    ///
    /// The first matching rule wins; otherwise the default options are
    /// returned, which may be `None`.
    pub fn check_access(&self, client_ip: [u8; 4]) -> Option<&ExportOptions> {
        self.access_rules
            .iter()
            .find(|rule| rule.client.matches(client_ip))
            .map(|rule| &rule.options)
            .or_else(|| self.default_options.as_ref())
    }

    /// Returns `/` followed by the dataset name and then the (possibly empty) path.
    pub fn export_path(&self) -> String {
        let mut full = String::with_capacity(1 + self.dataset.len() + self.path.len());
        full.push('/');
        full.push_str(&self.dataset);
        full.push_str(&self.path);
        full
    }
}
// Process-wide export registry: built on first access and guarded by a spin `Mutex`.
lazy_static! {
static ref EXPORTS: Mutex<ExportRegistry> = Mutex::new(ExportRegistry::new());
}
/// Storage behind the `EXPORTS` global.
#[derive(Debug)]
struct ExportRegistry {
    /// Every registered export, keyed by `Export::id`.
    exports: BTreeMap<u64, Export>,
    /// Export ids grouped by dataset name.
    by_dataset: BTreeMap<String, Vec<u64>>,
}

impl ExportRegistry {
    /// Creates a registry with no exports.
    fn new() -> Self {
        let exports = BTreeMap::new();
        let by_dataset = BTreeMap::new();
        Self {
            exports,
            by_dataset,
        }
    }
}
/// Registers `export` in the global registry and returns its id.
///
/// If an export with the same id is already registered it is replaced, and
/// the dataset index is updated so the id is listed exactly once, under the
/// new export's dataset.
pub fn add_export(export: Export) -> u64 {
    let id = export.id;
    let dataset = export.dataset.clone();
    let mut reg = EXPORTS.lock();

    if let Some(previous) = reg.exports.insert(id, export) {
        // `Export` is `Clone` with public fields, so the same id can come
        // back with a different dataset. Drop the old index entry first so
        // the id is neither duplicated nor left under a stale dataset.
        let now_empty = match reg.by_dataset.get_mut(&previous.dataset) {
            Some(ids) => {
                ids.retain(|&existing| existing != id);
                ids.is_empty()
            }
            None => false,
        };
        if now_empty {
            reg.by_dataset.remove(&previous.dataset);
        }
    }

    reg.by_dataset.entry(dataset).or_default().push(id);
    id
}
/// Removes the export with the given id and returns it, or `None` if the id
/// is not registered.
///
/// The dataset index entry is dropped once its last export is gone, so the
/// index does not accumulate empty lists for datasets that are no longer
/// exported.
pub fn remove_export(id: u64) -> Option<Export> {
    let mut reg = EXPORTS.lock();
    let export = reg.exports.remove(&id)?;

    let now_empty = match reg.by_dataset.get_mut(&export.dataset) {
        Some(ids) => {
            ids.retain(|&existing| existing != id);
            ids.is_empty()
        }
        None => false,
    };
    if now_empty {
        reg.by_dataset.remove(&export.dataset);
    }

    Some(export)
}
/// Returns a copy of the export registered under `id`, if any.
pub fn get_export(id: u64) -> Option<Export> {
    let reg = EXPORTS.lock();
    reg.exports.get(&id).cloned()
}
/// Returns copies of all exports indexed under `dataset`.
///
/// Yields an empty vector when the dataset has no index entry; ids that are
/// indexed but missing from the export table are skipped.
pub fn get_exports_for_dataset(dataset: &str) -> Vec<Export> {
    let reg = EXPORTS.lock();
    match reg.by_dataset.get(dataset) {
        Some(ids) => ids
            .iter()
            .filter_map(|export_id| reg.exports.get(export_id))
            .cloned()
            .collect(),
        None => Vec::new(),
    }
}
/// Returns copies of every registered export, in ascending id order
/// (the iteration order of the underlying `BTreeMap`).
pub fn list_exports() -> Vec<Export> {
    let reg = EXPORTS.lock();
    let mut all = Vec::with_capacity(reg.exports.len());
    all.extend(reg.exports.values().cloned());
    all
}
/// Number of exports currently registered.
pub fn export_count() -> usize {
    let reg = EXPORTS.lock();
    reg.exports.len()
}
/// Looks up the export whose id equals the value `fh.dataset_id()` returns.
pub fn find_export_by_handle(fh: &FileHandle) -> Option<Export> {
    let export_id = fh.dataset_id();
    get_export(export_id)
}
pub fn clear_exports() {
let mut reg = EXPORTS.lock();
reg.exports.clear();
reg.by_dataset.clear();
}
// Process-wide client registry: built on first access and guarded by a spin `Mutex`.
lazy_static! {
static ref CLIENTS: Mutex<ClientRegistry> = Mutex::new(ClientRegistry::new());
}
/// Record kept for one client in the client registry.
///
/// All time values (`lease_expires`, `created_at`, `last_activity`) share the
/// unit of the `now` arguments passed by callers; this type never reads a
/// clock itself.
#[derive(Debug, Clone)]
pub struct ClientState {
    /// Identifier generated in [`ClientState::new`].
    pub id: ClientId,
    /// Eight-byte verifier supplied at construction.
    pub verifier: [u8; 8],
    /// Opaque owner identifier supplied at construction.
    pub owner_id: Vec<u8>,
    /// IPv4 address supplied at construction.
    pub address: [u8; 4],
    /// Time after which the lease counts as expired; 0 until the first renewal.
    pub lease_expires: u64,
    /// Sessions associated with this client; starts empty.
    pub sessions: Vec<SessionId>,
    /// Whether the client has been confirmed; starts `false`.
    pub confirmed: bool,
    /// Initialised to 0 by [`ClientState::new`].
    pub created_at: u64,
    /// Time of the most recent lease renewal; 0 until the first renewal.
    pub last_activity: u64,
}

impl ClientState {
    /// Creates an unconfirmed client with a freshly generated id, no sessions
    /// and all time fields set to 0.
    ///
    /// Because `lease_expires` starts at 0, the lease counts as expired for
    /// any `now > 0` until [`ClientState::renew_lease`] is called.
    pub fn new(owner_id: Vec<u8>, verifier: [u8; 8], address: [u8; 4]) -> Self {
        Self {
            id: ClientId::generate(),
            verifier,
            owner_id,
            address,
            lease_expires: 0,
            sessions: Vec::new(),
            confirmed: false,
            created_at: 0,
            last_activity: 0,
        }
    }

    /// Returns `true` when `now` is strictly later than the lease expiry.
    pub fn is_lease_expired(&self, now: u64) -> bool {
        now > self.lease_expires
    }

    /// Extends the lease to `now + duration` and records `now` as the last activity.
    ///
    /// The sum saturates at `u64::MAX`, so an oversized duration yields a
    /// lease that never expires instead of overflowing.
    pub fn renew_lease(&mut self, duration: u64, now: u64) {
        self.lease_expires = now.saturating_add(duration);
        self.last_activity = now;
    }
}
/// Storage behind the `CLIENTS` global.
#[derive(Debug)]
struct ClientRegistry {
    /// Client records keyed by the numeric client id.
    clients: BTreeMap<u64, ClientState>,
    /// Maps an owner identifier to the numeric id of the client registered for it.
    by_owner: BTreeMap<Vec<u8>, u64>,
    /// Duration applied on each lease renewal; in the same unit as the `now`
    /// values callers supply.
    lease_duration: u64,
}

impl ClientRegistry {
    /// Creates an empty registry with a lease duration of 90.
    fn new() -> Self {
        let lease_duration = 90;
        Self {
            lease_duration,
            clients: BTreeMap::new(),
            by_owner: BTreeMap::new(),
        }
    }
}
/// Stores `state` in the global client registry and returns its id.
///
/// The owner index is pointed at this client, replacing any earlier mapping
/// for the same owner identifier; an existing record with the same numeric
/// id is overwritten.
pub fn register_client(state: ClientState) -> ClientId {
    let client_id = state.id;
    let owner_key = state.owner_id.clone();

    let mut reg = CLIENTS.lock();
    reg.by_owner.insert(owner_key, client_id.id);
    reg.clients.insert(client_id.id, state);
    client_id
}
/// Returns a copy of the client record registered under `id`, if any.
pub fn get_client(id: ClientId) -> Option<ClientState> {
    let reg = CLIENTS.lock();
    reg.clients.get(&id.id).cloned()
}
/// Returns a copy of the client currently indexed under `owner_id`.
///
/// Yields `None` when the owner is unknown or when the indexed id has no
/// record in the client table.
pub fn find_client_by_owner(owner_id: &[u8]) -> Option<ClientState> {
    let reg = CLIENTS.lock();
    let client_key = reg.by_owner.get(owner_id)?;
    reg.clients.get(client_key).cloned()
}
/// Marks the client registered under `id` as confirmed.
///
/// # Errors
///
/// Returns an error built from `NfsStatus::Stale` when no such client exists.
pub fn confirm_client(id: ClientId) -> NfsResult<()> {
    let mut reg = CLIENTS.lock();
    match reg.clients.get_mut(&id.id) {
        Some(client) => {
            client.confirmed = true;
            Ok(())
        }
        None => Err(NfsError::new(NfsStatus::Stale)),
    }
}
/// Renews the lease of the client registered under `id`, starting from `now`
/// and lasting for the registry's configured lease duration.
///
/// # Errors
///
/// Returns an error built from `NfsStatus::Stale` when no such client exists.
pub fn renew_client_lease(id: ClientId, now: u64) -> NfsResult<()> {
    let mut reg = CLIENTS.lock();
    // Read the duration first: the mutable borrow of the client record below
    // would otherwise overlap with reading the registry field.
    let lease = reg.lease_duration;
    match reg.clients.get_mut(&id.id) {
        None => Err(NfsError::new(NfsStatus::Stale)),
        Some(client) => {
            client.renew_lease(lease, now);
            Ok(())
        }
    }
}
/// Removes every client whose lease has expired at `now` and returns the ids
/// of the removed clients.
///
/// The owner index entry is removed only when it still points at the client
/// being removed. If the same owner has since registered a newer client, the
/// index points at that newer id and must be left alone; otherwise the newer
/// client would no longer be found by owner.
pub fn cleanup_expired_clients(now: u64) -> Vec<ClientId> {
    let mut reg = CLIENTS.lock();

    // Collect the keys first; the map cannot be mutated while it is iterated.
    let expired: Vec<u64> = reg
        .clients
        .iter()
        .filter(|(_, client)| client.is_lease_expired(now))
        .map(|(id, _)| *id)
        .collect();

    let mut removed = Vec::with_capacity(expired.len());
    for id in expired {
        if let Some(client) = reg.clients.remove(&id) {
            if reg.by_owner.get(&client.owner_id) == Some(&id) {
                reg.by_owner.remove(&client.owner_id);
            }
            removed.push(ClientId::new(id));
        }
    }
    removed
}
/// Number of client records currently registered.
pub fn client_count() -> usize {
    let reg = CLIENTS.lock();
    reg.clients.len()
}
/// Sets the duration applied by subsequent lease renewals.
///
/// Leases that were already granted keep their current expiry.
pub fn set_lease_duration(seconds: u64) {
    let mut reg = CLIENTS.lock();
    reg.lease_duration = seconds;
}
/// Returns the duration currently applied on lease renewal.
pub fn get_lease_duration() -> u64 {
    let reg = CLIENTS.lock();
    reg.lease_duration
}
pub fn clear_clients() {
let mut reg = CLIENTS.lock();
reg.clients.clear();
reg.by_owner.clear();
}
// Process-wide session registry: built on first access and guarded by a spin `Mutex`.
lazy_static! {
static ref SESSIONS: Mutex<SessionRegistry> = Mutex::new(SessionRegistry::new());
}
/// Record kept for one session in the session registry.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Identifier generated in [`SessionState::new`].
    pub id: SessionId,
    /// Client the session was created for.
    pub client_id: ClientId,
    /// Starts at 1.
    pub sequence_id: u32,
    /// Starts at 16.
    pub fore_slots: u32,
    /// Starts at 4.
    pub back_slots: u32,
    /// Initialised to 0 by [`SessionState::new`].
    pub created_at: u64,
    /// Starts `true`.
    pub active: bool,
}

impl SessionState {
    /// Creates an active session for `client_id` with a freshly generated id
    /// and the initial values documented on the fields.
    pub fn new(client_id: ClientId) -> Self {
        let id = SessionId::generate();
        Self {
            id,
            client_id,
            active: true,
            created_at: 0,
            sequence_id: 1,
            fore_slots: 16,
            back_slots: 4,
        }
    }
}
/// Storage behind the `SESSIONS` global.
#[derive(Debug)]
struct SessionRegistry {
    /// Session records keyed by the 16-byte session id.
    sessions: BTreeMap<[u8; 16], SessionState>,
    /// Session ids grouped by the numeric id of the owning client.
    by_client: BTreeMap<u64, Vec<[u8; 16]>>,
}

impl SessionRegistry {
    /// Creates a registry with no sessions.
    fn new() -> Self {
        let sessions = BTreeMap::new();
        let by_client = BTreeMap::new();
        Self {
            sessions,
            by_client,
        }
    }
}
/// Creates a new session for `client_id`, records it in the global session
/// registry and returns a copy of it.
///
/// No check is made that `client_id` refers to a registered client.
pub fn create_session(client_id: ClientId) -> SessionState {
    let mut reg = SESSIONS.lock();
    let created = SessionState::new(client_id);
    let key = created.id.id;

    reg.by_client.entry(client_id.id).or_default().push(key);
    reg.sessions.insert(key, created.clone());
    created
}
/// Returns a copy of the session registered under `id`, if any.
pub fn get_session(id: SessionId) -> Option<SessionState> {
    let reg = SESSIONS.lock();
    reg.sessions.get(&id.id).cloned()
}
/// Returns copies of all sessions indexed under `client_id`.
///
/// Yields an empty vector when the client has no index entry; ids that are
/// indexed but missing from the session table are skipped.
pub fn get_client_sessions(client_id: ClientId) -> Vec<SessionState> {
    let reg = SESSIONS.lock();
    match reg.by_client.get(&client_id.id) {
        Some(keys) => keys
            .iter()
            .filter_map(|key| reg.sessions.get(key))
            .cloned()
            .collect(),
        None => Vec::new(),
    }
}
/// Removes the session with the given id and returns it, or `None` if the id
/// is not registered.
///
/// The per-client index entry is dropped once its last session is gone, so
/// the index does not accumulate empty lists for clients without sessions.
pub fn destroy_session(id: SessionId) -> Option<SessionState> {
    let mut reg = SESSIONS.lock();
    let session = reg.sessions.remove(&id.id)?;
    let owner = session.client_id.id;

    let now_empty = match reg.by_client.get_mut(&owner) {
        Some(keys) => {
            keys.retain(|key| *key != id.id);
            keys.is_empty()
        }
        None => false,
    };
    if now_empty {
        reg.by_client.remove(&owner);
    }

    Some(session)
}
/// Number of sessions currently registered.
pub fn session_count() -> usize {
    let reg = SESSIONS.lock();
    reg.sessions.len()
}
pub fn clear_sessions() {
let mut reg = SESSIONS.lock();
reg.sessions.clear();
reg.by_client.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    // The registries are process-wide and the test harness runs tests on
    // several threads. Without serialisation, one test's `setup()` can wipe
    // the entries another test has just inserted. Every test that resets or
    // reads the registries holds this lock for its whole duration.
    lazy_static! {
        static ref REGISTRY_TEST_LOCK: Mutex<()> = Mutex::new(());
    }

    /// Takes the registry test lock, empties all three registries and returns
    /// the guard; the caller must keep it alive until the test is finished.
    fn setup() -> ::spin::MutexGuard<'static, ()> {
        let guard = REGISTRY_TEST_LOCK.lock();
        clear_exports();
        clear_clients();
        clear_sessions();
        guard
    }

    #[test]
    fn test_export_options_default() {
        let opts = ExportOptions::default();
        assert!(!opts.read_only);
        assert!(!opts.no_root_squash);
        assert!(opts.sync);
    }

    #[test]
    fn test_export_options_builder() {
        let opts = ExportOptions::read_only()
            .with_no_root_squash()
            .with_async();
        assert!(opts.read_only);
        assert!(opts.no_root_squash);
        assert!(!opts.sync);
    }

    #[test]
    fn test_client_spec_parse() {
        let all = ClientSpec::parse("*").unwrap();
        assert!(matches!(all, ClientSpec::All));
        let ip = ClientSpec::parse("192.168.1.1").unwrap();
        assert!(matches!(ip, ClientSpec::Ip([192, 168, 1, 1])));
        let net = ClientSpec::parse("10.0.0.0/8").unwrap();
        assert!(matches!(
            net,
            ClientSpec::Network {
                ip: [10, 0, 0, 0],
                ..
            }
        ));
        let host = ClientSpec::parse("server.example.com").unwrap();
        assert!(matches!(host, ClientSpec::Host(_)));
    }

    #[test]
    fn test_client_spec_matches() {
        let all = ClientSpec::All;
        assert!(all.matches([1, 2, 3, 4]));
        let ip = ClientSpec::Ip([192, 168, 1, 1]);
        assert!(ip.matches([192, 168, 1, 1]));
        assert!(!ip.matches([192, 168, 1, 2]));
        let net = ClientSpec::Network {
            ip: [192, 168, 1, 0],
            mask: [255, 255, 255, 0],
        };
        assert!(net.matches([192, 168, 1, 100]));
        assert!(!net.matches([192, 168, 2, 1]));
    }

    #[test]
    fn test_export_creation() {
        let export = Export::new("tank/data", "/shared").allow_all(ExportOptions::read_write());
        assert_eq!(export.dataset, "tank/data");
        assert_eq!(export.path, "/shared");
        assert_eq!(export.access_rules.len(), 1);
    }

    #[test]
    fn test_export_check_access() {
        let export = Export::new("tank", "")
            .with_rule(
                ClientSpec::Ip([192, 168, 1, 100]),
                ExportOptions::read_write(),
            )
            .with_default(ExportOptions::read_only());
        let opts = export.check_access([192, 168, 1, 100]).unwrap();
        assert!(!opts.read_only);
        let opts = export.check_access([192, 168, 1, 200]).unwrap();
        assert!(opts.read_only);
    }

    #[test]
    fn test_export_registry() {
        let _guard = setup();
        let export = Export::new("pool/test", "");
        let id = add_export(export);
        assert!(get_export(id).is_some());
        assert_eq!(export_count(), 1);
        let exports = get_exports_for_dataset("pool/test");
        assert_eq!(exports.len(), 1);
        remove_export(id);
        assert!(get_export(id).is_none());
    }

    #[test]
    fn test_client_registration() {
        let _guard = setup();
        let state = ClientState::new(b"test-client".to_vec(), [1; 8], [127, 0, 0, 1]);
        let id = register_client(state);
        let client = get_client(id).unwrap();
        assert_eq!(client.owner_id, b"test-client");
        let found = find_client_by_owner(b"test-client").unwrap();
        assert_eq!(found.id, id);
    }

    #[test]
    fn test_client_lease() {
        let _guard = setup();
        let mut state = ClientState::new(b"lease-test".to_vec(), [2; 8], [127, 0, 0, 1]);
        state.lease_expires = 100;
        assert!(!state.is_lease_expired(50));
        assert!(state.is_lease_expired(150));
        state.renew_lease(60, 100);
        assert_eq!(state.lease_expires, 160);
    }

    #[test]
    fn test_session_creation() {
        let _guard = setup();
        let client_id = ClientId::generate();
        let session = create_session(client_id);
        assert_eq!(session.client_id, client_id);
        assert!(session.active);
        let retrieved = get_session(session.id).unwrap();
        assert_eq!(retrieved.id, session.id);
        let sessions = get_client_sessions(client_id);
        assert_eq!(sessions.len(), 1);
        destroy_session(session.id);
        assert!(get_session(session.id).is_none());
    }

    #[test]
    fn test_security_flavor() {
        assert_eq!(SecurityFlavor::from_u32(0), Some(SecurityFlavor::None));
        assert_eq!(SecurityFlavor::from_u32(1), Some(SecurityFlavor::Sys));
        assert_eq!(SecurityFlavor::from_u32(390003), Some(SecurityFlavor::Krb5));
    }

    #[test]
    fn test_export_path() {
        let e1 = Export::new("tank", "");
        assert_eq!(e1.export_path(), "/tank");
        let e2 = Export::new("tank/data", "/shared");
        assert_eq!(e2.export_path(), "/tank/data/shared");
    }
}