Skip to main content

mold_core/
lambda.rs

1//! Lambda Cloud API client and deployment helpers.
2
3use crate::error::MoldError;
4use anyhow::Result;
5use reqwest::{Client, StatusCode};
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
/// Base URL of the Lambda API used when no endpoint is configured.
pub const DEFAULT_ENDPOINT: &str = "https://cloud.lambda.ai/api/v1";
/// Name of the environment variable consulted for the API key.
pub const API_KEY_ENV: &str = "LAMBDA_API_KEY";
/// Container image repository used when none is configured.
pub const DEFAULT_IMAGE_REPOSITORY: &str = "ghcr.io/utensils/mold";
12
13#[derive(Debug, Clone, Deserialize, Serialize)]
14pub struct LambdaSettings {
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    pub api_key: Option<String>,
17    #[serde(
18        default = "default_endpoint_opt",
19        skip_serializing_if = "Option::is_none"
20    )]
21    pub endpoint: Option<String>,
22    #[serde(
23        default = "default_image_repository_opt",
24        skip_serializing_if = "Option::is_none"
25    )]
26    pub image_repository: Option<String>,
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub ssh_key_name: Option<String>,
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub ssh_private_key_path: Option<String>,
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub filesystem_prefix: Option<String>,
33    #[serde(default = "default_filesystem_mount_path")]
34    pub filesystem_mount_path: String,
35    #[serde(default = "default_confirm_hourly_usd")]
36    pub confirm_hourly_usd: f64,
37    #[serde(default = "default_local_port")]
38    pub local_port: u16,
39}
40
41impl Default for LambdaSettings {
42    fn default() -> Self {
43        Self {
44            api_key: None,
45            endpoint: default_endpoint_opt(),
46            image_repository: default_image_repository_opt(),
47            ssh_key_name: None,
48            ssh_private_key_path: None,
49            filesystem_prefix: None,
50            filesystem_mount_path: default_filesystem_mount_path(),
51            confirm_hourly_usd: default_confirm_hourly_usd(),
52            local_port: default_local_port(),
53        }
54    }
55}
56
57fn default_endpoint_opt() -> Option<String> {
58    Some(DEFAULT_ENDPOINT.to_string())
59}
60
61fn default_image_repository_opt() -> Option<String> {
62    Some(DEFAULT_IMAGE_REPOSITORY.to_string())
63}
64
/// Serde default for `LambdaSettings::filesystem_mount_path`.
fn default_filesystem_mount_path() -> String {
    String::from("/data/mold")
}
68
/// Serde default for `LambdaSettings::confirm_hourly_usd`.
fn default_confirm_hourly_usd() -> f64 {
    5.0_f64
}
72
/// Serde default for `LambdaSettings::local_port`.
fn default_local_port() -> u16 {
    7680_u16
}
76
77impl LambdaSettings {
78    pub fn resolved_api_key(&self) -> Option<String> {
79        std::env::var(API_KEY_ENV)
80            .ok()
81            .filter(|s| !s.is_empty())
82            .or_else(|| self.api_key.clone())
83    }
84
85    pub fn endpoint(&self) -> &str {
86        self.endpoint.as_deref().unwrap_or(DEFAULT_ENDPOINT)
87    }
88
89    pub fn image_repository(&self) -> &str {
90        self.image_repository
91            .as_deref()
92            .unwrap_or(DEFAULT_IMAGE_REPOSITORY)
93    }
94
95    pub fn redacted_debug(&self) -> String {
96        format!(
97            "LambdaSettings {{ api_key: {}, endpoint: {:?}, image_repository: {:?}, \
98             ssh_key_name: {:?}, ssh_private_key_path: {:?}, filesystem_prefix: {:?}, \
99             filesystem_mount_path: {:?}, confirm_hourly_usd: {}, local_port: {} }}",
100            if self.api_key.is_some() {
101                "Some(\"<redacted>\")"
102            } else {
103                "None"
104            },
105            self.endpoint,
106            self.image_repository,
107            self.ssh_key_name,
108            self.ssh_private_key_path,
109            self.filesystem_prefix,
110            self.filesystem_mount_path,
111            self.confirm_hourly_usd,
112            self.local_port,
113        )
114    }
115}
116
117#[derive(Debug, Clone, Deserialize, Serialize)]
118pub struct ApiList<T> {
119    #[serde(default)]
120    pub data: Vec<T>,
121}
122
123#[derive(Debug, Clone, Deserialize, Serialize)]
124pub struct ApiItem<T> {
125    pub data: T,
126}
127
128#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
129pub struct Region {
130    pub name: String,
131    #[serde(default)]
132    pub description: String,
133}
134
135#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
136pub struct InstanceTypeSpecs {
137    #[serde(default)]
138    pub gpus: u32,
139    #[serde(default)]
140    pub gpu_description: String,
141    #[serde(default)]
142    pub memory_gib: u32,
143    #[serde(default)]
144    pub storage_gib: u32,
145    #[serde(default)]
146    pub vcpus: u32,
147}
148
149#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
150pub struct InstanceType {
151    pub name: String,
152    #[serde(default)]
153    pub description: String,
154    #[serde(default)]
155    pub gpu_description: String,
156    #[serde(default)]
157    pub price_cents_per_hour: u32,
158    #[serde(default)]
159    pub specs: InstanceTypeSpecs,
160    #[serde(default)]
161    pub regions_with_capacity_available: Vec<Region>,
162}
163
164#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
165pub struct SshKey {
166    pub id: String,
167    pub name: String,
168    pub public_key: String,
169}
170
171#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
172pub struct Filesystem {
173    pub id: String,
174    pub name: String,
175    #[serde(default)]
176    pub mount_point: String,
177    #[serde(default)]
178    pub region: Option<Region>,
179    #[serde(default)]
180    pub bytes_used: Option<u64>,
181}
182
183#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
184pub struct Tag {
185    pub key: String,
186    pub value: String,
187}
188
189#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
190pub struct Instance {
191    pub id: String,
192    #[serde(default)]
193    pub name: Option<String>,
194    #[serde(default)]
195    pub status: String,
196    #[serde(default)]
197    pub ip: Option<String>,
198    #[serde(default)]
199    pub private_ip: Option<String>,
200    #[serde(default)]
201    pub instance_type: Option<InstanceType>,
202    #[serde(default)]
203    pub region: Option<Region>,
204    #[serde(default)]
205    pub ssh_key_names: Vec<String>,
206    #[serde(default)]
207    pub file_system_names: Vec<String>,
208    #[serde(default)]
209    pub tags: Vec<Tag>,
210}
211
212#[derive(Debug, Clone, Default, Deserialize, Serialize)]
213pub struct CreateSshKeyRequest {
214    pub name: String,
215    pub public_key: String,
216}
217
218#[derive(Debug, Clone, Default, Deserialize, Serialize)]
219pub struct CreateFilesystemRequest {
220    pub name: String,
221    pub region: String,
222}
223
224#[derive(Debug, Clone, Default, Deserialize, Serialize)]
225pub struct LaunchInstancesRequest {
226    pub region_name: String,
227    pub instance_type_name: String,
228    pub ssh_key_names: Vec<String>,
229    #[serde(skip_serializing_if = "Vec::is_empty", default)]
230    pub file_system_names: Vec<String>,
231    #[serde(skip_serializing_if = "Vec::is_empty", default)]
232    pub file_system_mounts: Vec<FilesystemMount>,
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub hostname: Option<String>,
235    pub name: String,
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub image: Option<LaunchImage>,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub user_data: Option<String>,
240    #[serde(skip_serializing_if = "Vec::is_empty", default)]
241    pub tags: Vec<Tag>,
242}
243
244#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
245pub struct InstanceLaunchResponse {
246    #[serde(default)]
247    pub instance_ids: Vec<String>,
248}
249
250#[derive(Debug, Clone, Deserialize)]
251struct InstanceTypesResponse {
252    #[serde(default)]
253    data: std::collections::BTreeMap<String, InstanceTypeOffering>,
254}
255
256#[derive(Debug, Clone, Deserialize)]
257struct InstanceTypeOffering {
258    instance_type: InstanceType,
259    #[serde(default)]
260    regions_with_capacity_available: Vec<Region>,
261}
262
263#[derive(Debug, Clone, Deserialize, Serialize)]
264pub struct FilesystemMount {
265    pub mount_point: String,
266    #[serde(skip_serializing_if = "Option::is_none")]
267    pub file_system_name: Option<String>,
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub file_system_id: Option<String>,
270}
271
272#[derive(Debug, Clone, Deserialize, Serialize)]
273pub struct LaunchImage {
274    pub id: String,
275}
276
/// Borrowed inputs consumed by [`build_launch_request`].
pub struct LaunchRequestInput<'req> {
    /// Region to launch in.
    pub region_name: &'req str,
    /// Instance type to launch.
    pub instance_type_name: &'req str,
    /// Single SSH key name to install.
    pub ssh_key_name: &'req str,
    /// Name of the filesystem to attach.
    pub filesystem_name: &'req str,
    /// Filesystem id; when present the mount references the id, not the name.
    pub filesystem_id: Option<&'req str>,
    /// Path at which the filesystem is mounted.
    pub filesystem_mount_path: &'req str,
    /// Name given to the instance.
    pub instance_name: &'req str,
    /// Optional image id.
    pub image_id: Option<&'req str>,
    /// User data passed along with the launch request.
    pub user_data: &'req str,
}
288
289pub fn build_launch_request(input: LaunchRequestInput<'_>) -> LaunchInstancesRequest {
290    LaunchInstancesRequest {
291        region_name: input.region_name.to_string(),
292        instance_type_name: input.instance_type_name.to_string(),
293        ssh_key_names: vec![input.ssh_key_name.to_string()],
294        file_system_names: vec![input.filesystem_name.to_string()],
295        file_system_mounts: vec![FilesystemMount {
296            mount_point: input.filesystem_mount_path.to_string(),
297            file_system_name: input
298                .filesystem_id
299                .is_none()
300                .then(|| input.filesystem_name.to_string()),
301            file_system_id: input.filesystem_id.map(str::to_string),
302        }],
303        hostname: None,
304        name: input.instance_name.to_string(),
305        image: input.image_id.map(|id| LaunchImage { id: id.to_string() }),
306        user_data: Some(input.user_data.to_string()),
307        tags: vec![Tag {
308            key: "managed-by".to_string(),
309            value: "mold".to_string(),
310        }],
311    }
312}
313
314#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
315pub struct AvailabilityRow {
316    pub instance_type: String,
317    pub region: String,
318    pub gpu_description: String,
319    pub gpu_count: u32,
320    pub generation_slots: u32,
321    pub price_per_hour_usd: f64,
322    pub memory_gib: u32,
323    pub storage_gib: u32,
324    pub image: String,
325}
326
327impl AvailabilityRow {
328    pub fn from_instance_type(
329        instance_type: &InstanceType,
330        image_repository: &str,
331        version: &str,
332    ) -> Self {
333        let region = instance_type
334            .regions_with_capacity_available
335            .first()
336            .map(|r| r.name.clone())
337            .unwrap_or_default();
338        let image = if gpu_uses_unsupported_linux_arm64(&instance_type.specs.gpu_description)
339            || gpu_uses_unsupported_linux_arm64(&instance_type.name)
340        {
341            "unsupported: linux/arm64 host".to_string()
342        } else {
343            let tag = image_tag_for_gpu(&instance_type.specs.gpu_description, version);
344            format!("{image_repository}:{tag}")
345        };
346        Self {
347            instance_type: instance_type.name.clone(),
348            region,
349            gpu_description: instance_type.specs.gpu_description.clone(),
350            gpu_count: instance_type.specs.gpus,
351            generation_slots: instance_type.specs.gpus,
352            price_per_hour_usd: instance_type.price_cents_per_hour as f64 / 100.0,
353            memory_gib: instance_type.specs.memory_gib,
354            storage_gib: instance_type.specs.storage_gib,
355            image,
356        }
357    }
358}
359
/// Picks the container image tag for a GPU description.
///
/// Matching is a case-insensitive substring search. Families are tried in
/// order and the first hit wins; anything unrecognized gets `"latest"`.
/// `_version` is currently unused.
pub fn image_tag_for_gpu(gpu_description: &str, _version: &str) -> String {
    // (substrings identifying the family, tag to use), in priority order.
    const FAMILIES: [(&[&str], &str); 3] = [
        (&["a100", "a10", "a40", "rtx 30", "3090"], "latest-sm80"),
        (&["h100", "h200", "gh"], "latest-sm90"),
        (&["b200", "5090", "blackwell"], "latest-sm120"),
    ];

    let haystack = gpu_description.to_ascii_lowercase();
    FAMILIES
        .iter()
        .find(|(markers, _)| markers.iter().any(|marker| haystack.contains(*marker)))
        .map_or("latest", |&(_, tag)| tag)
        .to_string()
}
377
/// Returns `true` when the text mentions `gh200` (ASCII case-insensitive).
///
/// The name is misleading in one respect: callers pass either a GPU
/// description or an instance type name.
pub fn gpu_uses_unsupported_linux_arm64(gpu_description: &str) -> bool {
    const MARKER: &[u8] = b"gh200";
    gpu_description
        .as_bytes()
        .windows(MARKER.len())
        .any(|window| window.eq_ignore_ascii_case(MARKER))
}
381
382pub fn filesystem_name(settings: &LambdaSettings, region: &str) -> String {
383    let prefix = settings.filesystem_prefix.as_deref().unwrap_or("mold");
384    format!("{prefix}-{region}")
385}
386
/// Values substituted into the template by [`render_cloud_init`].
#[derive(Clone, Debug)]
pub struct CloudInitOptions {
    /// Full container image reference to pull and run.
    pub image: String,
    /// Host path bind-mounted into the container at `/workspace`.
    pub mount_path: String,
    /// Host path of the env file passed to `docker run --env-file`.
    pub env_file: String,
}
393
394pub fn render_cloud_init(opts: &CloudInitOptions) -> String {
395    format!(
396        r#"#cloud-config
397write_files:
398  - path: /etc/systemd/system/mold-lambda.service
399    permissions: '0644'
400    content: |
401      [Unit]
402      Description=mold Lambda container
403      After=docker.service network-online.target
404      Wants=network-online.target
405
406      [Service]
407      Restart=always
408      RestartSec=10
409      ExecStartPre=-/usr/bin/docker rm -f mold
410      ExecStartPre=/usr/bin/docker pull {image}
411      ExecStart=/usr/bin/docker run --name mold --gpus all --restart unless-stopped --env-file {env_file} -e MOLD_PORT=7680 -p 127.0.0.1:7680:7680 -v {mount_path}:/workspace {image}
412      ExecStop=/usr/bin/docker stop mold
413
414      [Install]
415      WantedBy=multi-user.target
416runcmd:
417  - [ mkdir, -p, /etc/mold ]
418  - [ sh, -c, "touch {env_file} && chmod 600 {env_file}" ]
419  - [ systemctl, daemon-reload ]
420  - [ systemctl, enable, --now, mold-lambda.service ]
421"#,
422        image = opts.image,
423        mount_path = opts.mount_path,
424        env_file = opts.env_file,
425    )
426}
427
428#[derive(Clone)]
429pub struct LambdaClient {
430    client: Client,
431    endpoint: String,
432    api_key: String,
433}
434
435impl LambdaClient {
436    pub fn from_settings(settings: &LambdaSettings) -> Result<Self> {
437        let api_key = settings.resolved_api_key().ok_or_else(|| {
438            MoldError::Config("missing Lambda API key; set LAMBDA_API_KEY or lambda.api_key".into())
439        })?;
440        Ok(Self {
441            client: Client::builder().timeout(Duration::from_secs(60)).build()?,
442            endpoint: settings.endpoint().trim_end_matches('/').to_string(),
443            api_key,
444        })
445    }
446
447    pub fn new(endpoint: impl Into<String>, api_key: impl Into<String>) -> Self {
448        Self {
449            client: Client::new(),
450            endpoint: endpoint.into().trim_end_matches('/').to_string(),
451            api_key: api_key.into(),
452        }
453    }
454
455    async fn get_list<T: for<'de> Deserialize<'de> + Default>(&self, path: &str) -> Result<Vec<T>> {
456        let resp = self
457            .client
458            .get(format!("{}{}", self.endpoint, path))
459            .basic_auth(&self.api_key, Some(""))
460            .send()
461            .await?;
462        decode_list(resp).await
463    }
464
465    async fn post_item<B: Serialize, T: for<'de> Deserialize<'de>>(
466        &self,
467        path: &str,
468        body: &B,
469    ) -> Result<T> {
470        let resp = self
471            .client
472            .post(format!("{}{}", self.endpoint, path))
473            .basic_auth(&self.api_key, Some(""))
474            .json(body)
475            .send()
476            .await?;
477        decode_item(resp).await
478    }
479
480    pub async fn list_instance_types(&self) -> Result<Vec<InstanceType>> {
481        let resp = self
482            .client
483            .get(format!("{}/instance-types", self.endpoint))
484            .basic_auth(&self.api_key, Some(""))
485            .send()
486            .await?;
487        if !resp.status().is_success() {
488            return Err(lambda_error(resp).await.into());
489        }
490        decode_instance_types_body(&resp.text().await?)
491    }
492
493    pub async fn list_instances(&self) -> Result<Vec<Instance>> {
494        self.get_list("/instances").await
495    }
496
497    pub async fn get_instance(&self, id: &str) -> Result<Instance> {
498        let resp = self
499            .client
500            .get(format!("{}/instances/{id}", self.endpoint))
501            .basic_auth(&self.api_key, Some(""))
502            .send()
503            .await?;
504        decode_item(resp).await
505    }
506
507    pub async fn launch_instance(
508        &self,
509        req: &LaunchInstancesRequest,
510    ) -> Result<InstanceLaunchResponse> {
511        self.post_item("/instance-operations/launch", req).await
512    }
513
514    pub async fn terminate_instance(&self, id: &str) -> Result<()> {
515        let body = serde_json::json!({ "instance_ids": [id] });
516        let resp = self
517            .client
518            .post(format!("{}/instance-operations/terminate", self.endpoint))
519            .basic_auth(&self.api_key, Some(""))
520            .json(&body)
521            .send()
522            .await?;
523        ensure_success(resp).await
524    }
525
526    pub async fn list_ssh_keys(&self) -> Result<Vec<SshKey>> {
527        self.get_list("/ssh-keys").await
528    }
529
530    pub async fn create_ssh_key(&self, req: &CreateSshKeyRequest) -> Result<SshKey> {
531        self.post_item("/ssh-keys", req).await
532    }
533
534    pub async fn list_filesystems(&self) -> Result<Vec<Filesystem>> {
535        self.get_list("/file-systems").await
536    }
537
538    pub async fn create_filesystem(&self, req: &CreateFilesystemRequest) -> Result<Filesystem> {
539        self.post_item("/filesystems", req).await
540    }
541
542    pub async fn delete_filesystem(&self, id: &str) -> Result<()> {
543        let resp = self
544            .client
545            .delete(format!("{}/filesystems/{id}", self.endpoint))
546            .basic_auth(&self.api_key, Some(""))
547            .send()
548            .await?;
549        ensure_success(resp).await
550    }
551}
552
553async fn decode_list<T: for<'de> Deserialize<'de> + Default>(
554    resp: reqwest::Response,
555) -> Result<Vec<T>> {
556    if !resp.status().is_success() {
557        return Err(lambda_error(resp).await.into());
558    }
559    Ok(resp.json::<ApiList<T>>().await?.data)
560}
561
562async fn decode_item<T: for<'de> Deserialize<'de>>(resp: reqwest::Response) -> Result<T> {
563    if !resp.status().is_success() {
564        return Err(lambda_error(resp).await.into());
565    }
566    Ok(resp.json::<ApiItem<T>>().await?.data)
567}
568
569async fn ensure_success(resp: reqwest::Response) -> Result<()> {
570    if !resp.status().is_success() {
571        return Err(lambda_error(resp).await.into());
572    }
573    Ok(())
574}
575
576async fn lambda_error(resp: reqwest::Response) -> MoldError {
577    let status = resp.status();
578    let body = resp.text().await.unwrap_or_default();
579    let message = if status == StatusCode::UNAUTHORIZED {
580        "Lambda API authentication failed".to_string()
581    } else {
582        format!(
583            "Lambda API request failed with {status}: {}",
584            truncate(&body)
585        )
586    };
587    MoldError::Config(message)
588}
589
/// Limits `s` to 400 characters, appending `…` when anything was cut.
///
/// Lengths are counted in `char`s, so the cut always lands on a character
/// boundary.
fn truncate(s: &str) -> String {
    const MAX: usize = 400;
    // The byte offset of character number MAX + 1, if it exists, is where to cut.
    match s.char_indices().nth(MAX) {
        None => s.to_string(),
        Some((cut, _)) => {
            let mut shortened = String::with_capacity(cut + '…'.len_utf8());
            shortened.push_str(&s[..cut]);
            shortened.push('…');
            shortened
        }
    }
}
599
600pub fn decode_instance_types_body(body: &str) -> Result<Vec<InstanceType>> {
601    let response: InstanceTypesResponse = serde_json::from_str(body)?;
602    Ok(response
603        .data
604        .into_values()
605        .map(|offering| {
606            let mut instance_type = offering.instance_type;
607            if instance_type.specs.gpu_description.is_empty() {
608                instance_type.specs.gpu_description = instance_type.gpu_description.clone();
609            }
610            instance_type.regions_with_capacity_available =
611                offering.regions_with_capacity_available;
612            instance_type
613        })
614        .collect())
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// The instance-types endpoint returns a map keyed by name; decoding
    /// flattens it and back-fills `specs.gpu_description`.
    #[test]
    fn instance_types_decode_lambda_map_shape() {
        let body = r#"{
          "data": {
            "gpu_1x_a10": {
              "instance_type": {
                "name": "gpu_1x_a10",
                "description": "1x A10",
                "gpu_description": "A10",
                "price_cents_per_hour": 75,
                "specs": {
                  "vcpus": 30,
                  "memory_gib": 200,
                  "storage_gib": 1400,
                  "gpus": 1
                }
              },
              "regions_with_capacity_available": [
                {"name": "us-west-1", "description": "California"}
              ]
            }
          }
        }"#;

        let decoded = decode_instance_types_body(body).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].name, "gpu_1x_a10");
        assert_eq!(decoded[0].specs.gpu_description, "A10");
        assert_eq!(
            decoded[0].regions_with_capacity_available[0].name,
            "us-west-1"
        );
    }

    /// An empty TOML document yields the defaults, and explicit values
    /// survive a serialize/deserialize round trip.
    #[test]
    fn lambda_settings_toml_roundtrip_and_defaults() {
        let settings: LambdaSettings = toml::from_str("").unwrap();
        assert_eq!(settings.endpoint.as_deref(), Some(DEFAULT_ENDPOINT));
        assert_eq!(
            settings.image_repository.as_deref(),
            Some(DEFAULT_IMAGE_REPOSITORY)
        );
        assert_eq!(settings.filesystem_mount_path, "/data/mold");
        assert_eq!(settings.confirm_hourly_usd, 5.0);
        assert_eq!(settings.local_port, 7680);

        let original = LambdaSettings {
            api_key: Some("secret".into()),
            endpoint: Some("http://localhost:9999".into()),
            image_repository: Some("ghcr.io/example/mold".into()),
            ssh_key_name: Some("mold-key".into()),
            ssh_private_key_path: Some("~/.ssh/mold_lambda_ed25519".into()),
            filesystem_prefix: Some("mold".into()),
            filesystem_mount_path: "/mnt/mold".into(),
            confirm_hourly_usd: 9.5,
            local_port: 7777,
        };
        let encoded = toml::to_string(&original).unwrap();
        let decoded: LambdaSettings = toml::from_str(&encoded).unwrap();
        assert_eq!(decoded.api_key, original.api_key);
        assert_eq!(decoded.filesystem_mount_path, "/mnt/mold");
        assert_eq!(decoded.local_port, 7777);
    }

    /// The environment variable wins over the configured key.
    ///
    /// The previous value of the variable is restored before asserting, so a
    /// failure cannot leak the override into other tests and a developer's
    /// real key is not wiped by running the suite.
    #[test]
    fn auth_prefers_lambda_api_key_env_over_config() {
        let guard = crate::test_support::ENV_LOCK.lock().unwrap();
        let previous = std::env::var_os(API_KEY_ENV);
        std::env::set_var(API_KEY_ENV, "from-env");
        let settings = LambdaSettings {
            api_key: Some("from-config".into()),
            ..Default::default()
        };
        let resolved = settings.resolved_api_key();
        match previous {
            Some(value) => std::env::set_var(API_KEY_ENV, value),
            None => std::env::remove_var(API_KEY_ENV),
        }
        drop(guard);
        assert_eq!(resolved.as_deref(), Some("from-env"));
    }

    #[test]
    fn image_tag_maps_gpu_generations() {
        assert_eq!(
            image_tag_for_gpu("NVIDIA A100-SXM4-80GB", "0.10.0"),
            "latest-sm80"
        );
        assert_eq!(image_tag_for_gpu("NVIDIA L40S", "0.10.0"), "latest");
        assert_eq!(
            image_tag_for_gpu("NVIDIA H100 PCIe", "0.10.0"),
            "latest-sm90"
        );
        assert_eq!(image_tag_for_gpu("NVIDIA B200", "0.10.0"), "latest-sm120");
    }

    #[test]
    fn gh200_is_not_supported_by_published_linux_images() {
        assert!(gpu_uses_unsupported_linux_arm64("GH200 (96 GB)"));
        assert!(gpu_uses_unsupported_linux_arm64("gpu_1x_gh200"));
        assert!(!gpu_uses_unsupported_linux_arm64("NVIDIA H100 PCIe"));
        assert!(!gpu_uses_unsupported_linux_arm64("NVIDIA A100-SXM4-80GB"));
    }

    #[test]
    fn availability_marks_gh200_as_unsupported() {
        let ty = InstanceType {
            name: "gpu_1x_gh200".into(),
            description: "1x GH200".into(),
            gpu_description: "GH200 (96 GB)".into(),
            price_cents_per_hour: 229,
            specs: InstanceTypeSpecs {
                gpus: 1,
                gpu_description: "GH200 (96 GB)".into(),
                memory_gib: 432,
                storage_gib: 4096,
                ..Default::default()
            },
            regions_with_capacity_available: vec![Region {
                name: "us-east-3".into(),
                description: "Austin".into(),
            }],
        };
        let row = AvailabilityRow::from_instance_type(&ty, "ghcr.io/utensils/mold", "0.10.0");
        assert_eq!(row.image, "unsupported: linux/arm64 host");
    }

    #[test]
    fn availability_row_uses_gpu_count_as_generation_slots() {
        let ty = InstanceType {
            name: "gpu_8x_h100".into(),
            description: "8x H100".into(),
            gpu_description: "NVIDIA H100".into(),
            price_cents_per_hour: 15920,
            specs: InstanceTypeSpecs {
                gpus: 8,
                gpu_description: "NVIDIA H100".into(),
                memory_gib: 1800,
                storage_gib: 200,
                ..Default::default()
            },
            regions_with_capacity_available: vec![Region {
                name: "us-east-1".into(),
                description: "Virginia".into(),
            }],
        };
        let row = AvailabilityRow::from_instance_type(&ty, "ghcr.io/utensils/mold", "0.10.0");
        assert_eq!(row.generation_slots, 8);
        assert_eq!(row.image, "ghcr.io/utensils/mold:latest-sm90");
        assert_eq!(row.price_per_hour_usd, 159.20);
    }

    #[test]
    fn filesystem_name_defaults_to_prefix_region() {
        let settings = LambdaSettings::default();
        assert_eq!(filesystem_name(&settings, "us-west-1"), "mold-us-west-1");
        let custom = LambdaSettings {
            filesystem_prefix: Some("team-mold".into()),
            ..Default::default()
        };
        assert_eq!(filesystem_name(&custom, "us-east-1"), "team-mold-us-east-1");
    }

    #[test]
    fn launch_request_contains_expected_shape() {
        let req = build_launch_request(LaunchRequestInput {
            region_name: "us-west-1",
            instance_type_name: "gpu_1x_a10",
            ssh_key_name: "mold-laptop",
            filesystem_name: "mold-us-west-1",
            filesystem_id: None,
            filesystem_mount_path: "/data/mold",
            instance_name: "mold-us-west-1",
            image_id: None,
            user_data: "#cloud-config\n",
        });
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(json["region_name"], "us-west-1");
        assert_eq!(json["ssh_key_names"], serde_json::json!(["mold-laptop"]));
        assert_eq!(
            json["file_system_names"],
            serde_json::json!(["mold-us-west-1"])
        );
        assert_eq!(json["file_system_mounts"][0]["mount_point"], "/data/mold");
        assert_eq!(json["tags"][0]["key"], "managed-by");
        assert_eq!(json["tags"][0]["value"], "mold");
    }

    #[test]
    fn launch_request_uses_filesystem_id_when_available() {
        let req = build_launch_request(LaunchRequestInput {
            region_name: "us-west-1",
            instance_type_name: "gpu_1x_a10",
            ssh_key_name: "mold-laptop",
            filesystem_name: "mold-us-west-1",
            filesystem_id: Some("fs-123"),
            filesystem_mount_path: "/data/mold",
            instance_name: "mold-us-west-1",
            image_id: None,
            user_data: "#cloud-config\n",
        });
        let mount = &req.file_system_mounts[0];
        assert_eq!(mount.file_system_id.as_deref(), Some("fs-123"));
        assert!(mount.file_system_name.is_none());
    }

    #[test]
    fn create_filesystem_request_uses_lambda_region_field() {
        let req = CreateFilesystemRequest {
            name: "mold-us-east-1".into(),
            region: "us-east-1".into(),
        };
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(json["name"], "mold-us-east-1");
        assert_eq!(json["region"], "us-east-1");
        assert!(json.get("region_name").is_none());
    }

    #[test]
    fn cloud_init_keeps_service_private_and_omits_secrets_by_default() {
        let rendered = render_cloud_init(&CloudInitOptions {
            image: "ghcr.io/utensils/mold:0.10.0-sm90".into(),
            mount_path: "/data/mold".into(),
            env_file: "/etc/mold/lambda.env".into(),
        });
        assert!(rendered.contains("-p 127.0.0.1:7680:7680"));
        assert!(rendered.contains("-v /data/mold:/workspace"));
        assert!(rendered.contains("--gpus all"));
        assert!(rendered.contains("ghcr.io/utensils/mold:0.10.0-sm90"));
        assert!(!rendered.contains("HF_TOKEN"));
        assert!(!rendered.contains("CIVITAI_TOKEN"));
    }
}