1use crate::error::MoldError;
8use anyhow::Result;
9use reqwest::{Client, StatusCode};
10use serde::{Deserialize, Serialize};
11use std::fmt;
12use std::time::Duration;
13
/// Base URL of the RunPod REST API. `RunPodClient::from_settings` falls back
/// to this value when the settings carry no `endpoint` override.
pub const DEFAULT_ENDPOINT: &str = "https://rest.runpod.io/v1";

/// GraphQL URL that `RunPodClient::new` selects whenever the REST endpoint
/// starts with [`DEFAULT_ENDPOINT`].
pub const GRAPHQL_ENDPOINT: &str = "https://api.runpod.io/graphql";

/// Name of the environment variable that `RunPodClient::from_settings`
/// consults for the API key before looking at the settings.
pub const API_KEY_ENV: &str = "RUNPOD_API_KEY";
22
23#[derive(Debug, Clone, Deserialize, Serialize, Default)]
25pub struct RunPodSettings {
26 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub api_key: Option<String>,
29
30 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub default_gpu: Option<String>,
33
34 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub default_datacenter: Option<String>,
37
38 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub default_network_volume_id: Option<String>,
41
42 #[serde(default)]
45 pub auto_teardown: bool,
46
47 #[serde(default = "default_auto_teardown_idle_mins")]
50 pub auto_teardown_idle_mins: u32,
51
52 #[serde(default)]
55 pub cost_alert_usd: f64,
56
57 #[serde(default, skip_serializing_if = "Option::is_none")]
59 pub endpoint: Option<String>,
60}
61
/// Serde default for `RunPodSettings::auto_teardown_idle_mins`, in minutes.
fn default_auto_teardown_idle_mins() -> u32 {
    const DEFAULT_IDLE_MINS: u32 = 20;
    DEFAULT_IDLE_MINS
}
65
66impl RunPodSettings {
68 pub fn redacted_debug(&self) -> String {
69 format!(
70 "RunPodSettings {{ api_key: {}, default_gpu: {:?}, default_datacenter: {:?}, \
71 default_network_volume_id: {:?}, auto_teardown: {}, auto_teardown_idle_mins: {}, \
72 cost_alert_usd: {}, endpoint: {:?} }}",
73 if self.api_key.is_some() {
74 "Some(\"<redacted>\")"
75 } else {
76 "None"
77 },
78 self.default_gpu,
79 self.default_datacenter,
80 self.default_network_volume_id,
81 self.auto_teardown,
82 self.auto_teardown_idle_mins,
83 self.cost_alert_usd,
84 self.endpoint,
85 )
86 }
87}
88
/// Account summary returned by `RunPodClient::user`.
///
/// `user` fills this struct by hand from the GraphQL `myself` object, so the
/// serde field names here stay in snake_case and only matter when the value is
/// (de)serialized elsewhere.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserInfo {
    /// Account id; `user` stores an empty string when the API omits it.
    pub id: String,
    /// Account email; `user` stores an empty string when the API omits it.
    pub email: String,
    /// Taken from `clientBalance`; `0.0` when absent.
    // NOTE(review): currency/units are not visible here — presumably USD.
    #[serde(default)]
    pub client_balance: f64,
    /// Taken from `currentSpendPerHr`; `0.0` when absent.
    #[serde(default)]
    pub current_spend_per_hr: f64,
    /// Taken from `spendLimit`; `None` when absent or not a number.
    #[serde(default)]
    pub spend_limit: Option<f64>,
}
103
104#[derive(Debug, Clone, Deserialize, Serialize)]
106pub struct GpuType {
107 #[serde(default)]
108 pub id: Option<String>,
109 #[serde(rename = "displayName", default)]
110 pub display_name: String,
111 #[serde(rename = "gpuId", default)]
112 pub gpu_id: String,
113 #[serde(rename = "memoryInGb", default)]
114 pub memory_in_gb: u32,
115 #[serde(rename = "secureCloud", default)]
116 pub secure_cloud: bool,
117 #[serde(rename = "communityCloud", default)]
118 pub community_cloud: bool,
119 #[serde(rename = "stockStatus", default)]
120 pub stock_status: Option<String>,
121 #[serde(default)]
122 pub available: bool,
123}
124
125#[derive(Debug, Clone, Deserialize, Serialize)]
127pub struct Datacenter {
128 pub id: String,
129 #[serde(default)]
130 pub name: String,
131 #[serde(default)]
132 pub location: Option<String>,
133 #[serde(rename = "gpuAvailability", default)]
134 pub gpu_availability: Vec<GpuAvailability>,
135}
136
137#[derive(Debug, Clone, Deserialize, Serialize)]
138pub struct GpuAvailability {
139 #[serde(rename = "displayName", default)]
140 pub display_name: String,
141 #[serde(rename = "gpuId", default)]
142 pub gpu_id: String,
143 #[serde(rename = "stockStatus", default)]
144 pub stock_status: Option<String>,
145}
146
147#[derive(Debug, Clone, Deserialize, Serialize)]
152pub struct Pod {
153 pub id: String,
154 #[serde(default)]
155 pub name: Option<String>,
156 #[serde(rename = "desiredStatus", default)]
157 pub desired_status: String,
158 #[serde(rename = "imageName", default)]
159 pub image_name: Option<String>,
160 #[serde(rename = "gpuCount", default)]
161 pub gpu_count: u32,
162 #[serde(rename = "costPerHr", default)]
163 pub cost_per_hr: f64,
164 #[serde(rename = "uptimeSeconds", default)]
165 pub uptime_seconds: u64,
166 #[serde(rename = "lastStatusChange", default)]
167 pub last_status_change: Option<String>,
168 #[serde(rename = "memoryInGb", default)]
169 pub memory_in_gb: u32,
170 #[serde(rename = "vcpuCount", default)]
171 pub vcpu_count: u32,
172 #[serde(rename = "volumeInGb", default)]
173 pub volume_in_gb: u32,
174 #[serde(rename = "volumeMountPath", default)]
175 pub volume_mount_path: Option<String>,
176 #[serde(default)]
177 pub ports: serde_json::Value,
178 #[serde(default)]
179 pub env: serde_json::Value,
180 #[serde(default)]
181 pub machine: Option<PodMachine>,
182 #[serde(default)]
183 pub runtime: Option<serde_json::Value>,
184}
185
186#[derive(Debug, Clone, Deserialize, Serialize)]
187pub struct PodMachine {
188 #[serde(rename = "gpuDisplayName", default)]
189 pub gpu_display_name: Option<String>,
190 #[serde(default)]
191 pub location: Option<String>,
192}
193
194#[derive(Debug, Clone, Serialize, Default)]
196pub struct CreatePodRequest {
197 pub name: String,
198 #[serde(rename = "imageName")]
199 pub image_name: String,
200 #[serde(rename = "gpuTypeIds")]
201 pub gpu_type_ids: Vec<String>,
202 #[serde(rename = "cloudType")]
203 pub cloud_type: String,
204 #[serde(rename = "dataCenterIds", skip_serializing_if = "Option::is_none")]
205 pub data_center_ids: Option<Vec<String>>,
206 #[serde(rename = "gpuCount")]
207 pub gpu_count: u32,
208 #[serde(rename = "containerDiskInGb")]
209 pub container_disk_in_gb: u32,
210 #[serde(rename = "volumeInGb")]
211 pub volume_in_gb: u32,
212 #[serde(rename = "volumeMountPath")]
213 pub volume_mount_path: String,
214 pub ports: Vec<String>,
215 pub env: serde_json::Map<String, serde_json::Value>,
216 #[serde(rename = "networkVolumeId", skip_serializing_if = "Option::is_none")]
217 pub network_volume_id: Option<String>,
218}
219
220#[derive(Debug, Clone, Deserialize, Serialize)]
222pub struct NetworkVolume {
223 pub id: String,
224 pub name: String,
225 #[serde(rename = "dataCenterId", default)]
226 pub data_center_id: String,
227 pub size: u32,
228}
229
/// HTTP client for the RunPod REST and GraphQL APIs.
///
/// `Debug` is implemented by hand elsewhere in this file so that the API key
/// is never printed; do not add `Debug` to the derive list.
#[derive(Clone)]
pub struct RunPodClient {
    /// REST base URL; request paths are appended to it.
    endpoint: String,
    /// URL that GraphQL queries are posted to.
    graphql_endpoint: String,
    /// Bearer token attached to every request.
    api_key: String,
    /// Underlying reqwest client.
    http: Client,
}
239
240impl fmt::Debug for RunPodClient {
241 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242 f.debug_struct("RunPodClient")
243 .field("endpoint", &self.endpoint)
244 .field("api_key", &"<redacted>")
245 .finish()
246 }
247}
248
249impl RunPodClient {
250 pub fn new(endpoint: impl Into<String>, api_key: impl Into<String>) -> Self {
255 let rest = endpoint.into();
256 let graphql = if rest.starts_with(DEFAULT_ENDPOINT) {
257 GRAPHQL_ENDPOINT.to_string()
258 } else {
259 rest.clone()
260 };
261 Self::new_with_graphql(rest, graphql, api_key)
262 }
263
264 pub fn new_with_graphql(
266 endpoint: impl Into<String>,
267 graphql_endpoint: impl Into<String>,
268 api_key: impl Into<String>,
269 ) -> Self {
270 let http = Client::builder()
271 .timeout(Duration::from_secs(30))
272 .build()
273 .unwrap_or_default();
274 Self {
275 endpoint: endpoint.into(),
276 graphql_endpoint: graphql_endpoint.into(),
277 api_key: api_key.into(),
278 http,
279 }
280 }
281
282 pub fn from_settings(settings: &RunPodSettings) -> std::result::Result<Self, MoldError> {
285 let key = std::env::var(API_KEY_ENV)
286 .ok()
287 .filter(|k| !k.is_empty())
288 .or_else(|| settings.api_key.clone())
289 .ok_or_else(|| {
290 MoldError::RunPodAuth(format!(
291 "RunPod API key not set — export {API_KEY_ENV} or run \
292 `mold config set runpod.api_key <key>`"
293 ))
294 })?;
295 let endpoint = settings
296 .endpoint
297 .clone()
298 .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());
299 Ok(Self::new(endpoint, key))
300 }
301
302 fn url(&self, path: &str) -> String {
303 format!("{}{}", self.endpoint.trim_end_matches('/'), path)
304 }
305
306 async fn get_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
307 let resp = self
308 .http
309 .get(self.url(path))
310 .bearer_auth(&self.api_key)
311 .send()
312 .await
313 .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
314 let status = resp.status();
315 if status.is_success() {
316 let body = resp
317 .text()
318 .await
319 .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?;
320 serde_json::from_str(&body).map_err(|e| {
321 MoldError::RunPod(format!(
322 "RunPod {path}: failed to parse response: {e} — body: {}",
323 truncate_for_error(&body)
324 ))
325 .into()
326 })
327 } else {
328 Err(http_error(path, status, resp).await.into())
329 }
330 }
331
332 async fn post_json<B: Serialize, T: for<'de> Deserialize<'de>>(
333 &self,
334 path: &str,
335 body: &B,
336 ) -> Result<T> {
337 let resp = self
338 .http
339 .post(self.url(path))
340 .bearer_auth(&self.api_key)
341 .json(body)
342 .send()
343 .await
344 .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
345 let status = resp.status();
346 if status.is_success() {
347 let text = resp
348 .text()
349 .await
350 .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?;
351 serde_json::from_str(&text).map_err(|e| {
352 MoldError::RunPod(format!(
353 "RunPod {path}: failed to parse response: {e} — body: {}",
354 truncate_for_error(&text)
355 ))
356 .into()
357 })
358 } else {
359 Err(http_error(path, status, resp).await.into())
360 }
361 }
362
363 async fn post_empty(&self, path: &str) -> Result<()> {
364 let resp = self
365 .http
366 .post(self.url(path))
367 .bearer_auth(&self.api_key)
368 .send()
369 .await
370 .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
371 let status = resp.status();
372 if status.is_success() {
373 Ok(())
374 } else {
375 Err(http_error(path, status, resp).await.into())
376 }
377 }
378
379 async fn delete(&self, path: &str) -> Result<()> {
380 let resp = self
381 .http
382 .delete(self.url(path))
383 .bearer_auth(&self.api_key)
384 .send()
385 .await
386 .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
387 let status = resp.status();
388 if status.is_success() {
389 Ok(())
390 } else {
391 Err(http_error(path, status, resp).await.into())
392 }
393 }
394
395 async fn get_text(&self, path: &str) -> Result<String> {
396 let resp = self
397 .http
398 .get(self.url(path))
399 .bearer_auth(&self.api_key)
400 .send()
401 .await
402 .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
403 let status = resp.status();
404 if status.is_success() {
405 Ok(resp
406 .text()
407 .await
408 .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?)
409 } else {
410 Err(http_error(path, status, resp).await.into())
411 }
412 }
413
414 pub async fn user(&self) -> Result<UserInfo> {
419 let query = serde_json::json!({
420 "query": "query { myself { id email clientBalance currentSpendPerHr spendLimit } }"
421 });
422 let resp = self
423 .http
424 .post(&self.graphql_endpoint)
425 .bearer_auth(&self.api_key)
426 .json(&query)
427 .send()
428 .await
429 .map_err(|e| MoldError::RunPod(format!("RunPod graphql /user: {e}")))?;
430 let status = resp.status();
431 if !status.is_success() {
432 return Err(http_error("graphql /user", status, resp).await.into());
433 }
434 let body: serde_json::Value = resp
435 .json()
436 .await
437 .map_err(|e| MoldError::RunPod(format!("RunPod graphql /user json: {e}")))?;
438 if let Some(errs) = body.get("errors") {
439 return Err(MoldError::RunPod(format!("RunPod graphql errors: {errs}")).into());
440 }
441 let myself = body
442 .get("data")
443 .and_then(|d| d.get("myself"))
444 .ok_or_else(|| MoldError::RunPod("graphql: missing data.myself".into()))?;
445 let info = UserInfo {
446 id: myself
447 .get("id")
448 .and_then(|v| v.as_str())
449 .unwrap_or("")
450 .to_string(),
451 email: myself
452 .get("email")
453 .and_then(|v| v.as_str())
454 .unwrap_or("")
455 .to_string(),
456 client_balance: myself
457 .get("clientBalance")
458 .and_then(|v| v.as_f64())
459 .unwrap_or(0.0),
460 current_spend_per_hr: myself
461 .get("currentSpendPerHr")
462 .and_then(|v| v.as_f64())
463 .unwrap_or(0.0),
464 spend_limit: myself.get("spendLimit").and_then(|v| v.as_f64()),
465 };
466 Ok(info)
467 }
468
469 pub async fn gpu_types(&self) -> Result<Vec<GpuType>> {
472 let query = serde_json::json!({
473 "query": "query { gpuTypes { id displayName memoryInGb secureCloud communityCloud } dataCenters { gpuAvailability { displayName stockStatus } } }"
474 });
475 let body = self.graphql(&query).await?;
476 let data = body
477 .get("data")
478 .ok_or_else(|| MoldError::RunPod("graphql: missing data".into()))?;
479 let types: Vec<GpuType> = serde_json::from_value(
480 data.get("gpuTypes")
481 .cloned()
482 .unwrap_or(serde_json::Value::Array(vec![])),
483 )
484 .map_err(|e| MoldError::RunPod(format!("parse gpuTypes: {e}")))?;
485 let mut best_stock: std::collections::HashMap<String, String> =
487 std::collections::HashMap::new();
488 if let Some(dcs) = data.get("dataCenters").and_then(|v| v.as_array()) {
489 for dc in dcs {
490 if let Some(avail) = dc.get("gpuAvailability").and_then(|v| v.as_array()) {
491 for a in avail {
492 if let (Some(name), Some(stock)) = (
493 a.get("displayName").and_then(|v| v.as_str()),
494 a.get("stockStatus").and_then(|v| v.as_str()),
495 ) {
496 let current = best_stock.get(name).cloned().unwrap_or_default();
497 if stock_rank(stock) > stock_rank(¤t) {
498 best_stock.insert(name.to_string(), stock.to_string());
499 }
500 }
501 }
502 }
503 }
504 }
505 let mut out = types;
506 for g in out.iter_mut() {
507 if let Some(s) = best_stock.get(&g.display_name) {
508 if !s.is_empty() {
509 g.stock_status = Some(s.clone());
510 }
511 }
512 g.available = g.stock_status.as_deref().is_some_and(|s| s != "None");
513 }
514 Ok(out)
515 }
516
517 pub async fn datacenters(&self) -> Result<Vec<Datacenter>> {
519 let query = serde_json::json!({
520 "query": "query { dataCenters { id name listed gpuAvailability { id displayName stockStatus } } }"
521 });
522 let body = self.graphql(&query).await?;
523 let arr = body
524 .get("data")
525 .and_then(|d| d.get("dataCenters"))
526 .cloned()
527 .unwrap_or(serde_json::Value::Array(vec![]));
528 let arr = match arr {
530 serde_json::Value::Array(mut dcs) => {
531 for dc in dcs.iter_mut() {
532 if let Some(avail) =
533 dc.get_mut("gpuAvailability").and_then(|v| v.as_array_mut())
534 {
535 for a in avail.iter_mut() {
536 if let Some(id) = a.get("id").and_then(|v| v.as_str()) {
537 let id = id.to_string();
538 if let Some(obj) = a.as_object_mut() {
539 obj.insert("gpuId".into(), serde_json::Value::String(id));
540 }
541 }
542 }
543 }
544 }
545 serde_json::Value::Array(dcs)
546 }
547 other => other,
548 };
549 let dcs: Vec<Datacenter> = serde_json::from_value(arr)
550 .map_err(|e| MoldError::RunPod(format!("parse dataCenters: {e}")))?;
551 Ok(dcs)
552 }
553
554 async fn graphql(&self, query: &serde_json::Value) -> Result<serde_json::Value> {
555 let resp = self
556 .http
557 .post(&self.graphql_endpoint)
558 .bearer_auth(&self.api_key)
559 .json(query)
560 .send()
561 .await
562 .map_err(|e| MoldError::RunPod(format!("RunPod graphql: {e}")))?;
563 let status = resp.status();
564 if !status.is_success() {
565 return Err(http_error("graphql", status, resp).await.into());
566 }
567 let body: serde_json::Value = resp
568 .json()
569 .await
570 .map_err(|e| MoldError::RunPod(format!("graphql body: {e}")))?;
571 if let Some(errs) = body
572 .get("errors")
573 .filter(|e| !e.as_array().map(|a| a.is_empty()).unwrap_or(true))
574 {
575 return Err(MoldError::RunPod(format!("graphql errors: {errs}")).into());
576 }
577 Ok(body)
578 }
579
580 pub async fn list_pods(&self) -> Result<Vec<Pod>> {
581 self.get_json("/pods").await
582 }
583
584 pub async fn get_pod(&self, id: &str) -> Result<Pod> {
585 self.get_json(&format!("/pods/{id}")).await
586 }
587
588 pub async fn create_pod(&self, req: &CreatePodRequest) -> Result<Pod> {
589 self.post_json("/pods", req).await
590 }
591
592 pub async fn stop_pod(&self, id: &str) -> Result<()> {
593 self.post_empty(&format!("/pods/{id}/stop")).await
594 }
595
596 pub async fn start_pod(&self, id: &str) -> Result<()> {
597 self.post_empty(&format!("/pods/{id}/start")).await
598 }
599
600 pub async fn delete_pod(&self, id: &str) -> Result<()> {
601 self.delete(&format!("/pods/{id}")).await
602 }
603
604 pub async fn pod_logs(&self, id: &str) -> Result<String> {
605 self.get_text(&format!("/pods/{id}/logs")).await
606 }
607
608 pub async fn network_volumes(&self) -> Result<Vec<NetworkVolume>> {
609 self.get_json("/networkvolumes").await
610 }
611}
612
613async fn http_error(path: &str, status: StatusCode, resp: reqwest::Response) -> MoldError {
616 let body = resp.text().await.unwrap_or_default();
617 let msg = truncate_for_error(&body);
618 match status {
619 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
620 MoldError::RunPodAuth(format!("RunPod {path} {status}: {msg}"))
621 }
622 StatusCode::NOT_FOUND => {
623 MoldError::RunPodNotFound(format!("RunPod {path} {status}: {msg}"))
624 }
625 StatusCode::CONFLICT | StatusCode::SERVICE_UNAVAILABLE
626 if msg.to_lowercase().contains("does not have the resources") =>
627 {
628 MoldError::RunPodNoStock(format!("RunPod {path} {status}: {msg}"))
629 }
630 _ => MoldError::RunPod(format!("RunPod {path} {status}: {msg}")),
631 }
632}
633
/// Ranks a stock-level string: `"High"` = 3, `"Medium"` = 2, `"Low"` = 1,
/// and anything else (including the empty string) = 0. Matching is
/// case-sensitive.
fn stock_rank(s: &str) -> u8 {
    // Ordered from lowest to highest so that rank = position + 1.
    const LEVELS: [&str; 3] = ["Low", "Medium", "High"];
    let mut rank = 0u8;
    for level in LEVELS {
        rank += 1;
        if level == s {
            return rank;
        }
    }
    0
}
642
/// Trims `s` and caps it at 400 bytes for inclusion in an error message.
///
/// Text of 400 bytes or fewer (after trimming) is returned unchanged. Longer
/// text is cut and suffixed with `…`. The cut is moved back to the nearest
/// UTF-8 character boundary, so a multi-byte character straddling byte 400 is
/// dropped whole instead of causing a slicing panic.
fn truncate_for_error(s: &str) -> String {
    const MAX: usize = 400;
    let s = s.trim();
    if s.len() <= MAX {
        return s.to_string();
    }
    // Index 0 is always a boundary, so this loop terminates.
    let mut end = MAX;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &s[..end])
}
652
/// Picks the container image tag for a GPU display name.
///
/// Matching is a case-insensitive substring search, checked in this order:
/// - contains `5090`, `blackwell` or `b200` → `"latest-sm120"`
/// - contains `a100`, `3090`, `a40` or `ampere` → `"latest-sm80"`
/// - otherwise → `"latest"`
pub fn image_tag_for_gpu(display_name: &str) -> &'static str {
    const SM120_MARKERS: [&str; 3] = ["5090", "blackwell", "b200"];
    const SM80_MARKERS: [&str; 4] = ["a100", "3090", "a40", "ampere"];

    let lowered = display_name.to_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|m| lowered.contains(*m));

    if mentions_any(&SM120_MARKERS) {
        return "latest-sm120";
    }
    if mentions_any(&SM80_MARKERS) {
        return "latest-sm80";
    }
    "latest"
}
667
/// GPU display names considered when no GPU is chosen explicitly.
// NOTE(review): presumably ordered from most to least preferred — the consumer
// is not in this file; confirm before reordering.
pub const GPU_PREFERENCE: &[&str] = &[
    "A100 PCIe",
    "L40",
    "L40S",
    "RTX A6000",
    "RTX 5090",
    "RTX 4090",
];
677
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the display-name → image-tag mapping for representative GPUs.
    #[test]
    fn image_tag_mapping() {
        assert_eq!(image_tag_for_gpu("RTX 4090"), "latest");
        assert_eq!(image_tag_for_gpu("NVIDIA GeForce RTX 4090"), "latest");
        assert_eq!(image_tag_for_gpu("L40S"), "latest");
        assert_eq!(image_tag_for_gpu("RTX 5090"), "latest-sm120");
        assert_eq!(image_tag_for_gpu("NVIDIA GeForce RTX 5090"), "latest-sm120");
        assert_eq!(image_tag_for_gpu("A100 80GB"), "latest-sm80");
        assert_eq!(image_tag_for_gpu("A100 PCIe"), "latest-sm80");
        assert_eq!(image_tag_for_gpu("RTX 3090"), "latest-sm80");
    }

    /// The redacted rendering must never contain the real key.
    #[test]
    fn redacted_debug_hides_api_key() {
        let s = RunPodSettings {
            api_key: Some("secret-key".to_string()),
            ..Default::default()
        };
        let out = s.redacted_debug();
        assert!(!out.contains("secret-key"));
        assert!(out.contains("<redacted>"));
    }

    /// With no env var and no configured key, construction fails with an
    /// auth error.
    // NOTE(review): this mutates process-wide environment state, which can
    // race with other tests that read the same variable when run in parallel.
    #[test]
    fn from_settings_requires_key() {
        std::env::remove_var(API_KEY_ENV);
        let err = RunPodClient::from_settings(&RunPodSettings::default()).unwrap_err();
        assert!(matches!(err, MoldError::RunPodAuth(_)));
    }

    /// Short text passes through; long text is cut and ends with an ellipsis.
    #[test]
    fn truncate_for_error_boundary() {
        let short = "short";
        assert_eq!(truncate_for_error(short), "short");
        let long = "x".repeat(500);
        let truncated = truncate_for_error(&long);
        assert!(truncated.ends_with('…'));
        // 400 kept characters plus the ellipsis.
        assert!(truncated.chars().count() <= 401);
    }

    /// Settings survive a TOML serialize → deserialize round trip.
    #[test]
    fn runpod_settings_toml_roundtrip() {
        let original = RunPodSettings {
            api_key: Some("k".to_string()),
            default_gpu: Some("RTX 5090".to_string()),
            default_datacenter: Some("EUR-IS-2".to_string()),
            default_network_volume_id: Some("nv-123".to_string()),
            auto_teardown: true,
            auto_teardown_idle_mins: 30,
            cost_alert_usd: 3.5,
            endpoint: None,
        };
        let toml_s = toml::to_string(&original).unwrap();
        let round: RunPodSettings = toml::from_str(&toml_s).unwrap();
        assert_eq!(round.api_key, original.api_key);
        assert_eq!(round.default_gpu, original.default_gpu);
        assert_eq!(round.default_datacenter, original.default_datacenter);
        assert_eq!(
            round.default_network_volume_id,
            original.default_network_volume_id
        );
        assert_eq!(round.auto_teardown, original.auto_teardown);
        assert_eq!(
            round.auto_teardown_idle_mins,
            original.auto_teardown_idle_mins
        );
        assert_eq!(round.cost_alert_usd, original.cost_alert_usd);
    }

    /// Deserializing an empty document yields the 20-minute idle default.
    #[test]
    fn default_auto_teardown_idle_mins_is_20() {
        let s: RunPodSettings = toml::from_str("").unwrap();
        assert_eq!(s.auto_teardown_idle_mins, 20);
    }
}