1use std::{net::Ipv4Addr, time::Duration};
2
3use aws_sdk_ec2::{
4 client::Waiters,
5 error::ProvideErrorMetadata,
6 types::{
7 Filter, Instance, InstanceStateName, InstanceType, IpPermission, IpRange, KeyFormat,
8 KeyPairInfo, KeyType, ResourceType, SecurityGroup, Tag, TagSpecification,
9 },
10 Client as EC2Client,
11};
12use aws_smithy_runtime_api::client::waiters::error::WaiterError;
13
14use crate::util::UtilImpl as Util;
15
/// Default value of the `application` tag used for resources handled by this module.
pub const GLOBAL_TAG_FILTER: &str = "hpc-launcher";
/// Name constant for the SSH key pair (not referenced in this file; presumably used by callers).
pub const SSH_KEY_NAME: &str = "ec2-ssh-key";
/// Name of the security group that `get_ssh_security_group` creates or looks up.
pub const SSH_SECURITY_GROUP: &str = "allow-ssh";
21
22#[derive(Clone)]
23pub struct EC2Impl {
24 pub client: EC2Client,
26
27 custom_tag: Option<String>,
29}
30
31impl EC2Impl {
32 pub fn new(client: EC2Client, custom_tag: Option<String>) -> Self {
33 EC2Impl { client, custom_tag }
34 }
35
36 pub fn create_tag(&self, res_type: ResourceType) -> TagSpecification {
37 TagSpecification::builder()
38 .set_resource_type(Some(res_type))
39 .set_tags(Some(vec![Tag::builder()
40 .set_key(Some("application".into()))
41 .set_value(Some(
42 self.custom_tag
43 .clone()
44 .unwrap_or(GLOBAL_TAG_FILTER.to_string()),
45 ))
46 .build()]))
47 .build()
48 }
49
50 pub async fn create_key_pair(
51 &self,
52 name: &str,
53 key_type: KeyType,
54 key_format: KeyFormat,
55 ) -> Result<(KeyPairInfo, String), EC2Error> {
56 tracing::info!("Creating key pair {name}");
57 let output = self
58 .client
59 .create_key_pair()
60 .key_name(name)
61 .key_type(key_type)
62 .key_format(key_format)
63 .set_tag_specifications(Some(vec![self.create_tag(ResourceType::KeyPair)]))
64 .send()
65 .await?;
66 tracing::info!("key pair output = {:?}", output);
67 let info = KeyPairInfo::builder()
68 .set_key_name(output.key_name)
69 .set_key_fingerprint(output.key_fingerprint)
70 .set_key_pair_id(output.key_pair_id)
71 .build();
72 let material = output
73 .key_material
74 .ok_or_else(|| EC2Error::new("Create Key Pair has no key material"))?;
75 Ok((info, material))
76 }
77
78 pub async fn list_key_pair(&self, key_names: &str) -> Result<Vec<KeyPairInfo>, EC2Error> {
79 let output = self
80 .client
81 .describe_key_pairs()
82 .key_names(key_names)
83 .set_filters(Some(vec![Filter::builder()
84 .set_name(Some("tag:application".into()))
85 .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
86 .build()]))
87 .send()
88 .await?;
89 Ok(output.key_pairs.unwrap_or_default())
90 }
91
92 pub async fn delete_key_pair(&self, key_pair_id: &str) -> Result<(), EC2Error> {
93 let key_pair_id: String = key_pair_id.into();
94 tracing::info!("Deleting key pair {key_pair_id}");
95 self.client
96 .delete_key_pair()
97 .key_pair_id(key_pair_id)
98 .send()
99 .await?;
100 Ok(())
101 }
102
103 pub async fn create_security_group(
104 &self,
105 name: &str,
106 description: &str,
107 ) -> Result<SecurityGroup, EC2Error> {
108 tracing::info!("Creating security group {name}");
109 let create_output = self
110 .client
111 .create_security_group()
112 .group_name(name)
113 .description(description)
114 .set_tag_specifications(Some(vec![self.create_tag(ResourceType::SecurityGroup)]))
115 .send()
116 .await
117 .map_err(EC2Error::from)?;
118
119 let group_id = create_output
120 .group_id
121 .ok_or_else(|| EC2Error::new("Missing security group id after creation"))?;
122
123 let group = self
124 .describe_security_group(&group_id)
125 .await?
126 .ok_or_else(|| {
127 EC2Error::new(format!("Could not find security group with id {group_id}"))
128 })?;
129
130 tracing::info!("Created security group {name} as {group_id}");
131
132 Ok(group)
133 }
134
135 pub async fn describe_security_group(
137 &self,
138 group_name: &str,
139 ) -> Result<Option<SecurityGroup>, EC2Error> {
140 let describe_output = self
141 .client
142 .describe_security_groups()
143 .group_names(group_name)
144 .set_filters(Some(vec![Filter::builder()
145 .set_name(Some("tag:application".into()))
146 .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
147 .build()]))
148 .send()
149 .await?;
150
151 let mut groups = describe_output.security_groups.unwrap_or_default();
152
153 match groups.len() {
154 0 => Ok(None),
155 1 => Ok(Some(groups.remove(0))),
156 _ => Err(EC2Error::new(format!(
157 "Expected single group for {group_name}"
158 ))),
159 }
160 }
161
162 pub async fn authorize_security_group_ssh_ingress(
165 &self,
166 group_id: &str,
167 ingress_ips: Vec<Ipv4Addr>,
168 ) -> Result<(), EC2Error> {
169 tracing::info!("Authorizing ingress for security group {group_id}");
170 self.client
171 .authorize_security_group_ingress()
172 .group_id(group_id)
173 .set_ip_permissions(Some(
174 ingress_ips
175 .into_iter()
176 .map(|ip| {
177 IpPermission::builder()
178 .ip_protocol("tcp")
179 .from_port(22)
180 .to_port(22)
181 .ip_ranges(IpRange::builder().cidr_ip(format!("{ip}/32")).build())
182 .build()
183 })
184 .collect(),
185 ))
186 .send()
187 .await?;
188 Ok(())
189 }
190
191 pub async fn delete_security_group(&self, group_id: &str) -> Result<(), EC2Error> {
192 tracing::info!("Deleting security group {group_id}");
193 self.client
194 .delete_security_group()
195 .group_id(group_id)
196 .send()
197 .await?;
198 Ok(())
199 }
200
201 pub async fn create_instances<'a>(
202 &self,
203 instance_name: &str,
204 image_id: &'a str,
205 instance_type: InstanceType,
206 key_pair: &'a KeyPairInfo,
207 security_groups: Vec<&'a SecurityGroup>,
208 user_data: Option<String>,
209 ) -> Result<Vec<String>, EC2Error> {
210 let run_instances = self
211 .client
212 .run_instances()
213 .image_id(image_id)
214 .instance_type(instance_type)
215 .key_name(
216 key_pair
217 .key_name()
218 .ok_or_else(|| EC2Error::new("Missing key name when launching instance"))?,
219 )
220 .set_security_group_ids(Some(
221 security_groups
222 .iter()
223 .filter_map(|sg| sg.group_id.clone())
224 .collect(),
225 ))
226 .set_user_data(user_data)
227 .set_tag_specifications(Some(vec![self.create_tag(ResourceType::Instance)]))
228 .min_count(1)
229 .max_count(1)
230 .send()
231 .await?;
232
233 if run_instances.instances().is_empty() {
234 return Err(EC2Error::new("Failed to create instance"));
235 }
236
237 let mut instance_ids = vec![];
238 for i in run_instances.instances() {
239 let instance_id = i.instance_id().unwrap();
240 let response = self
241 .client
242 .create_tags()
243 .resources(instance_id)
244 .tags(Tag::builder().key("Name").value(instance_name).build())
245 .send()
246 .await;
247
248 match response {
249 Ok(_) => {
250 tracing::info!("Created {instance_id} and applied tags.");
251 instance_ids.push(instance_id.to_string());
252 }
253 Err(err) => {
254 tracing::info!("Error applying tags to {instance_id}: {err:?}");
255 return Err(err.into());
256 }
257 }
258 }
259
260 Ok(instance_ids)
261 }
262
263 pub async fn wait_for_instance_ready(
265 &self,
266 instance_id: &str,
267 duration: Option<Duration>,
268 ) -> Result<(), EC2Error> {
269 self.client
270 .wait_until_instance_status_ok()
271 .instance_ids(instance_id)
272 .wait(duration.unwrap_or(Duration::from_secs(60)))
273 .await
274 .map_err(|err| match err {
275 WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
276 "Exceeded max time ({}s) waiting for instance to start.",
277 exceeded.max_wait().as_secs()
278 )),
279 _ => EC2Error::from(err),
280 })?;
281 Ok(())
282 }
283
284 pub async fn describe_instance(
289 &self,
290 mut statuses: Vec<InstanceStateName>,
291 ) -> Result<Vec<Instance>, EC2Error> {
292 let non_terminated = vec![
293 InstanceStateName::Pending,
294 InstanceStateName::Running,
295 InstanceStateName::ShuttingDown,
296 InstanceStateName::Stopping,
297 InstanceStateName::Stopped,
298 ];
299 if statuses.is_empty() {
300 statuses = non_terminated;
301 }
302 let response = self
303 .client
304 .describe_instances()
305 .set_filters(Some(vec![
306 Filter::builder()
307 .set_name(Some("tag:application".into()))
308 .set_values(Some(vec![GLOBAL_TAG_FILTER.into()]))
309 .build(),
310 Filter::builder()
311 .set_name(Some("instance-state-name".into()))
312 .set_values(Some(statuses.into_iter().map(|s| s.to_string()).collect()))
313 .build(),
314 ]))
315 .send()
316 .await?;
317
318 let instances: Vec<_> = response
319 .reservations()
320 .iter()
321 .flat_map(|r| r.instances().to_owned())
322 .collect();
323
324 Ok(instances)
325 }
326
327 pub async fn start_instances(&self, instance_id: &str) -> Result<(), EC2Error> {
328 tracing::info!("Starting instance {instance_id}");
329
330 let mut starter = self.client.start_instances();
331 for id in instance_id.split(",") {
332 starter = starter.instance_ids(id);
333 }
334 starter.send().await?;
335
336 tracing::info!("Started instance.");
337
338 Ok(())
339 }
340
341 pub async fn stop_instances(&self, instance_ids: &str, wait: bool) -> Result<(), EC2Error> {
342 tracing::info!("Stopping instance {instance_ids}");
343
344 let mut stopper = self.client.stop_instances();
345 for id in instance_ids.split(",") {
346 stopper = stopper.instance_ids(id);
347 }
348 stopper.send().await?;
349
350 if wait {
351 self.wait_for_instance_stopped(instance_ids, None).await?;
352 tracing::info!("Stopped instance.");
353 }
354
355 Ok(())
356 }
357
358 pub async fn reboot_instance(&self, instance_id: &str) -> Result<(), EC2Error> {
359 tracing::info!("Rebooting instance {instance_id}");
360
361 self.client
362 .reboot_instances()
363 .instance_ids(instance_id)
364 .send()
365 .await?;
366
367 Ok(())
368 }
369
370 pub async fn wait_for_instance_stopped(
371 &self,
372 instance_ids: &str,
373 duration: Option<Duration>,
374 ) -> Result<(), EC2Error> {
375 let mut waiter = self.client.wait_until_instance_stopped();
376 for id in instance_ids.split(",") {
377 waiter = waiter.instance_ids(id);
378 }
379 waiter
380 .wait(duration.unwrap_or(Duration::from_secs(90)))
381 .await
382 .map_err(|err| match err {
383 WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
384 "Exceeded max time ({}s) waiting for instance to stop.",
385 exceeded.max_wait().as_secs(),
386 )),
387 _ => EC2Error::from(err),
388 })?;
389
390 Ok(())
391 }
392
393 pub async fn delete_instances(&self, instance_ids: &str, wait: bool) -> Result<(), EC2Error> {
394 tracing::info!("Deleting instance with id {:?}", instance_ids);
395
396 self.stop_instances(instance_ids, true).await?;
397
398 let mut terminator = self.client.terminate_instances();
399 for id in instance_ids.split(",") {
400 terminator = terminator.instance_ids(id);
401 }
402 terminator.send().await?;
403
404 if wait {
405 self.wait_for_instance_terminated(instance_ids).await?;
406 tracing::info!("Terminated instance with ids {:?}", instance_ids);
407 }
408
409 Ok(())
410 }
411
412 async fn wait_for_instance_terminated(&self, instance_ids: &str) -> Result<(), EC2Error> {
413 let mut waiter = self.client.wait_until_instance_terminated();
414 for id in instance_ids.split(",") {
415 waiter = waiter.instance_ids(id);
416 }
417 waiter
418 .wait(Duration::from_secs(60))
419 .await
420 .map_err(|err| match err {
421 WaiterError::ExceededMaxWait(exceeded) => EC2Error(format!(
422 "Exceeded max time ({}s) waiting for instance to terminate.",
423 exceeded.max_wait().as_secs(),
424 )),
425 _ => EC2Error::from(err),
426 })?;
427 Ok(())
428 }
429
430 async fn update_inbound_ip(&self, group_id: &str) -> Result<(), EC2Error> {
434 let check_ip = Util::do_get("https://checkip.amazonaws.com").await?;
435 tracing::info!("Current IP address = {}", check_ip);
436
437 let current_ip_address: Ipv4Addr = check_ip.trim().parse().map_err(|e| {
438 EC2Error::new(format!(
439 "Failed to convert response {} to IP Address: {e:?}",
440 check_ip
441 ))
442 })?;
443
444 if let Err(err) = self
445 .authorize_security_group_ssh_ingress(group_id, vec![current_ip_address])
446 .await
447 {
448 tracing::warn!("Most likely inbound rule already exists. Err = {err}");
449 };
450
451 Ok(())
452 }
453
454 pub async fn get_ssh_security_group(&self) -> Result<SecurityGroup, EC2Error> {
456 let group = match self
457 .create_security_group(
458 SSH_SECURITY_GROUP,
459 "Enables ssh into instance from your IP.",
460 )
461 .await
462 {
463 Ok(grp) => grp,
464 Err(err) => {
465 let res = self.describe_security_group(SSH_SECURITY_GROUP).await?;
467
468 if res.is_none() {
469 return Err(err);
470 }
471
472 res.unwrap()
473 }
474 };
475
476 self.update_inbound_ip(group.group_id.as_ref().unwrap())
477 .await?;
478
479 Ok(group)
480 }
481}
482
/// String-backed error type returned by the EC2 wrapper.
#[derive(Debug)]
pub struct EC2Error(String);

impl EC2Error {
    /// Creates an error from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        Self(value.into())
    }

    /// Consumes the error and returns a new one whose text is
    /// `"<message>: <previous text>"`.
    pub fn _add_message(self, message: impl Into<String>) -> Self {
        let EC2Error(previous) = self;
        let prefix: String = message.into();
        Self(format!("{}: {}", prefix, previous))
    }
}
494
495impl<T: ProvideErrorMetadata> From<T> for EC2Error {
496 fn from(value: T) -> Self {
497 EC2Error(format!(
498 "{}: {}",
499 value
500 .code()
501 .map(String::from)
502 .unwrap_or("unknown code".into()),
503 value
504 .message()
505 .map(String::from)
506 .unwrap_or("missing reason (most likely profile credentials not set)".into()),
507 ))
508 }
509}
510
511impl std::error::Error for EC2Error {}
512
513impl std::fmt::Display for EC2Error {
514 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
515 write!(f, "{}", self.0)
516 }
517}