Skip to main content

mold_core/
runpod.rs

//! RunPod REST API client.
//!
//! Wraps `https://rest.runpod.io/v1/` for pod lifecycle management from within
//! the mold CLI. Uses Bearer-token auth via the `RUNPOD_API_KEY` env var or an
//! explicit key. All methods are async.
6
7use crate::error::MoldError;
8use anyhow::Result;
9use reqwest::{Client, StatusCode};
10use serde::{Deserialize, Serialize};
11use std::fmt;
12use std::time::Duration;
13
/// Base URL of the RunPod REST API (v1); used when no endpoint override is set.
pub const DEFAULT_ENDPOINT: &str = "https://rest.runpod.io/v1";

/// RunPod GraphQL URL, used for queries the REST API does not expose
/// (for example the account lookup).
pub const GRAPHQL_ENDPOINT: &str = "https://api.runpod.io/graphql";

/// Name of the environment variable that holds the RunPod API key.
pub const API_KEY_ENV: &str = "RUNPOD_API_KEY";
22
23/// Persisted configuration under `[runpod]` in `config.toml`.
24#[derive(Debug, Clone, Deserialize, Serialize, Default)]
25pub struct RunPodSettings {
26    /// API key stored in config. Env var `RUNPOD_API_KEY` takes precedence.
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub api_key: Option<String>,
29
30    /// Preferred GPU, e.g. `"NVIDIA GeForce RTX 5090"`.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub default_gpu: Option<String>,
33
34    /// Preferred datacenter id, e.g. `"EUR-IS-2"`.
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub default_datacenter: Option<String>,
37
38    /// Attach this network volume to new pods (id from RunPod console).
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub default_network_volume_id: Option<String>,
41
42    /// When `true`, `mold runpod run` deletes the pod after generating.
43    /// When `false`, the pod is left running for reuse. Default `false`.
44    #[serde(default)]
45    pub auto_teardown: bool,
46
47    /// After this many minutes of idle time, background reap deletes the pod.
48    /// `0` disables the idle reaper. Default `20`.
49    #[serde(default = "default_auto_teardown_idle_mins")]
50    pub auto_teardown_idle_mins: u32,
51
52    /// Fail UAT or `run` if cumulative pod spend for the session exceeds this
53    /// many USD. `0.0` disables the guard. Default `0.0`.
54    #[serde(default)]
55    pub cost_alert_usd: f64,
56
57    /// Override the REST endpoint (mostly for testing).
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub endpoint: Option<String>,
60}
61
/// Serde default for `auto_teardown_idle_mins`: 20 minutes.
fn default_auto_teardown_idle_mins() -> u32 {
    const IDLE_MINS: u32 = 20;
    IDLE_MINS
}
65
66/// Redact `api_key` when logging.
67impl RunPodSettings {
68    pub fn redacted_debug(&self) -> String {
69        format!(
70            "RunPodSettings {{ api_key: {}, default_gpu: {:?}, default_datacenter: {:?}, \
71             default_network_volume_id: {:?}, auto_teardown: {}, auto_teardown_idle_mins: {}, \
72             cost_alert_usd: {}, endpoint: {:?} }}",
73            if self.api_key.is_some() {
74                "Some(\"<redacted>\")"
75            } else {
76                "None"
77            },
78            self.default_gpu,
79            self.default_datacenter,
80            self.default_network_volume_id,
81            self.auto_teardown,
82            self.auto_teardown_idle_mins,
83            self.cost_alert_usd,
84            self.endpoint,
85        )
86    }
87}
88
// ─── API response types ────────────────────────────────────────────────────
90
91/// `GET /user` response.
92#[derive(Debug, Clone, Deserialize, Serialize)]
93pub struct UserInfo {
94    pub id: String,
95    pub email: String,
96    #[serde(default)]
97    pub client_balance: f64,
98    #[serde(default)]
99    pub current_spend_per_hr: f64,
100    #[serde(default)]
101    pub spend_limit: Option<f64>,
102}
103
104/// One entry from `GET /gputypes`.
105#[derive(Debug, Clone, Deserialize, Serialize)]
106pub struct GpuType {
107    #[serde(default)]
108    pub id: Option<String>,
109    #[serde(rename = "displayName", default)]
110    pub display_name: String,
111    #[serde(rename = "gpuId", default)]
112    pub gpu_id: String,
113    #[serde(rename = "memoryInGb", default)]
114    pub memory_in_gb: u32,
115    #[serde(rename = "secureCloud", default)]
116    pub secure_cloud: bool,
117    #[serde(rename = "communityCloud", default)]
118    pub community_cloud: bool,
119    #[serde(rename = "stockStatus", default)]
120    pub stock_status: Option<String>,
121    #[serde(default)]
122    pub available: bool,
123}
124
125/// One entry from `GET /datacenters`.
126#[derive(Debug, Clone, Deserialize, Serialize)]
127pub struct Datacenter {
128    pub id: String,
129    #[serde(default)]
130    pub name: String,
131    #[serde(default)]
132    pub location: Option<String>,
133    #[serde(rename = "gpuAvailability", default)]
134    pub gpu_availability: Vec<GpuAvailability>,
135}
136
137#[derive(Debug, Clone, Deserialize, Serialize)]
138pub struct GpuAvailability {
139    #[serde(rename = "displayName", default)]
140    pub display_name: String,
141    #[serde(rename = "gpuId", default)]
142    pub gpu_id: String,
143    #[serde(rename = "stockStatus", default)]
144    pub stock_status: Option<String>,
145}
146
147/// `GET /pods` / `GET /pods/{id}` response.
148///
149/// Only fields we actually use are deserialized — anything else is allowed via
150/// `#[serde(default)]` on the struct to avoid breaking on RunPod API drift.
151#[derive(Debug, Clone, Deserialize, Serialize)]
152pub struct Pod {
153    pub id: String,
154    #[serde(default)]
155    pub name: Option<String>,
156    #[serde(rename = "desiredStatus", default)]
157    pub desired_status: String,
158    #[serde(rename = "imageName", default)]
159    pub image_name: Option<String>,
160    #[serde(rename = "gpuCount", default)]
161    pub gpu_count: u32,
162    #[serde(rename = "costPerHr", default)]
163    pub cost_per_hr: f64,
164    #[serde(rename = "uptimeSeconds", default)]
165    pub uptime_seconds: u64,
166    #[serde(rename = "lastStatusChange", default)]
167    pub last_status_change: Option<String>,
168    #[serde(rename = "memoryInGb", default)]
169    pub memory_in_gb: u32,
170    #[serde(rename = "vcpuCount", default)]
171    pub vcpu_count: u32,
172    #[serde(rename = "volumeInGb", default)]
173    pub volume_in_gb: u32,
174    #[serde(rename = "volumeMountPath", default)]
175    pub volume_mount_path: Option<String>,
176    #[serde(default)]
177    pub ports: serde_json::Value,
178    #[serde(default)]
179    pub env: serde_json::Value,
180    #[serde(default)]
181    pub machine: Option<PodMachine>,
182    #[serde(default)]
183    pub runtime: Option<serde_json::Value>,
184}
185
186#[derive(Debug, Clone, Deserialize, Serialize)]
187pub struct PodMachine {
188    #[serde(rename = "gpuDisplayName", default)]
189    pub gpu_display_name: Option<String>,
190    #[serde(default)]
191    pub location: Option<String>,
192}
193
194/// Body for `POST /pods`.
195#[derive(Debug, Clone, Serialize, Default)]
196pub struct CreatePodRequest {
197    pub name: String,
198    #[serde(rename = "imageName")]
199    pub image_name: String,
200    #[serde(rename = "gpuTypeIds")]
201    pub gpu_type_ids: Vec<String>,
202    #[serde(rename = "cloudType")]
203    pub cloud_type: String,
204    #[serde(rename = "dataCenterIds", skip_serializing_if = "Option::is_none")]
205    pub data_center_ids: Option<Vec<String>>,
206    #[serde(rename = "gpuCount")]
207    pub gpu_count: u32,
208    #[serde(rename = "containerDiskInGb")]
209    pub container_disk_in_gb: u32,
210    #[serde(rename = "volumeInGb")]
211    pub volume_in_gb: u32,
212    #[serde(rename = "volumeMountPath")]
213    pub volume_mount_path: String,
214    pub ports: Vec<String>,
215    pub env: serde_json::Map<String, serde_json::Value>,
216    #[serde(rename = "networkVolumeId", skip_serializing_if = "Option::is_none")]
217    pub network_volume_id: Option<String>,
218}
219
220/// One entry from `GET /networkvolumes`.
221#[derive(Debug, Clone, Deserialize, Serialize)]
222pub struct NetworkVolume {
223    pub id: String,
224    pub name: String,
225    #[serde(rename = "dataCenterId", default)]
226    pub data_center_id: String,
227    pub size: u32,
228}
229
// ─── Client ────────────────────────────────────────────────────────────────
231
232#[derive(Clone)]
233pub struct RunPodClient {
234    endpoint: String,
235    graphql_endpoint: String,
236    api_key: String,
237    http: Client,
238}
239
240impl fmt::Debug for RunPodClient {
241    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242        f.debug_struct("RunPodClient")
243            .field("endpoint", &self.endpoint)
244            .field("api_key", &"<redacted>")
245            .finish()
246    }
247}
248
249impl RunPodClient {
250    /// Construct with explicit endpoint + key. The GraphQL endpoint
251    /// defaults to `GRAPHQL_ENDPOINT` when the REST endpoint is production,
252    /// and falls back to the same URL when the REST endpoint is overridden
253    /// (so tests pointing at a mock server route GraphQL calls there too).
254    pub fn new(endpoint: impl Into<String>, api_key: impl Into<String>) -> Self {
255        let rest = endpoint.into();
256        let graphql = if rest.starts_with(DEFAULT_ENDPOINT) {
257            GRAPHQL_ENDPOINT.to_string()
258        } else {
259            rest.clone()
260        };
261        Self::new_with_graphql(rest, graphql, api_key)
262    }
263
264    /// Construct with explicit REST + GraphQL endpoints.
265    pub fn new_with_graphql(
266        endpoint: impl Into<String>,
267        graphql_endpoint: impl Into<String>,
268        api_key: impl Into<String>,
269    ) -> Self {
270        let http = Client::builder()
271            .timeout(Duration::from_secs(30))
272            .build()
273            .unwrap_or_default();
274        Self {
275            endpoint: endpoint.into(),
276            graphql_endpoint: graphql_endpoint.into(),
277            api_key: api_key.into(),
278            http,
279        }
280    }
281
282    /// Construct from config + environment. `RUNPOD_API_KEY` overrides
283    /// `settings.api_key`. Returns `RunPodAuth` error if no key is available.
284    pub fn from_settings(settings: &RunPodSettings) -> std::result::Result<Self, MoldError> {
285        let key = std::env::var(API_KEY_ENV)
286            .ok()
287            .filter(|k| !k.is_empty())
288            .or_else(|| settings.api_key.clone())
289            .ok_or_else(|| {
290                MoldError::RunPodAuth(format!(
291                    "RunPod API key not set — export {API_KEY_ENV} or run \
292                     `mold config set runpod.api_key <key>`"
293                ))
294            })?;
295        let endpoint = settings
296            .endpoint
297            .clone()
298            .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());
299        Ok(Self::new(endpoint, key))
300    }
301
302    fn url(&self, path: &str) -> String {
303        format!("{}{}", self.endpoint.trim_end_matches('/'), path)
304    }
305
306    async fn get_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
307        let resp = self
308            .http
309            .get(self.url(path))
310            .bearer_auth(&self.api_key)
311            .send()
312            .await
313            .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
314        let status = resp.status();
315        if status.is_success() {
316            let body = resp
317                .text()
318                .await
319                .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?;
320            serde_json::from_str(&body).map_err(|e| {
321                MoldError::RunPod(format!(
322                    "RunPod {path}: failed to parse response: {e} — body: {}",
323                    truncate_for_error(&body)
324                ))
325                .into()
326            })
327        } else {
328            Err(http_error(path, status, resp).await.into())
329        }
330    }
331
332    async fn post_json<B: Serialize, T: for<'de> Deserialize<'de>>(
333        &self,
334        path: &str,
335        body: &B,
336    ) -> Result<T> {
337        let resp = self
338            .http
339            .post(self.url(path))
340            .bearer_auth(&self.api_key)
341            .json(body)
342            .send()
343            .await
344            .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
345        let status = resp.status();
346        if status.is_success() {
347            let text = resp
348                .text()
349                .await
350                .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?;
351            serde_json::from_str(&text).map_err(|e| {
352                MoldError::RunPod(format!(
353                    "RunPod {path}: failed to parse response: {e} — body: {}",
354                    truncate_for_error(&text)
355                ))
356                .into()
357            })
358        } else {
359            Err(http_error(path, status, resp).await.into())
360        }
361    }
362
363    async fn post_empty(&self, path: &str) -> Result<()> {
364        let resp = self
365            .http
366            .post(self.url(path))
367            .bearer_auth(&self.api_key)
368            .send()
369            .await
370            .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
371        let status = resp.status();
372        if status.is_success() {
373            Ok(())
374        } else {
375            Err(http_error(path, status, resp).await.into())
376        }
377    }
378
379    async fn delete(&self, path: &str) -> Result<()> {
380        let resp = self
381            .http
382            .delete(self.url(path))
383            .bearer_auth(&self.api_key)
384            .send()
385            .await
386            .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
387        let status = resp.status();
388        if status.is_success() {
389            Ok(())
390        } else {
391            Err(http_error(path, status, resp).await.into())
392        }
393    }
394
395    async fn get_text(&self, path: &str) -> Result<String> {
396        let resp = self
397            .http
398            .get(self.url(path))
399            .bearer_auth(&self.api_key)
400            .send()
401            .await
402            .map_err(|e| MoldError::RunPod(format!("RunPod {path}: {e}")))?;
403        let status = resp.status();
404        if status.is_success() {
405            Ok(resp
406                .text()
407                .await
408                .map_err(|e| MoldError::RunPod(format!("RunPod {path} body: {e}")))?)
409        } else {
410            Err(http_error(path, status, resp).await.into())
411        }
412    }
413
    // ─── Typed endpoints ────────────────────────────────────────────
