// libcontainer 0.5.1
//
// Library for container control
use std::any::Any;
use std::cell::{Ref, RefCell, RefMut};
use std::collections::HashMap;
use std::ffi::{OsStr, OsString};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use caps::{CapSet, CapsHashSet};
use nix::mount::{MntFlags, MsFlags};
use nix::sched::CloneFlags;
use nix::sys::stat::{Mode, SFlag};
use nix::unistd::{Gid, Uid};
use oci_spec::runtime::PosixRlimit;

use super::{linux, Result, Syscall};

/// Arguments captured from one `mount` call made against the test syscall mock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MountArgs {
    /// Source path passed to `mount`, if any.
    pub source: Option<PathBuf>,
    /// Target path passed to `mount`.
    pub target: PathBuf,
    /// Filesystem type passed to `mount`, if any.
    pub fstype: Option<String>,
    /// Mount flags, stored exactly as received.
    pub flags: MsFlags,
    /// Extra `data` string passed to `mount`, if any.
    pub data: Option<String>,
}

/// Arguments captured from one `mknod` call made against the test syscall mock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MknodArgs {
    /// Path of the node to create.
    pub path: PathBuf,
    /// `kind` argument as passed to `mknod`.
    pub kind: SFlag,
    /// Permission bits as passed to `mknod`.
    pub perm: Mode,
    /// `dev` value as passed to `mknod`.
    pub dev: u64,
}

/// Arguments captured from one `chown` call made against the test syscall mock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChownArgs {
    /// Path whose ownership was requested to change.
    pub path: PathBuf,
    /// Requested owner, if one was supplied.
    pub owner: Option<Uid>,
    /// Requested group, if one was supplied.
    pub group: Option<Gid>,
}

/// Arguments captured from one `set_io_priority` call made against the test syscall mock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IoPriorityArgs {
    /// `class` argument as received.
    pub class: i64,
    /// `priority` argument as received.
    pub priority: i64,
}

/// Arguments captured from one `umount2` call made against the test syscall mock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UMount2Args {
    /// Path that was requested to be unmounted.
    pub target: PathBuf,
    /// Unmount flags, stored exactly as received.
    pub flags: MntFlags,
}

/// Mock state for a single syscall: the recorded call arguments plus an
/// optional error to inject on upcoming calls.
#[derive(Default)]
struct Mock {
    // Boxed arguments of each recorded call, in call order.
    values: Vec<Box<dyn Any>>,
    // Factory whose result is returned instead of recording, while
    // `ret_err_times` is non-zero.
    ret_err: Option<fn() -> Result<()>>,
    // Number of upcoming calls that still consume the injected error.
    ret_err_times: usize,
}

/// Identifies which mocked syscall a recorded call or injected error belongs to.
///
/// Every variant must also be listed in [`ArgName::iterator`]: `MockCalls::default`
/// only allocates mock storage for the names that iterator yields, and recording
/// a call for a name without storage panics.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub enum ArgName {
    Namespace,
    Unshare,
    Mount,
    Symlink,
    Mknod,
    Chown,
    Hostname,
    Domainname,
    Groups,
    Capability,
    IoPriority,
    UMount2,
}

impl ArgName {
    /// Yields every `ArgName` variant exactly once, in declaration order.
    ///
    /// Keep this list in sync with the enum. A missing variant (previously
    /// `UMount2`) gets no mock storage, so `umount2` and `get_umount_args`
    /// panicked on the missing map entry.
    fn iterator() -> impl Iterator<Item = ArgName> {
        [
            ArgName::Namespace,
            ArgName::Unshare,
            ArgName::Mount,
            ArgName::Symlink,
            ArgName::Mknod,
            ArgName::Chown,
            ArgName::Hostname,
            ArgName::Domainname,
            ArgName::Groups,
            ArgName::Capability,
            ArgName::IoPriority,
            ArgName::UMount2,
        ]
        .iter()
        .copied()
    }
}

/// Mock state keyed by syscall name. Each entry is wrapped in a `RefCell`
/// so it can be updated through a shared `&self`.
struct MockCalls {
    args: HashMap<ArgName, RefCell<Mock>>,
}

impl Default for MockCalls {
    /// Builds a `MockCalls` with a fresh, empty `Mock` for every name yielded
    /// by `ArgName::iterator`.
    fn default() -> Self {
        let args = ArgName::iterator()
            .map(|name| (name, RefCell::new(Mock::default())))
            .collect::<HashMap<_, _>>();

        MockCalls { args }
    }
}

