Skip to main content

mixtape_tools/aws/
use_aws.rs

1//! AWS service integration tool for dynamic API calls.
2//!
3//! This module provides a universal interface to AWS services, allowing agents to
4//! invoke any AWS API operation dynamically using SigV4 signing.
5//!
6//! # Examples
7//!
8//! ## Basic Usage with GetCallerIdentity
9//!
10//! ```no_run
11//! use mixtape_core::Tool;
12//! use mixtape_tools::aws::UseAwsTool;
13//!
14//! #[tokio::main]
15//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//!     let tool = UseAwsTool::new().await?;
17//!     let input = serde_json::from_value(serde_json::json!({
18//!         "service_name": "sts",
19//!         "operation_name": "GetCallerIdentity",
20//!         "parameters": {},
21//!         "region": "us-east-1",
22//!         "label": "Get AWS caller identity"
23//!     }))?;
24//!     let result = tool.execute(input).await?;
25//!     println!("{}", result.as_text());
26//!     Ok(())
27//! }
28//! ```
29//!
30//! ## Using a Specific Profile
31//!
32//! ```no_run
33//! use mixtape_tools::aws::UseAwsTool;
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
37//!     let tool = UseAwsTool::builder()
38//!         .profile("my-aws-profile")
39//!         .build()
40//!         .await?;
41//!     Ok(())
42//! }
43//! ```
44//!
45//! ## DynamoDB Query Example
46//!
47//! ```no_run
48//! use mixtape_core::Tool;
49//! use mixtape_tools::aws::UseAwsTool;
50//!
51//! #[tokio::main]
52//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
53//!     let tool = UseAwsTool::new().await?;
54//!     let input = serde_json::from_value(serde_json::json!({
55//!         "service_name": "dynamodb",
56//!         "operation_name": "Scan",
57//!         "parameters": {
58//!             "TableName": "my-table",
59//!             "Limit": 10
60//!         },
61//!         "region": "us-west-2",
62//!         "label": "Scan DynamoDB table"
63//!     }))?;
64//!     let result = tool.execute(input).await?;
65//!     Ok(())
66//! }
67//! ```
68
69use crate::prelude::*;
70use aws_config::BehaviorVersion;
71use aws_credential_types::provider::ProvideCredentials;
72use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
73use aws_sigv4::sign::v4;
74use aws_types::region::Region;
75use http::header::{CONTENT_TYPE, HOST};
76use http::{HeaderValue, Method};
77use reqwest::Client;
78use std::collections::HashMap;
79use std::sync::Arc;
80use std::time::{Duration, SystemTime};
81
82// ============================================================================
83// Public Types
84// ============================================================================
85
86/// Input parameters for the AWS service tool.
87///
88/// # Required Fields
89///
90/// - `service_name`: AWS service (e.g., "sts", "dynamodb", "lambda")
91/// - `operation_name`: API operation in PascalCase (e.g., "GetCallerIdentity")
92/// - `region`: AWS region (e.g., "us-east-1")
93/// - `label`: Human-readable description for logging
94///
95/// # Optional Fields
96///
97/// - `parameters`: Operation parameters as a JSON object (default: `{}`)
98/// - `profile_name`: AWS profile from ~/.aws/credentials
99#[derive(Debug, Deserialize, JsonSchema)]
100pub struct UseAwsInput {
101    /// The AWS service name (e.g., "sts", "s3", "dynamodb", "lambda", "ec2").
102    /// Use lowercase service names as they appear in AWS endpoint URLs.
103    pub service_name: String,
104
105    /// The API operation to perform (e.g., "GetCallerIdentity", "ListBuckets").
106    /// Use PascalCase as they appear in AWS API documentation.
107    pub operation_name: String,
108
109    /// Parameters for the operation as a JSON object.
110    /// These are passed as the request body for JSON-based APIs.
111    #[serde(default = "default_parameters")]
112    pub parameters: serde_json::Value,
113
114    /// AWS region for the API call (e.g., "us-east-1", "us-west-2").
115    pub region: String,
116
117    /// Human-readable description of what this operation does.
118    /// Used for logging and display purposes.
119    #[serde(default)]
120    pub label: Option<String>,
121
122    /// Optional AWS profile name from ~/.aws/credentials.
123    /// If not specified, uses default credential chain.
124    #[serde(default)]
125    pub profile_name: Option<String>,
126}
127
128fn default_parameters() -> serde_json::Value {
129    serde_json::json!({})
130}
131
132/// Tool for making AWS API calls using SigV4 signing.
133///
134/// This tool provides a universal interface to AWS services, allowing agents to
135/// invoke any AWS API operation dynamically. It supports:
136///
137/// - All AWS services accessible via JSON-based APIs
138/// - Multiple credential sources (environment, profiles, IAM roles, SSO)
139/// - SigV4 request signing for authentication
140/// - Automatic region-specific endpoint resolution
141/// - Extensible service target prefix configuration
142///
143/// # Construction
144///
145/// This tool requires async initialization due to AWS credential loading.
146/// Use `UseAwsTool::new().await` or `UseAwsTool::builder()...build().await`.
147///
148/// **Note**: Unlike other tools in this crate, `UseAwsTool` does not implement
149/// `Default` because it requires async credential loading. Attempting to use
150/// an uninitialized tool will result in credential errors.
151///
152/// # Safety
153///
154/// Operations that match mutative prefixes (Create, Delete, Update, etc.) will
155/// include a warning in the output. The calling application should implement
156/// appropriate confirmation mechanisms.
157pub struct UseAwsTool {
158    client: Client,
159    credentials_provider: Arc<dyn ProvideCredentials>,
160    service_targets: HashMap<String, String>,
161    #[allow(dead_code)] // Stored for potential future use (e.g., per-request timeout override)
162    timeout: Duration,
163}
164
165/// Builder for creating `UseAwsTool` instances with custom configuration.
166///
167/// # Example
168///
169/// ```no_run
170/// use mixtape_tools::aws::UseAwsTool;
171/// use std::time::Duration;
172///
173/// #[tokio::main]
174/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
175///     let tool = UseAwsTool::builder()
176///         .profile("my-profile")
177///         .timeout(Duration::from_secs(120))
178///         .with_service_target("custom-service", "CustomService_20240101")
179///         .build()
180///         .await?;
181///     Ok(())
182/// }
183/// ```
184#[derive(Default)]
185pub struct UseAwsToolBuilder {
186    profile: Option<String>,
187    timeout: Option<Duration>,
188    custom_service_targets: HashMap<String, String>,
189    credentials_provider: Option<Arc<dyn ProvideCredentials>>,
190}
191
192// ============================================================================
193// Builder Implementation
194// ============================================================================
195
196impl UseAwsToolBuilder {
197    /// Set the AWS profile to use for credentials.
198    pub fn profile(mut self, profile: impl Into<String>) -> Self {
199        self.profile = Some(profile.into());
200        self
201    }
202
203    /// Set the HTTP request timeout (default: 60 seconds).
204    pub fn timeout(mut self, timeout: Duration) -> Self {
205        self.timeout = Some(timeout);
206        self
207    }
208
209    /// Add a custom service target prefix for the x-amz-target header.
210    ///
211    /// This is useful for services not in the default mapping or for
212    /// using different API versions.
213    pub fn with_service_target(
214        mut self,
215        service_name: impl Into<String>,
216        target_prefix: impl Into<String>,
217    ) -> Self {
218        self.custom_service_targets
219            .insert(service_name.into(), target_prefix.into());
220        self
221    }
222
223    /// Inject a custom credentials provider (useful for testing).
224    ///
225    /// When set, skips the default AWS credential chain and uses
226    /// the provided credentials directly.
227    pub fn credentials_provider(mut self, provider: Arc<dyn ProvideCredentials>) -> Self {
228        self.credentials_provider = Some(provider);
229        self
230    }
231
232    /// Build the `UseAwsTool` instance.
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if:
237    /// - No AWS credentials are found (and no custom provider was set)
238    /// - The HTTP client fails to initialize
239    pub async fn build(self) -> Result<UseAwsTool, ToolError> {
240        let timeout = self.timeout.unwrap_or(Duration::from_secs(60));
241
242        // Get credentials provider
243        let credentials_provider = if let Some(provider) = self.credentials_provider {
244            provider
245        } else {
246            let mut config_loader =
247                aws_config::defaults(BehaviorVersion::latest()).region(Region::new("us-east-1"));
248
249            if let Some(profile_name) = &self.profile {
250                config_loader = config_loader.profile_name(profile_name);
251            }
252
253            let config = config_loader.load().await;
254
255            config
256                .credentials_provider()
257                .map(Arc::from)
258                .ok_or_else(|| ToolError::from("No AWS credentials found. Ensure AWS credentials are configured via environment variables, ~/.aws/credentials, or IAM role."))?
259        };
260
261        let client = Client::builder()
262            .timeout(timeout)
263            .build()
264            .map_err(|e| ToolError::from(format!("Failed to create HTTP client: {}", e)))?;
265
266        // Merge default service targets with custom ones
267        let mut service_targets = default_service_targets();
268        for (k, v) in self.custom_service_targets {
269            service_targets.insert(k, v);
270        }
271
272        Ok(UseAwsTool {
273            client,
274            credentials_provider,
275            service_targets,
276            timeout,
277        })
278    }
279}
280
281// ============================================================================
282// UseAwsTool Implementation
283// ============================================================================
284
285impl UseAwsTool {
286    /// Create a new `UseAwsTool` with default configuration.
287    ///
288    /// This loads credentials from the default AWS credential chain:
289    /// 1. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
290    /// 2. Shared credentials file (~/.aws/credentials)
291    /// 3. IAM instance profile (on EC2)
292    /// 4. Container credentials (in ECS/Fargate)
293    /// 5. SSO credentials (if configured)
294    ///
295    /// # Errors
296    ///
297    /// Returns an error if no AWS credentials are found.
298    pub async fn new() -> Result<Self, ToolError> {
299        Self::builder().build().await
300    }
301
302    /// Create a builder for custom configuration.
303    pub fn builder() -> UseAwsToolBuilder {
304        UseAwsToolBuilder::default()
305    }
306
307    /// Get the service target prefix for a service.
308    fn get_service_target(&self, service_name: &str) -> String {
309        self.service_targets
310            .get(service_name)
311            .cloned()
312            .unwrap_or_else(|| service_name.to_string())
313    }
314}
315
316// ============================================================================
317// Tool Trait Implementation
318// ============================================================================
319
320impl Tool for UseAwsTool {
321    type Input = UseAwsInput;
322
323    fn name(&self) -> &str {
324        "use_aws"
325    }
326
327    fn description(&self) -> &str {
328        "Make AWS API calls using service and operation names. \
329         Supports all AWS services with JSON-based APIs. \
330         Use PascalCase operation names (e.g., 'ListBuckets', 'GetCallerIdentity')."
331    }
332
333    async fn execute(&self, input: Self::Input) -> Result<ToolResult, ToolError> {
334        // Validate required fields with actionable error messages
335        validate_input(&input)?;
336
337        let label = input
338            .label
339            .as_deref()
340            .unwrap_or_else(|| &input.operation_name);
341
342        // Check for mutative operations
343        let is_mutative = is_mutative_operation(&input.operation_name);
344
345        // Build and send the request
346        let request = self
347            .build_signed_request(
348                &input.service_name,
349                &input.operation_name,
350                &input.parameters,
351                &input.region,
352            )
353            .await
354            .map_err(|e| {
355                ToolError::from(format!(
356                    "Failed to build request for {}.{} in {}: {}",
357                    input.service_name, input.operation_name, input.region, e
358                ))
359            })?;
360
361        let response = self.client.execute(request).await.map_err(|e| {
362            ToolError::from(format!(
363                "AWS request failed for {}.{} in {}: {}",
364                input.service_name, input.operation_name, input.region, e
365            ))
366        })?;
367
368        let status = response.status();
369        let body = response.text().await.map_err(|e| {
370            ToolError::from(format!(
371                "Failed to read response from {}.{}: {}",
372                input.service_name, input.operation_name, e
373            ))
374        })?;
375
376        if !status.is_success() {
377            return Err(parse_aws_error(
378                &input.service_name,
379                &input.operation_name,
380                &input.region,
381                status,
382                &body,
383            ));
384        }
385
386        // Parse and format success response
387        let response_json: serde_json::Value = serde_json::from_str(&body)
388            .unwrap_or_else(|_| serde_json::json!({ "raw_response": body }));
389
390        // Build result with metadata
391        let mut result = String::with_capacity(body.len() + 256);
392
393        result.push_str(&format!("Service: {}\n", input.service_name));
394        result.push_str(&format!("Operation: {}\n", input.operation_name));
395        result.push_str(&format!("Region: {}\n", input.region));
396        result.push_str(&format!("Label: {}\n", label));
397
398        if is_mutative {
399            result.push_str("Warning: This was a mutative operation\n");
400        }
401
402        result.push_str("\n---\n\n");
403
404        let pretty_response = serde_json::to_string_pretty(&response_json)
405            .unwrap_or_else(|_| response_json.to_string());
406        result.push_str(&pretty_response);
407
408        Ok(ToolResult::text(result))
409    }
410
411    fn format_output_plain(&self, result: &ToolResult) -> String {
412        let output = result.as_text();
413        let (metadata, content) = parse_output_header(&output);
414
415        if metadata.is_empty() {
416            return output.to_string();
417        }
418
419        let mut out = String::new();
420        out.push_str(&"─".repeat(60));
421        out.push('\n');
422
423        for (key, value) in &metadata {
424            let icon = match *key {
425                "Service" => "[S]",
426                "Operation" => "[O]",
427                "Region" => "[R]",
428                "Label" => "[L]",
429                "Warning" => "[!]",
430                _ => "   ",
431            };
432            out.push_str(&format!("{} {:12} {}\n", icon, key, value));
433        }
434
435        out.push_str(&"─".repeat(60));
436        out.push_str("\n\n");
437        out.push_str(content);
438        out
439    }
440
441    fn format_output_ansi(&self, result: &ToolResult) -> String {
442        let output = result.as_text();
443        let (metadata, content) = parse_output_header(&output);
444
445        if metadata.is_empty() {
446            return output.to_string();
447        }
448
449        let mut out = String::new();
450        out.push_str(&format!("\x1b[2m{}\x1b[0m\n", "─".repeat(60)));
451
452        for (key, value) in &metadata {
453            let (icon, color) = match *key {
454                "Service" => ("\x1b[33m\x1b[0m", "\x1b[33m"),
455                "Operation" => ("\x1b[34m\x1b[0m", "\x1b[34m"),
456                "Region" => ("\x1b[36m\x1b[0m", "\x1b[36m"),
457                "Label" => ("\x1b[32m\x1b[0m", "\x1b[32m"),
458                "Warning" => ("\x1b[31m\x1b[0m", "\x1b[31m"),
459                _ => ("  ", "\x1b[0m"),
460            };
461            out.push_str(&format!(
462                "{} \x1b[2m{:12}\x1b[0m {}{}\x1b[0m\n",
463                icon, key, color, value
464            ));
465        }
466
467        out.push_str(&format!("\x1b[2m{}\x1b[0m\n\n", "─".repeat(60)));
468        out.push_str(content);
469        out
470    }
471
472    fn format_output_markdown(&self, result: &ToolResult) -> String {
473        let output = result.as_text();
474        let (metadata, content) = parse_output_header(&output);
475
476        if metadata.is_empty() {
477            return output.to_string();
478        }
479
480        let mut out = String::new();
481
482        let label = metadata
483            .iter()
484            .find(|(k, _)| *k == "Label")
485            .map(|(_, v)| *v);
486
487        if let Some(l) = label {
488            out.push_str(&format!("## {}\n\n", l));
489        }
490
491        for (key, value) in &metadata {
492            if *key != "Label" {
493                out.push_str(&format!("- **{}**: {}\n", key, value));
494            }
495        }
496
497        out.push_str("\n---\n\n");
498        out.push_str("```json\n");
499        out.push_str(content);
500        out.push_str("\n```");
501        out
502    }
503}
504
505// ============================================================================
506// Private Implementation Details
507// ============================================================================
508
509impl UseAwsTool {
510    /// Build and sign an AWS API request.
511    async fn build_signed_request(
512        &self,
513        service_name: &str,
514        operation_name: &str,
515        parameters: &serde_json::Value,
516        region: &str,
517    ) -> Result<reqwest::Request, ToolError> {
518        let endpoint = get_endpoint(service_name, region);
519
520        let credentials = self
521            .credentials_provider
522            .provide_credentials()
523            .await
524            .map_err(|e| ToolError::from(format!("Failed to get AWS credentials: {}", e)))?;
525
526        let body = serde_json::to_string(parameters)
527            .map_err(|e| ToolError::from(format!("Failed to serialize parameters: {}", e)))?;
528
529        let content_type = "application/x-amz-json-1.1; charset=utf-8";
530        let target_header = format!(
531            "{}.{}",
532            self.get_service_target(service_name),
533            operation_name
534        );
535
536        let url = url::Url::parse(&endpoint)
537            .map_err(|e| ToolError::from(format!("Invalid endpoint URL: {}", e)))?;
538        let host = url
539            .host_str()
540            .ok_or_else(|| ToolError::from("Endpoint has no host"))?;
541
542        let mut builder = http::Request::builder()
543            .method(Method::POST)
544            .uri(&endpoint)
545            .header(HOST, host)
546            .header(CONTENT_TYPE, HeaderValue::from_static(content_type))
547            .header(
548                "x-amz-target",
549                HeaderValue::from_str(&target_header).unwrap(),
550            );
551
552        if let Some(token) = credentials.session_token() {
553            builder = builder.header(
554                "x-amz-security-token",
555                HeaderValue::from_str(token).unwrap(),
556            );
557        }
558
559        let http_request = builder
560            .body(body.clone())
561            .map_err(|e| ToolError::from(format!("Failed to build request: {}", e)))?;
562
563        let signing_settings = SigningSettings::default();
564        let identity = credentials.into();
565        let signing_params = v4::SigningParams::builder()
566            .identity(&identity)
567            .region(region)
568            .name(service_name)
569            .time(SystemTime::now())
570            .settings(signing_settings)
571            .build()
572            .map_err(|e| ToolError::from(format!("Failed to build signing params: {}", e)))?;
573
574        let signable_request = SignableRequest::new(
575            http_request.method().as_str(),
576            http_request.uri().to_string(),
577            http_request
578                .headers()
579                .iter()
580                .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or(""))),
581            SignableBody::Bytes(body.as_bytes()),
582        )
583        .map_err(|e| ToolError::from(format!("Failed to create signable request: {}", e)))?;
584
585        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
586            .map_err(|e| ToolError::from(format!("Failed to sign request: {}", e)))?
587            .into_parts();
588
589        let mut req_builder = self.client.post(&endpoint).body(body);
590
591        for (name, value) in http_request.headers() {
592            if let Ok(v) = value.to_str() {
593                req_builder = req_builder.header(name.as_str(), v);
594            }
595        }
596
597        for (name, value) in signing_instructions.headers() {
598            let name_str: &str = name;
599            let value_str = std::str::from_utf8(value.as_bytes()).unwrap_or("");
600            req_builder = req_builder.header(name_str, value_str);
601        }
602
603        req_builder
604            .build()
605            .map_err(|e| ToolError::from(format!("Failed to build final request: {}", e)))
606    }
607}
608
609// ============================================================================
610// Helper Functions
611// ============================================================================
612
613/// Validate input fields and provide actionable error messages.
614fn validate_input(input: &UseAwsInput) -> Result<(), ToolError> {
615    if input.service_name.is_empty() {
616        return Err(ToolError::from(
617            "service_name cannot be empty. Use lowercase AWS service names like 'sts', 'dynamodb', 's3'.",
618        ));
619    }
620    if input.operation_name.is_empty() {
621        return Err(ToolError::from(
622            "operation_name cannot be empty. Use PascalCase operation names like 'GetCallerIdentity', 'ListBuckets'.",
623        ));
624    }
625    if input.region.is_empty() {
626        return Err(ToolError::from(
627            "region cannot be empty. Use AWS region codes like 'us-east-1', 'eu-west-1'.",
628        ));
629    }
630
631    // Validate parameters is an object (not array, string, etc.)
632    if !input.parameters.is_object() {
633        return Err(ToolError::from(format!(
634            "parameters must be a JSON object, got: {}",
635            match &input.parameters {
636                serde_json::Value::Null => "null",
637                serde_json::Value::Bool(_) => "boolean",
638                serde_json::Value::Number(_) => "number",
639                serde_json::Value::String(_) => "string",
640                serde_json::Value::Array(_) => "array",
641                serde_json::Value::Object(_) => "object",
642            }
643        )));
644    }
645
646    Ok(())
647}
648
649/// Parse AWS error response and create an actionable error message.
650fn parse_aws_error(
651    service_name: &str,
652    operation_name: &str,
653    region: &str,
654    status: reqwest::StatusCode,
655    body: &str,
656) -> ToolError {
657    if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(body) {
658        let error_type = error_json
659            .get("__type")
660            .or_else(|| error_json.get("Error").and_then(|e| e.get("Code")))
661            .and_then(|v| v.as_str())
662            .unwrap_or("Unknown");
663        let error_message = error_json
664            .get("message")
665            .or_else(|| error_json.get("Message"))
666            .or_else(|| error_json.get("Error").and_then(|e| e.get("Message")))
667            .and_then(|v| v.as_str())
668            .unwrap_or(body);
669
670        ToolError::from(format!(
671            "AWS API error for {}.{} in {} (HTTP {}): {} - {}",
672            service_name, operation_name, region, status, error_type, error_message
673        ))
674    } else {
675        ToolError::from(format!(
676            "AWS API error for {}.{} in {} (HTTP {}): {}",
677            service_name, operation_name, region, status, body
678        ))
679    }
680}
681
/// List of operation prefixes that indicate potentially mutative operations.
///
/// Prefixes are matched case-sensitively against the start of the PascalCase
/// operation name. The list deliberately errs on the side of flagging: it
/// includes side-effecting verbs such as `Send`, `Publish` and `Invoke`, and
/// both halves of paired verbs (`Authorize`/`Revoke`, `Accept`/`Reject`,
/// `Tag`/`Untag`).
const MUTATIVE_OPERATIONS: &[&str] = &[
    "Create",
    "Put",
    "Delete",
    "Update",
    "Terminate",
    "Revoke",
    "Disable",
    "Deregister",
    "Stop",
    "Add",
    "Modify",
    "Remove",
    "Attach",
    "Detach",
    "Start",
    "Enable",
    "Register",
    "Set",
    "Associate",
    "Disassociate",
    "Allocate",
    "Release",
    "Cancel",
    "Reboot",
    "Accept",
    // Counterparts of verbs already listed above.
    "Authorize",
    "Reject",
    "Tag",
    "Untag",
    // Side-effecting verbs that do not fit the create/update/delete pattern.
    "Send",
    "Publish",
    "Invoke",
    "Run",
    "Purge",
    "Reset",
    "Restore",
    "Rotate",
    "Import",
    "Copy",
    "Replace",
    "Change",
    "Increase",
    "Decrease",
    "BatchWrite",
    "BatchDelete",
    "BatchPut",
    "BatchUpdate",
    "TransactWrite",
];
710
711/// Check if an operation is potentially mutative (destructive).
712fn is_mutative_operation(operation_name: &str) -> bool {
713    MUTATIVE_OPERATIONS
714        .iter()
715        .any(|prefix| operation_name.starts_with(prefix))
716}
717
/// Resolve the HTTPS endpoint for `service_name` in `region`.
///
/// Global hostnames (`https://<service>.amazonaws.com`) are used for:
/// - `iam`, `route53` and `cloudfront`, in every region;
/// - `sts`, in `us-east-1` only.
///
/// Every other combination uses the regional
/// `https://<service>.<region>.amazonaws.com` form.
fn get_endpoint(service_name: &str, region: &str) -> String {
    let global_host = match service_name {
        "iam" | "route53" | "cloudfront" => Some(service_name),
        "sts" if region == "us-east-1" => Some(service_name),
        _ => None,
    };

    match global_host {
        Some(host) => format!("https://{}.amazonaws.com", host),
        None => format!("https://{}.{}.amazonaws.com", service_name, region),
    }
}
729
/// Built-in `x-amz-target` prefixes, keyed by lowercase service name.
///
/// Entries registered via `UseAwsToolBuilder::with_service_target` override
/// these. Services missing from the table fall back to their own name.
fn default_service_targets() -> HashMap<String, String> {
    const TARGETS: &[(&str, &str)] = &[
        ("dynamodb", "DynamoDB_20120810"),
        ("kinesis", "Kinesis_20131202"),
        ("logs", "Logs_20140328"),
        ("events", "AWSEvents"),
        ("lambda", "AWSLambda"),
        ("sts", "AWSSecurityTokenServiceV20110615"),
        ("sqs", "AmazonSQS"),
        ("sns", "AmazonSimpleNotificationService"),
        ("secretsmanager", "secretsmanager"),
        ("ssm", "AmazonSSM"),
        ("kms", "TrentService"),
        ("iam", "IAMService"),
        ("cognito-idp", "AWSCognitoIdentityProviderService"),
        ("cognito-identity", "AWSCognitoIdentityService"),
        ("cloudwatch", "GraniteServiceVersion20100801"),
        ("application-autoscaling", "AnyScaleFrontendService"),
        ("elasticache", "AmazonElastiCacheV9"),
        ("ecr", "AmazonEC2ContainerRegistry_V20150921"),
        ("ecs", "AmazonEC2ContainerServiceV20141113"),
        ("cloudformation", "CloudFormation"),
        ("codepipeline", "CodePipeline_20150709"),
        ("codebuild", "CodeBuild_20161006"),
        ("codecommit", "CodeCommit_20150413"),
        ("codedeploy", "CodeDeploy_20141006"),
        ("stepfunctions", "AWSStepFunctions"),
        ("glue", "AWSGlue"),
        ("athena", "AmazonAthena"),
        ("redshift-data", "RedshiftData"),
        ("bedrock", "AmazonBedrock"),
        ("bedrock-runtime", "AmazonBedrockRuntime"),
        ("sagemaker", "SageMaker"),
        ("rekognition", "RekognitionService"),
        ("textract", "Textract"),
        ("comprehend", "Comprehend_20171127"),
        ("translate", "AWSShineFrontendService_20170701"),
        ("polly", "Parrot_v1"),
        ("transcribe", "Transcribe"),
    ];

    TARGETS
        .iter()
        .map(|&(service, prefix)| (service.to_string(), prefix.to_string()))
        .collect()
}
784
/// Parse output header into metadata fields and content.
///
/// The header is every line before the first line that is exactly `---`.
/// Header lines containing `": "` are split at the first occurrence into
/// `(key, value)` pairs; other header lines are ignored. The content is
/// everything after the separator line, with leading newlines removed. If the
/// separator is the last line, the content is empty. Without a separator, the
/// whole output (minus leading newlines) is returned as content. Both `\n` and
/// `\r\n` line endings are handled.
fn parse_output_header(output: &str) -> (Vec<(&str, &str)>, &str) {
    let mut metadata = Vec::new();
    let mut consumed = 0;

    for raw_line in output.split_inclusive('\n') {
        // Track real byte offsets (including the terminator) so slicing stays
        // correct whatever the line-ending style.
        consumed += raw_line.len();
        let line = raw_line.strip_suffix('\n').unwrap_or(raw_line);
        let line = line.strip_suffix('\r').unwrap_or(line);

        if line == "---" {
            return (metadata, output[consumed..].trim_start_matches('\n'));
        }

        if let Some((key, value)) = line.split_once(": ") {
            metadata.push((key, value));
        }
    }

    (metadata, output.trim_start_matches('\n'))
}
816
817// ============================================================================
818// Tests
819// ============================================================================
820
#[cfg(test)]
mod tests {
    // Unit tests for the `use_aws` tool, grouped by concern: builder
    // configuration, input validation, mutative-operation detection,
    // endpoint construction, default service target prefixes, output-header
    // parsing, and AWS error-body parsing.
    use super::*;

