tsunami/providers/
aws.rs

1//! AWS backend for tsunami.
2//!
3//! The primary `impl Launcher` type is [`Launcher`].
4//! It internally uses the lower-level, region-specific [`aws::RegionLauncher`].
5//! Both these types use [`aws::Setup`] as their descriptor type.
6//!
7//! By default, this implementation uses 6-hour [defined
8//! duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
9//! spot instances. You can switch to on-demand instances using [`Launcher::set_mode`].
10//!
11//! # Examples
12//! ```rust,no_run
13//! #[tokio::main]
14//! async fn main() {
15//!     use tsunami::Tsunami;
16//!     use tsunami::providers::aws;
17//!
18//!     let mut l = aws::Launcher::default();
19//!     // make the defined-duration instances expire after 1 hour
20//!     l.set_mode(aws::LaunchMode::duration_spot(1));
21//!     l.spawn(vec![(String::from("my machine"), aws::Setup::default())], None).await.unwrap();
22//!     let vms = l.connect_all().await.unwrap();
23//!     let my_machine = vms.get("my machine").unwrap();
24//!     let out = my_machine
25//!         .ssh
26//!         .command("echo")
27//!         .arg("\"Hello, EC2\"")
28//!         .output()
29//!         .await
30//!         .unwrap();
31//!     let stdout = std::string::String::from_utf8(out.stdout).unwrap();
32//!     println!("{}", stdout);
33//!     l.terminate_all().await.unwrap();
34//! }
35//! ```
36//! ```rust,no_run
37//! use rusoto_core::{credential::DefaultCredentialsProvider};
38//! use tsunami::Tsunami;
39//! use tsunami::providers::aws::{self, Region};
40//! #[tokio::main]
41//! async fn main() -> Result<(), color_eyre::Report> {
42//!     // Initialize AWS
43//!     let mut aws = aws::Launcher::default();
44//!     // make the defined-duration instances expire after 1 hour
45//!     // default is the maximum (6 hours)
46//!     aws.set_mode(aws::LaunchMode::duration_spot(1)).open_ports();
47//!
48//!     // Create a machine descriptor and add it to the Tsunami
49//!     let m = aws::Setup::default()
50//!         .region_with_ubuntu_ami(Region::UsWest1) // default is UsEast1
51//!         .await
52//!         .unwrap()
53//!         .setup(|vm| {
54//!             // default is a no-op
55//!             Box::pin(async move {
56//!                 vm.ssh.command("sudo")
57//!                     .arg("apt")
58//!                     .arg("update")
59//!                     .status()
60//!                     .await?;
61//!                 vm.ssh.command("bash")
62//!                     .arg("-c")
63//!                     .arg("\"curl https://sh.rustup.rs -sSf | sh -- -y\"")
64//!                     .status()
65//!                     .await?;
66//!                 Ok(())
67//!             })
68//!         });
69//!
70//!     // Launch the VM
71//!     aws.spawn(vec![(String::from("my_vm"), m)], None).await?;
72//!
73//!     // SSH to the VM and run a command on it
74//!     let vms = aws.connect_all().await?;
75//!     let my_vm = vms.get("my_vm").unwrap();
76//!     println!("public ip: {}", my_vm.public_ip);
77//!     my_vm.ssh
78//!         .command("git")
79//!         .arg("clone")
80//!         .arg("https://github.com/jonhoo/tsunami")
81//!         .status()
82//!         .await?;
83//!     my_vm.ssh
84//!         .command("bash")
85//!         .arg("-c")
86//!         .arg("\"cd tsunami && cargo build\"")
87//!         .status()
88//!         .await?;
89//!     aws.terminate_all().await?;
90//!     Ok(())
91//! }
92//! ```
93
94use color_eyre::{
95    eyre::{self, eyre, WrapErr},
96    Report,
97};
98use educe::Educe;
99use itertools::Itertools;
100use rusoto_core::credential::{DefaultCredentialsProvider, ProvideAwsCredentials};
101use rusoto_core::request::HttpClient;
102pub use rusoto_core::Region;
103use rusoto_ec2::Ec2;
104use std::collections::HashMap;
105use std::future::Future;
106use std::io::Write;
107use std::pin::Pin;
108use std::sync::Arc;
109use std::time;
110use tracing::instrument;
111use tracing_futures::Instrument;
112
/// How a set of instances should be launched.
#[derive(Debug, Clone)]
#[allow(missing_copy_implementations)]
#[non_exhaustive]
pub enum LaunchMode {
    /// Use AWS [defined duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
    /// spot instances, failing with an error when there is no spot capacity.
    DefinedDuration {
        /// Lifetime of the defined duration instances, in hours.
        /// Must lie between 1 and 6.
        hours: usize,
    },
    /// Attempt AWS [defined duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
    /// spot instances first; if that fails (for example for lack of capacity), fall back to
    /// `OnDemand`.
    TrySpot {
        /// Lifetime of the defined duration instances, in hours.
        /// Must lie between 1 and 6.
        hours: usize,
    },
    /// Use regular AWS on-demand instances.
    OnDemand,
}

impl LaunchMode {
    /// Force a requested defined-duration lifetime into the 1..=6 hour window.
    fn clamp_hours(hours: usize) -> usize {
        hours.max(1).min(6)
    }

    /// Launch using AWS [defined duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances) spot instances.
    ///
    /// The lifetime of such instances has to be declared up front (1-6 hours), so `hours` is
    /// clamped into that range.
    pub fn duration_spot(hours: usize) -> Self {
        LaunchMode::DefinedDuration {
            hours: Self::clamp_hours(hours),
        }
    }

    /// Attempt to launch using AWS [defined
    /// duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
    /// spot instances, falling back to on-demand instances when that fails.
    ///
    /// The lifetime of such instances has to be declared up front (1-6 hours), so `hours` is
    /// clamped into that range.
    pub fn try_duration_spot(hours: usize) -> Self {
        LaunchMode::TrySpot {
            hours: Self::clamp_hours(hours),
        }
    }

    /// Launch using regular AWS on-demand instances.
    pub fn on_demand() -> Self {
        LaunchMode::OnDemand
    }
}
165
/// Where EC2 may place instances, expressed in terms of availability zones.
///
/// See [the aws docs](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#using-regions-availability-zones-launching) for more information.
#[derive(Debug, Clone)]
pub enum AvailabilityZoneSpec {
    /// Place the instance wherever there is capacity. This is the default.
    Any,
    /// Group instances by the given `usize` id, placing every instance of a group in the same
    /// availability zone. To choose exactly which zone that is, use
    /// `AvailabilityZoneSpec::Specify` instead.
    Cluster(usize),
    /// Place all the instances in the named availability zone.
    ///
    /// The string must be the zone's full name, such as `us-east-1a`.
    Specify(String),
}

impl Default for AvailabilityZoneSpec {
    fn default() -> Self {
        AvailabilityZoneSpec::Any
    }
}

