korasi_cli/
ec2.rs

1use std::{net::Ipv4Addr, time::Duration};
2
3use aws_sdk_ec2::{
4    client::Waiters,
5    error::ProvideErrorMetadata,
6    types::{
7        Filter, Instance, InstanceStateName, InstanceType, IpPermission, IpRange, KeyFormat,
8        KeyPairInfo, KeyType, ResourceType, SecurityGroup, Tag, TagSpecification,
9    },
10    Client as EC2Client,
11};
12use aws_smithy_runtime_api::client::waiters::error::WaiterError;
13
14use crate::util::UtilImpl as Util;
15
// Co-locate all common keys here for now till a flexible
// configuration is needed.

/// Default value of the `application` tag: used when tagging created
/// resources (unless a custom tag is configured) and as the value of the
/// `tag:application` filter when listing resources.
pub const GLOBAL_TAG_FILTER: &str = "hpc-launcher";
/// Name for the SSH key pair.
/// NOTE(review): not referenced in this part of the file — presumably used by callers.
pub const SSH_KEY_NAME: &str = "ec2-ssh-key";
/// Name of the security group created or looked up by `get_ssh_security_group`.
pub const SSH_SECURITY_GROUP: &str = "allow-ssh";
21
/// Wrapper around the AWS SDK EC2 client that groups the key pair,
/// security group and instance operations used by this tool.
#[derive(Clone)]
pub struct EC2Impl {
    /// AWS sdk client to access EC2 resources.
    pub client: EC2Client,