    // ==================== Shared helpers ====================

    /// Builds a `UseAwsInput` with no label and no profile override.
    fn aws_input(
        service: &str,
        operation: &str,
        parameters: serde_json::Value,
        region: &str,
    ) -> UseAwsInput {
        UseAwsInput {
            service_name: service.to_string(),
            operation_name: operation.to_string(),
            parameters,
            region: region.to_string(),
            label: None,
            profile_name: None,
        }
    }

    /// Asserts that `input` fails validation with an error whose text mentions `needle`.
    fn assert_rejected(input: &UseAwsInput, needle: &str) {
        let message = validate_input(input)
            .expect_err("input should have been rejected")
            .to_string();
        assert!(
            message.contains(needle),
            "expected error mentioning {:?}, got {:?}",
            needle,
            message
        );
    }

    /// Asserts that every operation name in `ops` has the given mutative classification.
    fn assert_mutative(ops: &[&str], expected: bool) {
        for &op in ops {
            assert_eq!(
                is_mutative_operation(op),
                expected,
                "unexpected mutative classification for {}",
                op
            );
        }
    }

    /// Asserts that `service` in `region` resolves to exactly `expected`.
    fn assert_endpoint(service: &str, region: &str, expected: &str) {
        assert_eq!(get_endpoint(service, region), expected);
    }