impl MockCalls {
    /// Records one call to the syscall identified by `name`.
    ///
    /// While `ret_err_times` is non-zero, each call decrements it. If an error
    /// factory is also set, that factory's result is returned and `value` is
    /// discarded. In every other case, `value` is appended to the recorded
    /// values and `Ok(())` is returned.
    ///
    /// # Panics
    /// Panics if `name` has no entry in `args`, or if its mock is already borrowed.
    fn act(&self, name: ArgName, value: Box<dyn Any>) -> Result<()> {
        let cell = self.args.get(&name).unwrap();

        // Consume one pending failure (if any) and copy out the factory.
        // The mutable borrow ends with this block.
        let injected = {
            let mut mock = cell.borrow_mut();
            if mock.ret_err_times == 0 {
                None
            } else {
                mock.ret_err_times -= 1;
                mock.ret_err
            }
        };

        match injected {
            Some(make_err) => make_err(),
            None => {
                cell.borrow_mut().values.push(value);
                Ok(())
            }
        }
    }

    /// Shared view of the mock state for `name`. Panics if the entry is missing.
    fn fetch(&self, name: ArgName) -> Ref<'_, Mock> {
        let cell = self.args.get(&name).unwrap();
        cell.borrow()
    }

    /// Mutable view of the mock state for `name`. Panics if the entry is missing.
    fn fetch_mut(&self, name: ArgName) -> RefMut<'_, Mock> {
        let cell = self.args.get(&name).unwrap();
        cell.borrow_mut()
    }
}

/// `Syscall` implementation for tests. Calls are recorded per `ArgName`
/// instead of being performed, and selected calls can be armed to fail.
#[derive(Default)]
pub struct TestHelperSyscall {
    mocks: MockCalls,
}

impl Syscall for TestHelperSyscall {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn pivot_rootfs(&self, _path: &Path) -> Result<()> {
        unimplemented!()
    }

    fn set_ns(&self, rawfd: i32, nstype: CloneFlags) -> Result<()> {
        self.mocks
            .act(ArgName::Namespace, Box::new((rawfd, nstype)))
    }

    fn set_id(&self, _uid: Uid, _gid: Gid) -> Result<()> {
        unimplemented!()
    }

    fn unshare(&self, flags: CloneFlags) -> Result<()> {
        self.mocks.act(ArgName::Unshare, Box::new(flags))
    }

    fn set_capability(&self, cset: CapSet, value: &CapsHashSet) -> Result<()> {
        self.mocks
            .act(ArgName::Capability, Box::new((cset, value.clone())))
    }

    fn set_hostname(&self, hostname: &str) -> Result<()> {
        self.mocks
            .act(ArgName::Hostname, Box::new(hostname.to_owned()))
    }

    fn set_domainname(&self, domainname: &str) -> Result<()> {
        self.mocks
            .act(ArgName::Domainname, Box::new(domainname.to_owned()))
    }

    fn set_rlimit(&self, _rlimit: &PosixRlimit) -> Result<()> {
        todo!()
    }

    fn get_pwuid(&self, _: u32) -> Option<Arc<OsStr>> {
        Some(OsString::from("youki").into())
    }

    fn chroot(&self, _: &Path) -> Result<()> {
        todo!()
    }

    fn mount(
        &self,
        source: Option<&Path>,
        target: &Path,
        fstype: Option<&str>,
        flags: MsFlags,
        data: Option<&str>,
    ) -> Result<()> {
        self.mocks.act(
            ArgName::Mount,
            Box::new(MountArgs {
                source: source.map(|x| x.to_owned()),
                target: target.to_owned(),
                fstype: fstype.map(|x| x.to_owned()),
                flags,
                data: data.map(|x| x.to_owned()),
            }),
        )
    }

    fn symlink(&self, original: &Path, link: &Path) -> Result<()> {
        self.mocks.act(
            ArgName::Symlink,
            Box::new((original.to_path_buf(), link.to_path_buf())),
        )
    }

    fn mknod(&self, path: &Path, kind: SFlag, perm: Mode, dev: u64) -> Result<()> {
        self.mocks.act(
            ArgName::Mknod,
            Box::new(MknodArgs {
                path: path.to_path_buf(),
                kind,
                perm,
                dev,
            }),
        )
    }
    fn chown(&self, path: &Path, owner: Option<Uid>, group: Option<Gid>) -> Result<()> {
        self.mocks.act(
            ArgName::Chown,
            Box::new(ChownArgs {
                path: path.to_path_buf(),
                owner,
                group,
            }),
        )
    }

    fn set_groups(&self, groups: &[Gid]) -> Result<()> {
        self.mocks.act(ArgName::Groups, Box::new(groups.to_vec()))
    }

    fn close_range(&self, _: i32) -> Result<()> {
        todo!()
    }

    fn mount_setattr(
        &self,
        _: i32,
        _: &Path,
        _: u32,
        _: &linux::MountAttr,
        _: libc::size_t,
    ) -> Result<()> {
        todo!()
    }

