// athena_cli/commands/query.rs

1use crate::cli;
2use crate::context::Context;
3use anyhow::Result;
4use aws_sdk_athena::types::{
5    QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultReuseByAgeConfiguration,
6    ResultReuseConfiguration,
7};
8use aws_sdk_athena::Client;
9use byte_unit::Byte;
10use polars::prelude::*;
11use std::{thread, time::Duration};
12
13pub async fn execute(ctx: &Context, args: &cli::QueryArgs) -> Result<()> {
14    println!("Executing query: {}", args.query);
15
16    let database = ctx
17        .database()
18        .ok_or_else(|| anyhow::anyhow!("Database name is required but was not provided"))?;
19
20    let client = ctx.create_athena_client();
21
22    let query_id = start_query(
23        &client,
24        &database,
25        &args.query,
26        &ctx.workgroup(),
27        args.reuse_time,
28        ctx.output_location()
29            .as_deref()
30            .unwrap_or("s3://aws-athena-query-results"),
31    )
32    .await?;
33
34    println!("Query execution ID: {}", query_id);
35
36    let df = get_query_results(&client, &query_id).await?;
37    println!("Results DataFrame:");
38    println!("{}", df);
39
40    Ok(())
41}
42
43async fn start_query(
44    client: &Client,
45    database: &str,
46    query: &str,
47    workgroup: &str,
48    reuse_duration: Duration,
49    output_location: &str,
50) -> Result<String> {
51    let context = QueryExecutionContext::builder().database(database).build();
52
53    let config = ResultConfiguration::builder()
54        .output_location(output_location)
55        .build();
56
57    let result = client
58        .start_query_execution()
59        .result_reuse_configuration(
60            ResultReuseConfiguration::builder()
61                .result_reuse_by_age_configuration(
62                    ResultReuseByAgeConfiguration::builder()
63                        .enabled(true)
64                        .max_age_in_minutes(reuse_duration.as_secs() as i32 / 60)
65                        .build(),
66                )
67                .build(),
68        )
69        .query_string(query)
70        .query_execution_context(context)
71        .result_configuration(config)
72        .work_group(workgroup)
73        .send()
74        .await?;
75
76    Ok(result.query_execution_id().unwrap_or_default().to_string())
77}
78
79async fn get_query_results(client: &Client, query_execution_id: &str) -> Result<DataFrame> {
80    // Wait for query to complete
81    loop {
82        let status = client
83            .get_query_execution()
84            .query_execution_id(query_execution_id)
85            .send()
86            .await?;
87
88        if let Some(execution) = status.query_execution() {
89            match execution.status().unwrap().state().as_ref() {
90                Some(QueryExecutionState::Succeeded) => {
91                    // Print query info once before breaking
92                    if let Some(result_config) = execution.result_configuration() {
93                        if let Some(output_location) = result_config.output_location() {
94                            println!("Results S3 path: {}", output_location);
95                        }
96                    }
97
98                    if let Some(statistics) = execution.statistics() {
99                        let data_scanned = statistics.data_scanned_in_bytes().unwrap_or(0);
100                        let is_cached = data_scanned == 0;
101                        println!(
102                            "Query cache status: {}",
103                            if is_cached {
104                                String::from("Results retrieved from cache")
105                            } else {
106                                let formatted_size = Byte::from_i64(data_scanned)
107                                    .map(|b| {
108                                        b.get_appropriate_unit(byte_unit::UnitType::Decimal)
109                                            .to_string()
110                                    })
111                                    .unwrap_or_else(|| "-".to_string());
112                                format!("Fresh query execution (scanned {})", formatted_size)
113                            }
114                        );
115                    }
116                    break;
117                }
118                Some(QueryExecutionState::Failed) | Some(QueryExecutionState::Cancelled) => {
119                    return Err(anyhow::anyhow!("Query failed or was cancelled"));
120                }
121                _ => {
122                    thread::sleep(Duration::from_secs(1));
123                    continue;
124                }
125            }
126        }
127    }
128
129    let mut all_columns: Vec<Vec<String>> = Vec::new();
130    let mut column_names: Vec<String> = Vec::new();
131    let mut next_token: Option<String> = None;
132
133    // Get first page and column names
134    let mut results = client
135        .get_query_results()
136        .query_execution_id(query_execution_id)
137        .max_results(100)
138        .send()
139        .await?;
140
141    // Initialize column names from first result
142    if let Some(rs) = results.result_set() {
143        if let Some(first_row) = rs.rows().first() {
144            column_names = first_row
145                .data()
146                .iter()
147                .map(|d| d.var_char_value().unwrap_or_default().to_string())
148                .collect();
149            all_columns = vec![Vec::new(); column_names.len()];
150        }
151    }
152
153    // Process results page by page
154    let mut page_count = 1;
155    loop {
156        if let Some(rs) = results.result_set() {
157            let start_idx = if next_token.is_none() { 1 } else { 0 };
158            let rows_count = rs.rows().len() - start_idx;
159
160            println!("Processing page {}: {} rows", page_count, rows_count);
161
162            for row in rs.rows().iter().skip(start_idx) {
163                for (i, data) in row.data().iter().enumerate() {
164                    all_columns[i].push(data.var_char_value().unwrap_or_default().to_string());
165                }
166            }
167        }
168
169        next_token = results.next_token().map(|s| s.to_string());
170
171        if next_token.is_none() {
172            println!(
173                "Finished processing {} pages, total rows: {}",
174                page_count,
175                all_columns[0].len()
176            );
177            break;
178        }
179
180        page_count += 1;
181        results = client
182            .get_query_results()
183            .query_execution_id(query_execution_id)
184            .max_results(100)
185            .next_token(next_token.as_ref().unwrap())
186            .send()
187            .await?;
188    }
189
190    // Create DataFrame
191    let series = all_columns
192        .iter()
193        .zip(column_names.iter())
194        .map(|(col, name)| Series::new(name.into(), col))
195        .map(|s| s.into_column())
196        .collect();
197
198    // Convert Series to Columns and create DataFrame
199    Ok(DataFrame::new(series)?)
200}