use crate::error::Result;
use crate::faiss_try;
use faiss_sys::*;
use std::ptr;
/// Common interface for configuring a set of Faiss GPU resources.
// NOTE(review): `inner_ptr` is also declared on `GpuResourcesProvider` with the
// same signature. On a type implementing both traits (e.g. `StandardGpuResources`),
// `x.inner_ptr()` is ambiguous and must be written with fully-qualified syntax,
// e.g. `GpuResources::inner_ptr(&x)`.
pub trait GpuResources {
/// Return the raw pointer to the underlying native GPU resources object.
// NOTE(review): presumably the pointer remains owned by `self` and must not be
// freed by the caller — confirm for each implementor.
fn inner_ptr(&self) -> *mut FaissGpuResourcesProvider;
/// Disable temporary-memory reservation on the native resources.
///
/// For `StandardGpuResources` this forwards to
/// `faiss_StandardGpuResources_noTempMemory`.
fn no_temp_memory(&mut self) -> Result<()>;
/// Set the amount of temporary memory to use on the native resources.
///
/// For `StandardGpuResources` this forwards `size` to
/// `faiss_StandardGpuResources_setTempMemory`.
// NOTE(review): `size` is presumably a byte count — confirm against the Faiss docs.
fn set_temp_memory(&mut self, size: usize) -> Result<()>;
/// Set the amount of pinned host memory to use on the native resources.
///
/// For `StandardGpuResources` this forwards `size` to
/// `faiss_StandardGpuResources_setPinnedMemory`.
// NOTE(review): `size` is presumably a byte count — confirm against the Faiss docs.
fn set_pinned_memory(&mut self, size: usize) -> Result<()>;
}
/// Source of a raw Faiss GPU resources provider pointer.
// NOTE(review): `GpuResources` declares an `inner_ptr` with the same signature;
// on a type implementing both traits, `x.inner_ptr()` is ambiguous and must be
// written with fully-qualified syntax.
pub trait GpuResourcesProvider {
/// Return the raw pointer to the native provider object.
// NOTE(review): presumably ownership stays with `self` (for
// `StandardGpuResources` the pointer is freed in its `Drop`), so callers must
// not free it — confirm for other implementors.
fn inner_ptr(&self) -> *mut FaissGpuResourcesProvider;
}
/// Owning wrapper around a native Faiss `StandardGpuResources` object.
///
/// The handle is created by [`StandardGpuResources::new`] through
/// `faiss_StandardGpuResources_new` and released with
/// `faiss_StandardGpuResources_free` when the wrapper is dropped.
#[derive(Debug)]
pub struct StandardGpuResources {
/// Raw handle written by `faiss_StandardGpuResources_new`; owned by this wrapper.
inner: *mut FaissGpuResourcesProvider,
}
// SAFETY: `StandardGpuResources` exclusively owns its native handle. The `inner`
// field is private, the value is created only in `new` and freed only in `Drop`,
// and the type is not `Clone`, so moving it to another thread transfers sole
// ownership of the handle. `Sync` is not implemented, so the handle is never
// shared between threads through `&StandardGpuResources`.
// NOTE(review): this assumes the native Faiss `StandardGpuResources` object has no
// thread affinity (it may be used from a thread other than its creator) — confirm
// against the Faiss GPU documentation.
unsafe impl Send for StandardGpuResources {}
impl StandardGpuResources {
    /// Create a new set of standard Faiss GPU resources.
    ///
    /// The native object is allocated by `faiss_StandardGpuResources_new` and is
    /// released when the returned value is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `faiss_try` if the native constructor
    /// reports a failure. In that case no wrapper is created, so nothing is freed.
    pub fn new() -> Result<Self> {
        let mut inner = ptr::null_mut();
        // SAFETY: `&mut inner` is a valid, writable out-pointer that lives for the
        // whole duration of the FFI call.
        faiss_try(unsafe { faiss_StandardGpuResources_new(&mut inner) })?;
        // Only reached on success, so `inner` now holds the handle written by Faiss.
        Ok(StandardGpuResources { inner })
    }
}
impl Drop for StandardGpuResources {
    /// Release the native resources object owned by this wrapper.
    fn drop(&mut self) {
        let raw = self.inner;
        // SAFETY: `raw` is the handle written by `faiss_StandardGpuResources_new` in
        // `StandardGpuResources::new` (the only constructor, since `inner` is private).
        // The type is not `Clone` and `drop` runs at most once, so the handle is freed
        // exactly once.
        unsafe { faiss_StandardGpuResources_free(raw) }
    }
}
impl GpuResourcesProvider for StandardGpuResources {
    /// Return the raw native handle owned by this wrapper.
    ///
    /// The pointer stays owned by `self` and is freed when `self` is dropped.
    fn inner_ptr(&self) -> *mut FaissGpuResourcesProvider {
        // `inner` already has type `*mut FaissGpuResourcesProvider`, so no cast is
        // needed; this matches the `GpuResources::inner_ptr` implementation.
        self.inner
    }
}
impl GpuResources for StandardGpuResources {
    /// Return the raw native handle owned by this wrapper.
    fn inner_ptr(&self) -> *mut FaissGpuResourcesProvider {
        self.inner
    }