415
416    /// User/account info isn't exposed by the REST API, so we fall back to
417    /// the GraphQL endpoint (same API key works for both).
418    pub async fn user(&self) -> Result<UserInfo> {
419        let query = serde_json::json!({
420            "query": "query { myself { id email clientBalance currentSpendPerHr spendLimit } }"
421        });
422        let resp = self
423            .http
424            .post(&self.graphql_endpoint)
425            .bearer_auth(&self.api_key)
426            .json(&query)
427            .send()
428            .await
429            .map_err(|e| MoldError::RunPod(format!("RunPod graphql /user: {e}")))?;
430        let status = resp.status();
431        if !status.is_success() {
432            return Err(http_error("graphql /user", status, resp).await.into());
433        }
434        let body: serde_json::Value = resp
435            .json()
436            .await
437            .map_err(|e| MoldError::RunPod(format!("RunPod graphql /user json: {e}")))?;
438        if let Some(errs) = body.get("errors") {
439            return Err(MoldError::RunPod(format!("RunPod graphql errors: {errs}")).into());
440        }
441        let myself = body
442            .get("data")
443            .and_then(|d| d.get("myself"))
444            .ok_or_else(|| MoldError::RunPod("graphql: missing data.myself".into()))?;
445        let info = UserInfo {
446            id: myself
447                .get("id")
448                .and_then(|v| v.as_str())
449                .unwrap_or("")
450                .to_string(),
451            email: myself
452                .get("email")
453                .and_then(|v| v.as_str())
454                .unwrap_or("")
455                .to_string(),
456            client_balance: myself
457                .get("clientBalance")
458                .and_then(|v| v.as_f64())
459                .unwrap_or(0.0),
460            current_spend_per_hr: myself
461                .get("currentSpendPerHr")
462                .and_then(|v| v.as_f64())
463                .unwrap_or(0.0),
464            spend_limit: myself.get("spendLimit").and_then(|v| v.as_f64()),
465        };
466        Ok(info)
467    }
468
469    /// Query GPU types via GraphQL (not exposed in REST v1).
470    /// Stock status is aggregated: the highest stock level across all DCs.
471    pub async fn gpu_types(&self) -> Result<Vec<GpuType>> {
472        let query = serde_json::json!({
473            "query": "query { gpuTypes { id displayName memoryInGb secureCloud communityCloud } dataCenters { gpuAvailability { displayName stockStatus } } }"
474        });
475        let body = self.graphql(&query).await?;
476        let data = body
477            .get("data")
478            .ok_or_else(|| MoldError::RunPod("graphql: missing data".into()))?;
479        let types: Vec<GpuType> = serde_json::from_value(
480            data.get("gpuTypes")
481                .cloned()
482                .unwrap_or(serde_json::Value::Array(vec![])),
483        )
484        .map_err(|e| MoldError::RunPod(format!("parse gpuTypes: {e}")))?;
485        // Aggregate stock across datacenters.
486        let mut best_stock: std::collections::HashMap<String, String> =
487            std::collections::HashMap::new();
488        if let Some(dcs) = data.get("dataCenters").and_then(|v| v.as_array()) {
489            for dc in dcs {
490                if let Some(avail) = dc.get("gpuAvailability").and_then(|v| v.as_array()) {
491                    for a in avail {
492                        if let (Some(name), Some(stock)) = (
493                            a.get("displayName").and_then(|v| v.as_str()),
494                            a.get("stockStatus").and_then(|v| v.as_str()),
495                        ) {
496                            let current = best_stock.get(name).cloned().unwrap_or_default();
497                            if stock_rank(stock) > stock_rank(&current) {
498                                best_stock.insert(name.to_string(), stock.to_string());
499                            }
500                        }
501                    }
502                }
503            }
504        }
505        let mut out = types;
506        for g in out.iter_mut() {
507            if let Some(s) = best_stock.get(&g.display_name) {
508                if !s.is_empty() {
509                    g.stock_status = Some(s.clone());
510                }
511            }
512            g.available = g.stock_status.as_deref().is_some_and(|s| s != "None");
513        }
514        Ok(out)
515    }
516
517    /// Query datacenters with per-GPU availability via GraphQL.
518    pub async fn datacenters(&self) -> Result<Vec<Datacenter>> {
519        let query = serde_json::json!({
520            "query": "query { dataCenters { id name listed gpuAvailability { id displayName stockStatus } } }"
521        });
522        let body = self.graphql(&query).await?;
523        let arr = body
524            .get("data")
525            .and_then(|d| d.get("dataCenters"))
526            .cloned()
527            .unwrap_or(serde_json::Value::Array(vec![]));
528        // Map GraphQL `id` → `gpuId` so we can reuse the same Datacenter type.
529        let arr = match arr {
530            serde_json::Value::Array(mut dcs) => {
531                for dc in dcs.iter_mut() {
532                    if let Some(avail) =
533                        dc.get_mut("gpuAvailability").and_then(|v| v.as_array_mut())
534                    {
535                        for a in avail.iter_mut() {
536                            if let Some(id) = a.get("id").and_then(|v| v.as_str()) {
537                                let id = id.to_string();
538                                if let Some(obj) = a.as_object_mut() {
539                                    obj.insert("gpuId".into(), serde_json::Value::String(id));
540                                }
541                            }
542                        }
543                    }
544                }
545                serde_json::Value::Array(dcs)
546            }
547            other => other,
548        };
549        let dcs: Vec<Datacenter> = serde_json::from_value(arr)
550            .map_err(|e| MoldError::RunPod(format!("parse dataCenters: {e}")))?;
551        Ok(dcs)
552    }
553
554    async fn graphql(&self, query: &serde_json::Value) -> Result<serde_json::Value> {
555        let resp = self
556            .http
557            .post(&self.graphql_endpoint)
558            .bearer_auth(&self.api_key)
559            .json(query)
560            .send()
561            .await
562            .map_err(|e| MoldError::RunPod(format!("RunPod graphql: {e}")))?;
563        let status = resp.status();
564        if !status.is_success() {
565            return Err(http_error("graphql", status, resp).await.into());
566        }
567        let body: serde_json::Value = resp
568            .json()
569            .await
570            .map_err(|e| MoldError::RunPod(format!("graphql body: {e}")))?;
571        if let Some(errs) = body
572            .get("errors")
573            .filter(|e| !e.as_array().map(|a| a.is_empty()).unwrap_or(true))
574        {
575            return Err(MoldError::RunPod(format!("graphql errors: {errs}")).into());
576        }
577        Ok(body)
578    }
579
580    pub async fn list_pods(&self) -> Result<Vec<Pod>> {
581        self.get_json("/pods").await
582    }
583
584    pub async fn get_pod(&self, id: &str) -> Result<Pod> {
585        self.get_json(&format!("/pods/{id}")).await
586    }
587
588    pub async fn create_pod(&self, req: &CreatePodRequest) -> Result<Pod> {
589        self.post_json("/pods", req).await
590    }
591
592    pub async fn stop_pod(&self, id: &str) -> Result<()> {
593        self.post_empty(&format!("/pods/{id}/stop")).await
594    }
595
596    pub async fn start_pod(&self, id: &str) -> Result<()> {
597        self.post_empty(&format!("/pods/{id}/start")).await
598    }
599
600    pub async fn delete_pod(&self, id: &str) -> Result<()> {
601        self.delete(&format!("/pods/{id}")).await
602    }
603
604    pub async fn pod_logs(&self, id: &str) -> Result<String> {
605        self.get_text(&format!("/pods/{id}/logs")).await
606    }
607
608    pub async fn network_volumes(&self) -> Result<Vec<NetworkVolume>> {
609        self.get_json("/networkvolumes").await
610    }
611}
612
// ─── Helpers ────────────────────────────────────────────────────────────────
614
615async fn http_error(path: &str, status: StatusCode, resp: reqwest::Response) -> MoldError {
616    let body = resp.text().await.unwrap_or_default();
617    let msg = truncate_for_error(&body);
618    match status {
619        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
620            MoldError::RunPodAuth(format!("RunPod {path} {status}: {msg}"))
621        }
622        StatusCode::NOT_FOUND => {
623            MoldError::RunPodNotFound(format!("RunPod {path} {status}: {msg}"))
624        }
625        StatusCode::CONFLICT | StatusCode::SERVICE_UNAVAILABLE
626            if msg.to_lowercase().contains("does not have the resources") =>
627        {
628            MoldError::RunPodNoStock(format!("RunPod {path} {status}: {msg}"))
629        }
630        _ => MoldError::RunPod(format!("RunPod {path} {status}: {msg}")),
631    }
632}
633
/// Numeric rank for a stock-level string: `"High"` = 3, `"Medium"` = 2,
/// `"Low"` = 1, anything else (including the empty string) = 0.
/// Matching is exact and case-sensitive.
fn stock_rank(s: &str) -> u8 {
    const RANKED: [(&str, u8); 3] = [("High", 3), ("Medium", 2), ("Low", 1)];
    RANKED
        .iter()
        .find(|(label, _)| *label == s)
        .map_or(0, |&(_, rank)| rank)
}
642
/// Trim `s` and cap it at 400 bytes for inclusion in an error message.
///
/// Longer input is cut and suffixed with `…`. The cut is moved back to the
/// nearest UTF-8 character boundary, so a multi-byte character straddling
/// byte 400 is dropped whole instead of causing a slice panic.
fn truncate_for_error(s: &str) -> String {
    const MAX: usize = 400;
    let s = s.trim();
    if s.len() <= MAX {
        return s.to_string();
    }
    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = MAX;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}…", &s[..cut])
}
652
/// Pick the `ghcr.io/utensils/mold` image tag for a RunPod GPU `displayName`
/// (e.g. `"RTX 4090"`).
///
/// Matching is a case-insensitive substring test, checked in this order:
/// 1. `5090`, `blackwell`, `b200` → `"latest-sm120"`
/// 2. `a100`, `3090`, `a40`, `ampere` → `"latest-sm80"`
/// 3. everything else (Ada such as 4090 / L40 / L40S, and unknown names) → `"latest"`
pub fn image_tag_for_gpu(display_name: &str) -> &'static str {
    const SM120_MARKERS: &[&str] = &["5090", "blackwell", "b200"];
    const SM80_MARKERS: &[&str] = &["a100", "3090", "a40", "ampere"];

    let lowered = display_name.to_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|&m| lowered.contains(m));

    if mentions_any(SM120_MARKERS) {
        "latest-sm120"
    } else if mentions_any(SM80_MARKERS) {
        "latest-sm80"
    } else {
        "latest"
    }
}
667
/// GPU names ranked for automatic selection, least preferred first: a
/// higher index means a more preferred GPU, so the last entry wins.
pub const GPU_PREFERENCE: &[&str] = &[
    "A100 PCIe",
    "L40",
    "L40S",
    "RTX A6000",
    "RTX 5090",
    "RTX 4090",
];
677
#[cfg(test)]
mod tests {
    use super::*;

