// halldyll_deploy_pods/runpod/provisioner.rs
use std::collections::HashMap;

use tracing::{debug, info, warn};

use crate::config::{CloudType, GpuConfig, PodConfig, PortConfig, ProjectConfig, RuntimeConfig};
use crate::error::{HalldyllError, Result, RunPodError};

use super::client::RunPodClient;
use super::types::{CreatePodRequest, Pod, PodStatus};

/// Default persistent volume size (GB) used when no configured volume
/// declares a positive size.
const DEFAULT_VOLUME_GB: u32 = 50;

/// Container disk size (GB) applied to every created pod.
const DEFAULT_CONTAINER_DISK_GB: u32 = 20;

/// Provisions, stops, resumes, and terminates RunPod pods from halldyll
/// configuration.
#[derive(Debug)]
pub struct PodProvisioner {
    // API client used for all RunPod calls.
    client: RunPodClient,
    // Lookup table populated by `init_gpu_types`: maps both GPU display
    // names and raw GPU ids to the canonical GPU id.
    gpu_type_map: HashMap<String, String>,
}

30impl PodProvisioner {
31 #[must_use]
33 pub fn new(client: RunPodClient) -> Self {
34 Self {
35 client,
36 gpu_type_map: HashMap::new(),
37 }
38 }
39
40 pub async fn init_gpu_types(&mut self) -> Result<()> {
46 info!("Fetching available GPU types");
47
48 let gpu_types = self.client.list_gpu_types().await?;
49
50 self.gpu_type_map.clear();
51 for gpu in gpu_types {
52 self.gpu_type_map
54 .insert(gpu.display_name.clone(), gpu.id.clone());
55 self.gpu_type_map.insert(gpu.id.clone(), gpu.id);
56 }
57
58 debug!("Loaded {} GPU type mappings", self.gpu_type_map.len());
59 Ok(())
60 }
61
62 fn resolve_gpu_type(&self, gpu_type: &str) -> Option<&String> {
64 self.gpu_type_map.get(gpu_type)
65 }
66
67 pub async fn create_pod(
73 &self,
74 pod_config: &PodConfig,
75 project: &ProjectConfig,
76 spec_hash: &str,
77 ) -> Result<Pod> {
78 let full_name = pod_config.full_name(project);
79 info!("Creating pod: {full_name}");
80
81 let gpu_type_id = self
83 .resolve_gpu_type_with_fallback(&pod_config.gpu, &project.cloud_type)
84 .await?;
85
86 let request = Self::build_create_request(pod_config, project, &gpu_type_id, spec_hash);
88
89 let pod = self.client.create_pod(&request).await?;
91
92 info!(
93 "Created pod: {} (ID: {})",
94 full_name, pod.id
95 );
96
97 Ok(pod)
98 }
99
100 pub async fn create_pod_with_setup(
106 &self,
107 pod_config: &PodConfig,
108 project: &ProjectConfig,
109 spec_hash: &str,
110 ) -> Result<(Pod, Option<super::executor::PostProvisionResult>)> {
111 let pod = self.create_pod(pod_config, project, spec_hash).await?;
113
114 if !pod_config.models.is_empty() {
116 info!("Starting post-provisioning setup for pod {}", pod.id);
117
118 let executor = super::executor::PodExecutor::new(self.client.clone());
119
120 match executor.post_provision_setup(&pod.id, pod_config).await {
121 Ok(result) => {
122 info!("Post-provisioning completed: {}", result.summary());
123 return Ok((pod, Some(result)));
124 }
125 Err(e) => {
126 warn!("Post-provisioning failed (pod still running): {}", e);
127 }
129 }
130 }
131
132 Ok((pod, None))
133 }
134
135 async fn resolve_gpu_type_with_fallback(
137 &self,
138 gpu_config: &GpuConfig,
139 cloud_type: &CloudType,
140 ) -> Result<String> {
141 let cloud_type_str = match cloud_type {
142 CloudType::Secure => "SECURE",
143 CloudType::Community => "COMMUNITY",
144 };
145
146 if let Some(gpu_id) = self.resolve_gpu_type(&gpu_config.gpu_type) {
148 if self
149 .client
150 .is_gpu_available(gpu_id, cloud_type_str)
151 .await?
152 {
153 debug!(
154 "Using primary GPU type: {} ({})",
155 gpu_config.gpu_type, gpu_id
156 );
157 return Ok(gpu_id.clone());
158 }
159 warn!(
160 "Primary GPU type {} not available in {} cloud",
161 gpu_config.gpu_type, cloud_type_str
162 );
163 }
164
165 for fallback in &gpu_config.fallback {
167 if let Some(gpu_id) = self.resolve_gpu_type(fallback) {
168 if self
169 .client
170 .is_gpu_available(gpu_id, cloud_type_str)
171 .await?
172 {
173 info!(
174 "Using fallback GPU type: {} ({})",
175 fallback, gpu_id
176 );
177 return Ok(gpu_id.clone());
178 }
179 debug!("Fallback GPU type {fallback} not available");
180 }
181 }
182
183 Err(HalldyllError::RunPod(RunPodError::GpuNotAvailable {
184 gpu_type: gpu_config.gpu_type.clone(),
185 region: cloud_type_str.to_string(),
186 }))
187 }
188
189 fn build_create_request(
191 pod_config: &PodConfig,
192 project: &ProjectConfig,
193 gpu_type_id: &str,
194 spec_hash: &str,
195 ) -> CreatePodRequest {
196 let full_name = pod_config.full_name(project);
197
198 let ports = Self::build_ports_string(&pod_config.ports);
200
201 let volume_gb = pod_config
203 .volumes
204 .iter()
205 .filter_map(|v| v.size_gb)
206 .max()
207 .unwrap_or_default();
208 let volume_gb = if volume_gb == 0 { DEFAULT_VOLUME_GB } else { volume_gb };
209
210 let mount_path = pod_config
212 .volumes
213 .first()
214 .map(|v| v.mount.clone());
215
216 let env = Self::build_env_vars(&pod_config.runtime);
218
219 let tags = Self::build_tags(pod_config, project, spec_hash);
221
222 let cloud_type = match project.cloud_type {
223 CloudType::Secure => "SECURE",
224 CloudType::Community => "COMMUNITY",
225 };
226
227 let mut request = CreatePodRequest::new(&full_name, gpu_type_id, &pod_config.runtime.image)
228 .with_cloud_type(cloud_type)
229 .with_gpu_count(pod_config.gpu.count)
230 .with_volume_gb(volume_gb)
231 .with_container_disk_gb(DEFAULT_CONTAINER_DISK_GB)
232 .with_ports(&ports)
233 .with_env_map(env)
234 .with_tags(tags);
235
236 if let Some(path) = mount_path {
237 request = request.with_mount_path(&path);
238 }
239
240 request
241 }
242
243 fn build_ports_string(ports: &[PortConfig]) -> String {
245 if ports.is_empty() {
246 return String::from("8000/http");
247 }
248
249 ports
250 .iter()
251 .map(|p| {
252 let protocol = match p.protocol {
253 crate::config::PortProtocol::Tcp => "tcp",
254 crate::config::PortProtocol::Http | crate::config::PortProtocol::Https => "http",
255 crate::config::PortProtocol::Udp => "udp",
256 };
257 format!("{}/{protocol}", p.port)
258 })
259 .collect::<Vec<_>>()
260 .join(",")
261 }
262
263 fn build_env_vars(runtime: &RuntimeConfig) -> HashMap<String, String> {
265 let mut env = runtime.env.clone();
266
267 if let Ok(hf_token) = std::env::var("HF_TOKEN") {
269 env.entry(String::from("HF_TOKEN"))
270 .or_insert(hf_token);
271 }
272
273 env
274 }
275
276 fn build_tags(
278 pod_config: &PodConfig,
279 project: &ProjectConfig,
280 spec_hash: &str,
281 ) -> HashMap<String, String> {
282 let mut tags = pod_config.tags.clone();
283
284 tags.insert(String::from("halldyll_project"), project.name.clone());
286 tags.insert(String::from("halldyll_env"), project.environment.clone());
287 tags.insert(String::from("halldyll_pod"), pod_config.name.clone());
288 tags.insert(String::from("halldyll_spec_hash"), spec_hash.to_string());
289
290 tags
291 }
292
293 pub async fn terminate_pod(&self, pod_id: &str) -> Result<()> {
299 info!("Terminating pod: {pod_id}");
300 self.client.terminate_pod(pod_id).await?;
301 info!("Pod terminated: {pod_id}");
302 Ok(())
303 }
304
305 pub async fn stop_pod(&self, pod_id: &str) -> Result<()> {
311 info!("Stopping pod: {pod_id}");
312 self.client.stop_pod(pod_id).await?;
313 info!("Pod stopped: {pod_id}");
314 Ok(())
315 }
316
317 pub async fn resume_pod(&self, pod_id: &str) -> Result<Pod> {
323 info!("Resuming pod: {pod_id}");
324 let pod = self.client.resume_pod(pod_id).await?;
325 info!("Pod resumed: {pod_id}");
326 Ok(pod)
327 }
328
329 pub async fn wait_for_status(
335 &self,
336 pod_id: &str,
337 expected_status: PodStatus,
338 timeout_secs: u64,
339 ) -> Result<Pod> {
340 let start = std::time::Instant::now();
341 let timeout = std::time::Duration::from_secs(timeout_secs);
342
343 loop {
344 let pod = self.client.get_pod(pod_id).await?;
345
346 if pod.desired_status == expected_status {
347 return Ok(pod);
348 }
349
350 if start.elapsed() > timeout {
351 return Err(HalldyllError::RunPod(RunPodError::Timeout {
352 pod_id: pod_id.to_string(),
353 expected_state: expected_status.to_string(),
354 }));
355 }
356
357 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
358 }
359 }
360
361 #[must_use]
363 pub const fn client(&self) -> &RunPodClient {
364 &self.client
365 }
366}