    /// Forward to `faiss_StandardGpuResources_noTempMemory`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `faiss_try` if the native call fails.
    fn no_temp_memory(&mut self) -> Result<()> {
        // SAFETY: `self.inner` is the live handle created in `new`; it is freed
        // only in `Drop`, which cannot have run while `&mut self` is borrowed.
        faiss_try(unsafe { faiss_StandardGpuResources_noTempMemory(self.inner) })?;
        Ok(())
    }

    /// Forward `size` to `faiss_StandardGpuResources_setTempMemory`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `faiss_try` if the native call fails.
    fn set_temp_memory(&mut self, size: usize) -> Result<()> {
        // SAFETY: `self.inner` is the live handle created in `new`; it is freed
        // only in `Drop`, which cannot have run while `&mut self` is borrowed.
        faiss_try(unsafe { faiss_StandardGpuResources_setTempMemory(self.inner, size) })?;
        Ok(())
    }

    /// Forward `size` to `faiss_StandardGpuResources_setPinnedMemory`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `faiss_try` if the native call fails.
    fn set_pinned_memory(&mut self, size: usize) -> Result<()> {
        // SAFETY: `self.inner` is the live handle created in `new`; it is freed
        // only in `Drop`, which cannot have run while `&mut self` is borrowed.
        faiss_try(unsafe { faiss_StandardGpuResources_setPinnedMemory(self.inner, size) })?;
        Ok(())
    }
}
/// Forwarding implementation so that a mutable borrow of a
/// [`StandardGpuResources`] also implements `GpuResources`.
///
/// Every method delegates to the owned `StandardGpuResources` implementation.
/// Fully-qualified paths are used because `inner_ptr` is also provided by
/// `GpuResourcesProvider`.
impl GpuResources for &mut StandardGpuResources {
    fn inner_ptr(&self) -> *mut FaissGpuResourcesProvider {
        <StandardGpuResources as GpuResources>::inner_ptr(&**self)
    }

    fn no_temp_memory(&mut self) -> Result<()> {
        <StandardGpuResources as GpuResources>::no_temp_memory(&mut **self)
    }

    fn set_temp_memory(&mut self, size: usize) -> Result<()> {
        <StandardGpuResources as GpuResources>::set_temp_memory(&mut **self, size)
    }

    fn set_pinned_memory(&mut self, size: usize) -> Result<()> {
        <StandardGpuResources as GpuResources>::set_pinned_memory(&mut **self, size)
    }
}
#[cfg(test)]
mod tests {
use super::StandardGpuResources;
/// Constructing a `StandardGpuResources` must succeed.
// NOTE(review): presumably requires a CUDA-capable device at test time; unlike
// `resources_leak` this test is not `#[ignore]`d — confirm the test environment
// provides a GPU.
#[test]
fn smoke_detector() {
StandardGpuResources::new().unwrap();
}
/// Repeatedly creates GPU resources, builds a "Flat" inner-product index and
/// converts it with `into_gpu(&res, 0)`. All values are scoped to the loop body,
/// so they are dropped at the end of every iteration.
// NOTE(review): the test asserts nothing beyond every `unwrap` succeeding;
// presumably leaks are meant to be observed externally (e.g. by watching device
// memory) while it runs. `0` is presumably the GPU device index — confirm.
#[ignore]
#[test]
fn resources_leak() {
use crate::{index_factory, MetricType};
for _ in 0..50 {
let res = StandardGpuResources::new().unwrap();
let index = index_factory(32, "Flat", MetricType::InnerProduct).unwrap();
// Bound to a named `_`-prefixed variable (not `_`) so the GPU index lives
// until the end of the iteration instead of being dropped immediately.
let _gpu_index = index.into_gpu(&res, 0).unwrap();
}
}
}