hive_gpu/
types.rs

//! Core types for Hive GPU

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A GPU vector with its associated data
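///
/// # Examples
///
/// A minimal construction sketch (the IDs and values here are illustrative only):
///
/// ```
/// use hive_gpu::types::GpuVector;
/// use std::collections::HashMap;
///
/// let v = GpuVector::new("vec-1".to_string(), vec![0.1, 0.2, 0.3]);
/// assert_eq!(v.dimension(), 3);
///
/// let mut metadata = HashMap::new();
/// metadata.insert("source".to_string(), "docs".to_string());
/// let tagged = GpuVector::with_metadata("vec-2".to_string(), vec![0.4, 0.5], metadata);
/// assert_eq!(tagged.metadata.len(), 1);
/// ```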
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuVector {
    /// Unique identifier for the vector
    pub id: String,
    /// The vector data (always f32 for compatibility)
    pub data: Vec<f32>,
    /// Optional metadata associated with the vector
    pub metadata: HashMap<String, String>,
}

impl GpuVector {
    /// Create a new GPU vector
    pub fn new(id: String, data: Vec<f32>) -> Self {
        Self {
            id,
            data,
            metadata: HashMap::new(),
        }
    }

    /// Create a new GPU vector with metadata
    pub fn with_metadata(id: String, data: Vec<f32>, metadata: HashMap<String, String>) -> Self {
        Self { id, data, metadata }
    }

    /// Get the dimension of the vector
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Get memory usage in bytes
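    ///
    /// # Examples
    ///
    /// A sketch of the estimate for a small vector; the metadata term is a rough
    /// per-entry constant, not an exact measurement:
    ///
    /// ```
    /// use hive_gpu::types::GpuVector;
    ///
    /// let v = GpuVector::new("v1".to_string(), vec![0.0; 3]);
    /// // 3 floats * 4 bytes + 2 id bytes + no metadata entries
    /// assert_eq!(v.memory_size(), 14);
    /// ```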
    pub fn memory_size(&self) -> usize {
        // Vector data + id bytes + a rough 32-byte allowance per metadata entry
        // (key/value contents are not measured, so this is only an approximation).
        self.data.len() * std::mem::size_of::<f32>() + self.id.len() + self.metadata.len() * 32
    }
}

impl From<&GpuVector> for Vec<f32> {
    fn from(v: &GpuVector) -> Self {
        v.data.clone()
    }
}

/// Distance metrics for vector similarity
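///
/// # Examples
///
/// A short sketch of the string form produced by the `Display` impl below:
///
/// ```
/// use hive_gpu::types::GpuDistanceMetric;
///
/// assert_eq!(GpuDistanceMetric::Cosine.to_string(), "cosine");
/// assert_eq!(GpuDistanceMetric::DotProduct.to_string(), "dot_product");
/// ```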
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GpuDistanceMetric {
    /// Cosine similarity
    Cosine,
    /// Euclidean distance
    Euclidean,
    /// Dot product
    DotProduct,
}

impl std::fmt::Display for GpuDistanceMetric {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GpuDistanceMetric::Cosine => write!(f, "cosine"),
            GpuDistanceMetric::Euclidean => write!(f, "euclidean"),
            GpuDistanceMetric::DotProduct => write!(f, "dot_product"),
        }
    }
}

/// Search result from GPU operations
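///
/// # Examples
///
/// Results are plain data; a typical post-processing step (sketched here with
/// made-up scores) is sorting by descending score:
///
/// ```
/// use hive_gpu::types::GpuSearchResult;
///
/// let mut results = vec![
///     GpuSearchResult { id: "a".to_string(), score: 0.42, index: 0 },
///     GpuSearchResult { id: "b".to_string(), score: 0.91, index: 1 },
/// ];
/// results.sort_by(|x, y| y.score.partial_cmp(&x.score).unwrap());
/// assert_eq!(results[0].id, "b");
/// ```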
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuSearchResult {
    /// Vector ID
    pub id: String,
    /// Similarity score
    pub score: f32,
    /// Vector index in storage
    pub index: usize,
}

/// GPU device information
///
/// Provides detailed information about the GPU device including memory,
/// capabilities, and backend-specific details.
///
/// # Examples
///
/// ```no_run
/// # #[cfg(all(target_os = "macos", feature = "metal-native"))]
/// # {
/// use hive_gpu::metal::MetalNativeContext;
/// use hive_gpu::traits::GpuContext;
///
/// let context = MetalNativeContext::new().expect("Failed to create Metal context");
/// let info = context.device_info().expect("Failed to get device info");
///
/// println!("Device: {}", info.name);
/// println!("Backend: {}", info.backend);
/// println!("VRAM: {} MB", info.total_vram_bytes / 1024 / 1024);
/// println!("Usage: {:.1}%", info.vram_usage_percent());
/// # }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDeviceInfo {
    /// Device name (e.g., "Apple M2 Pro", "NVIDIA RTX 4090")
    pub name: String,

    /// Backend type (e.g., "Metal", "CUDA", "ROCm")
    pub backend: String,

    /// Total VRAM in bytes
    pub total_vram_bytes: u64,

    /// Currently available VRAM in bytes
    pub available_vram_bytes: u64,

    /// Currently used VRAM in bytes (calculated as total - available)
    pub used_vram_bytes: u64,

    /// Driver version string (e.g., "macOS 14.1", "CUDA 12.0", "ROCm 5.4")
    pub driver_version: String,

    /// Compute capability or architecture version
    /// - Metal: None
    /// - CUDA: e.g., "8.9" for sm_89
    /// - ROCm: e.g., "gfx1030"
    pub compute_capability: Option<String>,

    /// Maximum threads per block/workgroup
    pub max_threads_per_block: u32,

    /// Maximum shared memory per block (in bytes)
    pub max_shared_memory_per_block: u64,

    /// Device ID (0-indexed)
    pub device_id: i32,

    /// PCI bus ID (e.g., "0000:01:00.0")
    /// None for Metal (Apple Silicon doesn't expose PCI)
    pub pci_bus_id: Option<String>,
}

impl GpuDeviceInfo {
    /// Calculate VRAM usage percentage (0.0 to 100.0)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use hive_gpu::types::GpuDeviceInfo;
    /// # let info = GpuDeviceInfo {
    /// #     name: "Test GPU".to_string(),
    /// #     backend: "Test".to_string(),
    /// #     total_vram_bytes: 16 * 1024 * 1024 * 1024,
    /// #     available_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     used_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     driver_version: "1.0".to_string(),
    /// #     compute_capability: None,
    /// #     max_threads_per_block: 1024,
    /// #     max_shared_memory_per_block: 49152,
    /// #     device_id: 0,
    /// #     pci_bus_id: None,
    /// # };
    /// let usage = info.vram_usage_percent();
    /// assert!(usage >= 0.0 && usage <= 100.0);
    /// ```
    pub fn vram_usage_percent(&self) -> f64 {
        if self.total_vram_bytes == 0 {
            return 0.0;
        }
        (self.used_vram_bytes as f64 / self.total_vram_bytes as f64) * 100.0
    }

    /// Check if there is sufficient available VRAM
    ///
    /// # Arguments
    ///
    /// * `required_bytes` - Minimum required VRAM in bytes
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use hive_gpu::types::GpuDeviceInfo;
    /// # let info = GpuDeviceInfo {
    /// #     name: "Test GPU".to_string(),
    /// #     backend: "Test".to_string(),
    /// #     total_vram_bytes: 16 * 1024 * 1024 * 1024,
    /// #     available_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     used_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     driver_version: "1.0".to_string(),
    /// #     compute_capability: None,
    /// #     max_threads_per_block: 1024,
    /// #     max_shared_memory_per_block: 49152,
    /// #     device_id: 0,
    /// #     pci_bus_id: None,
    /// # };
    /// // Check if we have at least 1GB available
    /// if info.has_available_vram(1 * 1024 * 1024 * 1024) {
    ///     println!("Sufficient VRAM available");
    /// }
    /// ```
    pub fn has_available_vram(&self, required_bytes: u64) -> bool {
        self.available_vram_bytes >= required_bytes
    }

    /// Get VRAM available in megabytes (convenience method)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use hive_gpu::types::GpuDeviceInfo;
    /// # let info = GpuDeviceInfo {
    /// #     name: "Test GPU".to_string(),
    /// #     backend: "Test".to_string(),
    /// #     total_vram_bytes: 16 * 1024 * 1024 * 1024,
    /// #     available_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     used_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     driver_version: "1.0".to_string(),
    /// #     compute_capability: None,
    /// #     max_threads_per_block: 1024,
    /// #     max_shared_memory_per_block: 49152,
    /// #     device_id: 0,
    /// #     pci_bus_id: None,
    /// # };
    /// println!("Available: {} MB", info.available_vram_mb());
    /// ```
    pub fn available_vram_mb(&self) -> u64 {
        self.available_vram_bytes / (1024 * 1024)
    }

    /// Get total VRAM in megabytes (convenience method)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use hive_gpu::types::GpuDeviceInfo;
    /// # let info = GpuDeviceInfo {
    /// #     name: "Test GPU".to_string(),
    /// #     backend: "Test".to_string(),
    /// #     total_vram_bytes: 16 * 1024 * 1024 * 1024,
    /// #     available_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     used_vram_bytes: 8 * 1024 * 1024 * 1024,
    /// #     driver_version: "1.0".to_string(),
    /// #     compute_capability: None,
    /// #     max_threads_per_block: 1024,
    /// #     max_shared_memory_per_block: 49152,
    /// #     device_id: 0,
    /// #     pci_bus_id: None,
    /// # };
    /// println!("Total: {} MB", info.total_vram_mb());
    /// ```
    pub fn total_vram_mb(&self) -> u64 {
        self.total_vram_bytes / (1024 * 1024)
    }
}

/// GPU capabilities
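///
/// # Examples
///
/// A sketch of a pre-flight check against the reported limits; the numbers are
/// illustrative placeholders, not values from a real device:
///
/// ```
/// use hive_gpu::types::GpuCapabilities;
///
/// let caps = GpuCapabilities {
///     supports_hnsw: true,
///     supports_batch: true,
///     max_dimension: 4096,
///     max_batch_size: 10_000,
/// };
/// assert!(caps.supports_batch && 1536 <= caps.max_dimension);
/// ```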
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuCapabilities {
    /// Supports HNSW operations
    pub supports_hnsw: bool,
    /// Supports batch operations
    pub supports_batch: bool,
    /// Maximum vector dimension
    pub max_dimension: usize,
    /// Maximum vectors per batch
    pub max_batch_size: usize,
}

/// GPU memory statistics
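///
/// # Examples
///
/// A sketch with illustrative numbers; in practice the values come from the
/// GPU backend rather than being constructed by hand:
///
/// ```
/// use hive_gpu::types::GpuMemoryStats;
///
/// let stats = GpuMemoryStats {
///     total_allocated: 512 * 1024 * 1024,
///     available: 512 * 1024 * 1024,
///     utilization: 0.5,
///     buffer_count: 8,
/// };
/// println!("GPU memory utilization: {:.1}%", stats.utilization * 100.0);
/// ```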
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuMemoryStats {
    /// Total allocated memory in bytes
    pub total_allocated: usize,
    /// Available memory in bytes
    pub available: usize,
    /// Memory utilization as a fraction (0.0-1.0)
    pub utilization: f32,
    /// Number of active buffers
    pub buffer_count: usize,
}

/// HNSW configuration for GPU operations
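///
/// # Examples
///
/// A sketch that starts from the defaults below and raises `ef_search`; the
/// exact value is workload-dependent and the 200 here is only a placeholder:
///
/// ```
/// use hive_gpu::types::HnswConfig;
///
/// let config = HnswConfig {
///     ef_search: 200,
///     ..HnswConfig::default()
/// };
/// assert_eq!(config.max_connections, 16);
/// ```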
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Number of bidirectional links created for each node
    pub max_connections: usize,
    /// Size of the dynamic list for nearest neighbors (construction)
    pub ef_construction: usize,
    /// Size of the dynamic list for nearest neighbors (search)
    pub ef_search: usize,
    /// Maximum level in the hierarchy
    pub max_level: usize,
    /// Level assignment multiplier
    pub level_multiplier: f32,
    /// Random seed for level assignment
    pub seed: Option<u64>,
}

impl Default for HnswConfig {
    fn default() -> Self {
        Self {
            max_connections: 16,
            ef_construction: 100,
            ef_search: 50,
            max_level: 8,
            level_multiplier: 0.5,
            seed: None,
        }
    }
}

/// Vector metadata for GPU operations
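///
/// # Examples
///
/// A construction sketch; the `timestamp` field is only documented as
/// "timestamp of creation", so seconds since the Unix epoch is an assumption here:
///
/// ```
/// use hive_gpu::types::VectorMetadata;
/// use std::time::{SystemTime, UNIX_EPOCH};
///
/// let meta = VectorMetadata {
///     original_id: "vec-1".to_string(),
///     index: 0,
///     timestamp: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
/// };
/// assert_eq!(meta.index, 0);
/// ```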
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// Original vector ID
    pub original_id: String,
    /// Index in storage
    pub index: usize,
    /// Timestamp of creation
    pub timestamp: u64,
}