    /// Looks up the built-in target prefix for `service`, if one is registered.
    fn default_target(service: &str) -> Option<String> {
        default_service_targets().get(service).cloned()
    }

    /// Asserts that `message` contains every string in `fragments`.
    fn assert_mentions_all(message: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(
                message.contains(fragment),
                "{:?} should contain {:?}",
                message,
                fragment
            );
        }
    }

    // ==================== Builder tests ====================

    #[test]
    fn test_builder_default() {
        let fresh = UseAwsToolBuilder::default();
        assert!(fresh.profile.is_none());
        assert!(fresh.timeout.is_none());
        assert!(fresh.custom_service_targets.is_empty());
    }

    #[test]
    fn test_builder_profile() {
        let configured = UseAwsTool::builder().profile("my-profile");
        assert_eq!(configured.profile.as_deref(), Some("my-profile"));
    }

    #[test]
    fn test_builder_timeout() {
        let two_minutes = Duration::from_secs(120);
        let configured = UseAwsTool::builder().timeout(two_minutes);
        assert_eq!(configured.timeout, Some(two_minutes));
    }

    #[test]
    fn test_builder_custom_service_target() {
        let configured =
            UseAwsTool::builder().with_service_target("custom", "CustomService_20240101");
        let target = configured
            .custom_service_targets
            .get("custom")
            .map(String::as_str);
        assert_eq!(target, Some("CustomService_20240101"));
    }

    // ==================== Validation tests ====================

    #[test]
    fn test_validate_input_empty_service() {
        let input = aws_input("", "GetCallerIdentity", serde_json::json!({}), "us-east-1");
        assert_rejected(&input, "service_name");
    }

    #[test]
    fn test_validate_input_empty_operation() {
        let input = aws_input("sts", "", serde_json::json!({}), "us-east-1");
        assert_rejected(&input, "operation_name");
    }

    #[test]
    fn test_validate_input_empty_region() {
        let input = aws_input("sts", "GetCallerIdentity", serde_json::json!({}), "");
        assert_rejected(&input, "region");
    }

    #[test]
    fn test_validate_input_parameters_not_object() {
        let input = aws_input(
            "sts",
            "GetCallerIdentity",
            serde_json::json!([1, 2, 3]),
            "us-east-1",
        );
        assert_rejected(&input, "array");
    }

    #[test]
    fn test_validate_input_success() {
        let input = aws_input("sts", "GetCallerIdentity", serde_json::json!({}), "us-east-1");
        assert!(validate_input(&input).is_ok());
    }

    // ==================== Mutative operation detection ====================

    #[test]
    fn test_is_mutative_operation_create() {
        assert_mutative(&["CreateBucket", "CreateTable"], true);
    }

    #[test]
    fn test_is_mutative_operation_delete() {
        assert_mutative(&["DeleteBucket", "DeleteItem"], true);
    }

    #[test]
    fn test_is_mutative_operation_update() {
        assert_mutative(&["UpdateItem", "UpdateTable"], true);
    }

    #[test]
    fn test_is_mutative_operation_put() {
        assert_mutative(&["PutObject", "PutItem"], true);
    }

    #[test]
    fn test_is_mutative_operation_terminate() {
        assert_mutative(&["TerminateInstances"], true);
    }

    #[test]
    fn test_is_mutative_operation_non_mutative() {
        assert_mutative(
            &[
                "GetCallerIdentity",
                "ListBuckets",
                "DescribeInstances",
                "Scan",
                "Query",
            ],
            false,
        );
    }

    // ==================== Endpoint generation ====================

    #[test]
    fn test_get_endpoint_standard_service() {
        assert_endpoint("dynamodb", "us-east-1", "https://dynamodb.us-east-1.amazonaws.com");
    }

    #[test]
    fn test_get_endpoint_sts_us_east_1() {
        assert_endpoint("sts", "us-east-1", "https://sts.amazonaws.com");
    }

    #[test]
    fn test_get_endpoint_sts_other_region() {
        assert_endpoint("sts", "us-west-2", "https://sts.us-west-2.amazonaws.com");
    }

    #[test]
    fn test_get_endpoint_iam() {
        assert_endpoint("iam", "us-east-1", "https://iam.amazonaws.com");
    }

    #[test]
    fn test_get_endpoint_s3() {
        assert_endpoint("s3", "us-west-2", "https://s3.us-west-2.amazonaws.com");
    }

    // ==================== Service target prefix ====================

    #[test]
    fn test_default_service_targets_contains_dynamodb() {
        assert_eq!(
            default_target("dynamodb").as_deref(),
            Some("DynamoDB_20120810")
        );
    }

    #[test]
    fn test_default_service_targets_contains_sts() {
        assert_eq!(
            default_target("sts").as_deref(),
            Some("AWSSecurityTokenServiceV20110615")
        );
    }

    #[test]
    fn test_default_service_targets_contains_lambda() {
        assert_eq!(default_target("lambda").as_deref(), Some("AWSLambda"));
    }

    // ==================== Header parsing ====================

    #[test]
    fn test_parse_output_header_complete() {
        let output = concat!(
            "Service: sts\n",
            "Operation: GetCallerIdentity\n",
            "Region: us-east-1\n",
            "Label: Get identity\n",
            "\n---\n\n",
            "{\"Account\": \"123456789\"}"
        );
        let (metadata, content) = parse_output_header(output);

        // Header pairs must come back in the order they appear in the output.
        let expected = [
            ("Service", "sts"),
            ("Operation", "GetCallerIdentity"),
            ("Region", "us-east-1"),
            ("Label", "Get identity"),
        ];
        assert_eq!(metadata.len(), expected.len());
        for (position, pair) in expected.iter().enumerate() {
            assert_eq!(metadata[position], *pair);
        }
        assert!(content.contains("Account"));
    }

    #[test]
    fn test_parse_output_header_no_separator() {
        let output = "Just plain content";
        let (metadata, content) = parse_output_header(output);

        // Without a separator the whole input is treated as content.
        assert!(metadata.is_empty());
        assert_eq!(content, output);
    }

    #[test]
    fn test_parse_output_header_with_warning() {
        let output = "Service: s3\nOperation: DeleteBucket\nWarning: This was a mutative operation\n\n---\n\n{}";
        let (metadata, _) = parse_output_header(output);

        assert_eq!(metadata.len(), 3);
        assert_eq!(
            metadata.last(),
            Some(&("Warning", "This was a mutative operation"))
        );
    }

    // ==================== Error parsing ====================

    #[test]
    fn test_parse_aws_error_with_type() {
        let body = r#"{"__type": "ValidationException", "message": "Invalid input"}"#;
        let message = parse_aws_error(
            "dynamodb",
            "PutItem",
            "us-east-1",
            reqwest::StatusCode::BAD_REQUEST,
            body,
        )
        .to_string();

        assert_mentions_all(
            &message,
            &[
                "dynamodb.PutItem",
                "us-east-1",
                "ValidationException",
                "Invalid input",
            ],
        );
    }

    #[test]
    fn test_parse_aws_error_with_nested_error() {
        let body = r#"{"Error": {"Code": "AccessDenied", "Message": "Access denied"}}"#;
        let message = parse_aws_error(
            "s3",
            "GetObject",
            "us-west-2",
            reqwest::StatusCode::FORBIDDEN,
            body,
        )
        .to_string();

        assert_mentions_all(&message, &["s3.GetObject", "AccessDenied", "Access denied"]);
    }

    #[test]
    fn test_parse_aws_error_plain_text() {
        let body = "Service unavailable";
        let message = parse_aws_error(
            "sts",
            "GetCallerIdentity",
            "us-east-1",
            reqwest::StatusCode::SERVICE_UNAVAILABLE,
            body,
        )
        .to_string();

        assert_mentions_all(
            &message,
            &["sts.GetCallerIdentity", "us-east-1", "Service unavailable"],
        );
    }
}