use common_failures::prelude::*;
use csv::{self, StringRecord};
use failure::{format_err, ResultExt};
use futures::{executor::block_on, future, FutureExt, StreamExt};
use hyper::Client;
use hyper_tls::HttpsConnector;
use log::{debug, error, trace, warn};
use std::{
cmp::max, io, iter::FromIterator, sync::Arc, thread::sleep, time::Duration,
};
use strum_macros::EnumString;
use tokio::sync::mpsc::{self, Receiver, Sender};
use crate::addresses::AddressColumnSpec;
use crate::async_util::run_sync_fn_in_background;
use crate::smartystreets::{
AddressRequest, MatchStrategy, SharedHyperClient, SmartyStreets,
};
use crate::structure::Structure;
use crate::Result;
/// How many messages each inter-stage channel buffers before the sender
/// blocks; keeps backpressure on the reader so we don't buffer the whole
/// input file in memory.
const CHANNEL_BUFFER: usize = 8;

/// How many chunks we geocode concurrently (`buffered(CONCURRENCY)`); also
/// used as the HTTP connection pool size for the hyper client.
const CONCURRENCY: usize = 48;

/// Rough target number of addresses per geocoding request. The actual
/// rows-per-chunk is `GEOCODE_SIZE / prefix_count()` (see
/// `read_csv_from_stdin`), since each row generates one request per prefix.
const GEOCODE_SIZE: usize = 72;
/// What to do when an input column has the same name as one of the
/// geocoding output columns we intend to add. Parsed from snake_case
/// strings (e.g. on a CLI flag) via `strum`'s `EnumString`.
#[derive(Debug, Clone, Copy, EnumString, Eq, PartialEq)]
#[strum(serialize_all = "snake_case")]
pub enum OnDuplicateColumns {
    /// Fail with an error listing the conflicting columns.
    Error,
    /// Drop the conflicting input columns and emit ours in their place.
    Replace,
    /// Keep the input columns, producing duplicate column names in the output.
    Append,
}
/// Read-only state shared (behind an `Arc`) by every chunk flowing
/// through the pipeline.
struct Shared {
    // Address column spec with column names already resolved to indices.
    spec: AddressColumnSpec<usize>,
    // Describes the geocoding output columns appended to each row.
    structure: Structure,
    // The complete output header row (input headers plus geocoding columns).
    out_headers: StringRecord,
}
/// A batch of CSV rows moving through the geocoding pipeline.
struct Chunk {
    // Pipeline-wide state shared with all other chunks.
    shared: Arc<Shared>,
    // The rows in this batch (may be empty for a header-only input file).
    rows: Vec<StringRecord>,
}
/// Messages passed over the channels between pipeline stages.
enum Message {
    /// A batch of rows to geocode (or, downstream, to write out).
    Chunk(Chunk),
    /// Marks a clean end of input; nothing follows this message.
    EndOfStream,
}
/// Geocode CSV data from standard input, writing the augmented CSV to
/// standard output.
///
/// Pipeline layout: a background thread reads stdin and sends `Chunk`s
/// over a bounded channel; this async task geocodes up to `CONCURRENCY`
/// chunks at a time; a second background thread writes results to stdout.
///
/// # Parameters
///
/// - `spec`: which input columns hold address parts, keyed by column name.
/// - `match_strategy`: the SmartyStreets match strategy to request.
/// - `license`: the SmartyStreets license string sent with each request.
/// - `on_duplicate_columns`: how to resolve clashes between input columns
///   and the geocoding output columns.
/// - `structure`: describes the output columns added per address prefix.
pub async fn geocode_stdio(
    spec: AddressColumnSpec<String>,
    match_strategy: MatchStrategy,
    license: String,
    on_duplicate_columns: OnDuplicateColumns,
    structure: Structure,
) -> Result<()> {
    // Bounded channels (reader -> geocoder, geocoder -> writer) apply
    // backpressure so we never hold the whole input in memory.
    let (in_tx, in_rx) = mpsc::channel::<Message>(CHANNEL_BUFFER);
    let (mut out_tx, out_rx) = mpsc::channel::<Message>(CHANNEL_BUFFER);

    // CSV reading/writing are blocking I/O, so each runs on its own thread.
    let read_fut = run_sync_fn_in_background("read CSV".to_owned(), move || {
        read_csv_from_stdin(spec, structure, on_duplicate_columns, in_tx)
    });
    let write_fut = run_sync_fn_in_background("write CSV".to_owned(), move || {
        write_csv_to_stdout(out_rx)
    });

    // Shared HTTPS client, with enough pooled connections for our
    // concurrency level.
    let client = Arc::new(
        Client::builder()
            .pool_max_idle_per_host(CONCURRENCY)
            .build(HttpsConnector::new()),
    );

    // Geocode incoming messages. `buffered(CONCURRENCY)` polls up to
    // CONCURRENCY geocoding futures at once but yields results in input
    // order, so output rows stay aligned with input rows.
    let geocode_fut = async move {
        let mut stream = in_rx
            .map(move |message| {
                geocode_message(
                    client.clone(),
                    match_strategy,
                    license.clone(),
                    message,
                )
                .boxed()
            })
            .buffered(CONCURRENCY);
        while let Some(result) = stream.next().await {
            out_tx
                .send(result?)
                .await
                .map_err(|_| format_err!("could not send message to output thread"))?;
        }
        Ok::<_, Error>(())
    }
    .boxed();

    // Wait for all three stages to finish, then attach per-stage context
    // to each result.
    let (read_result, geocode_result, write_result) =
        future::join3(read_fut, geocode_fut, write_fut).await;
    let read_result: Result<()> = read_result
        .context("error reading input")
        .map_err(|e| e.into());
    let geocode_result: Result<()> = geocode_result
        .context("error geocoding")
        .map_err(|e| e.into());
    let write_result: Result<()> = write_result
        .context("error writing output")
        .map_err(|e| e.into());

    // Print every failure (with causes and backtrace) so none is masked,
    // then return a single summary error if anything went wrong.
    let mut failed = false;
    if let Err(err) = &read_result {
        failed = true;
        eprintln!("{}", err.display_causes_and_backtrace());
    }
    if let Err(err) = &geocode_result {
        failed = true;
        eprintln!("{}", err.display_causes_and_backtrace());
    }
    if let Err(err) = &write_result {
        failed = true;
        eprintln!("{}", err.display_causes_and_backtrace());
    }
    if failed {
        Err(format_err!("geocoding stdio failed"))
    } else {
        Ok(())
    }
}
/// Read CSV rows from standard input and send them downstream in chunks,
/// followed by a final `EndOfStream` marker.
///
/// Runs on a dedicated thread (spawned by `geocode_stdio`), so it uses
/// `block_on` to drive the async channel sends synchronously.
///
/// # Parameters
///
/// - `spec`: address columns keyed by header name (resolved to indices here).
/// - `structure`: describes the geocoding columns appended to the headers.
/// - `on_duplicate_columns`: how to resolve input/output column name clashes.
/// - `tx`: channel to the geocoding stage.
fn read_csv_from_stdin(
    spec: AddressColumnSpec<String>,
    structure: Structure,
    on_duplicate_columns: OnDuplicateColumns,
    mut tx: Sender<Message>,
) -> Result<()> {
    let stdin = io::stdin();
    let mut rdr = csv::Reader::from_reader(stdin.lock());
    let mut in_headers = rdr.headers()?.to_owned();
    debug!("input headers: {:?}", in_headers);

    // Find input columns whose names collide with the geocoding output
    // columns: their indices (for removal) and a joined name list (for
    // error/warning messages).
    let (duplicate_column_indices, duplicate_column_names) = {
        let duplicate_columns = spec.duplicate_columns(&structure, &in_headers)?;
        let indices = duplicate_columns
            .iter()
            .map(|name_idx| name_idx.1)
            .collect::<Vec<_>>();
        let names = duplicate_columns
            .iter()
            .map(|name_idx| name_idx.0)
            .collect::<Vec<_>>()
            .join(", ");
        (indices, names)
    };

    // Resolve collisions per `on_duplicate_columns`.
    let mut should_remove_columns = false;
    let mut remove_column_flags = vec![false; in_headers.len()];
    if !duplicate_column_indices.is_empty() {
        match on_duplicate_columns {
            OnDuplicateColumns::Error => {
                return Err(format_err!(
                    "input columns would conflict with geocoding columns: {}",
                    duplicate_column_names,
                ));
            }
            OnDuplicateColumns::Replace => {
                warn!("replacing input columns: {}", duplicate_column_names);
                should_remove_columns = true;
                for i in duplicate_column_indices.iter().cloned() {
                    remove_column_flags[i] = true;
                }
            }
            OnDuplicateColumns::Append => {
                warn!(
                    "output contains duplicate columns: {}",
                    duplicate_column_names,
                );
            }
        }
    }
    if should_remove_columns {
        in_headers = remove_columns(&in_headers, &remove_column_flags);
    }

    // Convert column names to indices AFTER any removals, so the indices
    // match the rows we actually send downstream.
    let spec = spec.convert_to_indices_using_headers(&in_headers)?;

    // Each row generates one geocoding request per prefix, so shrink the
    // chunk so that a whole chunk is roughly GEOCODE_SIZE addresses.
    let chunk_size = max(1, GEOCODE_SIZE / spec.prefix_count());

    // Output headers = (possibly trimmed) input headers plus the
    // geocoding columns for each prefix.
    let mut out_headers = in_headers;
    for prefix in spec.prefixes() {
        structure.add_header_columns(prefix, &mut out_headers)?;
    }
    debug!("output headers: {:?}", out_headers);

    // Read-only state each chunk carries along the pipeline.
    let shared = Arc::new(Shared {
        spec,
        structure,
        out_headers,
    });

    // Batch rows into chunks and send them downstream.
    let mut sent_chunk = false;
    let mut rows = Vec::with_capacity(chunk_size);
    for row in rdr.records() {
        let mut row = row?;
        if should_remove_columns {
            row = remove_columns(&row, &remove_column_flags);
        }
        rows.push(row);
        if rows.len() >= chunk_size {
            trace!("sending {} input rows", rows.len());
            block_on(tx.send(Message::Chunk(Chunk {
                shared: shared.clone(),
                rows,
            })))
            .map_err(|_| {
                format_err!("could not send rows to geocoder (perhaps it failed)")
            })?;
            sent_chunk = true;
            rows = Vec::with_capacity(chunk_size);
        }
    }

    // Always send at least one chunk — even an empty one — so the writer
    // gets the `Shared` headers and can emit a header row, plus whatever
    // leftover rows didn't fill a full chunk.
    if !sent_chunk || !rows.is_empty() {
        trace!("sending final {} input rows", rows.len());
        block_on(tx.send(Message::Chunk(Chunk { shared, rows }))).map_err(|_| {
            format_err!("could not send rows to geocoder (perhaps it failed)")
        })?;
    }
    trace!("sending end-of-stream for input");
    block_on(tx.send(Message::EndOfStream)).map_err(|_| {
        format_err!("could not send end-of-stream to geocoder (perhaps it failed)")
    })?;
    debug!("done sending input");
    Ok(())
}
/// Build a new record containing only the fields of `row` whose entry in
/// `remove_column_flags` is `false`. The flags slice must be the same
/// length as the record.
fn remove_columns(row: &StringRecord, remove_column_flags: &[bool]) -> StringRecord {
    debug_assert_eq!(row.len(), remove_column_flags.len());
    let kept_fields = row
        .iter()
        .zip(remove_column_flags)
        .filter(|&(_, &remove)| !remove)
        .map(|(value, _)| value.to_owned());
    StringRecord::from_iter(kept_fields)
}
/// Receive geocoded chunks and write them to standard output as CSV.
///
/// Runs on a dedicated thread, driving the async channel with `block_on`.
/// The header row is taken from the first chunk's shared state. If the
/// channel closes before an `EndOfStream` marker arrives, the geocoder
/// died, and we report that as an error.
fn write_csv_to_stdout(mut rx: Receiver<Message>) -> Result<()> {
    let stdout = io::stdout();
    let mut wtr = csv::Writer::from_writer(stdout.lock());
    let mut headers_written = false;
    loop {
        match block_on(rx.next()) {
            Some(Message::Chunk(chunk)) => {
                trace!("received {} output rows", chunk.rows.len());
                // Emit the header row exactly once, before any data rows.
                if !headers_written {
                    wtr.write_record(&chunk.shared.out_headers)?;
                    headers_written = true;
                }
                for row in chunk.rows {
                    wtr.write_record(&row)?;
                }
            }
            Some(Message::EndOfStream) => {
                trace!("received end-of-stream for output");
                // The reader always sends at least one chunk before
                // end-of-stream, so headers must have been written.
                assert!(headers_written);
                return Ok(());
            }
            None => {
                // Channel closed without a clean end-of-stream marker.
                error!("did not receive end-of-stream");
                return Err(format_err!(
                    "did not receive end-of-stream from geocoder (perhaps it failed)"
                ));
            }
        }
    }
}
/// Geocode a single message from the input channel.
///
/// `Chunk` messages are geocoded; the `EndOfStream` marker is passed
/// through untouched so the output stage knows when to stop.
async fn geocode_message(
    client: SharedHyperClient,
    match_strategy: MatchStrategy,
    license: String,
    message: Message,
) -> Result<Message> {
    match message {
        Message::EndOfStream => {
            trace!("geocoding received end-of-stream");
            Ok(Message::EndOfStream)
        }
        Message::Chunk(chunk) => {
            trace!("geocoding {} rows", chunk.rows.len());
            let geocoded = geocode_chunk(client, match_strategy, license, chunk).await?;
            Ok(Message::Chunk(geocoded))
        }
    }
}
async fn geocode_chunk(
client: SharedHyperClient,
match_strategy: MatchStrategy,
license: String,
mut chunk: Chunk,
) -> Result<Chunk> {
let prefixes = chunk.shared.spec.prefixes();
let mut addresses = vec![];
for prefix in &prefixes {
let column_keys = chunk
.shared
.spec
.get(prefix)
.expect("should always have prefix");
for row in &chunk.rows {
addresses.push(AddressRequest {
address: column_keys.extract_address_from_record(row)?,
match_strategy,
});
}
}
let addresses_len = addresses.len();
let smartystreets = SmartyStreets::new(client)?;
trace!("geocoding {} addresses", addresses_len);
let mut failures: u8 = 0;
let geocoded = loop {
let result = smartystreets
.street_addresses(addresses.clone(), license.clone())
.await;
match result {
Err(ref err) if failures < 5 => {
failures += 1;
debug!("retrying smartystreets error: {}", err);
sleep(Duration::from_secs(2));
}
Err(err) => {
return Err(err)
.context("smartystreets error")
.map_err(|e| e.into());
}
Ok(geocoded) => {
break geocoded;
}
}
};
trace!("geocoded {} addresses", addresses_len);
for geocoded_for_prefix in geocoded.chunks(chunk.rows.len()) {
assert_eq!(geocoded_for_prefix.len(), chunk.rows.len());
for (response, row) in geocoded_for_prefix.iter().zip(&mut chunk.rows) {
if let Some(response) = response {
chunk
.shared
.structure
.add_value_columns_to_row(&response.fields, row)?;
} else {
chunk.shared.structure.add_empty_columns_to_row(row)?;
}
}
}
Ok(chunk)
}