Skip to main content

dactor_discover_aws/
lib.rs

1//! AWS node discovery for the dactor distributed actor framework.
2//!
3//! Provides two discovery mechanisms:
4//! - [`AutoScalingDiscovery`]: Lists instances in an EC2 Auto Scaling Group.
5//! - [`Ec2TagDiscovery`]: Queries EC2 instances by tag key/value filters.
6
7use dactor::{ClusterDiscovery, DiscoveryError};
8use std::fmt;
9
10// ---------------------------------------------------------------------------
11// Error type
12// ---------------------------------------------------------------------------
13
/// Errors returned by AWS discovery operations.
///
/// The SDK-backed variants carry the formatted error text rather than the
/// SDK error value, so the type is cheap to clone and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AwsDiscoveryError {
    /// Error from the Auto Scaling API (message text of the failed call).
    AutoScaling(String),
    /// Error from the EC2 API (message text of the failed call).
    Ec2(String),
    /// No instances matched the discovery criteria.
    NoInstances,
    /// Invalid or missing configuration (describes what is wrong).
    Config(String),
}
26
27impl fmt::Display for AwsDiscoveryError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match self {
30            AwsDiscoveryError::AutoScaling(e) => write!(f, "Auto Scaling API error: {e}"),
31            AwsDiscoveryError::Ec2(e) => write!(f, "EC2 API error: {e}"),
32            AwsDiscoveryError::NoInstances => write!(f, "no instances found"),
33            AwsDiscoveryError::Config(e) => write!(f, "configuration error: {e}"),
34        }
35    }
36}
37
38impl std::error::Error for AwsDiscoveryError {}
39
40// ---------------------------------------------------------------------------
41// ASG configuration
42// ---------------------------------------------------------------------------
43
/// Configuration for Auto Scaling Group discovery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AsgDiscoveryConfig {
    /// Name of the Auto Scaling Group. Must be set: discovery rejects an
    /// empty name with a configuration error.
    pub asg_name: String,
    /// Port to append to each discovered IP address.
    pub port: u16,
    /// AWS region override. Uses the SDK default chain when `None`.
    pub region: Option<String>,
    /// When `true`, return public IPs instead of private IPs.
    pub use_public_ip: bool,
}

