// alith_devices/devices/gpu.rs

/// Per-device VRAM bookkeeping filled in by [`GpuLayerAllocator`].
///
/// All sizes are in bytes.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GpuDevice {
    /// Device ordinal, reported in the allocation summary.
    pub ordinal: u32,
    /// VRAM budget of this device; `can_allocate` refuses any layer that would
    /// push `allocated_layer_bytes` past it.
    pub available_vram_bytes: u64,
    /// Bytes reserved so far through `allocate_layer`. Buffer layers are reserved
    /// through the same call, so this figure includes `allocated_buffer_bytes`.
    pub allocated_layer_bytes: u64,
    /// Portion of `allocated_layer_bytes` reserved for buffer layers.
    pub allocated_buffer_bytes: u64,
    /// Number of layer-sized slots reserved, buffer layers included.
    pub allocated_layers: u64,
    /// Whether this device receives the extra main-GPU buffer layers.
    pub is_main_gpu: bool,
}

11impl GpuDevice {
12    fn can_allocate(&self, layer_size: u64) -> bool {
13        self.available_vram_bytes >= self.allocated_layer_bytes + layer_size
14    }
15
16    fn allocate_layer(&mut self, layer_size: u64) {
17        self.allocated_layers += 1;
18        self.allocated_layer_bytes += layer_size;
19    }
20}
21
/// Distributes a model's equally sized layers across a set of [`GpuDevice`]s,
/// after reserving buffer layers on each device.
#[derive(Debug, Clone)]
pub struct GpuLayerAllocator {
    /// Size of one layer slot (model or buffer), in bytes.
    layer_size: u64,
    /// Number of model layers to place.
    total_layers: u64,
    /// Buffer layer slots reserved on every GPU.
    buffer_layer_per_gpu: u64,
    /// Extra buffer layer slots reserved on each GPU flagged `is_main_gpu`.
    buffer_layer_main_gpu: u64,
}

29impl GpuLayerAllocator {
30    pub fn new(
31        layer_size: u64,
32        total_layers: u64,
33        buffer_layer_per_gpu: u64,
34        buffer_layer_main_gpu: u64,
35    ) -> Self {
36        GpuLayerAllocator {
37            layer_size,
38            total_layers,
39            buffer_layer_per_gpu,
40            buffer_layer_main_gpu,
41        }
42    }
43
44    pub fn allocate(&self, gpus: &mut [GpuDevice]) -> crate::Result<()> {
45        // Sort GPUs by available VRAM, descending
46        gpus.sort_by_key(|gpu| std::cmp::Reverse(gpu.available_vram_bytes));
47
48        // Calculate total available VRAM
49        let total_available_vram: u64 = gpus.iter().map(|gpu| gpu.available_vram_bytes).sum();
50
51        let mut buffer_layers = 0;
52        // Allocate buffer layers
53        for gpu in gpus.iter_mut() {
54            for _ in 1..=self.buffer_layer_per_gpu {
55                buffer_layers += 1;
56                gpu.allocate_layer(self.layer_size);
57                gpu.allocated_buffer_bytes += self.layer_size;
58            }
59            if gpu.is_main_gpu {
60                for _ in 1..=self.buffer_layer_main_gpu {
61                    buffer_layers += 1;
62                    gpu.allocate_layer(self.layer_size);
63                    gpu.allocated_buffer_bytes += self.layer_size;
64                }
65            }
66        }
67
68        let total_required_vram = (self.total_layers + buffer_layers) * self.layer_size;
69
70        // Check if there's enough total VRAM
71        if total_available_vram < total_required_vram {
72            crate::bail!(
73                "Insufficient total VRAM. Required: {}GB, Available: {}GB",
74                total_required_vram / 1_073_741_824,
75                total_available_vram / 1_073_741_824
76            );
77        }
78
79        let mut allocation = vec![0; gpus.len()];
80        let result = self.dfs_allocate(gpus, &mut allocation, 0, self.total_layers);
81        Self::print_allocation(gpus);
82        if !result {
83            // Check why allocation failed
84            let allocated_layers: u64 = gpus.iter().map(|gpu| gpu.allocated_layers).sum();
85            let remaining_layers = self.total_layers - (allocated_layers - buffer_layers);
86
87            if remaining_layers > 0 {
88                crate::bail!(
89                    "Failed to allocate all layers. {} layers remaining unallocated.",
90                    remaining_layers
91                );
92            } else {
93                crate::bail!("Allocation failed due to VRAM fragmentation across GPUs.");
94            }
95        }
96        Ok(())
97    }
98
99    fn dfs_allocate(
100        &self,
101        gpus: &mut [GpuDevice],
102        allocation: &mut Vec<u64>,
103        gpu_index: usize,
104        remaining_layers: u64,
105    ) -> bool {
106        if remaining_layers == 0 {
107            return true;
108        }
109
110        // Try to allocate to each GPU in a round-robin fashion
111        for i in 0..gpus.len() {
112            let current_gpu_index = (gpu_index + i) % gpus.len();
113            if gpus[current_gpu_index].can_allocate(self.layer_size) {
114                gpus[current_gpu_index].allocate_layer(self.layer_size);
115                allocation[current_gpu_index] += 1;
116
117                if self.dfs_allocate(
118                    gpus,
119                    allocation,
120                    (current_gpu_index + 1) % gpus.len(),
121                    remaining_layers - 1,
122                ) {
123                    return true;
124                }
125
126                // If allocation failed, backtrack
127                gpus[current_gpu_index].allocated_layers -= 1;
128                gpus[current_gpu_index].allocated_layer_bytes -= self.layer_size;
129                allocation[current_gpu_index] -= 1;
130            }
131        }
132
133        false
134    }
135
136    fn print_allocation(gpus: &[GpuDevice]) {
137        let message = std::fmt::format(format_args!(
138            "\nGPU Allocation:\n{}",
139            gpus.iter()
140                .map(|gpu| format!("{}", gpu))
141                .collect::<Vec<_>>()
142                .join("\n")
143        ));
144        crate::info!("{}", message);
145    }
146}
147
148impl std::fmt::Display for GpuDevice {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        crate::i_nlns(
151            f,
152            &[
153                format_args!("Ordinal: {}", self.ordinal),
154                format_args!("Main GPU: {}", self.is_main_gpu),
155                format_args!("Allocated Layers: {}", self.allocated_layers),
156                format_args!(
157                    "Layer Size: {:.2} GB",
158                    self.allocated_layer_bytes as f64 / 1_073_741_824.0
159                ),
160                format_args!(
161                    "Buffer Size: {:.2} GB",
162                    self.allocated_buffer_bytes as f64 / 1_073_741_824.0
163                ),
164                format_args!(
165                    "Total Size: {:.2} GB",
166                    (self.allocated_layer_bytes + self.allocated_buffer_bytes) as f64
167                        / 1_073_741_824.0
168                ),
169                format_args!(
170                    "Available VRAM: {:.2} GB",
171                    self.available_vram_bytes as f64 / 1_073_741_824.0
172                ),
173            ],
174        )
175    }
176}