    fn set_io_priority(&self, class: i64, priority: i64) -> Result<()> {
        self.mocks.act(
            ArgName::IoPriority,
            Box::new(IoPriorityArgs { class, priority }),
        )
    }

    fn umount2(&self, target: &Path, flags: MntFlags) -> Result<()> {
        self.mocks.act(
            ArgName::UMount2,
            Box::new(UMount2Args {
                target: target.to_owned(),
                flags,
            }),
        )
    }
}

impl TestHelperSyscall {
    pub fn set_ret_err(&self, name: ArgName, err: fn() -> Result<()>) {
        self.mocks.fetch_mut(name).ret_err = Some(err);
        self.set_ret_err_times(name, 1);
    }

    pub fn set_ret_err_times(&self, name: ArgName, times: usize) {
        self.mocks.fetch_mut(name).ret_err_times = times;
    }

    pub fn get_setns_args(&self) -> Vec<(i32, CloneFlags)> {
        self.mocks
            .fetch(ArgName::Namespace)
            .values
            .iter()
            .map(|x| *x.downcast_ref::<(i32, CloneFlags)>().unwrap())
            .collect::<Vec<(i32, CloneFlags)>>()
    }

    pub fn get_unshare_args(&self) -> Vec<CloneFlags> {
        self.mocks
            .fetch(ArgName::Unshare)
            .values
            .iter()
            .map(|x| *x.downcast_ref::<CloneFlags>().unwrap())
            .collect::<Vec<CloneFlags>>()
    }

    pub fn get_set_capability_args(&self) -> Vec<(CapSet, CapsHashSet)> {
        self.mocks
            .fetch(ArgName::Capability)
            .values
            .iter()
            .map(|x| x.downcast_ref::<(CapSet, CapsHashSet)>().unwrap().clone())
            .collect::<Vec<(CapSet, CapsHashSet)>>()
    }

    pub fn get_mount_args(&self) -> Vec<MountArgs> {
        self.mocks
            .fetch(ArgName::Mount)
            .values
            .iter()
            .map(|x| x.downcast_ref::<MountArgs>().unwrap().clone())
            .collect::<Vec<MountArgs>>()
    }

    pub fn get_symlink_args(&self) -> Vec<(PathBuf, PathBuf)> {
        self.mocks
            .fetch(ArgName::Symlink)
            .values
            .iter()
            .map(|x| x.downcast_ref::<(PathBuf, PathBuf)>().unwrap().clone())
            .collect::<Vec<(PathBuf, PathBuf)>>()
    }

    pub fn get_mknod_args(&self) -> Vec<MknodArgs> {
        self.mocks
            .fetch(ArgName::Mknod)
            .values
            .iter()
            .map(|x| x.downcast_ref::<MknodArgs>().unwrap().clone())
            .collect::<Vec<MknodArgs>>()
    }

    pub fn get_chown_args(&self) -> Vec<ChownArgs> {
        self.mocks
            .fetch(ArgName::Chown)
            .values
            .iter()
            .map(|x| x.downcast_ref::<ChownArgs>().unwrap().clone())
            .collect::<Vec<ChownArgs>>()
    }

    pub fn get_hostname_args(&self) -> Vec<String> {
        self.mocks
            .fetch(ArgName::Hostname)
            .values
            .iter()
            .map(|x| x.downcast_ref::<String>().unwrap().clone())
            .collect::<Vec<String>>()
    }

    pub fn get_domainname_args(&self) -> Vec<String> {
        self.mocks
            .fetch(ArgName::Domainname)
            .values
            .iter()
            .map(|x| x.downcast_ref::<String>().unwrap().clone())
            .collect::<Vec<String>>()
    }

    pub fn get_groups_args(&self) -> Vec<Gid> {
        self.mocks
            .fetch(ArgName::Groups)
            .values
            .iter()
            .flat_map(|x| x.downcast_ref::<Vec<Gid>>().unwrap().clone())
            .collect::<Vec<Gid>>()
    }

    pub fn get_io_priority_args(&self) -> Vec<IoPriorityArgs> {
        self.mocks
            .fetch(ArgName::IoPriority)
            .values
            .iter()
            .map(|x| x.downcast_ref::<IoPriorityArgs>().unwrap().clone())
            .collect::<Vec<IoPriorityArgs>>()
    }

    pub fn get_umount_args(&self) -> Vec<UMount2Args> {
        self.mocks
            .fetch(ArgName::UMount2)
            .values
            .iter()
            .map(|x| x.downcast_ref::<UMount2Args>().unwrap().clone())
            .collect::<Vec<UMount2Args>>()
    }
}