impl Default for AsgDiscoveryConfig {
    /// Empty group name, port `9000`, no region override, private IPs.
    fn default() -> Self {
        Self {
            asg_name: String::new(),
            port: 9000,
            region: None,
            use_public_ip: false,
        }
    }
}
67
68// ---------------------------------------------------------------------------
69// AutoScalingDiscovery
70// ---------------------------------------------------------------------------
71
72/// Discovers peer nodes by listing instances in an EC2 Auto Scaling Group.
73///
74/// Filters for `InService` + healthy instances and returns their private
75/// (or public) IP addresses with the configured port.
76pub struct AutoScalingDiscovery {
77    config: AsgDiscoveryConfig,
78}
79
80impl AutoScalingDiscovery {
81    /// Returns a new builder with default configuration.
82    pub fn builder() -> AsgDiscoveryBuilder {
83        AsgDiscoveryBuilder {
84            config: AsgDiscoveryConfig::default(),
85        }
86    }
87
88    /// Returns a reference to the current configuration.
89    pub fn config(&self) -> &AsgDiscoveryConfig {
90        &self.config
91    }
92
93    /// Asynchronously discover peer addresses from the Auto Scaling Group.
94    pub async fn discover_async(&self) -> Result<Vec<String>, AwsDiscoveryError> {
95        if self.config.asg_name.is_empty() {
96            return Err(AwsDiscoveryError::Config(
97                "asg_name must not be empty".to_string(),
98            ));
99        }
100
101        let mut config_loader =
102            aws_config::defaults(aws_config::BehaviorVersion::latest());
103        if let Some(region) = &self.config.region {
104            config_loader =
105                config_loader.region(aws_config::Region::new(region.clone()));
106        }
107        let sdk_config = config_loader.load().await;
108
109        // 1. List instances in the ASG.
110        let asg_client = aws_sdk_autoscaling::Client::new(&sdk_config);
111        let asg_resp = asg_client
112            .describe_auto_scaling_groups()
113            .auto_scaling_group_names(&self.config.asg_name)
114            .send()
115            .await
116            .map_err(|e| AwsDiscoveryError::AutoScaling(e.to_string()))?;
117
118        let asg = asg_resp
119            .auto_scaling_groups()
120            .first()
121            .ok_or(AwsDiscoveryError::NoInstances)?;
122
123        let instance_ids: Vec<String> = asg
124            .instances()
125            .iter()
126            .filter(|i| {
127                i.lifecycle_state()
128                    .map(|s| s.as_str() == "InService")
129                    .unwrap_or(false)
130                    && i.health_status().map(|h| h == "Healthy").unwrap_or(false)
131            })
132            .filter_map(|i| i.instance_id().map(String::from))
133            .collect();
134
135        if instance_ids.is_empty() {
136            return Err(AwsDiscoveryError::NoInstances);
137        }
138
139        // 2. Describe instances to get their IP addresses.
140        let ec2_client = aws_sdk_ec2::Client::new(&sdk_config);
141        let ec2_resp = ec2_client
142            .describe_instances()
143            .set_instance_ids(Some(instance_ids))
144            .send()
145            .await
146            .map_err(|e| AwsDiscoveryError::Ec2(e.to_string()))?;
147
148        let mut addresses = Vec::new();
149        for reservation in ec2_resp.reservations() {
150            for instance in reservation.instances() {
151                let ip = if self.config.use_public_ip {
152                    instance.public_ip_address()
153                } else {
154                    instance.private_ip_address()
155                };
156                if let Some(ip) = ip {
157                    addresses.push(format!("{ip}:{}", self.config.port));
158                }
159            }
160        }
161
162        Ok(addresses)
163    }
164}
165
166#[async_trait::async_trait]
167impl ClusterDiscovery for AutoScalingDiscovery {
168    async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
169        self.discover_async()
170            .await
171            .map(|addrs| addrs.into_iter().map(dactor::DiscoveredPeer::from_address).collect())
172            .map_err(|e| DiscoveryError::new(e.to_string()))
173    }
174}
175
176// ---------------------------------------------------------------------------
177// ASG Builder
178// ---------------------------------------------------------------------------
179
180/// Builder for [`AutoScalingDiscovery`].
181pub struct AsgDiscoveryBuilder {
182    config: AsgDiscoveryConfig,
183}
184
185impl AsgDiscoveryBuilder {
186    /// Set the Auto Scaling Group name.
187    pub fn asg_name(mut self, name: &str) -> Self {
188        self.config.asg_name = name.to_string();
189        self
190    }
191
192    /// Set the port number (default: `9000`).
193    pub fn port(mut self, port: u16) -> Self {
194        self.config.port = port;
195        self
196    }
197
198    /// Set an explicit AWS region override.
199    pub fn region(mut self, region: &str) -> Self {
200        self.config.region = Some(region.to_string());
201        self
202    }
203
204    /// Return public IPs instead of private IPs.
205    pub fn use_public_ip(mut self, yes: bool) -> Self {
206        self.config.use_public_ip = yes;
207        self
208    }
209
210    /// Build the [`AutoScalingDiscovery`] instance.
211    pub fn build(self) -> AutoScalingDiscovery {
212        AutoScalingDiscovery {
213            config: self.config,
214        }
215    }
216}
217
218// ---------------------------------------------------------------------------
219// EC2 Tag configuration
220// ---------------------------------------------------------------------------
221
/// Configuration for EC2 tag-based discovery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ec2TagConfig {
    /// Tag key to filter on (e.g., `"dactor-cluster"`). Must be set:
    /// discovery rejects an empty key with a configuration error.
    pub tag_key: String,
    /// Tag value to match (e.g., `"my-cluster"`).
    pub tag_value: String,
    /// Port to append to each discovered IP address.
    pub port: u16,
    /// AWS region override. Uses the SDK default chain when `None`.
    pub region: Option<String>,
    /// When `true`, return public IPs instead of private IPs.
    pub use_public_ip: bool,
}