    /// Each GPU family resolves to its expected image tag.
    #[test]
    fn image_tag_mapping() {
        let cases = [
            ("RTX 4090", "latest"),
            ("NVIDIA GeForce RTX 4090", "latest"),
            ("L40S", "latest"),
            ("RTX 5090", "latest-sm120"),
            ("NVIDIA GeForce RTX 5090", "latest-sm120"),
            ("A100 80GB", "latest-sm80"),
            ("A100 PCIe", "latest-sm80"),
            ("RTX 3090", "latest-sm80"),
        ];
        for (gpu, expected) in cases {
            assert_eq!(image_tag_for_gpu(gpu), expected);
        }
    }

    /// The redacted rendering must not contain the key itself.
    #[test]
    fn redacted_debug_hides_api_key() {
        let settings = RunPodSettings {
            api_key: Some("secret-key".to_string()),
            ..Default::default()
        };
        let rendered = settings.redacted_debug();
        assert!(!rendered.contains("secret-key"));
        assert!(rendered.contains("<redacted>"));
    }

    /// With neither env var nor config key, construction fails with `RunPodAuth`.
    #[test]
    fn from_settings_requires_key() {
        // NOTE(review): this mutates process-wide environment state.
        std::env::remove_var(API_KEY_ENV);
        let outcome = RunPodClient::from_settings(&RunPodSettings::default());
        let err = outcome.unwrap_err();
        assert!(matches!(err, MoldError::RunPodAuth(_)));
    }

