use crate::Error;
use hyper::client::HttpConnector;
use hyper_proxy::{Intercept, Proxy, ProxyConnector};
use rusoto_core::credential::ChainProvider;
use rusoto_core::request::HttpClient;
use rusoto_core::Region;
use rusoto_athena::{
Athena as AthenaTrait,
AthenaClient,
GetQueryExecutionInput,
GetQueryExecutionOutput,
GetQueryResultsInput,
ResultConfiguration,
StartQueryExecutionInput,
QueryExecution,
QueryExecutionContext,
QueryExecutionStatistics,
QueryExecutionStatus,
};
use std::str::FromStr;
use std::time::{Duration, Instant};
mod types;
pub use self::types::*;
pub struct Athena {
client: AthenaClient,
database: String,
result_bucket: String,
workgroup: Option<String>
}
impl Athena {
pub fn new(region: &str, database: &str, result_bucket: &str, workgroup: Option<String>) -> crate::Result<Self> {
let client = match std::env::var("HTTPS_PROXY") {
Ok(proxy_uri) => {
let proxy = Proxy::new(Intercept::All, proxy_uri.parse()?);
let proxy_connector = ProxyConnector::from_proxy(HttpConnector::new(), proxy)?;
let http = HttpClient::from_connector(proxy_connector);
AthenaClient::new_with(http, ChainProvider::new(), Region::from_str(®ion)?)
}
Err(_) => AthenaClient::new(Region::from_str(®ion)?),
};
Ok(Self {
client,
database: database.into(),
result_bucket: result_bucket.into(),
workgroup
})
}
pub async fn query(&self, query: &str) -> crate::Result<QueryResult> {
let query_req = StartQueryExecutionInput {
client_request_token: Some(uuid::Uuid::new_v4().to_string()),
result_configuration: Some(ResultConfiguration {
encryption_configuration: Default::default(),
output_location: Some(self.result_bucket.clone()),
}),
query_execution_context: Some(QueryExecutionContext{
database: Some(self.database.clone()),
..Default::default()
}),
query_string: query.to_owned(),
work_group: self.workgroup.clone()
};
let query_execution_id = self.client.start_query_execution(query_req).await
.map_err(|error| {
tracing::error!(%error, "error starting query execution");
Error::AthenaError
})?
.query_execution_id.expect("missing execution id");
tracing::trace!(%query_execution_id);
let start = Instant::now();
let mut err_count = 0u8;
let mut result = loop {
let res = match self.client.get_query_execution(GetQueryExecutionInput {
query_execution_id: query_execution_id.clone()
}).await {
Ok(res) => res,
Err(error) => {
err_count += 1;
if err_count > 5 {
tracing::error!(%error, "error getting query execution status");
Err(Error::AthenaError)?;
}
tokio::time::delay_for(Duration::from_millis(250)).await;
continue;
}
};
match res {
GetQueryExecutionOutput {
query_execution: Some(QueryExecution {
query_execution_id: Some(ref query_execution_id),
statistics: Some(QueryExecutionStatistics {
data_scanned_in_bytes: Some(data_scanned_bytes),
query_queue_time_in_millis: Some(query_queue_time_ms),
total_execution_time_in_millis: Some(total_execution_time_ms),
..
}),
status: Some(QueryExecutionStatus {
state: Some(ref state),
..
}),
..
})
} if state == "SUCCEEDED" => {
break QueryResult {
query_execution_id: query_execution_id.into(),
data_scanned_bytes,
query_queue_time_ms,
total_execution_time_ms,
rows: 0,
columns: vec![],
data: vec![]
}
},
GetQueryExecutionOutput {
query_execution: Some(QueryExecution {
status: Some(QueryExecutionStatus {
state: Some(ref state),
state_change_reason: Some(ref state_change_reason),
..
}),
..
}),
..
} => {
tracing::error!(result = %state, reason = %state_change_reason);
match state.as_str() {
"FAILED" => Err(Error::QueryFailed(state_change_reason.to_owned()))?,
"CANCELLED" => Err(Error::QueryCancelled)?,
_ => unimplemented!()
};
},
GetQueryExecutionOutput {
query_execution: Some(QueryExecution {
status: Some(QueryExecutionStatus {
state: Some(ref state),
..
}),
..
}),
..
} if state == "RUNNING" || state == "QUEUED" => {
tracing::trace!(
%state,
time_taken = %humantime::format_duration(start.elapsed()).to_string()
);
tokio::time::delay_for(Duration::from_secs(1)).await;
},
v @ _ => {
tracing::debug!(debug = ?v);
Err(Error::QueryError)?
}
}
};
let mut result_req = GetQueryResultsInput {
next_token: None,
query_execution_id: result.query_execution_id.clone(),
..Default::default()
};
loop {
let res = self.client.get_query_results(result_req.clone()).await?;
let resultset = res.result_set.expect("result_set was none");
let cols = resultset.result_set_metadata.expect("missing result_set_metadata")
.column_info.expect("missing column_info");
result.columns = cols.into_iter().map(|c| c.name).collect();
let rows = resultset.rows.expect("missing result_set.rows");
rows.into_iter()
.skip(1) .for_each(|row| {
let row = row.data.expect("missing row.data");
let new: Vec<String> = row.into_iter()
.map(|field| {
field.var_char_value.unwrap_or_else(String::new)
})
.collect();
result.append_row(new);
});
tracing::trace!(
state = %"SUCCEEDED",
rows_read = %result.rows
);
if res.next_token.is_some() {
result_req.next_token = res.next_token;
continue;
}
break;
}
Ok(result)
}
}