use crate::core::{Feature, ColmapError};
use image::GrayImage;
use nalgebra::Point2;
use std::collections::HashMap;
/// Compute backend that GPU-accelerated components may be asked to use.
///
/// The enum is fieldless, so it is cheap to copy and can be compared and
/// hashed; it derives `Copy`, `Eq` and `Hash` in addition to the basics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// CUDA backend (device detection is gated on the `cuda` cargo feature).
    Cuda,
    /// OpenCL backend (device detection is gated on the `opencl` cargo feature).
    OpenCL,
    /// No explicit backend requested. This is the value used by
    /// `GpuConfig::default()` and by `GpuManager::new()`.
    Auto,
}
/// Description of a single compute device tracked by `GpuManager`.
///
/// When no real device is detected, `GpuManager` registers a synthetic
/// entry named `"CPU Fallback"` built from this same struct.
#[derive(Debug, Clone)]
pub struct GpuDevice {
/// Device identifier.
/// NOTE(review): `GpuManager` looks devices up by their position in its
/// device list, not by this field — confirm the two always agree.
pub id: u32,
/// Human-readable device name; the CPU fallback uses `"CPU Fallback"`.
pub name: String,
/// Compute capability pair; `(0, 0)` for the CPU fallback.
/// Presumably `(major, minor)` — TODO confirm.
pub compute_capability: (u32, u32),
/// Global memory size. The CPU fallback reports `8 * 1024 * 1024 * 1024`,
/// which suggests the unit is bytes — TODO confirm.
pub global_memory: u64,
/// Shared memory size; `0` for the CPU fallback. Unit presumably matches
/// `global_memory` — TODO confirm.
pub shared_memory: u32,
/// Maximum thread count; the CPU fallback sets this to `num_cpus::get()`.
pub max_threads: u32,
/// Whether the device may be selected: `GpuManager::set_active_device`
/// rejects devices for which this is `false`.
pub available: bool,
}
/// Settings handed to GPU-backed components such as `GpuSiftDetector`.
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Requested compute backend.
    pub backend: GpuBackend,
    /// Explicit device id, or `None` when no particular device is requested.
    pub device_id: Option<u32>,
    /// Whether memory pooling is requested.
    pub enable_memory_pool: bool,
    /// Batch size.
    /// NOTE(review): the CPU fallback of `GpuSiftDetector` forwards this value
    /// as the detector's `max_features` — confirm that is the intended meaning.
    pub batch_size: usize,
    /// Work-group size; not read anywhere in the visible part of this file.
    pub work_group_size: usize,
}

impl Default for GpuConfig {
    /// Returns the stock configuration: automatic backend, no explicit
    /// device, pooling enabled, batch size 1000 and work-group size 256.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 1000;
        const DEFAULT_WORK_GROUP_SIZE: usize = 256;

        GpuConfig {
            backend: GpuBackend::Auto,
            device_id: None,
            enable_memory_pool: true,
            batch_size: DEFAULT_BATCH_SIZE,
            work_group_size: DEFAULT_WORK_GROUP_SIZE,
        }
    }
}
/// Interface for feature detectors that run on a compute device.
///
/// Implementors must be `Send + Sync`.
pub trait GpuFeatureDetector: Send + Sync {
/// Detects features in a single grayscale image.
///
/// # Errors
/// Returns a `ColmapError` when detection fails.
fn detect_gpu(&self, image: &GrayImage) -> Result<Vec<Feature>, ColmapError>;
/// Detects features for a slice of images.
///
/// The `GpuSiftDetector` implementation in this file returns one feature
/// list per input image, in input order, and stops at the first error.
fn detect_batch_gpu(&self, images: &[GrayImage]) -> Result<Vec<Vec<Feature>>, ColmapError>;
/// Returns the device this detector is bound to.
fn device_info(&self) -> &GpuDevice;
/// Returns the configuration this detector was created with.
fn config(&self) -> &GpuConfig;
}
/// Interface for descriptor extractors that run on a compute device.
///
/// Implementors must be `Send + Sync`. No implementation is visible in this
/// file, so the notes below about result layout are assumptions to confirm.
pub trait GpuDescriptorExtractor: Send + Sync {
/// Computes descriptors for `keypoints` in a single grayscale image.
///
/// Presumably returns one byte-vector descriptor per keypoint — TODO confirm.
fn compute_gpu(
&self,
image: &GrayImage,
keypoints: &[Point2<f64>],
) -> Result<Vec<Vec<u8>>, ColmapError>;
/// Batch variant of `compute_gpu`.
///
/// Presumably `keypoints[i]` belongs to `images[i]` and the result holds one
/// descriptor list per image — TODO confirm against an implementation.
fn compute_batch_gpu(
&self,
images: &[GrayImage],
keypoints: &[Vec<Point2<f64>>],
) -> Result<Vec<Vec<Vec<u8>>>, ColmapError>;
}
/// Keeps the list of known compute devices and remembers which one is active.
///
/// When no device is detected, a synthetic `"CPU Fallback"` device is
/// registered so that the list is never empty after construction.
pub struct GpuManager {
    /// Every device found during detection (or the single CPU fallback).
    devices: Vec<GpuDevice>,
    /// Position of the active device inside `devices`, if one is selected.
    active_device: Option<u32>,
    /// Backend preference; set to `Auto` at construction and not read by any
    /// method in this block.
    backend: GpuBackend,
}

