1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use std::u32;
use std::os::raw::c_void;
use std::ptr::{null, null_mut};

use native::*;
use super::{bitflags, check, get_info, ErrorStatus, Flags};

impl Region {
    pub fn segment(&self) -> Result<RegionSegment, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::Segment, x))
    }

    pub fn global_flags(&self) -> Result<Flags<RegionGlobalFlag>, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::GlobalFlags, x))
            .map(|flags: RegionGlobalFlag| bitflags(flags as u32))
    }

    pub fn size(&self) -> Result<usize, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::Size, x))
    }

    pub fn alloc_max_size(&self) -> Result<usize, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::AllocMaxSize, x))
    }

    pub fn alloc_max_private_workgroup_size(&self) -> Result<u32, ErrorStatus> {
        get_info(|x| {
            self.get_info(RegionInfo::AllocMaxPrivateWorkgroupSize, x)
        })
    }

    pub fn runtime_alloc_allowed(&self) -> Result<bool, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::RuntimeAllocAllowed, x))
    }

    pub fn runtime_alloc_granule(&self) -> Result<usize, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::RuntimeAllocGranule, x))
    }

    pub fn runtime_alloc_alignment(&self) -> Result<usize, ErrorStatus> {
        get_info(|x| self.get_info(RegionInfo::RuntimeAllocAlignment, x))
    }

    fn get_info(&self, attr: RegionInfo, v: *mut c_void) -> HSAStatus {
        unsafe { hsa_region_get_info(*self, attr, v) }
    }
}

/// A handle to memory shared with the HSA runtime.
///
/// Raw pointers are stored as-is, so deriving `Debug` prints their addresses. `Clone`
/// is intentionally not derived, because duplicating a handle would free or deregister
/// the same memory twice.
#[derive(Debug)]
pub enum Memory<T> {
    /// Memory obtained from `hsa_memory_allocate` (see `Memory::new` / `Memory::allocate`).
    RegionMemory(*mut T),
    /// A caller-owned buffer and its length in bytes, registered via `hsa_memory_register`.
    Registered(*mut T, usize),
    /// No memory; pointer accessors return null.
    None,
}

impl Memory<u8> {
    /// Allocates `size` bytes from `region` with `hsa_memory_allocate`.
    ///
    /// On success returns `Memory::RegionMemory`, which is released with
    /// `hsa_memory_free` when dropped.
    ///
    /// # Errors
    /// Returns the `ErrorStatus` produced by `check` when the allocation fails.
    pub fn allocate(region: Region, size: usize) -> Result<Memory<u8>, ErrorStatus> {
        // `ptr` is an out-parameter written by the runtime. It must be a `mut` binding
        // passed as `&mut`. Writing through the shared borrow `&ptr` is undefined
        // behavior and allows the optimizer to assume `ptr` is still null afterwards.
        let mut ptr: *mut c_void = null_mut();
        // SAFETY: `ptr` is a valid, writable location for the returned pointer.
        let status = unsafe { hsa_memory_allocate(region, size, &mut ptr) };
        check(status, ()).map(|_| Memory::RegionMemory(ptr as *mut u8))
    }

    /// Registers the caller-owned buffer `ptr` of `size` bytes with `hsa_memory_register`.
    ///
    /// On success returns `Memory::Registered`, which is deregistered (not freed) with
    /// `hsa_memory_deregister` when dropped. The caller keeps ownership of the buffer
    /// and must keep it alive for as long as the returned value exists.
    ///
    /// # Errors
    /// Returns the `ErrorStatus` produced by `check` when registration fails.
    pub fn register(ptr: *mut u8, size: usize) -> Result<Memory<u8>, ErrorStatus> {
        // SAFETY: the runtime only records the (ptr, size) range here.
        let status = unsafe { hsa_memory_register(ptr as *mut c_void, size) };
        check(status, ()).map(|_| Memory::Registered(ptr, size))
    }
}

impl<T> Memory<T> {
    /// Allocates storage for one `T` (`size_of::<T>()` bytes) from `region`.
    ///
    /// On success returns `Memory::RegionMemory`.
    ///
    /// # Errors
    /// Returns the `ErrorStatus` produced by `check` when `hsa_memory_allocate` fails.
    pub fn new(region: Region) -> Result<Memory<T>, ErrorStatus> {
        use std::mem::size_of;
        // `ptr` is an out-parameter written by the runtime. It must be a `mut` binding
        // passed as `&mut`. Writing through the shared borrow `&ptr` is undefined
        // behavior and allows the optimizer to assume `ptr` is still null afterwards.
        let mut ptr: *mut c_void = null_mut();
        // SAFETY: `ptr` is a valid, writable location for the returned pointer.
        let status = unsafe { hsa_memory_allocate(region, size_of::<T>(), &mut ptr) };
        check(status, ()).map(|_| Memory::RegionMemory(ptr as *mut T))
    }