impl std::fmt::Display for AvailabilityZoneSpec {
    /// Renders `any`, `cluster(<id>)`, or the zone name itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Any => f.write_str("any"),
            Self::Cluster(group) => write!(f, "cluster({})", group),
            Self::Specify(zone) => f.write_str(zone),
        }
    }
}
198
199/// A descriptor for a particular machine setup in a tsunami.
200///
201/// The default region and ami is Ubuntu 18.04 LTS in us-east-1. The default AMI is updated on a
202/// passive basis, so you almost certainly want to call one of:
203/// - [`Setup::region_with_ubuntu_ami`]
204/// - [`Setup::ami`]
205/// - [`Setup::region`]
206/// to change these defaults.
207#[derive(Clone, Educe)]
208#[educe(Debug)]
209pub struct Setup {
210    region: Region,
211    availability_zone: AvailabilityZoneSpec,
212    instance_type: String,
213    ami: String,
214    username: String,
215    #[educe(Debug(ignore))]
216    setup_fn: Option<
217        Arc<
218            dyn for<'r> Fn(
219                    &'r crate::Machine<'_>,
220                )
221                    -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'r>>
222                + Send
223                + Sync
224                + 'static,
225        >,
226    >,
227}
228
229impl super::MachineSetup for Setup {
230    type Region = String;
231
232    fn region(&self) -> Self::Region {
233        match self.availability_zone {
234            AvailabilityZoneSpec::Specify(ref id) => format!("{}-{}", self.region.name(), id),
235            AvailabilityZoneSpec::Cluster(id) => format!("{}-{}", self.region.name(), id),
236            AvailabilityZoneSpec::Any => self.region.name().to_string(),
237        }
238    }
239}
240
241impl Default for Setup {
242    fn default() -> Self {
243        Setup {
244            region: Region::UsEast1,
245            availability_zone: AvailabilityZoneSpec::Any,
246            instance_type: "t3.small".into(),
247            ami: String::from("ami-085925f297f89fce1"),
248            username: "ubuntu".into(),
249            setup_fn: None,
250        }
251    }
252}
253
254impl Setup {
255    /// Set up the machine in a specific EC2
256    /// [`Region`](http://rusoto.github.io/rusoto/rusoto_core/region/enum.Region.html).
257    ///
258    /// The default region is us-east-1. [Available regions are listed
259    /// here.](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-available-regions)
260    ///
261    /// AMIs are region-specific.  This method uses
262    /// [`ubuntu-ami`](https://crates.io/crates/ubuntu-ami), which queries [Ubuntu's cloud image
263    /// list](https://cloud-images.ubuntu.com/) to get the latest Ubuntu 18.04 LTS AMI in the
264    /// selected region.
265    pub async fn region_with_ubuntu_ami(mut self, region: Region) -> Result<Self, Report> {
266        self.region = region.clone();
267        let ami: String = UbuntuAmi::new(region).await?.into();
268        Ok(self.ami(ami, "ubuntu"))
269    }
270
271    /// Set the username used to ssh into the machine.
272    ///
273    /// If the user sets a custom AMI, they must call this method to
274    /// set a username.
275    pub fn username(self, username: impl ToString) -> Self {
276        Self {
277            username: username.to_string(),
278            ..self
279        }
280    }
281
282    /// The new instance will start out in the state dictated by the Amazon Machine Image specified
283    /// in `ami`. Default is Ubuntu 18.04 LTS.
284    pub fn ami(self, ami: impl ToString, username: impl ToString) -> Self {
285        Self {
286            ami: ami.to_string(),
287            username: username.to_string(),
288            ..self
289        }
290    }
291
292    /// The given AWS EC2 instance type will be used.
293    ///
294    /// Note that only [EC2 Defined Duration Spot
295    /// Instance types](https://aws.amazon.com/ec2/spot/pricing/) are allowed.
296    pub fn instance_type(mut self, typ: impl ToString) -> Self {
297        self.instance_type = typ.to_string();
298        self
299    }
300
301    /// Specify instance setup.
302    ///
303    /// The provided callback, `setup`, is called once
304    /// for every spawned instances of this type with a handle
305    /// to the target machine. Use [`Machine::ssh`] to issue
306    /// commands on the host in question.
307    ///
308    /// # Example
309    ///
310    /// ```rust
311    /// use tsunami::providers::aws::Setup;
312    ///
313    /// let m = Setup::default().setup(|vm| {
314    ///     Box::pin(async move {
315    ///         vm.ssh
316    ///             .command("sudo")
317    ///             .arg("apt")
318    ///             .arg("update")
319    ///             .status()
320    ///             .await?;
321    ///         Ok(())
322    ///     })
323    /// });
324    /// ```
325    pub fn setup(
326        mut self,
327        setup: impl for<'r> Fn(
328                &'r crate::Machine<'_>,
329            ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'r>>
330            + Send
331            + Sync
332            + 'static,
333    ) -> Self {
334        self.setup_fn = Some(Arc::new(setup));
335        self
336    }
337
338    /// Set up the machine in a specific EC2
339    /// [`Region`](http://rusoto.github.io/rusoto/rusoto_core/region/enum.Region.html).
340    ///
341    /// The default region is us-east-1. [Available regions are listed
342    /// here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-available-regions)
343    ///
344    /// AMIs are region-specific. Therefore, when changing the region a new ami must be given, with
345    /// a corresponding username. For a shortcut helper function that provides an Ubunti ami, see
346    /// `region_with_ubuntu_ami`.
347    pub fn region(mut self, region: Region, ami: impl ToString, username: impl ToString) -> Self {
348        self.region = region;
349        self.ami(ami, username)
350    }
351
352    /// Set up the machine in a specific EC2 availability zone.
353    ///
354    /// The default availability zone is unspecified - EC2 will launch the machine wherever there
355    /// is capacity.
356    pub fn availability_zone(self, az: AvailabilityZoneSpec) -> Self {
357        Self {
358            availability_zone: az,
359            ..self
360        }
361    }
362}
363
364/// AWS EC2 spot instance launcher.
365///
366/// This is a lower-level API. Most users will use [`crate::TsunamiBuilder::spawn`].
367///
368/// Each individual region is handled by `RegionLauncher`.
369///
370/// While the regions are initialized serially, the setup functions for each machine are executed
371/// in parallel (within each region).
372///
373/// By default, `Launcher` launches instances using 6-hour [defined
374/// duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
375/// spot requests.
376#[derive(Educe)]
377#[educe(Debug)]
378pub struct Launcher<P = DefaultCredentialsProvider> {
379    #[educe(Debug(ignore))]
380    credential_provider: Box<dyn Fn() -> Result<P, Report> + Send + Sync>,
381    mode: LaunchMode,
382    use_open_ports: bool,
383    regions: HashMap<<Setup as super::MachineSetup>::Region, RegionLauncher>,
384}
385
386impl Default for Launcher {
387    fn default() -> Self {
388        Launcher {
389            credential_provider: Box::new(|| Ok(DefaultCredentialsProvider::new()?)),
390            mode: LaunchMode::DefinedDuration { hours: 6 },
391            use_open_ports: false,
392            regions: Default::default(),
393        }
394    }
395}
396
397impl<P> Launcher<P> {
398    /// Set defined duration instance max instance duration.
399    ///
400    /// The lifetime of such instances must be declared in advance (1-6 hours).
401    /// This method thus clamps `t` to be between 1 and 6.
402    ///
403    /// By default, we use 6 hours (the maximum).
404    ///
405    /// This method also changes the mode to [`LaunchMode::DefinedDuration`].
406    #[deprecated(note = "prefer set_mode")]
407    pub fn set_max_instance_duration(&mut self, t: usize) -> &mut Self {
408        self.mode = LaunchMode::duration_spot(t);
409        self
410    }
411
412    /// Set the launch mode to use for future instances.
413    ///
414    /// See [`LaunchMode`] for more details.
415    pub fn set_mode(&mut self, mode: LaunchMode) -> &mut Self {
416        self.mode = mode;
417        self
418    }
419
420    /// The machines spawned on this launcher will have
421    /// ports open to the public Internet.
422    pub fn open_ports(&mut self) -> &mut Self {
423        self.use_open_ports = true;
424        self
425    }
426
427    /// Set the credential provider used to authenticate to EC2.
428    ///
429    /// The provided function is called once for each region, and is expected to produce a
430    /// [`P: ProvideAwsCredentials`](https://docs.rs/rusoto_core/0.40.0/rusoto_core/trait.ProvideAwsCredentials.html)
431    /// that gives access to the region in question.
432    pub fn with_credentials<P2>(
433        self,
434        f: impl Fn() -> Result<P2, Report> + Send + Sync + 'static,
435    ) -> Launcher<P2> {
436        Launcher {
437            credential_provider: Box::new(f),
438            mode: self.mode,
439            use_open_ports: self.use_open_ports,
440            regions: self.regions,
441        }
442    }
443}
444
445impl<P> super::Launcher for Launcher<P>
446where
447    P: ProvideAwsCredentials + Send + Sync + 'static,
448{
449    type MachineDescriptor = Setup;
450
451    #[instrument(level = "debug", skip(self))]
452    fn launch<'l>(
453        &'l mut self,
454        l: super::LaunchDescriptor<Self::MachineDescriptor>,
455    ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'l>> {
456        Box::pin(async move {
457            let prov = (*self.credential_provider)()?;
458            let Self {
459                use_open_ports,
460                mode,
461                ref mut regions,
462                ..
463            } = self;
464
465            if !regions.contains_key(&l.region) {
466                let region_span = tracing::debug_span!("new_region", name = %l.region, az = %l.machines[0].1.availability_zone);
467                let awsregion = RegionLauncher::new(
468                    // region name and availability_zone spec are guaranteed to be the same because
469                    // they are included in the region specifier.
470                    l.machines[0].1.region.name(),
471                    l.machines[0].1.availability_zone.clone(),
472                    prov,
473                    *use_open_ports,
474                )
475                .instrument(region_span)
476                .await?;
477                regions.insert(l.region.clone(), awsregion);
478            }
479
480            let region_span = tracing::debug_span!("region", name = %l.region);
481            regions
482                .get_mut(&l.region)
483                .unwrap()
484                .launch(mode.clone(), l.max_wait, l.machines)
485                .instrument(region_span)
486                .await?;
487            Ok(())
488        }.in_current_span())
489    }
490
491    #[instrument(level = "debug", skip(self, max_wait))]
492    fn spawn<'l, I>(
493        &'l mut self,
494        descriptors: I,
495        max_wait: Option<std::time::Duration>,
496    ) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send + 'l>>
497    where
498        I: IntoIterator<Item = (String, Self::MachineDescriptor)> + Send + 'static,
499        I: std::fmt::Debug,
500        I::IntoIter: Send,
501    {
502        use super::MachineSetup;
503        Box::pin(
504            async move {
505                tracing::info!("spinning up tsunami");
506
507                // group by region
508                let names_to_setups = descriptors
509                    .into_iter()
510                    .map(|(name, setup)| (MachineSetup::region(&setup), (name, setup)))
511                    .into_group_map();
512
513                // separate into two lists:
514                // 1. we already have a RegionLauncher
515                // 2. we don't
516                let (mut haves, have_nots): (Vec<_>, Vec<_>) = names_to_setups
517                    .into_iter()
518                    .partition(|(region_name, _)| self.regions.contains_key(region_name));
519
520                // check that this works before unwrap() below
521                let _prov = (*self.credential_provider)()?;
522                let use_open_ports = self.use_open_ports;
523
524                let newly_initialized: Vec<Result<_, _>> =
525                    futures_util::future::join_all(have_nots.iter().map(|(region_name, s)| {
526                        let region_span = tracing::debug_span!("new_region", region = %region_name);
527                        let prov = (*self.credential_provider)().unwrap();
528                        async move {
529                            let awsregion = RegionLauncher::new(
530                                // region name and availability_zone spec are guaranteed to be the
531                                // same because they are included in the region specifier.
532                                s[0].1.region.name(),
533                                s[0].1.availability_zone.clone(),
534                                prov,
535                                use_open_ports,
536                            )
537                            .await?;
538                            Ok::<_, Report>((region_name.clone(), awsregion))
539                        }
540                        .instrument(region_span)
541                    }))
542                    .await;
543                self.regions.extend(
544                    newly_initialized
545                        .into_iter()
546                        .collect::<Result<Vec<_>, _>>()?,
547                );
548
549                // the have-nots are now haves
550                haves.extend(have_nots);
551
552                // Launch instances in the regions concurrently.
553                //
554                // The borrow checker can't know that each future only accesses one entry of the
555                // hashmap - for its RegionLauncher (guaranteed by the `into_group_map()` above).
556                // So, we help it by taking the appropriate RegionLauncher out of the hashmap,
557                // running `launch()`, then putting everything back later.
558                let max_wait = max_wait;
559                let regions = futures_util::future::join_all(haves.into_iter().map(
560                    |(region_name, machines)| {
561                        // unwrap ok because everything is a have now
562                        let mut region_launcher = self.regions.remove(&region_name).unwrap();
563                        let region_span = tracing::debug_span!("region", region = %region_name);
564                        let mode = self.mode.clone();
565                        async move {
566                            if let Err(e) = region_launcher.launch(mode, max_wait, machines).await {
567                                Err((region_name, region_launcher, e))
568                            } else {
569                                Ok((region_name, region_launcher))
570                            }
571                        }
572                        .instrument(region_span)
573                    },
574                ))
575                .await;
576
577                // Put our stuff back where we found it.
578                let (regions, res) =
579                    regions
580                        .into_iter()
581                        .fold((vec![], None), |acc, r| match (acc, r) {
582                            ((mut rs, x), Ok((name, rl))) => {
583                                rs.push((name, rl));
584                                (rs, x)
585                            }
586                            ((mut rs, None), Err((name, rl, e))) => {
587                                rs.push((name, rl));
588                                (rs, Some(e))
589                            }
590                            ((mut rs, x @ Some(_)), Err((name, rl, _))) => {
591                                rs.push((name, rl));
592                                (rs, x)
593                            }
594                        });
595                self.regions.extend(regions.into_iter());
596
597                if let Some(e) = res {
598                    Err(e)
599                } else {
600                    Ok(())
601                }
602            }
603            .in_current_span(),
604        )
605    }
606
607    #[instrument(level = "debug")]
608    fn connect_all<'l>(
609        &'l self,
610    ) -> Pin<
611        Box<dyn Future<Output = Result<HashMap<String, crate::Machine<'l>>, Report>> + Send + 'l>,
612    > {
613        Box::pin(async move { collect!(self.regions) }.in_current_span())
614    }
615
616    #[instrument(level = "debug")]
617    fn terminate_all(mut self) -> Pin<Box<dyn Future<Output = Result<(), Report>> + Send>> {
618        Box::pin(
619            async move {
620                if self.regions.is_empty() {
621                    return Ok(());
622                }
623
624                let res =
625                    futures_util::future::join_all(self.regions.drain().map(|(region, mut rl)| {
626                        let region_span = tracing::debug_span!("region", %region);
627                        async move { rl.terminate_all().await }.instrument(region_span)
628                    }))
629                    .await;
630                res.into_iter().fold(Ok(()), |acc, x| match acc {
631                    Ok(_) => x,
632                    Err(a) => match x {
633                        Ok(_) => Err(a),
634                        Err(e) => Err(a.wrap_err(e)),
635                    },
636                })
637            }
638            .in_current_span(),
639        )
640    }
641}
642
/// Network addresses of a launched instance, kept as strings.
#[derive(Debug, Clone)]
struct IpInfo {
    /// Public DNS name.
    public_dns: String,
    /// Public IP address.
    public_ip: String,
    /// Private IP address.
    private_ip: String,
}
649
650// Internal representation of an instance.
651//
652// Tagged with its nickname, and ip_info gets populated once it is available.
653#[derive(Debug, Clone)]
654struct TaggedSetup {
655    name: String,
656    setup: Setup,
657    ip_info: Option<IpInfo>,
658}
659
660/// Region specific. Launch AWS EC2 instances.
661///
662/// This implementation uses [rusoto](https://crates.io/crates/rusoto_core) to connect to AWS.
663///
664/// By default, `RegionLauncher` launches uses AWS [defined
665/// duration](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-requests.html#fixed-duration-spot-instances)
666/// spot instances. These cost slightly more than regular spot instances, but are never prematurely
667/// terminated.  The lifetime of such instances must be declared in advance (1-6 hours). By
668/// default, we use 6 hours (the maximum). To change this, or to switch to on-demand instances, use
669/// [`Launcher::set_mode`].
670///
671/// You must call [`RegionLauncher::shutdown`] to terminate the instances.
672#[derive(Educe, Default)]
673#[educe(Debug)]
674pub struct RegionLauncher {
675    /// The region this RegionLauncher is connected to.
676    pub region: rusoto_core::region::Region,
677    availability_zone: AvailabilityZoneSpec,
678    security_group_id: String,
679    ssh_key_name: String,
680    private_key_path: Option<tempfile::NamedTempFile>,
681    #[educe(Debug(ignore))]
682    client: Option<rusoto_ec2::Ec2Client>,
683    spot_requests: HashMap<String, TaggedSetup>,
684    instances: HashMap<String, TaggedSetup>,
685}
686
687impl RegionLauncher {
688    /// Connect to AWS region `region`, using credentials provider `provider`.
689    ///
690    /// This is a lower-level API, you may want [`Launcher`] instead.
691    ///
692    /// This will create a temporary security group and SSH key in the given AWS region.
693    pub async fn new<P>(
694        region: &str,
695        availability_zone: AvailabilityZoneSpec,
696        provider: P,
697        use_open_ports: bool,
698    ) -> Result<Self, Report>
699    where
700        P: ProvideAwsCredentials + Send + Sync + 'static,
701    {
702        let region = region.parse()?;
703        let ec2 = RegionLauncher::connect(region, availability_zone, provider)
704            .wrap_err("failed to connect to region")?
705            .make_security_group(use_open_ports)
706            .await
707            .wrap_err("failed to make security groups")?
708            .make_ssh_key()
709            .await
710            .wrap_err("failed to make ssh key")?;
711
712        Ok(ec2)
713    }
714
715    #[instrument(level = "debug", skip(provider))]
716    fn connect<P>(
717        region: rusoto_core::region::Region,
718        availability_zone: AvailabilityZoneSpec,
719        provider: P,
720    ) -> Result<Self, Report>
721    where
722        P: ProvideAwsCredentials + Send + Sync + 'static,
723    {
724        tracing::debug!("connecting to ec2");
725        let ec2 = rusoto_ec2::Ec2Client::new_with(
726            HttpClient::new().wrap_err("failed to construct new http client")?,
727            provider,
728            region.clone(),
729        );
730
731        Ok(Self {
732            region,
733            availability_zone,
734            security_group_id: Default::default(),
735            ssh_key_name: Default::default(),
736            private_key_path: Some(
737                tempfile::NamedTempFile::new()
738                    .wrap_err("failed to create temporary file for keypair")?,
739            ),
740            spot_requests: Default::default(),
741            instances: Default::default(),
742            client: Some(ec2),
743        })
744    }
745
746    /// Region-specific instance setup.
747    ///
748    /// Make spot instance requests, wait for the instances, and then call the
749    /// instance setup functions.
750    #[instrument(level = "debug", skip(self, max_wait))]
751    pub async fn launch<M>(
752        &mut self,
753        mode: LaunchMode,
754        mut max_wait: Option<time::Duration>,
755        machines: M,
756    ) -> Result<(), Report>
757    where
758        M: IntoIterator<Item = (String, Setup)> + std::fmt::Debug,
759    {
760        let machines: Vec<_> = machines.into_iter().collect();
761        let machines = machines.clone();
762        let mut do_ondemand = false;
763        match mode {
764            LaunchMode::TrySpot {
765                hours: max_instance_duration_hours,
766            }
767            | LaunchMode::DefinedDuration {
768                hours: max_instance_duration_hours,
769            } => {
770                let machines = machines.clone();
771
772                // leave this to short-circuit: we only want to fall back to OnDemand if there is
773                // no spot capacity, not if we can't make the request in the first place.
774                self.make_spot_instance_requests(
775                    max_instance_duration_hours * 60, // 60 mins/hr
776                    machines,
777                )
778                .await
779                .wrap_err("failed to make spot instance requests")?;
780
781                let start = time::Instant::now();
782                if let Err(e) = self
783                    .wait_for_spot_instance_requests(max_wait)
784                    .await
785                    .wrap_err(eyre!(
786                        "failed while waiting for spot instances fulfilment in {}",
787                        self.region.name()
788                    ))
789                {
790                    // if wait_for_spot_instance_requests returned an Err, it will have cleaned up
791                    // the spot instance requests already.
792                    if let LaunchMode::TrySpot { .. } = mode {
793                        tracing::debug!(err = ?e, "re-trying with OnDemand instace");
794                        do_ondemand = true;
795                    } else {
796                        return Err(e);
797                    }
798                } else {
799                    if let Some(ref mut d) = max_wait {
800                        *d -= time::Instant::now().duration_since(start);
801                    }
802                }
803            }
804            LaunchMode::OnDemand => {
805                do_ondemand = true;
806            }
807        }
808
809        if do_ondemand {
810            self.make_on_demand_requests(machines)
811                .await
812                .wrap_err(eyre!(
813                    "failed to start on demand instances in {}",
814                    self.region.name()
815                ))?;
816
817            // give EC2 a bit of time to discover the instances
818            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
819        }
820
821        self.wait_for_instances(max_wait)
822            .await
823            .wrap_err("failed while waiting for instances to come up")?;
824        Ok(())
825    }
826
827    #[instrument(level = "trace", skip(self))]
828    async fn make_security_group(mut self, use_open_ports: bool) -> Result<Self, Report> {
829        let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
830
831        // set up network firewall for machines
832        let group_name = super::rand_name("security");
833        tracing::debug!(name = %group_name, "creating security group");
834        let req = rusoto_ec2::CreateSecurityGroupRequest {
835            group_name,
836            description: "temporary access group for tsunami VMs".to_string(),
837            ..Default::default()
838        };
839        let res = ec2
840            .create_security_group(req)
841            .await
842            .wrap_err("failed to create security group for new machines")?;
843        let group_id = res
844            .group_id
845            .expect("aws created security group with no group id");
846        tracing::trace!(id = %group_id, "security group created");
847
848        let mut req = rusoto_ec2::AuthorizeSecurityGroupIngressRequest {
849            group_id: Some(group_id.clone()),
850            // icmp access
851            ip_protocol: Some("icmp".to_string()),
852            from_port: Some(-1),
853            to_port: Some(-1),
854            cidr_ip: Some("0.0.0.0/0".to_string()),
855            ..Default::default()
856        };
857        tracing::trace!("adding icmp access");
858        ec2.authorize_security_group_ingress(req.clone())
859            .await
860            .wrap_err("failed to fill in security group for new machines")?;
861
862        // allow SSH from anywhere
863        req.ip_protocol = Some("tcp".to_string());
864        req.from_port = Some(22);
865        req.to_port = Some(22);
866        req.cidr_ip = Some("0.0.0.0/0".to_string());
867        tracing::trace!("adding ssh access");
868        ec2.authorize_security_group_ingress(req.clone())
869            .await
870            .wrap_err("failed to fill in security group for new machines")?;
871
872        // The default VPC uses IPs in range 172.31.0.0/16:
873        // https://docs.aws.amazon.com/vpc/latest/userguide/default-vpc.html
874        // TODO(might-be-nice) Support configurable rules for other VPCs
875        req.ip_protocol = Some("tcp".to_string());
876        req.from_port = Some(0);
877        req.to_port = Some(65535);
878        if use_open_ports {
879            req.cidr_ip = Some("0.0.0.0/0".to_string());
880        } else {
881            req.cidr_ip = Some("172.31.0.0/16".to_string());
882        }
883
884        tracing::trace!("adding intra-vm tcp access");
885        ec2.authorize_security_group_ingress(req.clone())
886            .await
887            .wrap_err("failed to fill in security group for new machines")?;
888
889        req.ip_protocol = Some("udp".to_string());
890        req.from_port = Some(0);
891        req.to_port = Some(65535);
892        if use_open_ports {
893            req.cidr_ip = Some("0.0.0.0/0".to_string());
894        } else {
895            req.cidr_ip = Some("172.31.0.0/16".to_string());
896        }
897
898        tracing::trace!("adding intra-vm udp access");
899        ec2.authorize_security_group_ingress(req)
900            .await
901            .wrap_err("failed to fill in security group for new machines")?;
902
903        self.security_group_id = group_id;
904        Ok(self)
905    }
906
907    #[instrument(level = "trace", skip(self))]
908    async fn make_ssh_key(mut self) -> Result<Self, Report> {
909        let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
910        let private_key_path = self
911            .private_key_path
912            .as_mut()
913            .expect("RegionLauncher unconnected");
914
915        // construct keypair for ssh access
916        tracing::debug!("creating keypair");
917        let key_name = super::rand_name("key");
918        let req = rusoto_ec2::CreateKeyPairRequest {
919            key_name: key_name.clone(),
920            ..Default::default()
921        };
922        let res = ec2
923            .create_key_pair(req)
924            .await
925            .context("failed to generate new key pair")?;
926        tracing::trace!(fingerprint = ?res.key_fingerprint, "created keypair");
927
928        // write keypair to disk
929        let private_key = res
930            .key_material
931            .expect("aws did not generate key material for new key");
932        private_key_path
933            .write_all(private_key.as_bytes())
934            .context("could not write private key to file")?;
935        tracing::debug!(
936            filename = %private_key_path.path().display(),
937            "wrote keypair to file"
938        );
939
940        self.ssh_key_name = key_name;
941        Ok(self)
942    }
943
944    /// Make a new placement for a launch request.
945    ///
946    /// This method takes a "placement maker" (`mk`) to allow using this method for both
947    /// `SpotPlacement` and `Placement`. The `mk` function is passed a placement name and an
948    /// availability zone, and is expected to return an appropriate placement type.
949    #[instrument(level = "trace", skip(self, mk))]
950    async fn make_placement<R>(
951        &mut self,
952        mk: impl FnOnce(String, Option<String>) -> R,
953    ) -> Result<Option<R>, Report> {
954        if let AvailabilityZoneSpec::Any = self.availability_zone {
955            Ok(None)
956        } else {
957            let ec2 = self.client.as_mut().expect("RegionLauncher unconnected");
958            tracing::trace!("creating placement group");
959            let placement_name = super::rand_name("placement");
960            let req = rusoto_ec2::CreatePlacementGroupRequest {
961                group_name: Some(placement_name.clone()),
962                strategy: Some(String::from("cluster")),
963                ..Default::default()
964            };
965            ec2.create_placement_group(req).await?;
966            tracing::trace!("created placement group");
967
968            Ok(Some(mk(
969                placement_name,
970                match self.availability_zone {
971                    AvailabilityZoneSpec::Cluster(_) => None,
972                    AvailabilityZoneSpec::Specify(ref av) => Some(av.clone()),
973                    _ => unreachable!(),
974                },
975            )))
976        }
977    }
978
979    fn for_each_machine_group<M>(
980        machines: M,
981    ) -> impl Iterator<Item = ((String, String), Vec<(String, Setup)>)> + Send
982    where
983        M: IntoIterator<Item = (String, Setup)>,
984        M: std::fmt::Debug,
985    {
986        // minimize the number of instance requests:
987        machines
988            .into_iter()
989            .map(|(name, m)| {
990                // attach labels (ami name, instance type):
991                // the only fields that vary between tsunami spot instance requests
992                ((m.ami.clone(), m.instance_type.clone()), (name, m))
993            })
994            .into_group_map()
995            .into_iter()
996    }
997
998    #[instrument(level = "trace", skip(self))]
999    async fn make_on_demand_requests<M>(&mut self, machines: M) -> Result<(), Report>
1000    where
1001        M: IntoIterator<Item = (String, Setup)>,
1002        M: std::fmt::Debug,
1003    {
1004        tracing::info!("launching on demand instances");
1005
1006        // minimize the number of instance requests:
1007        for ((ami, instance_type), reqs) in Self::for_each_machine_group(machines) {
1008            let inst_span = tracing::debug_span!("run_instance", ?ami, ?instance_type);
1009            async {
1010                // and issue one spot request per group
1011                let placement = self
1012                    .make_placement(|group_name, az| rusoto_ec2::Placement {
1013                        group_name: Some(group_name),
1014                        availability_zone: az,
1015                        ..Default::default()
1016                    })
1017                    .await
1018                    .wrap_err("create new placement group")?;
1019                let req = rusoto_ec2::RunInstancesRequest {
1020                    image_id: Some(ami),
1021                    instance_type: Some(instance_type),
1022                    placement,
1023                    security_group_ids: Some(vec![self.security_group_id.clone()]),
1024                    key_name: Some(self.ssh_key_name.clone()),
1025                    min_count: reqs.len() as i64,
1026                    max_count: reqs.len() as i64,
1027                    instance_initiated_shutdown_behavior: Some("terminate".to_string()),
1028                    ..Default::default()
1029                };
1030
1031                // TODO: VPC
1032
1033                tracing::trace!("issuing request");
1034                let res = self
1035                    .client
1036                    .as_mut()
1037                    .unwrap()
1038                    .run_instances(req)
1039                    .await
1040                    .wrap_err("failed to request on demand instances")?;
1041
1042                // collect for length check below
1043                let instances: Vec<String> = res
1044                    .instances
1045                    .expect("run_instances should always return instances")
1046                    .into_iter()
1047                    .filter_map(|i| i.instance_id)
1048                    .inspect(|instance_id| {
1049                        tracing::trace!(id = %instance_id, "launched on-demand instance");
1050                    })
1051                    .collect();
1052
1053                // zip_eq will panic if lengths not equal, so check beforehand
1054                eyre::ensure!(
1055                    instances.len() == reqs.len(),
1056                    "Got {} instances but expected {}",
1057                    instances.len(),
1058                    reqs.len(),
1059                );
1060
1061                self.instances
1062                    .extend(instances.into_iter().zip_eq(reqs.into_iter()).map(
1063                        |(instance_id, (name, setup))| {
1064                            let setup = TaggedSetup {
1065                                name,
1066                                setup,
1067                                ip_info: None,
1068                            };
1069                            (instance_id, setup)
1070                        },
1071                    ));
1072
1073                Ok(())
1074            }
1075            .instrument(inst_span)
1076            .await?;
1077        }
1078
1079        Ok(())
1080    }
1081
1082    /// Make one-time spot instance requests, which will automatically get terminated after
1083    /// `max_duration` minutes.
1084    ///
1085    /// `machines` is a key-value iterator: keys are friendly names for the machines, and values
1086    /// are [`Setup`] describing each machine to launch. Once the machines launch,
1087    /// the friendly names are tied to SSH connections ([`crate::Machine`]) in the `HashMap` that
1088    /// [`connect_all`](RegionLauncher::connect_all) returns.
1089    ///
1090    /// Will *not* wait for the spot instance requests to complete. To wait, call
1091    /// [`wait_for_spot_instance_requests`](RegionLauncher::wait_for_spot_instance_requests).
1092    #[instrument(level = "trace", skip(self, max_duration))]
1093    async fn make_spot_instance_requests<M>(
1094        &mut self,
1095        max_duration: usize,
1096        machines: M,
1097    ) -> Result<(), Report>
1098    where
1099        M: IntoIterator<Item = (String, Setup)>,
1100        M: std::fmt::Debug,
1101    {
1102        tracing::info!("launching spot requests");
1103
1104        // minimize the number of spot requests:
1105        for ((ami, instance_type), reqs) in Self::for_each_machine_group(machines) {
1106            let spot_span = tracing::debug_span!("spot_request", ?ami, ?instance_type);
1107            async {
1108                // and issue one spot request per group
1109                let placement = self
1110                    .make_placement(|group_name, az| rusoto_ec2::SpotPlacement {
1111                        group_name: Some(group_name),
1112                        availability_zone: az,
1113                        ..Default::default()
1114                    })
1115                    .await
1116                    .wrap_err("create new placement group")?;
1117                let launch = rusoto_ec2::RequestSpotLaunchSpecification {
1118                    image_id: Some(ami),
1119                    instance_type: Some(instance_type),
1120                    placement,
1121                    security_group_ids: Some(vec![self.security_group_id.clone()]),
1122                    key_name: Some(self.ssh_key_name.clone()),
1123                    ..Default::default()
1124                };
1125
1126                // TODO: VPC
1127
1128                let req = rusoto_ec2::RequestSpotInstancesRequest {
1129                    instance_count: Some(reqs.len() as i64),
1130                    block_duration_minutes: Some(max_duration as i64),
1131                    launch_specification: Some(launch),
1132                    // one-time spot instances are only fulfilled once and therefore do not need to be
1133                    // cancelled.
1134                    type_: Some("one-time".into()),
1135                    ..Default::default()
1136                };
1137
1138                tracing::trace!("issuing spot request");
1139                let res = self
1140                    .client
1141                    .as_mut()
1142                    .unwrap()
1143                    .request_spot_instances(req)
1144                    .await
1145                    .wrap_err("failed to request spot instance")?;
1146
1147                // collect for length check below
1148                let spot_instance_requests: Vec<String> = res
1149                    .spot_instance_requests
1150                    .expect("request_spot_instances should always return spot instance requests")
1151                    .into_iter()
1152                    .filter_map(|sir| sir.spot_instance_request_id)
1153                    .inspect(|request_id| {
1154                        tracing::trace!(id = %request_id, "activated spot request");
1155                    })
1156                    .collect();
1157
1158                // zip_eq will panic if lengths not equal, so check beforehand
1159                eyre::ensure!(
1160                    spot_instance_requests.len() == reqs.len(),
1161                    "Got {} spot instance requests but expected {}",
1162                    spot_instance_requests.len(),
1163                    reqs.len(),
1164                );
1165
1166                for (request_id, (name, setup)) in
1167                    spot_instance_requests.into_iter().zip_eq(reqs.into_iter())
1168                {
1169                    self.spot_requests.insert(
1170                        request_id,
1171                        TaggedSetup {
1172                            name,
1173                            setup,
1174                            ip_info: None,
1175                        },
1176                    );
1177                }
1178
1179                Ok(())
1180            }
1181            .instrument(spot_span)
1182            .await?;
1183        }
1184
1185        Ok(())
1186    }
1187
1188    /// Poll AWS once a second until either `max_wait` (if not `None`) elapses, or
1189    /// the spot requests are fulfilled.
1190    ///
1191    /// This method will return when the spot requests are fulfilled, *not* when the instances are
1192    /// ready.
1193    ///
1194    /// To wait for the instances to be ready, call
1195    /// [`wait_for_instances`](RegionLauncher::wait_for_instances).
1196    #[instrument(level = "trace", skip(self, max_wait))]
1197    async fn wait_for_spot_instance_requests(
1198        &mut self,
1199        max_wait: Option<time::Duration>,
1200    ) -> Result<(), Report> {
1201        tracing::info!("waiting for instances to spawn");
1202
1203        let start = time::Instant::now();
1204
1205        loop {
1206            tracing::trace!("checking spot request status");
1207            let instances = self.describe_spot_instance_requests().await?;
1208
1209            let mut any_pending = false;
1210            for (request_id, state, status, instance_id) in &instances {
1211                match &**state {
1212                    "active" if instance_id.is_some() => {
1213                        tracing::trace!(%request_id, %state, ?instance_id, "spot instance request ready");
1214                    }
1215                    "active" | "open" => {
1216                        any_pending = true;
1217                    }
1218                    s => {
1219                        // closed | failed | cancelled
1220                        let _ = self.cancel_spot_instance_requests().await;
1221                        eyre::bail!("spot request unexpectedly {}: {}", s, status);
1222                    }
1223                }
1224            }
1225            let all_active = !any_pending;
1226
1227            if all_active {
1228                // unwraps okay because they are the same as expects above
1229                self.instances = instances
1230                    .into_iter()
1231                    .map(|(request_id, state, _, instance_id)| {
1232                        assert_eq!(state, "active");
1233                        let instance_id = instance_id.unwrap();
1234                        let setup = self.spot_requests[&request_id].clone();
1235                        (instance_id, setup)
1236                    })
1237                    .collect();
1238                break;
1239            }
1240
1241            // let's not hammer the API
1242            tokio::time::sleep(time::Duration::from_secs(1)).await;
1243
1244            if let Some(wait_limit) = max_wait {
1245                if start.elapsed() <= wait_limit {
1246                    continue;
1247                }
1248                self.cancel_spot_instance_requests().await?;
1249                eyre::bail!("wait limit reached");
1250            }
1251        }
1252
1253        Ok(())
1254    }
1255
1256    /// Poll AWS until `max_wait` (if not `None`) or the instances are ready to SSH to.
1257    #[instrument(level = "trace", skip(self, max_wait))]
1258    async fn wait_for_instances(&mut self, max_wait: Option<time::Duration>) -> Result<(), Report> {
1259        let start = time::Instant::now();
1260        let desc_req = rusoto_ec2::DescribeInstancesRequest {
1261            instance_ids: Some(self.instances.keys().cloned().collect()),
1262            ..Default::default()
1263        };
1264        let client = self.client.as_ref().unwrap();
1265        let private_key_path = self.private_key_path.as_ref().unwrap();
1266        let mut all_ready = self.instances.is_empty();
1267        while !all_ready {
1268            all_ready = true;
1269
1270            for reservation in client
1271                .describe_instances(desc_req.clone())
1272                .await
1273                .wrap_err("could not query AWS for instance state")?
1274                .reservations
1275                .unwrap_or_else(Vec::new)
1276            {
1277                for instance in reservation.instances.unwrap_or_else(Vec::new) {
1278                    match instance {
1279                        // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceState.html
1280                        // code 16 means "Running"
1281                        rusoto_ec2::Instance {
1282                            state: Some(rusoto_ec2::InstanceState { code: Some(16), .. }),
1283                            instance_id: Some(instance_id),
1284                            public_dns_name: Some(public_dns),
1285                            public_ip_address: Some(public_ip),
1286                            private_ip_address: Some(private_ip),
1287                            ..
1288                        } => {
1289                            let instance_span =
1290                                tracing::debug_span!("instance", %instance_id, ip = %public_ip);
1291                            let instances = &mut self.instances;
1292                            async {
1293                                tracing::trace!("instance running");
1294
1295                                // try connecting. If can't, not ready.
1296                                let tag_setup = instances.get_mut(&instance_id).unwrap();
1297
1298                                // no need to set public dns nor private ip since `connect_ssh` only uses the public ip
1299                                let m = crate::MachineDescriptor {
1300                                    nickname: Default::default(),
1301                                    public_dns: Default::default(),
1302                                    public_ip: public_ip.to_string(),
1303                                    private_ip: Default::default(),
1304                                    _tsunami: Default::default(),
1305                                };
1306
1307                                if let Err(e) = m
1308                                    .connect_ssh(
1309                                        &tag_setup.setup.username,
1310                                        Some(private_key_path.path()),
1311                                        max_wait,
1312                                        22,
1313                                    )
1314                                    .await
1315                                {
1316                                    tracing::trace!("ssh failed: {}", e);
1317                                    all_ready = false;
1318                                } else {
1319                                    tracing::debug!("instance ready");
1320
1321                                    tag_setup.ip_info = Some(IpInfo {
1322                                        public_dns: public_dns.clone(),
1323                                        public_ip: public_ip.clone(),
1324                                        private_ip: private_ip.clone(),
1325                                    });
1326                                }
1327                            }
1328                            .instrument(instance_span)
1329                            .await
1330                        }
1331                        _ => {
1332                            all_ready = false;
1333                        }
1334                    }
1335                }
1336            }
1337
1338            // let's not hammer the API
1339            tokio::time::sleep(time::Duration::from_secs(1)).await;
1340
1341            if let Some(wait_limit) = max_wait {
1342                if start.elapsed() <= wait_limit {
1343                    continue;
1344                }
1345                self.cancel_spot_instance_requests().await?;
1346                eyre::bail!("wait limit reached");
1347            }
1348        }
1349
1350        futures_util::future::join_all(self.instances.iter().map(
1351            |(
1352                instance_id,
1353                TaggedSetup {
1354                    ip_info,
1355                    name,
1356                    setup,
1357                },
1358            )| {
1359                let IpInfo {
1360                    public_dns,
1361                    public_ip,
1362                    private_ip,
1363                } = ip_info.as_ref().unwrap();
1364                let instance_span = tracing::debug_span!("instance", %instance_id, ip = %public_ip);
1365                async move {
1366                    if let Setup {
1367                        username,
1368                        setup_fn: Some(f),
1369                        ..
1370                    } = setup
1371                    {
1372                        super::setup_machine(
1373                            &name,
1374                            Some(&public_dns),
1375                            &public_ip,
1376                            Some(&private_ip),
1377                            &username,
1378                            max_wait,
1379                            Some(private_key_path.path()),
1380                            f.as_ref(),
1381                        )
1382                        .await?;
1383                    }
1384
1385                    Ok(())
1386                }
1387                .instrument(instance_span)
1388            },
1389        ))
1390        .await
1391        .into_iter()
1392        .collect()
1393    }
1394
1395    /// Establish SSH connections to the machines. The `Ok` value is a `HashMap` associating the
1396    /// friendly name for each `Setup` with the corresponding SSH connection.
1397    #[instrument(level = "debug")]
1398    pub async fn connect_all<'l>(&'l self) -> Result<HashMap<String, crate::Machine<'l>>, Report> {
1399        let private_key_path = self.private_key_path.as_ref().unwrap();
1400        futures_util::future::join_all(self.instances.values().map(|info| {
1401            let instance_span = tracing::trace_span!("instance", name = %info.name);
1402            async move {
1403                match info {
1404                    TaggedSetup {
1405                        name,
1406                        setup: Setup { username, .. },
1407                        ip_info:
1408                            Some(IpInfo {
1409                                public_dns,
1410                                public_ip,
1411                                private_ip,
1412                            }),
1413                    } => {
1414                        let m = crate::MachineDescriptor {
1415                            public_dns: Some(public_dns.clone()),
1416                            public_ip: public_ip.clone(),
1417                            private_ip: Some(private_ip.clone()),
1418                            nickname: name.clone(),
1419                            _tsunami: Default::default(),
1420                        };
1421
1422                        let m = m
1423                            .connect_ssh(&username, Some(private_key_path.path()), None, 22)
1424                            .await?;
1425                        Ok((name.clone(), m))
1426                    }
1427                    _ => eyre::bail!("machine has no ip information"),
1428                }
1429            }
1430            .instrument(instance_span)
1431        }))
1432        .await
1433        .into_iter()
1434        .collect()
1435    }
1436
1437    /// Terminate all running instances.
1438    ///
1439    /// Additionally deletes ephemeral keys and security groups. Sometimes, this deletion can fail
1440    /// for various reasons. This method deletes things in this order:
1441    /// 1. Try to delete the key pair, but emit a log message and continue if it fails.
1442    /// 2. Try to terminate the instances, and short-circuits to return the error if it fails.
1443    /// 3. Try to delete the security group. This can fail as the security groups are still
1444    ///    "attached" to the instances we just terminated in step 2. So, we retry for 2 minutes
1445    ///    before giving up and returning an error.
1446    #[instrument(level = "debug")]
1447    pub async fn terminate_all(&mut self) -> Result<(), Report> {
1448        let client = self.client.as_ref().unwrap();
1449
1450        if !self.ssh_key_name.trim().is_empty() {
1451            let key_span = tracing::trace_span!("key", name = %self.ssh_key_name);
1452            async {
1453                tracing::trace!("removing keypair");
1454                let req = rusoto_ec2::DeleteKeyPairRequest {
1455                    key_name: Some(self.ssh_key_name.clone()),
1456                    ..Default::default()
1457                };
1458                if let Err(e) = client.delete_key_pair(req).await {
1459                    tracing::warn!("failed to clean up temporary SSH key: {}", e);
1460                }
1461            }
1462            .instrument(key_span)
1463            .await;
1464        }
1465
1466        // terminate instances
1467        if !self.instances.is_empty() {
1468            tracing::info!("terminating instances");
1469            let instance_ids = self.instances.keys().cloned().collect();
1470            self.instances.clear();
1471            // Why is `?` here ok? either:
1472            // 1. there was no spot capacity. So self.instances will be empty, and this
1473            //    block will get skipped, so sg will get cleaned up below.
1474            // 2. there were instances, but we couldn't terminate them. Then, the sg will
1475            //    still be attached to them, so there's no point trying to delete it.
1476            self.terminate_instances(instance_ids).await?;
1477        }
1478
1479        use rusoto_core::RusotoError;
1480        if !self.security_group_id.trim().is_empty() {
1481            let group_span =
1482                tracing::trace_span!("removing security group", id = %self.security_group_id);
1483            async {
1484                tracing::trace!("removing security group.");
1485                // clean up security groups and keys
1486                let start = tokio::time::Instant::now();
1487                loop {
1488                    if start.elapsed() > tokio::time::Duration::from_secs(5 * 60) {
1489                        return Err(Report::msg(
1490                            "failed to clean up temporary security group after 5 minutes.",
1491                        ));
1492                    }
1493
1494                    let req = rusoto_ec2::DeleteSecurityGroupRequest {
1495                        group_id: Some(self.security_group_id.clone()),
1496                        ..Default::default()
1497                    };
1498                    match client.delete_security_group(req).await {
1499                        Ok(_) => break,
1500                        Err(RusotoError::Unknown(r)) => {
1501                            let err = r.body_as_str();
1502                            if err.contains("<Code>DependencyViolation</Code>") {
1503                                tracing::trace!("instances not yet shut down -- retrying");
1504                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
1505                            } else {
1506                                Err(Report::new(RusotoError::<
1507                                    rusoto_ec2::DeleteSecurityGroupError,
1508                                >::Unknown(r)))
1509                                .wrap_err("failed to clean up temporary security group")?;
1510                                unreachable!();
1511                            }
1512                        }
1513                        Err(e) => {
1514                            return Err(Report::new(e)
1515                                .wrap_err("failed to clean up temporary security group"));
1516                        }
1517                    }
1518                }
1519
1520                tracing::trace!("cleaned up temporary security group");
1521                Ok::<_, Report>(())
1522            }
1523            .instrument(group_span)
1524            .await?;
1525        }
1526
1527        Ok(())
1528    }
1529
1530    #[instrument(level = "debug")]
1531    async fn describe_spot_instance_requests(
1532        &self,
1533    ) -> Result<Vec<(String, String, String, Option<String>)>, Report> {
1534        let client = self.client.as_ref().unwrap();
1535        let request_ids = self.spot_requests.keys().cloned().collect();
1536        let req = rusoto_ec2::DescribeSpotInstanceRequestsRequest {
1537            spot_instance_request_ids: Some(request_ids),
1538            ..Default::default()
1539        };
1540        loop {
1541            let res = client.describe_spot_instance_requests(req.clone()).await;
1542            if let Err(ref e) = res {
1543                let msg = e.to_string();
1544                if msg.contains("The spot instance request ID") && msg.contains("does not exist") {
1545                    tracing::trace!("spot instance requests not yet ready");
1546
1547                    // let's not hammer the API
1548                    tokio::time::sleep(time::Duration::from_secs(1)).await;
1549                    continue;
1550                } else {
1551                    res.wrap_err("failed to describe spot instances")?;
1552                    unreachable!();
1553                }
1554            }
1555
1556            let res = res.expect("Err checked above");
1557            let instances = res
1558                .spot_instance_requests
1559                .expect("describe always returns at least one spot instance")
1560                .into_iter()
1561                .map(|sir| {
1562                    let request_id = sir
1563                        .spot_instance_request_id
1564                        .expect("spot request did not have id specified");
1565                    let state = sir
1566                        .state
1567                        .expect("spot request did not have state specified");
1568                    let status = sir
1569                        .status
1570                        .expect("spot request did not have status specified")
1571                        .code
1572                        .expect("spot request status did not have status code");
1573                    let instance_id = sir.instance_id;
1574                    (request_id, state, status, instance_id)
1575                })
1576                .collect();
1577            break Ok(instances);
1578        }
1579    }
1580
1581    #[instrument(level = "debug")]
1582    async fn cancel_spot_instance_requests(&self) -> Result<(), Report> {
1583        tracing::warn!("wait time exceeded for -- cancelling run");
1584        if self.spot_requests.is_empty() {
1585            return Ok(());
1586        }
1587        let request_ids = self.spot_requests.keys().cloned().collect();
1588        let cancel = rusoto_ec2::CancelSpotInstanceRequestsRequest {
1589            spot_instance_request_ids: request_ids,
1590            ..Default::default()
1591        };
1592        self.client
1593            .as_ref()
1594            .unwrap()
1595            .cancel_spot_instance_requests(cancel)
1596            .await
1597            .wrap_err("failed to cancel spot instances")?;
1598
1599        tracing::trace!("spot instances cancelled -- waiting for cancellation");
1600        loop {
1601            tracing::trace!("checking spot request status");
1602            let instances = self.describe_spot_instance_requests().await?;
1603
1604            let all_cancelled = instances.iter().all(|(request_id, state, _, instance_id)| {
1605                if state == "closed"
1606                    || state == "cancelled"
1607                    || state == "failed"
1608                    || state == "completed"
1609                {
1610                    tracing::trace!(%request_id, ?instance_id, "spot instance request {}", state);
1611                    true
1612                } else {
1613                    false
1614                }
1615            });
1616
1617            if all_cancelled {
1618                tracing::trace!("deleting spot instances");
1619                // find instances with an id assigned and terminate them
1620                let instance_ids = instances
1621                    .into_iter()
1622                    .filter_map(|(_, _, _, instance_id)| instance_id)
1623                    .collect();
1624                self.terminate_instances(instance_ids).await?;
1625                break;
1626            }
1627
1628            // let's not hammer the API
1629            tokio::time::sleep(time::Duration::from_secs(1)).await;
1630        }
1631        Ok(())
1632    }
1633
1634    #[instrument(level = "debug")]
1635    async fn terminate_instances(&self, instance_ids: Vec<String>) -> Result<(), Report> {
1636        if instance_ids.is_empty() {
1637            return Ok(());
1638        }
1639        let client = self.client.as_ref().unwrap();
1640        let termination_req = rusoto_ec2::TerminateInstancesRequest {
1641            instance_ids,
1642            ..Default::default()
1643        };
1644        while let Err(e) = client.terminate_instances(termination_req.clone()).await {
1645            let msg = e.to_string();
1646            if msg.contains("Pooled stream disconnected") || msg.contains("broken pipe") {
1647                tracing::trace!("retrying instance termination");
1648                continue;
1649            } else {
1650                Err(e).wrap_err("failed to terminate tsunami instances")?;
1651                unreachable!();
1652            }
1653        }
1654        Ok(())
1655    }
1656}
1657
1658struct UbuntuAmi(String);
1659
1660impl UbuntuAmi {
1661    async fn new(r: Region) -> Result<Self, Report> {
1662        Ok(UbuntuAmi(
1663            ubuntu_ami::get_latest(
1664                &r.name(),
1665                Some("bionic"),
1666                None,
1667                Some("hvm:ebs-ssd"),
1668                Some("amd64"),
1669            )
1670            .await
1671            .map_err(|e| eyre!(e))?,
1672        ))
1673    }
1674}
1675
1676impl From<UbuntuAmi> for String {
1677    fn from(s: UbuntuAmi) -> String {
1678        s.0
1679    }
1680}
1681
#[cfg(test)]
mod test {
    use super::RegionLauncher;
    use super::*;
    use rusoto_core::credential::DefaultCredentialsProvider;
    use rusoto_core::region::Region;
    use rusoto_ec2::Ec2;
    use std::future::Future;