    /// Short input passes through; long input is cut and marked with `…`.
    #[test]
    fn truncate_for_error_boundary() {
        assert_eq!(truncate_for_error("short"), "short");

        let long = "x".repeat(500);
        let truncated = truncate_for_error(&long);
        assert!(truncated.ends_with('…'));
        assert!(truncated.chars().count() <= 401);
    }

    /// Serializing to TOML and back preserves every compared field.
    #[test]
    fn runpod_settings_toml_roundtrip() {
        let original = RunPodSettings {
            api_key: Some("k".to_string()),
            default_gpu: Some("RTX 5090".to_string()),
            default_datacenter: Some("EUR-IS-2".to_string()),
            default_network_volume_id: Some("nv-123".to_string()),
            auto_teardown: true,
            auto_teardown_idle_mins: 30,
            cost_alert_usd: 3.5,
            endpoint: None,
        };

        let encoded = toml::to_string(&original).unwrap();
        let decoded: RunPodSettings = toml::from_str(&encoded).unwrap();

        assert_eq!(decoded.api_key, original.api_key);
        assert_eq!(decoded.default_gpu, original.default_gpu);
        assert_eq!(decoded.default_datacenter, original.default_datacenter);
        assert_eq!(
            decoded.default_network_volume_id,
            original.default_network_volume_id
        );
        assert_eq!(decoded.auto_teardown, original.auto_teardown);
        assert_eq!(
            decoded.auto_teardown_idle_mins,
            original.auto_teardown_idle_mins
        );
        assert_eq!(decoded.cost_alert_usd, original.cost_alert_usd);
    }

    /// An empty TOML document yields the 20-minute idle default.
    #[test]
    fn default_auto_teardown_idle_mins_is_20() {
        let parsed: RunPodSettings = toml::from_str("").unwrap();
        assert_eq!(parsed.auto_teardown_idle_mins, 20);
    }
}