//! athena_cli/commands/query.rs
use crate::cli;
2use crate::context::Context;
3use anyhow::Result;
4use aws_sdk_athena::types::{
5 QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultReuseByAgeConfiguration,
6 ResultReuseConfiguration,
7};
8use aws_sdk_athena::Client;
9use byte_unit::Byte;
10use polars::prelude::*;
11use std::{thread, time::Duration};
12
13pub async fn execute(ctx: &Context, args: &cli::QueryArgs) -> Result<()> {
14 println!("Executing query: {}", args.query);
15
16 let database = ctx
17 .database()
18 .ok_or_else(|| anyhow::anyhow!("Database name is required but was not provided"))?;
19
20 let client = ctx.create_athena_client();
21
22 let query_id = start_query(
23 &client,
24 &database,
25 &args.query,
26 &ctx.workgroup(),
27 args.reuse_time,
28 ctx.output_location()
29 .as_deref()
30 .unwrap_or("s3://aws-athena-query-results"),
31 )
32 .await?;
33
34 println!("Query execution ID: {}", query_id);
35
36 let df = get_query_results(&client, &query_id).await?;
37 println!("Results DataFrame:");
38 println!("{}", df);
39
40 Ok(())
41}
42
43async fn start_query(
44 client: &Client,
45 database: &str,
46 query: &str,
47 workgroup: &str,
48 reuse_duration: Duration,
49 output_location: &str,
50) -> Result<String> {
51 let context = QueryExecutionContext::builder().database(database).build();
52
53 let config = ResultConfiguration::builder()
54 .output_location(output_location)
55 .build();
56
57 let result = client
58 .start_query_execution()
59 .result_reuse_configuration(
60 ResultReuseConfiguration::builder()
61 .result_reuse_by_age_configuration(
62 ResultReuseByAgeConfiguration::builder()
63 .enabled(true)
64 .max_age_in_minutes(reuse_duration.as_secs() as i32 / 60)
65 .build(),
66 )
67 .build(),
68 )
69 .query_string(query)
70 .query_execution_context(context)
71 .result_configuration(config)
72 .work_group(workgroup)
73 .send()
74 .await?;
75
76 Ok(result.query_execution_id().unwrap_or_default().to_string())
77}
78
79async fn get_query_results(client: &Client, query_execution_id: &str) -> Result<DataFrame> {
80 loop {
82 let status = client
83 .get_query_execution()
84 .query_execution_id(query_execution_id)
85 .send()
86 .await?;
87
88 if let Some(execution) = status.query_execution() {
89 match execution.status().unwrap().state().as_ref() {
90 Some(QueryExecutionState::Succeeded) => {
91 if let Some(result_config) = execution.result_configuration() {
93 if let Some(output_location) = result_config.output_location() {
94 println!("Results S3 path: {}", output_location);
95 }
96 }
97
98 if let Some(statistics) = execution.statistics() {
99 let data_scanned = statistics.data_scanned_in_bytes().unwrap_or(0);
100 let is_cached = data_scanned == 0;
101 println!(
102 "Query cache status: {}",
103 if is_cached {
104 String::from("Results retrieved from cache")
105 } else {
106 let formatted_size = Byte::from_i64(data_scanned)
107 .map(|b| {
108 b.get_appropriate_unit(byte_unit::UnitType::Decimal)
109 .to_string()
110 })
111 .unwrap_or_else(|| "-".to_string());
112 format!("Fresh query execution (scanned {})", formatted_size)
113 }
114 );
115 }
116 break;
117 }
118 Some(QueryExecutionState::Failed) | Some(QueryExecutionState::Cancelled) => {
119 return Err(anyhow::anyhow!("Query failed or was cancelled"));
120 }
121 _ => {
122 thread::sleep(Duration::from_secs(1));
123 continue;
124 }
125 }
126 }
127 }
128
129 let mut all_columns: Vec<Vec<String>> = Vec::new();
130 let mut column_names: Vec<String> = Vec::new();
131 let mut next_token: Option<String> = None;
132
133 let mut results = client
135 .get_query_results()
136 .query_execution_id(query_execution_id)
137 .max_results(100)
138 .send()
139 .await?;
140
141 if let Some(rs) = results.result_set() {
143 if let Some(first_row) = rs.rows().first() {
144 column_names = first_row
145 .data()
146 .iter()
147 .map(|d| d.var_char_value().unwrap_or_default().to_string())
148 .collect();
149 all_columns = vec![Vec::new(); column_names.len()];
150 }
151 }
152
153 let mut page_count = 1;
155 loop {
156 if let Some(rs) = results.result_set() {
157 let start_idx = if next_token.is_none() { 1 } else { 0 };
158 let rows_count = rs.rows().len() - start_idx;
159
160 println!("Processing page {}: {} rows", page_count, rows_count);
161
162 for row in rs.rows().iter().skip(start_idx) {
163 for (i, data) in row.data().iter().enumerate() {
164 all_columns[i].push(data.var_char_value().unwrap_or_default().to_string());
165 }
166 }
167 }
168
169 next_token = results.next_token().map(|s| s.to_string());
170
171 if next_token.is_none() {
172 println!(
173 "Finished processing {} pages, total rows: {}",
174 page_count,
175 all_columns[0].len()
176 );
177 break;
178 }
179
180 page_count += 1;
181 results = client
182 .get_query_results()
183 .query_execution_id(query_execution_id)
184 .max_results(100)
185 .next_token(next_token.as_ref().unwrap())
186 .send()
187 .await?;
188 }
189
190 let series = all_columns
192 .iter()
193 .zip(column_names.iter())
194 .map(|(col, name)| Series::new(name.into(), col))
195 .map(|s| s.into_column())
196 .collect();
197
198 Ok(DataFrame::new(series)?)
200}