impl GpuManager {
    /// Name given to the synthetic device registered when nothing else is found.
    const CPU_FALLBACK_NAME: &'static str = "CPU Fallback";
    /// Global memory reported for the CPU fallback device.
    const CPU_FALLBACK_MEMORY: u64 = 8 * 1024 * 1024 * 1024;

    /// Creates a manager, runs device detection and activates the first device.
    ///
    /// # Errors
    /// Propagates any error raised by device detection.
    pub fn new() -> Result<Self, ColmapError> {
        let mut fresh = Self {
            devices: Vec::new(),
            active_device: None,
            backend: GpuBackend::Auto,
        };
        fresh.initialize()?;
        Ok(fresh)
    }

    /// Runs detection, then marks the first device as active when one exists.
    fn initialize(&mut self) -> Result<(), ColmapError> {
        self.detect_devices()?;
        if self.devices.first().is_some() {
            self.active_device = Some(0);
        }
        Ok(())
    }

    /// Fills `devices` from the compiled-in backends, falling back to a
    /// synthetic CPU device when none of them reported anything.
    fn detect_devices(&mut self) -> Result<(), ColmapError> {
        #[cfg(feature = "cuda")]
        {
            self.detect_cuda_devices()?;
        }
        #[cfg(feature = "opencl")]
        {
            self.detect_opencl_devices()?;
        }

        // A real device was found: no fallback entry is needed.
        if !self.devices.is_empty() {
            return Ok(());
        }

        self.devices.push(GpuDevice {
            id: 0,
            name: Self::CPU_FALLBACK_NAME.to_string(),
            compute_capability: (0, 0),
            global_memory: Self::CPU_FALLBACK_MEMORY,
            shared_memory: 0,
            max_threads: num_cpus::get() as u32,
            available: true,
        });
        Ok(())
    }

    /// CUDA device detection; currently a stub that registers nothing.
    #[cfg(feature = "cuda")]
    fn detect_cuda_devices(&mut self) -> Result<(), ColmapError> {
        Ok(())
    }

    /// OpenCL device detection; currently a stub that registers nothing.
    #[cfg(feature = "opencl")]
    fn detect_opencl_devices(&mut self) -> Result<(), ColmapError> {
        Ok(())
    }

    /// Returns every known device.
    pub fn devices(&self) -> &[GpuDevice] {
        &self.devices
    }

    /// Selects the device at position `device_id` in the device list.
    ///
    /// # Errors
    /// Returns `ColmapError::InvalidParameter` when `device_id` is out of
    /// range or when the addressed device is flagged as unavailable.
    pub fn set_active_device(&mut self, device_id: u32) -> Result<(), ColmapError> {
        match self.devices.get(device_id as usize) {
            None => Err(ColmapError::InvalidParameter(
                format!("Invalid device ID: {}", device_id),
            )),
            Some(candidate) if !candidate.available => Err(ColmapError::InvalidParameter(
                format!("Device {} is not available", device_id),
            )),
            Some(_) => {
                self.active_device = Some(device_id);
                Ok(())
            }
        }
    }