    /// Override default `GLOBAL_TAG_FILTER`.
    ///
    /// When set, this value is written to the `application` tag of created
    /// resources instead of the default.
    custom_tag: Option<String>,
}
30
31impl EC2Impl {
32    pub fn new(client: EC2Client, custom_tag: Option<String>) -> Self {
33        EC2Impl { client, custom_tag }
34    }
35
36    pub fn create_tag(&self, res_type: ResourceType) -> TagSpecification {
37        TagSpecification::builder()
38            .set_resource_type(Some(res_type))
39            .set_tags(Some(vec![Tag::builder()
40                .set_key(Some("application".into()))
41                .set_value(Some(
42                    self.custom_tag
43                        .clone()
44                        .unwrap_or(GLOBAL_TAG_FILTER.to_string()),
45                ))
46                .build()]))
47            .build()
48    }
49
50    pub async fn create_key_pair(
51        &self,
52        name: &str,
53        key_type: KeyType,
54        key_format: KeyFormat,
55    ) -> Result<(KeyPairInfo, String), EC2Error> {
56        tracing::info!("Creating key pair {name}");
57        let output = self
58            .client
59            .create_key_pair()
60            .key_name(name)
61            .key_type(key_type)
62            .key_format(key_format)
63            .set_tag_specifications(Some(vec![self.create_tag(ResourceType::KeyPair)]))
64            .send()
65            .await?;
66        tracing::info!("key pair output = {:?}", output);
67        let info = KeyPairInfo::builder()
68            .set_key_name(output.key_name)
69            .set_key_fingerprint(output.key_fingerprint)
70            .set_key_pair_id(output.key_pair_id)
71            .build();
72        let material = output
73            .key_material
74            .ok_or_else(|| EC2Error::new("Create Key Pair has no key material"))?;
75        Ok((info, material))
76    }
77
78    pub async fn list_key_pair(&self, key_names: &str) -> Result<Vec<KeyPairInfo>, EC2Error> {
79        let output = self
80            .client
81            .describe_key_pairs()
82            .key_names(key_names)
83            .set_filters(Some(vec![Filter::builder()
84                .set_name(Some("tag:application".into()))
85                .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
86                .build()]))
87            .send()
88            .await?;
89        Ok(output.key_pairs.unwrap_or_default())
90    }
91
92    pub async fn delete_key_pair(&self, key_pair_id: &str) -> Result<(), EC2Error> {
93        let key_pair_id: String = key_pair_id.into();
94        tracing::info!("Deleting key pair {key_pair_id}");
95        self.client
96            .delete_key_pair()
97            .key_pair_id(key_pair_id)
98            .send()
99            .await?;
100        Ok(())
101    }
102
103    pub async fn create_security_group(
104        &self,
105        name: &str,
106        description: &str,
107    ) -> Result<SecurityGroup, EC2Error> {
108        tracing::info!("Creating security group {name}");
109        let create_output = self
110            .client
111            .create_security_group()
112            .group_name(name)
113            .description(description)
114            .set_tag_specifications(Some(vec![self.create_tag(ResourceType::SecurityGroup)]))
115            .send()
116            .await
117            .map_err(EC2Error::from)?;
118
119        let group_id = create_output
120            .group_id
121            .ok_or_else(|| EC2Error::new("Missing security group id after creation"))?;
122
123        let group = self
124            .describe_security_group(&group_id)
125            .await?
126            .ok_or_else(|| {
127                EC2Error::new(format!("Could not find security group with id {group_id}"))
128            })?;
129
130        tracing::info!("Created security group {name} as {group_id}");
131
132        Ok(group)
133    }
134
135    /// Find a single security group, by name. Returns Err if multiple groups are found.
136    pub async fn describe_security_group(
137        &self,
138        group_name: &str,
139    ) -> Result<Option<SecurityGroup>, EC2Error> {
140        let describe_output = self
141            .client
142            .describe_security_groups()
143            .group_names(group_name)
144            .set_filters(Some(vec![Filter::builder()
145                .set_name(Some("tag:application".into()))
146                .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
147                .build()]))
148            .send()
149            .await?;
150
151        let mut groups = describe_output.security_groups.unwrap_or_default();
152
153        match groups.len() {
154            0 => Ok(None),
155            1 => Ok(Some(groups.remove(0))),
156            _ => Err(EC2Error::new(format!(
157                "Expected single group for {group_name}"
158            ))),
159        }
160    }
161
162    /// Add an ingress rule to a security group explicitly allowing IPv4 address
163    /// as {ip}/32 over TCP port 22.
164    pub async fn authorize_security_group_ssh_ingress(
165        &self,
166        group_id: &str,
167        ingress_ips: Vec<Ipv4Addr>,
168    ) -> Result<(), EC2Error> {
169        tracing::info!("Authorizing ingress for security group {group_id}");
170        self.client
171            .authorize_security_group_ingress()
172            .group_id(group_id)
173            .set_ip_permissions(Some(
174                ingress_ips
175                    .into_iter()
176                    .map(|ip| {
177                        IpPermission::builder()
178                            .ip_protocol("tcp")
179                            .from_port(22)
180                            .to_port(22)
181                            .ip_ranges(IpRange::builder().cidr_ip(format!("{ip}/32")).build())
182                            .build()
183                    })
184                    .collect(),
185            ))
186            .send()
187            .await?;
188        Ok(())
189    }
190
191    pub async fn delete_security_group(&self, group_id: &str) -> Result<(), EC2Error> {
192        tracing::info!("Deleting security group {group_id}");
193        self.client
194            .delete_security_group()
195            .group_id(group_id)
196            .send()
197            .await?;
198        Ok(())
199    }
200
201    pub async fn create_instances<'a>(
202        &self,
203        instance_name: &str,
204        image_id: &'a str,
205        instance_type: InstanceType,
206        key_pair: &'a KeyPairInfo,
207        security_groups: Vec<&'a SecurityGroup>,
208        user_data: Option<String>,
209    ) -> Result<Vec<String>, EC2Error> {
210        let run_instances = self
211            .client
212            .run_instances()
213            .image_id(image_id)
214            .instance_type(instance_type)
215            .key_name(
216                key_pair
217                    .key_name()
218                    .ok_or_else(|| EC2Error::new("Missing key name when launching instance"))?,
219            )
220            .set_security_group_ids(Some(
221                security_groups
222                    .iter()
223                    .filter_map(|sg| sg.group_id.clone())
224                    .collect(),
225            ))
226            .set_user_data(user_data)
227            .set_tag_specifications(Some(vec![self.create_tag(ResourceType::Instance)]))
228            .min_count(1)
229            .max_count(1)
230            .send()
231            .await?;
232
233        if run_instances.instances().is_empty() {
234            return Err(EC2Error::new("Failed to create instance"));
235        }
236
237        let mut instance_ids = vec![];
238        for i in run_instances.instances() {
239            let instance_id = i.instance_id().unwrap();
240            let response = self
241                .client
242                .create_tags()
243                .resources(instance_id)
244                .tags(Tag::builder().key("Name").value(instance_name).build())
245                .send()
246                .await;
247
248            match response {
249                Ok(_) => {
250                    tracing::info!("Created {instance_id} and applied tags.");
251                    instance_ids.push(instance_id.to_string());
252                }
253                Err(err) => {
254                    tracing::info!("Error applying tags to {instance_id}: {err:?}");
255                    return Err(err.into());
256                }
257            }
258        }
259
260        Ok(instance_ids)
261    }
262
263    /// Wait for an instance to be ready and status ok (default wait 60 seconds)
264    pub async fn wait_for_instance_ready(
265        &self,
266        instance_id: &str,
267        duration: Option<Duration>,
268    ) -> Result<(), EC2Error> {
269        self.client
270            .wait_until_instance_status_ok()
271            .instance_ids(instance_id)
272            .wait(duration.unwrap_or(Duration::from_secs(60)))
273            .await
274            .map_err(|err| match err {
275                WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
276                    "Exceeded max time ({}s) waiting for instance to start.",
277                    exceeded.max_wait().as_secs()
278                )),
279                _ => EC2Error::from(err),
280            })?;
281        Ok(())
282    }
283
284    /// List instances that are "active" (non-terminated) and are tagged
285    /// by this tool.
286    ///
287    /// If statuses is an empty `Vec`, return all non-terminated instances as the default.
288    pub async fn describe_instance(
289        &self,
290        mut statuses: Vec<InstanceStateName>,
291    ) -> Result<Vec<Instance>, EC2Error> {
292        let non_terminated = vec![
293            InstanceStateName::Pending,
294            InstanceStateName::Running,
295            InstanceStateName::ShuttingDown,
296            InstanceStateName::Stopping,
297            InstanceStateName::Stopped,
298        ];
299        if statuses.is_empty() {
300            statuses = non_terminated;
301        }
302        let response = self
303            .client
304            .describe_instances()
305            .set_filters(Some(vec![
306                Filter::builder()
307                    .set_name(Some("tag:application".into()))
308                    .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
309                    .build(),
310                Filter::builder()
311                    .set_name(Some("instance-state-name".into()))
312                    .set_values(Some(statuses.into_iter().map(|s| s.to_string()).collect()))
313                    .build(),
314            ]))
315            .send()
316            .await?;
317
318        let instances: Vec<_> = response
319            .reservations()
320            .iter()
321            .flat_map(|r| r.instances().to_owned())
322            .collect();
323
324        Ok(instances)
325    }
326
327    pub async fn start_instances(&self, instance_id: &str) -> Result<(), EC2Error> {
328        tracing::info!("Starting instance {instance_id}");
329
330        let mut starter = self.client.start_instances();
331        for id in instance_id.split(",") {
332            starter = starter.instance_ids(id);
333        }
334        starter.send().await?;
335
336        tracing::info!("Started instance.");
337
338        Ok(())
339    }
340
341    pub async fn stop_instances(&self, instance_ids: &str, wait: bool) -> Result<(), EC2Error> {
342        tracing::info!("Stopping instance {instance_ids}");
343
344        let mut stopper = self.client.stop_instances();
345        for id in instance_ids.split(",") {
346            stopper = stopper.instance_ids(id);
347        }
348        stopper.send().await?;
349
350        if wait {
351            self.wait_for_instance_stopped(instance_ids, None).await?;
352            tracing::info!("Stopped instance.");
353        }
354
355        Ok(())
356    }
357
358    pub async fn reboot_instance(&self, instance_id: &str) -> Result<(), EC2Error> {
359        tracing::info!("Rebooting instance {instance_id}");
360
361        self.client
362            .reboot_instances()
363            .instance_ids(instance_id)
364            .send()
365            .await?;
366
367        Ok(())
368    }
369
370    pub async fn wait_for_instance_stopped(
371        &self,
372        instance_ids: &str,
373        duration: Option<Duration>,
374    ) -> Result<(), EC2Error> {
375        let mut waiter = self.client.wait_until_instance_stopped();
376        for id in instance_ids.split(",") {
377            waiter = waiter.instance_ids(id);
378        }
379        waiter
380            .wait(duration.unwrap_or(Duration::from_secs(90)))
381            .await
382            .map_err(|err| match err {
383                WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
384                    "Exceeded max time ({}s) waiting for instance to stop.",
385                    exceeded.max_wait().as_secs(),
386                )),
387                _ => EC2Error::from(err),
388            })?;
389
390        Ok(())
391    }
392
393    pub async fn delete_instances(&self, instance_ids: &str, wait: bool) -> Result<(), EC2Error> {
394        tracing::info!("Deleting instance with id {:?}", instance_ids);
395
396        self.stop_instances(instance_ids, true).await?;
397
398        let mut terminator = self.client.terminate_instances();
399        for id in instance_ids.split(",") {
400            terminator = terminator.instance_ids(id);
401        }
402        terminator.send().await?;
403
404        if wait {
405            self.wait_for_instance_terminated(instance_ids).await?;
406            tracing::info!("Terminated instance with ids {:?}", instance_ids);
407        }
408
409        Ok(())
410    }
411
412    async fn wait_for_instance_terminated(&self, instance_ids: &str) -> Result<(), EC2Error> {
413        let mut waiter = self.client.wait_until_instance_terminated();
414        for id in instance_ids.split(",") {
415            waiter = waiter.instance_ids(id);
416        }
417        waiter
418            .wait(Duration::from_secs(60))
419            .await
420            .map_err(|err| match err {
421                WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
422                    "Exceeded max time ({}s) waiting for instance to terminate.",
423                    exceeded.max_wait().as_secs(),
424                )),
425                _ => EC2Error::from(err),
426            })?;
427        Ok(())
428    }
429
430    /// Add new local IP to inbound security group.
431    ///
432    /// Local IPs can rotate or if you change to a different location.
433    async fn update_inbound_ip(&self, group_id: &str) -> Result<(), EC2Error> {
434        let check_ip = Util::do_get("https://checkip.amazonaws.com").await?;
435        tracing::info!("Current IP address = {}", check_ip);
436
437        let current_ip_address: Ipv4Addr = check_ip.trim().parse().map_err(|e| {
438            EC2Error::new(format!(
439                "Failed to convert response {} to IP Address: {e:?}",
440                check_ip
441            ))
442        })?;
443
444        if let Err(err) = self
445            .authorize_security_group_ssh_ingress(group_id, vec![current_ip_address])
446            .await
447        {
448            tracing::warn!("Most likely inbound rule already exists. Err = {err}");
449        };
450
451        Ok(())
452    }
453
454    /// Call this function to update local IP in inbound group.
455    pub async fn get_ssh_security_group(&self) -> Result<SecurityGroup, EC2Error> {
456        let group = match self
457            .create_security_group(
458                SSH_SECURITY_GROUP,
459                "Enables ssh into instance from your IP.",
460            )
461            .await
462        {
463            Ok(grp) => grp,
464            Err(err) => {
465                // Try to find existing group (if any).
466                let res = self.describe_security_group(SSH_SECURITY_GROUP).await?;
467
468                if res.is_none() {
469                    return Err(err);
470                }
471
472                res.unwrap()
473            }
474        };
475
476        self.update_inbound_ip(group.group_id.as_ref().unwrap())
477            .await?;
478
479        Ok(group)
480    }
481}
482
/// String-backed error type returned by the EC2 helpers in this file.
#[derive(Debug)]
pub struct EC2Error(String);

impl EC2Error {
    /// Creates an error holding `value` as its message.
    pub fn new(value: impl Into<String>) -> Self {
        Self(value.into())
    }

    /// Consumes the error and returns a new one whose message is
    /// `message`, a colon and a space, then the previous message.
    pub fn _add_message(self, message: impl Into<String>) -> Self {
        let Self(cause) = self;
        let context: String = message.into();
        Self(format!("{}: {}", context, cause))
    }
}
494
495impl<T: ProvideErrorMetadata> From<T> for EC2Error {
496    fn from(value: T) -> Self {
497        EC2Error(format!(
498            "{}: {}",
499            value
500                .code()
501                .map(String::from)
502                .unwrap_or("unknown code".into()),
503            value
504                .message()
505                .map(String::from)
506                .unwrap_or("missing reason (most likely profile credentials not set)".into()),
507        ))
508    }
509}
510
511impl std::error::Error for EC2Error {}
512
513impl std::fmt::Display for EC2Error {
514    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
515        write!(f, "{}", self.0)
516    }
517}