impl Default for Ec2TagConfig {
    /// Empty tag key and value, port `9000`, no region override, private IPs.
    fn default() -> Self {
        Self {
            tag_key: String::new(),
            tag_value: String::new(),
            port: 9000,
            region: None,
            use_public_ip: false,
        }
    }
}
248
249// ---------------------------------------------------------------------------
250// Ec2TagDiscovery
251// ---------------------------------------------------------------------------
252
253/// Discovers peer nodes by querying EC2 instances with matching tags.
254///
255/// Uses the `DescribeInstances` API with tag filters and only returns
256/// instances in the `running` state.
257pub struct Ec2TagDiscovery {
258    config: Ec2TagConfig,
259}
260
261impl Ec2TagDiscovery {
262    /// Returns a new builder with default configuration.
263    pub fn builder() -> Ec2TagDiscoveryBuilder {
264        Ec2TagDiscoveryBuilder {
265            config: Ec2TagConfig::default(),
266        }
267    }
268
269    /// Returns a reference to the current configuration.
270    pub fn config(&self) -> &Ec2TagConfig {
271        &self.config
272    }
273
274    /// Asynchronously discover peer addresses by EC2 tags.
275    pub async fn discover_async(&self) -> Result<Vec<String>, AwsDiscoveryError> {
276        if self.config.tag_key.is_empty() {
277            return Err(AwsDiscoveryError::Config(
278                "tag_key must not be empty".to_string(),
279            ));
280        }
281
282        let mut config_loader =
283            aws_config::defaults(aws_config::BehaviorVersion::latest());
284        if let Some(region) = &self.config.region {
285            config_loader =
286                config_loader.region(aws_config::Region::new(region.clone()));
287        }
288        let sdk_config = config_loader.load().await;
289
290        let ec2_client = aws_sdk_ec2::Client::new(&sdk_config);
291
292        let tag_filter = aws_sdk_ec2::types::Filter::builder()
293            .name(format!("tag:{}", self.config.tag_key))
294            .values(&self.config.tag_value)
295            .build();
296
297        let running_filter = aws_sdk_ec2::types::Filter::builder()
298            .name("instance-state-name")
299            .values("running")
300            .build();
301
302        let resp = ec2_client
303            .describe_instances()
304            .filters(tag_filter)
305            .filters(running_filter)
306            .send()
307            .await
308            .map_err(|e| AwsDiscoveryError::Ec2(e.to_string()))?;
309
310        let mut addresses = Vec::new();
311        for reservation in resp.reservations() {
312            for instance in reservation.instances() {
313                let ip = if self.config.use_public_ip {
314                    instance.public_ip_address()
315                } else {
316                    instance.private_ip_address()
317                };
318                if let Some(ip) = ip {
319                    addresses.push(format!("{ip}:{}", self.config.port));
320                }
321            }
322        }
323
324        if addresses.is_empty() {
325            return Err(AwsDiscoveryError::NoInstances);
326        }
327
328        Ok(addresses)
329    }
330}
331
332#[async_trait::async_trait]
333impl ClusterDiscovery for Ec2TagDiscovery {
334    async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
335        self.discover_async()
336            .await
337            .map(|addrs| addrs.into_iter().map(dactor::DiscoveredPeer::from_address).collect())
338            .map_err(|e| DiscoveryError::new(e.to_string()))
339    }
340}
341
342// ---------------------------------------------------------------------------
343// EC2 Tag Builder
344// ---------------------------------------------------------------------------
345
346/// Builder for [`Ec2TagDiscovery`].
347pub struct Ec2TagDiscoveryBuilder {
348    config: Ec2TagConfig,
349}
350
351impl Ec2TagDiscoveryBuilder {
352    /// Set the tag key to filter on.
353    pub fn tag_key(mut self, key: &str) -> Self {
354        self.config.tag_key = key.to_string();
355        self
356    }
357
358    /// Set the tag value to match.
359    pub fn tag_value(mut self, value: &str) -> Self {
360        self.config.tag_value = value.to_string();
361        self
362    }
363
364    /// Set the port number (default: `9000`).
365    pub fn port(mut self, port: u16) -> Self {
366        self.config.port = port;
367        self
368    }
369
370    /// Set an explicit AWS region override.
371    pub fn region(mut self, region: &str) -> Self {
372        self.config.region = Some(region.to_string());
373        self
374    }
375
376    /// Return public IPs instead of private IPs.
377    pub fn use_public_ip(mut self, yes: bool) -> Self {
378        self.config.use_public_ip = yes;
379        self
380    }
381
382    /// Build the [`Ec2TagDiscovery`] instance.
383    pub fn build(self) -> Ec2TagDiscovery {
384        Ec2TagDiscovery {
385            config: self.config,
386        }
387    }
388}
389
390// ---------------------------------------------------------------------------
391// Environment helpers
392// ---------------------------------------------------------------------------
393
/// Read the current instance's private IP from the `DACTOR_INSTANCE_IP`
/// environment variable.
///
/// Returns `None` when the variable is unset, not valid Unicode, or empty.
/// An empty value is treated as unset so callers never receive `Some("")`,
/// which is not a usable address.
pub fn instance_private_ip() -> Option<String> {
    std::env::var("DACTOR_INSTANCE_IP")
        .ok()
        .filter(|ip| !ip.is_empty())
}
399
/// Read the current instance's ID from the `DACTOR_INSTANCE_ID`
/// environment variable.
///
/// Returns `None` when the variable is unset, not valid Unicode, or empty.
pub fn instance_id() -> Option<String> {
    std::env::var("DACTOR_INSTANCE_ID")
        .ok()
        .filter(|id| !id.is_empty())
}
405
/// Read the current AWS region from `AWS_REGION` or `AWS_DEFAULT_REGION`.
///
/// `AWS_REGION` takes precedence. A variable that is unset, not valid
/// Unicode, or empty is skipped. This means an empty `AWS_REGION` no longer
/// hides a usable `AWS_DEFAULT_REGION`.
pub fn current_region() -> Option<String> {
    ["AWS_REGION", "AWS_DEFAULT_REGION"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .find(|region| !region.is_empty())
}
412
413// ---------------------------------------------------------------------------
414// Tests
415// ---------------------------------------------------------------------------
416
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or modifies process environment
    /// variables.
    ///
    /// The test harness runs `#[test]` functions on parallel threads, but the
    /// environment is shared by the whole process. Without this lock,
    /// `current_region_returns_none_when_unset` can remove `AWS_REGION` while
    /// `current_region_preference_order` is asserting on it, or the reverse.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failed
    /// env test does not make every later env test fail too.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // -- ASG builder --------------------------------------------------------

    #[test]
    fn asg_builder_creates_valid_config() {
        let discovery = AutoScalingDiscovery::builder()
            .asg_name("my-asg")
            .port(8080)
            .region("us-west-2")
            .use_public_ip(true)
            .build();

        assert_eq!(discovery.config().asg_name, "my-asg");
        assert_eq!(discovery.config().port, 8080);
        assert_eq!(discovery.config().region.as_deref(), Some("us-west-2"));
        assert!(discovery.config().use_public_ip);
    }

    #[test]
    fn asg_builder_default_values() {
        let discovery = AutoScalingDiscovery::builder()
            .asg_name("test-asg")
            .build();

        assert_eq!(discovery.config().asg_name, "test-asg");
        assert_eq!(discovery.config().port, 9000);
        assert!(discovery.config().region.is_none());
        assert!(!discovery.config().use_public_ip);
    }

    #[test]
    fn asg_default_config() {
        let cfg = AsgDiscoveryConfig::default();
        assert!(cfg.asg_name.is_empty());
        assert_eq!(cfg.port, 9000);
        assert!(cfg.region.is_none());
        assert!(!cfg.use_public_ip);
    }

    // -- EC2 tag builder ----------------------------------------------------

    #[test]
    fn ec2_tag_builder_creates_valid_config() {
        let discovery = Ec2TagDiscovery::builder()
            .tag_key("dactor-cluster")
            .tag_value("production")
            .port(7000)
            .region("eu-west-1")
            .use_public_ip(false)
            .build();

        assert_eq!(discovery.config().tag_key, "dactor-cluster");
        assert_eq!(discovery.config().tag_value, "production");
        assert_eq!(discovery.config().port, 7000);
        assert_eq!(discovery.config().region.as_deref(), Some("eu-west-1"));
        assert!(!discovery.config().use_public_ip);
    }

    #[test]
    fn ec2_tag_builder_default_values() {
        let discovery = Ec2TagDiscovery::builder()
            .tag_key("cluster")
            .tag_value("dev")
            .build();

        assert_eq!(discovery.config().port, 9000);
        assert!(discovery.config().region.is_none());
        assert!(!discovery.config().use_public_ip);
    }

    #[test]
    fn ec2_tag_default_config() {
        let cfg = Ec2TagConfig::default();
        assert!(cfg.tag_key.is_empty());
        assert!(cfg.tag_value.is_empty());
        assert_eq!(cfg.port, 9000);
        assert!(cfg.region.is_none());
        assert!(!cfg.use_public_ip);
    }

    // -- Environment helpers ------------------------------------------------
    // Every test in this section holds `env_guard()` for its whole body.

    #[test]
    fn instance_private_ip_returns_none_outside_aws() {
        let _guard = env_guard();
        std::env::remove_var("DACTOR_INSTANCE_IP");
        assert!(instance_private_ip().is_none());
    }

    #[test]
    fn instance_id_returns_none_outside_aws() {
        let _guard = env_guard();
        std::env::remove_var("DACTOR_INSTANCE_ID");
        assert!(instance_id().is_none());
    }

    #[test]
    fn current_region_returns_none_when_unset() {
        let _guard = env_guard();
        std::env::remove_var("AWS_REGION");
        std::env::remove_var("AWS_DEFAULT_REGION");
        assert!(current_region().is_none());
    }

    #[test]
    fn current_region_preference_order() {
        let _guard = env_guard();

        // Step 1: AWS_REGION takes precedence over AWS_DEFAULT_REGION.
        std::env::set_var("AWS_REGION", "us-east-1");
        std::env::set_var("AWS_DEFAULT_REGION", "eu-west-1");
        assert_eq!(current_region(), Some("us-east-1".to_string()));

        // Step 2: Falls back to AWS_DEFAULT_REGION when AWS_REGION is absent.
        std::env::remove_var("AWS_REGION");
        assert_eq!(current_region(), Some("eu-west-1".to_string()));

        // Cleanup
        std::env::remove_var("AWS_DEFAULT_REGION");
    }

    // -- Error display ------------------------------------------------------

    #[test]
    fn error_display_autoscaling() {
        let err = AwsDiscoveryError::AutoScaling("timeout".to_string());
        assert_eq!(err.to_string(), "Auto Scaling API error: timeout");
    }

    #[test]
    fn error_display_ec2() {
        let err = AwsDiscoveryError::Ec2("access denied".to_string());
        assert_eq!(err.to_string(), "EC2 API error: access denied");
    }

    #[test]
    fn error_display_no_instances() {
        let err = AwsDiscoveryError::NoInstances;
        assert_eq!(err.to_string(), "no instances found");
    }

    #[test]
    fn error_display_config() {
        let err = AwsDiscoveryError::Config("missing asg_name".to_string());
        assert_eq!(err.to_string(), "configuration error: missing asg_name");
    }
}