halldyll_deploy_pods/runpod/provisioner.rs

//! Pod provisioner for creating and managing `RunPod` pods.
//!
//! This module handles the provisioning logic for pods, including
//! resource mapping, creation, and lifecycle management.
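//!
//! # Example
//!
//! Illustrative flow only (not a doctest); construction of the `RunPodClient` and of the
//! `PodConfig`/`ProjectConfig` values is assumed to happen elsewhere.
//!
//! ```ignore
//! let mut provisioner = PodProvisioner::new(client);
//! provisioner.init_gpu_types().await?;
//! let pod = provisioner.create_pod(&pod_config, &project, spec_hash).await?;
//! ```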

use std::collections::HashMap;
use tracing::{debug, info, warn};

use crate::config::{
    CloudType, GpuConfig, PodConfig, PortConfig, PortProtocol, ProjectConfig, RuntimeConfig,
};
use crate::error::{HalldyllError, Result, RunPodError};

use super::client::RunPodClient;
use super::types::{CreatePodRequest, Pod, PodStatus};

/// Default volume size in GB.
const DEFAULT_VOLUME_GB: u32 = 50;

/// Default container disk size in GB.
const DEFAULT_CONTAINER_DISK_GB: u32 = 20;

/// Pod provisioner for managing `RunPod` pods.
#[derive(Debug)]
pub struct PodProvisioner {
    /// `RunPod` API client.
    client: RunPodClient,
    /// GPU type mapping (display name -> ID).
    gpu_type_map: HashMap<String, String>,
}

impl PodProvisioner {
    /// Creates a new pod provisioner.
    #[must_use]
    pub fn new(client: RunPodClient) -> Self {
        Self {
            client,
            gpu_type_map: HashMap::new(),
        }
    }

    /// Initializes the GPU type mapping by fetching available types.
    ///
    /// # Errors
    ///
    /// Returns an error if the GPU types cannot be fetched.
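    ///
    /// Call this before creating pods: GPU type resolution in `create_pod` relies on the
    /// mapping built here, so an empty mapping causes GPU lookups to fail.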
    pub async fn init_gpu_types(&mut self) -> Result<()> {
        info!("Fetching available GPU types");

        let gpu_types = self.client.list_gpu_types().await?;

        self.gpu_type_map.clear();
        for gpu in gpu_types {
            // Map both ID and display name to the ID
            self.gpu_type_map
                .insert(gpu.display_name.clone(), gpu.id.clone());
            self.gpu_type_map.insert(gpu.id.clone(), gpu.id);
        }

        debug!("Loaded {} GPU type mappings", self.gpu_type_map.len());
        Ok(())
    }

    /// Resolves a GPU type name to its `RunPod` ID.
    fn resolve_gpu_type(&self, gpu_type: &str) -> Option<&String> {
        self.gpu_type_map.get(gpu_type)
    }

    /// Creates a pod from a pod configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the pod cannot be created.
    pub async fn create_pod(
        &self,
        pod_config: &PodConfig,
        project: &ProjectConfig,
        spec_hash: &str,
    ) -> Result<Pod> {
        let full_name = pod_config.full_name(project);
        info!("Creating pod: {full_name}");

        // Resolve GPU type
        let gpu_type_id = self
            .resolve_gpu_type_with_fallback(&pod_config.gpu, &project.cloud_type)
            .await?;

        // Build the create request
        let request = Self::build_create_request(pod_config, project, &gpu_type_id, spec_hash);

        // Create the pod
        let pod = self.client.create_pod(&request).await?;

        info!("Created pod: {} (ID: {})", full_name, pod.id);

        Ok(pod)
    }

    /// Creates a pod and performs post-provisioning setup (model download, engine start).
    ///
    /// # Errors
    ///
    /// Returns an error if the pod cannot be created. Post-provisioning failures are
    /// logged as warnings and do not fail the operation; in that case the setup result
    /// is `None`.
    pub async fn create_pod_with_setup(
        &self,
        pod_config: &PodConfig,
        project: &ProjectConfig,
        spec_hash: &str,
    ) -> Result<(Pod, Option<super::executor::PostProvisionResult>)> {
        // Create the pod first
        let pod = self.create_pod(pod_config, project, spec_hash).await?;

        // If there are models to set up, run post-provisioning
        if !pod_config.models.is_empty() {
            info!("Starting post-provisioning setup for pod {}", pod.id);

            let executor = super::executor::PodExecutor::new(self.client.clone());

            match executor.post_provision_setup(&pod.id, pod_config).await {
                Ok(result) => {
                    info!("Post-provisioning completed: {}", result.summary());
                    return Ok((pod, Some(result)));
                }
                Err(e) => {
                    warn!("Post-provisioning failed (pod still running): {e}");
                    // Don't fail the whole operation; the pod is still usable.
                }
            }
        }

        Ok((pod, None))
    }

    /// Resolves GPU type with fallback support.
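    ///
    /// The primary `gpu_type` is tried first; if it is not available in the configured
    /// cloud, each entry in `fallback` is tried in order. If none is available, a
    /// `GpuNotAvailable` error is returned.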
    async fn resolve_gpu_type_with_fallback(
        &self,
        gpu_config: &GpuConfig,
        cloud_type: &CloudType,
    ) -> Result<String> {
        let cloud_type_str = match cloud_type {
            CloudType::Secure => "SECURE",
            CloudType::Community => "COMMUNITY",
        };

        // Try primary GPU type
        if let Some(gpu_id) = self.resolve_gpu_type(&gpu_config.gpu_type) {
            if self.client.is_gpu_available(gpu_id, cloud_type_str).await? {
                debug!("Using primary GPU type: {} ({})", gpu_config.gpu_type, gpu_id);
                return Ok(gpu_id.clone());
            }
            warn!(
                "Primary GPU type {} not available in {} cloud",
                gpu_config.gpu_type, cloud_type_str
            );
        }

        // Try fallback GPU types
        for fallback in &gpu_config.fallback {
            if let Some(gpu_id) = self.resolve_gpu_type(fallback) {
                if self.client.is_gpu_available(gpu_id, cloud_type_str).await? {
                    info!("Using fallback GPU type: {} ({})", fallback, gpu_id);
                    return Ok(gpu_id.clone());
                }
                debug!("Fallback GPU type {fallback} not available");
            }
        }

        Err(HalldyllError::RunPod(RunPodError::GpuNotAvailable {
            gpu_type: gpu_config.gpu_type.clone(),
            region: cloud_type_str.to_string(),
        }))
    }

    /// Builds a pod creation request from configuration.
    fn build_create_request(
        pod_config: &PodConfig,
        project: &ProjectConfig,
        gpu_type_id: &str,
        spec_hash: &str,
    ) -> CreatePodRequest {
        let full_name = pod_config.full_name(project);

        // Build ports string
        let ports = Self::build_ports_string(&pod_config.ports);

        // Calculate volume size (largest requested volume, or the default if none is set)
        let volume_gb = pod_config
            .volumes
            .iter()
            .filter_map(|v| v.size_gb)
            .max()
            .filter(|&gb| gb > 0)
            .unwrap_or(DEFAULT_VOLUME_GB);

        // Get primary volume mount path
        let mount_path = pod_config.volumes.first().map(|v| v.mount.clone());

        // Build environment variables
        let env = Self::build_env_vars(&pod_config.runtime);

        // Build tags
        let tags = Self::build_tags(pod_config, project, spec_hash);

        let cloud_type = match project.cloud_type {
            CloudType::Secure => "SECURE",
            CloudType::Community => "COMMUNITY",
        };

        let mut request = CreatePodRequest::new(&full_name, gpu_type_id, &pod_config.runtime.image)
            .with_cloud_type(cloud_type)
            .with_gpu_count(pod_config.gpu.count)
            .with_volume_gb(volume_gb)
            .with_container_disk_gb(DEFAULT_CONTAINER_DISK_GB)
            .with_ports(&ports)
            .with_env_map(env)
            .with_tags(tags);

        if let Some(path) = mount_path {
            request = request.with_mount_path(&path);
        }

        request
    }

    /// Builds the ports string for the API request.
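    ///
    /// For example, a config exposing port 8888 over HTTP and port 22 over TCP yields
    /// `"8888/http,22/tcp"`; an empty port list falls back to `"8000/http"`.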
    fn build_ports_string(ports: &[PortConfig]) -> String {
        if ports.is_empty() {
            return String::from("8000/http");
        }

        ports
            .iter()
            .map(|p| {
                let protocol = match p.protocol {
                    PortProtocol::Tcp => "tcp",
                    PortProtocol::Http | PortProtocol::Https => "http",
                    PortProtocol::Udp => "udp",
                };
                format!("{}/{protocol}", p.port)
            })
            .collect::<Vec<_>>()
            .join(",")
    }

    /// Builds the environment variable map.
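    ///
    /// Starts from the runtime config's `env` entries; if an `HF_TOKEN` variable is set in
    /// the local environment and not already present in the config, it is added.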
    fn build_env_vars(runtime: &RuntimeConfig) -> HashMap<String, String> {
        let mut env = runtime.env.clone();

        // Add HF token if available
        if let Ok(hf_token) = std::env::var("HF_TOKEN") {
            env.entry(String::from("HF_TOKEN")).or_insert(hf_token);
        }

        env
    }

    /// Builds tags for the pod.
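    ///
    /// User-defined tags from the pod config are kept, and the system tags
    /// `halldyll_project`, `halldyll_env`, `halldyll_pod`, and `halldyll_spec_hash`
    /// are added (overwriting any user tags with the same keys).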
    fn build_tags(
        pod_config: &PodConfig,
        project: &ProjectConfig,
        spec_hash: &str,
    ) -> HashMap<String, String> {
        let mut tags = pod_config.tags.clone();

        // Add system tags
        tags.insert(String::from("halldyll_project"), project.name.clone());
        tags.insert(String::from("halldyll_env"), project.environment.clone());
        tags.insert(String::from("halldyll_pod"), pod_config.name.clone());
        tags.insert(String::from("halldyll_spec_hash"), spec_hash.to_string());

        tags
    }

    /// Terminates a pod.
    ///
    /// # Errors
    ///
    /// Returns an error if the pod cannot be terminated.
    pub async fn terminate_pod(&self, pod_id: &str) -> Result<()> {
        info!("Terminating pod: {pod_id}");
        self.client.terminate_pod(pod_id).await?;
        info!("Pod terminated: {pod_id}");
        Ok(())
    }

    /// Stops a pod (keeps it for later restart).
    ///
    /// # Errors
    ///
    /// Returns an error if the pod cannot be stopped.
    pub async fn stop_pod(&self, pod_id: &str) -> Result<()> {
        info!("Stopping pod: {pod_id}");
        self.client.stop_pod(pod_id).await?;
        info!("Pod stopped: {pod_id}");
        Ok(())
    }

    /// Resumes a stopped pod.
    ///
    /// # Errors
    ///
    /// Returns an error if the pod cannot be resumed.
    pub async fn resume_pod(&self, pod_id: &str) -> Result<Pod> {
        info!("Resuming pod: {pod_id}");
        let pod = self.client.resume_pod(pod_id).await?;
        info!("Pod resumed: {pod_id}");
        Ok(pod)
    }

    /// Waits for a pod to reach a specific status.
    ///
    /// # Errors
    ///
    /// Returns an error if the timeout is reached or the API call fails.
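    ///
    /// The pod is polled every five seconds until it reports the expected status.
    ///
    /// # Example
    ///
    /// Illustrative only (not compiled); the `PodStatus::Running` variant name is assumed.
    ///
    /// ```ignore
    /// let pod = provisioner
    ///     .wait_for_status(&pod.id, PodStatus::Running, 300)
    ///     .await?;
    /// ```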
    pub async fn wait_for_status(
        &self,
        pod_id: &str,
        expected_status: PodStatus,
        timeout_secs: u64,
    ) -> Result<Pod> {
        let start = std::time::Instant::now();
        let timeout = std::time::Duration::from_secs(timeout_secs);

        loop {
            let pod = self.client.get_pod(pod_id).await?;

            if pod.desired_status == expected_status {
                return Ok(pod);
            }

            if start.elapsed() > timeout {
                return Err(HalldyllError::RunPod(RunPodError::Timeout {
                    pod_id: pod_id.to_string(),
                    expected_state: expected_status.to_string(),
                }));
            }

            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
        }
    }

    /// Gets the underlying client reference.
    #[must_use]
    pub const fn client(&self) -> &RunPodClient {
        &self.client
    }
}