    /// Returns the underlying pointer as `*const T`, or null for `Memory::None`.
    pub fn as_ptr(&self) -> *const T {
        match *self {
            Memory::RegionMemory(x) |
            Memory::Registered(x, _) => x,
            Memory::None => null(),
        }
    }

    /// Returns the underlying pointer as `*mut T`, or null for `Memory::None`.
    pub fn as_mut_ptr(&self) -> *mut T {
        match *self {
            Memory::RegionMemory(x) |
            Memory::Registered(x, _) => x,
            Memory::None => null_mut(),
        }
    }

    /// Calls `hsa_memory_assign_agent` on this memory with `agent` and `access`.
    ///
    /// # Errors
    /// Returns `ErrorStatus::InvalidArgument` if the underlying pointer is null (e.g.
    /// `Memory::None`). Otherwise returns whatever `check` reports for the runtime call.
    pub fn assign_agent(&self, agent: Agent, access: AccessPermission) -> Result<(), ErrorStatus> {
        let ptr = self.as_mut_ptr();
        if ptr.is_null() {
            return Err(ErrorStatus::InvalidArgument);
        }
        // SAFETY: `ptr` is non-null and refers to memory this handle tracks.
        check(unsafe {
            hsa_memory_assign_agent(ptr as *mut c_void, agent, access)
        }, ())
    }

    /// Bitwise-copies `*src` into the start of this memory.
    ///
    /// For a non-`Copy` `T` this duplicates the value without transferring ownership,
    /// so the caller must ensure that is sound for `T`.
    ///
    /// # Panics
    /// Panics if this is `Memory::None` (null pointer), or if it is `Memory::Registered`
    /// with a recorded length smaller than `size_of::<T>()`. Previously both cases
    /// silently wrote through a null or too-small pointer.
    pub fn copy_from(&mut self, src: &T) {
        use std::mem::size_of;
        use std::ptr::copy_nonoverlapping;

        let dst = self.as_mut_ptr();
        assert!(!dst.is_null(), "Memory::copy_from called on unallocated memory");
        if let Memory::Registered(_, len) = *self {
            assert!(
                len >= size_of::<T>(),
                "Memory::copy_from: registered buffer is smaller than the value"
            );
        }
        // SAFETY: `dst` is non-null and, per the checks above, presumably points to at
        // least `size_of::<T>()` bytes; `src` is a valid reference to one `T`.
        // NOTE(review): this is a host-side write, so it assumes the region memory is
        // host-accessible. Confirm this for the regions used with `copy_from`.
        unsafe { copy_nonoverlapping(src as *const T, dst, 1) };
    }
}

impl<T> Drop for Memory<T> {
    /// Releases the memory according to how it was obtained.
    ///
    /// `Drop` cannot report failure, so the runtime's status codes are deliberately
    /// discarded (best effort).
    fn drop(&mut self) {
        match *self {
            // Obtained from `hsa_memory_allocate`: give it back to the runtime.
            Memory::RegionMemory(ptr) => {
                // SAFETY: `ptr` came from `hsa_memory_allocate` and is freed exactly once.
                let _ = unsafe { hsa_memory_free(ptr as *mut c_void) };
            }
            // Caller-owned buffer: only undo the registration, never free it.
            Memory::Registered(ptr, len) => {
                // SAFETY: `(ptr, len)` is the range passed to `hsa_memory_register`.
                let _ = unsafe { hsa_memory_deregister(ptr as *mut c_void, len) };
            }
            // Exhaustive on purpose: a new variant must decide how it is released.
            Memory::None => {}
        }
    }
}

/// Copies `bytes` bytes from `src` to `dst` using `hsa_memory_copy`.
///
/// Note that the length is in bytes, not in elements of `T`.
///
/// # Safety
/// `src` must be valid for reads of `bytes` bytes and `dst` must be valid for writes
/// of `bytes` bytes, in memory the runtime is able to access. NOTE(review): overlap
/// rules follow `hsa_memory_copy`; confirm them against the HSA specification.
///
/// # Errors
/// Returns the `ErrorStatus` produced by `check` when the runtime copy fails.
pub unsafe fn copy<T>(src: *const T, dst: *mut T, bytes: usize) -> Result<(), ErrorStatus> {
    let to = dst as *mut c_void;
    let from = src as *const c_void;
    let status = hsa_memory_copy(to, from, bytes);
    check(status, ())
}