    /// Spawns one machine whose setup step must successfully run `whoami`. Then connects to it
    /// and runs an `echo` over SSH.
    ///
    /// Does not terminate anything; the caller is responsible for cleanup.
    fn do_make_machine_and_ssh_setupfn<'l>(
        l: &'l mut super::Launcher,
    ) -> impl Future<Output = Result<(), Report>> + 'l {
        use crate::providers::Launcher;
        async move {
            l.spawn(
                vec![(
                    String::from("my machine"),
                    super::Setup::default().setup(|vm| {
                        Box::pin(async move {
                            if vm.ssh.command("whoami").status().await?.success() {
                                Ok(())
                            } else {
                                Err(eyre!("failed"))
                            }
                        })
                    }),
                )],
                None,
            )
            .await?;
            let vms = l.connect_all().await?;
            let my_machine = vms
                .get("my machine")
                .ok_or_else(|| eyre!("machine not found"))?;
            my_machine
                .ssh
                .command("echo")
                .arg("\"Hello, EC2\"")
                .status()
                .await?;

            Ok(())
        }
    }

    #[test]
    #[ignore]
    fn make_machine_and_ssh_setupfn() {
        use crate::providers::Launcher;
        tracing_subscriber::fmt::init();
        let rt = tokio::runtime::Runtime::new().unwrap();
        let mut l = super::Launcher::default();
        // make the defined-duration instances expire after 1 hour
        l.set_mode(LaunchMode::duration_spot(1));
        rt.block_on(async move {
            let result = do_make_machine_and_ssh_setupfn(&mut l).await;
            // always tear the machines down, whether or not the test body succeeded
            l.terminate_all().await.unwrap();
            if let Err(e) = result {
                // format the report so its message and context chain appear in the panic output
                panic!("{:?}", e);
            }
        })
    }

    #[test]
    #[ignore]
    fn make_key() -> Result<(), Report> {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let region = Region::UsEast1;
        let provider = DefaultCredentialsProvider::new()?;
        let ec2 = RegionLauncher::connect(region, super::AvailabilityZoneSpec::Any, provider)?;
        rt.block_on(async {
            let ec2 = ec2.make_ssh_key().await?;
            tracing::debug!(
                name = %ec2.ssh_key_name,
                path = %ec2.private_key_path.as_ref().unwrap().path().display(),
                "made key"
            );
            assert!(!ec2.ssh_key_name.is_empty());
            assert!(ec2.private_key_path.as_ref().unwrap().path().exists());

            // remove the key pair from EC2 again so the test leaves nothing behind
            let req = rusoto_ec2::DeleteKeyPairRequest {
                key_name: Some(ec2.ssh_key_name.clone()),
                ..Default::default()
            };
            ec2.client
                .as_ref()
                .unwrap()
                .delete_key_pair(req)
                .await
                .wrap_err_with(|| {
                    format!("Could not delete ssh key pair {:?}", ec2.ssh_key_name)
                })?;

            Ok(())
        })
    }

    /// Issues five spot instance requests through `ec2` and then waits on them with
    /// `wait_for_spot_instance_requests(None)`.
    ///
    /// Does not terminate anything; the caller is responsible for cleanup.
    fn do_multi_instance_spot_request<'l>(
        ec2: &'l mut super::RegionLauncher,
    ) -> impl Future<Output = Result<(), Report>> + 'l {
        async move {
            let names = (1u32..).map(|x| x.to_string());
            let setup = Setup::default();
            let ms: Vec<(String, Setup)> = names.zip(itertools::repeat_n(setup, 5)).collect();

            tracing::debug!(num = %ms.len(), "make spot instance requests");
            ec2.make_spot_instance_requests(60 as _, ms).await?;
            assert_eq!(ec2.spot_requests.len(), 5);
            tracing::debug!("wait for spot instance requests");
            ec2.wait_for_spot_instance_requests(None).await?;

            Ok(())
        }
    }

    #[test]
    #[ignore]
    fn multi_instance_spot_request() -> Result<(), Report> {
        let region = "us-east-1";
        let provider = DefaultCredentialsProvider::new()?;

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut ec2 =
                RegionLauncher::new(region, super::AvailabilityZoneSpec::Any, provider, false)
                    .await?;

            let result = do_multi_instance_spot_request(&mut ec2).await;
            // always clean up, even when the test body failed; the error (if any) is then
            // reported by the test harness via the returned `Result`
            ec2.terminate_all().await.unwrap();
            result
        })
    }
}