    /// Returns the currently active device, if any.
    pub fn active_device(&self) -> Option<&GpuDevice> {
        let position = self.active_device? as usize;
        self.devices.get(position)
    }

    /// Reports whether the active device is available and is not the CPU fallback.
    pub fn is_gpu_available(&self) -> bool {
        matches!(
            self.active_device(),
            Some(dev) if dev.available && dev.name != Self::CPU_FALLBACK_NAME
        )
    }
}
/// SIFT feature detector bound to one compute device.
pub struct GpuSiftDetector {
    /// Configuration supplied at construction time.
    config: GpuConfig,
    /// Snapshot (clone) of the device selected at construction time.
    device: GpuDevice,
}

impl GpuSiftDetector {
    /// Creates a detector bound to a device chosen through a fresh `GpuManager`.
    ///
    /// When `config.device_id` is `Some`, that device is selected via
    /// `GpuManager::set_active_device`; otherwise the manager's default
    /// active device is used.
    ///
    /// # Errors
    /// * Propagates any error from `GpuManager::new`.
    /// * Propagates the error from `GpuManager::set_active_device` when the
    ///   requested `device_id` is rejected.
    /// * Returns `ColmapError::InvalidParameter` when the manager has no
    ///   active device.
    pub fn new(config: GpuConfig) -> Result<Self, ColmapError> {
        let mut manager = GpuManager::new()?;

        // Honour an explicitly requested device instead of silently falling
        // back to whatever the manager activated by default.
        if let Some(requested) = config.device_id {
            manager.set_active_device(requested)?;
        }

        let device = manager
            .active_device()
            .cloned()
            .ok_or_else(|| ColmapError::InvalidParameter("No GPU device available".to_string()))?;

        Ok(Self { config, device })
    }
}
impl GpuFeatureDetector for GpuSiftDetector {
    /// Detects features in `image`.
    ///
    /// When the bound device is the `"CPU Fallback"` device the work is
    /// delegated to the CPU path; for any other device this currently returns
    /// `ColmapError::NotImplemented`.
    fn detect_gpu(&self, image: &GrayImage) -> Result<Vec<Feature>, ColmapError> {
        let on_cpu_fallback = self.device.name == "CPU Fallback";
        if !on_cpu_fallback {
            return Err(ColmapError::NotImplemented(
                "GPU SIFT detection not implemented yet".to_string(),
            ));
        }
        self.detect_cpu_fallback(image)
    }

    /// Runs `detect_gpu` on each image in order, stopping at the first error.
    ///
    /// On success the result holds one feature list per input image.
    fn detect_batch_gpu(&self, images: &[GrayImage]) -> Result<Vec<Vec<Feature>>, ColmapError> {
        images
            .iter()
            .map(|frame| self.detect_gpu(frame))
            .collect::<Result<Vec<_>, _>>()
    }

    /// Returns the device captured when the detector was created.
    fn device_info(&self) -> &GpuDevice {
        &self.device
    }

    /// Returns the configuration supplied when the detector was created.
    fn config(&self) -> &GpuConfig {
        &self.config
    }
}
impl GpuSiftDetector {
    /// Runs the crate's CPU SIFT detector on `image`.
    ///
    /// A new `SiftDetector` is built on every call from a `DetectorConfig`
    /// whose remaining fields take their default values.
    ///
    /// # Errors
    /// Propagates errors from `SiftDetector::new` and from `detect`.
    fn detect_cpu_fallback(&self, image: &GrayImage) -> Result<Vec<Feature>, ColmapError> {
        use crate::feature::detector::{DetectorConfig, DetectorType, FeatureDetector, SiftDetector};

        // NOTE(review): `batch_size` is forwarded as the feature cap here;
        // confirm this is the intended meaning of that setting.
        let feature_cap = self.config.batch_size;

        SiftDetector::new(&DetectorConfig {
            detector_type: DetectorType::Sift,
            max_features: feature_cap,
            ..Default::default()
        })?
        .detect(image)
    }
}
/// Bookkeeping for a fixed-size memory budget.
///
/// Only counters are tracked; no method here performs an actual device
/// allocation. Sizes are presumably in bytes — TODO confirm with callers.
#[derive(Debug)]
pub struct GpuMemoryManager {
    /// Total budget, in the same unit as the sizes passed to `allocate`.
    pool_size: usize,
    /// Amount currently reserved; never exceeds `pool_size`.
    allocated: usize,
    /// Pooling flag stored from the constructor; not consulted by any method
    /// in this block.
    enable_pool: bool,
}

impl GpuMemoryManager {
    /// Creates a manager with a budget of `pool_size` and nothing reserved.
    pub fn new(pool_size: usize, enable_pool: bool) -> Self {
        Self {
            pool_size,
            allocated: 0,
            enable_pool,
        }
    }

    /// Reserves `size` units from the budget.
    ///
    /// # Errors
    /// Returns `ColmapError::OutOfMemory` when `size` exceeds the remaining
    /// budget; the bookkeeping is left untouched in that case.
    pub fn allocate(&mut self, size: usize) -> Result<(), ColmapError> {
        // Compare against the remaining budget rather than computing
        // `allocated + size`, which could overflow `usize` for huge requests.
        let remaining = self.available_memory();
        if size > remaining {
            return Err(ColmapError::OutOfMemory(
                format!("Insufficient GPU memory: requested {}, available {}",
                    size, remaining)
            ));
        }
        // Cannot overflow: `size <= pool_size - allocated`.
        self.allocated += size;
        Ok(())
    }

    /// Releases `size` units; releasing more than is reserved clamps to zero.
    pub fn deallocate(&mut self, size: usize) {
        self.allocated = self.allocated.saturating_sub(size);
    }

    /// Returns how much of the budget is still unreserved.
    pub fn available_memory(&self) -> usize {
        // `allocated <= pool_size` is an invariant; saturate defensively so a
        // broken invariant can never turn into an arithmetic panic.
        self.pool_size.saturating_sub(self.allocated)
    }

    /// Marks the whole budget as unreserved again.
    pub fn reset(&mut self) {
        self.allocated = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the automatic backend, no explicit
    /// device, pooling enabled and a batch size of 1000.
    #[test]
    fn test_gpu_config_default() {
        let defaults = GpuConfig::default();

        assert_eq!(defaults.backend, GpuBackend::Auto);
        assert_eq!(defaults.device_id, None);
        assert!(defaults.enable_memory_pool);
        assert_eq!(defaults.batch_size, 1000);
    }

    /// Constructing a manager succeeds and always yields at least one device.
    #[test]
    fn test_gpu_manager_creation() {
        let outcome = GpuManager::new();
        assert!(outcome.is_ok());

        let gpu_manager = outcome.unwrap();
        assert!(!gpu_manager.devices().is_empty());
    }

    /// Walks the memory bookkeeping through allocate / reject / deallocate.
    #[test]
    fn test_gpu_memory_manager() {
        let mut pool = GpuMemoryManager::new(1024 * 1024, true);

        // First reservation takes half of the budget.
        assert!(pool.allocate(512 * 1024).is_ok());
        assert_eq!(pool.available_memory(), 512 * 1024);

        // Second reservation leaves a quarter.
        assert!(pool.allocate(256 * 1024).is_ok());
        assert_eq!(pool.available_memory(), 256 * 1024);

        // Asking for more than what is left must fail.
        assert!(pool.allocate(512 * 1024).is_err());

        // Releasing the second reservation restores half of the budget.
        pool.deallocate(256 * 1024);
        assert_eq!(pool.available_memory(), 512 * 1024);
    }

    /// With the default configuration the detector binds to the CPU fallback.
    #[test]
    fn test_gpu_sift_detector_creation() {
        let outcome = GpuSiftDetector::new(GpuConfig::default());
        assert!(outcome.is_ok());

        let sift = outcome.unwrap();
        assert_eq!(sift.device_info().name, "CPU Fallback");
    }
}