use std::collections::HashSet;
use std::fmt::Debug;
use std::io::Cursor;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use aws_lambda_events::event::sqs::SqsEvent;
use grapl_graph_descriptions::graph_description::*;
use lambda_runtime::Context;
use log::*;
use rusoto_core::Region;
use rusoto_s3::S3Client;
use rusoto_sqs::SqsClient;
use sqs_lambda::cache::Cache;
use sqs_lambda::completion_event_serializer::CompletionEventSerializer;
use sqs_lambda::event_decoder::PayloadDecoder;
use sqs_lambda::event_emitter::S3EventEmitter;
use sqs_lambda::event_handler::EventHandler;
use sqs_lambda::event_processor::{EventProcessor, EventProcessorActor};
use sqs_lambda::event_retriever::S3PayloadRetriever;
use sqs_lambda::redis_cache::RedisCache;
use sqs_lambda::sqs_completion_handler::{
    CompletionPolicy, SqsCompletionHandler, SqsCompletionHandlerActor,
};
use sqs_lambda::sqs_consumer::{ConsumePolicy, SqsConsumer, SqsConsumerActor};
use thiserror::Error;
use tokio::sync::oneshot;
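/// Errors produced while generating, serializing, or emitting subgraphs.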
#[derive(Error, Debug)]
pub enum GraphGeneratorError {
#[error("IOError")]
IoError(#[from] std::io::Error),
#[error("Failed to encode protobuf")]
EncodeError(#[from] prost::EncodeError),
#[error("Failed to generate graph")]
GenerationError,
}
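/// Serializes batches of completed `Graph`s into zstd-compressed protobuf
/// payloads, reusing `proto` as a scratch buffer across invocations.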
#[derive(Clone, Debug, Default)]
pub struct SubgraphSerializer {
proto: Vec<u8>,
}
impl SubgraphSerializer {
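    /// Merges the completed graphs into a single subgraph, encodes it as a
    /// `GeneratedSubgraphs` protobuf message, and compresses it with zstd.
    /// Returns an empty Vec if the merged subgraph is empty.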
pub fn serialize_completed_events(
&mut self,
completed_events: &[Graph],
) -> Result<Vec<Vec<u8>>, GraphGeneratorError> {
        let mut subgraph = Graph::new(0);
let mut pre_nodes = 0;
let mut pre_edges = 0;
for sg in completed_events.iter() {
pre_nodes += sg.nodes.len();
pre_edges += sg.edges.len();
subgraph.merge(sg);
}
if subgraph.is_empty() {
warn!(
concat!(
"Output subgraph is empty. Serializing to empty vector.",
"pre_nodes: {} pre_edges: {}"
),
pre_nodes,
pre_edges,
);
return Ok(vec![]);
}
info!(
"Serializing {} nodes {} edges. Down from {} nodes {} edges.",
subgraph.nodes.len(),
subgraph.edges.len(),
pre_nodes,
pre_edges,
);
let subgraphs = GeneratedSubgraphs { subgraphs: vec![subgraph] };
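        // Reuse the scratch buffer rather than allocating a fresh one per batch.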
self.proto.clear();
prost::Message::encode(&subgraphs, &mut self.proto)?;
let mut compressed = Vec::with_capacity(self.proto.len());
let mut proto = Cursor::new(&self.proto);
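        // Compress the encoded protobuf with zstd at compression level 4.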
zstd::stream::copy_encode(&mut proto, &mut compressed, 4)?;
Ok(vec![compressed])
}
}
impl CompletionEventSerializer for SubgraphSerializer {
type CompletedEvent = Graph;
type Output = Vec<u8>;
type Error = GraphGeneratorError;
fn serialize_completed_events(
&mut self,
completed_events: &[Self::CompletedEvent],
) -> Result<Vec<Self::Output>, Self::Error> {
        SubgraphSerializer::serialize_completed_events(self, completed_events)
}
}
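/// Decodes zstd-compressed payloads retrieved from S3.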
#[derive(Debug, Clone, Default)]
pub struct ZstdDecoder;
impl PayloadDecoder<Vec<u8>> for ZstdDecoder {
    fn decode(&mut self, body: Vec<u8>) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let mut decompressed = Vec::new();
        let mut body = Cursor::new(&body);
        zstd::stream::copy_decode(&mut body, &mut decompressed)?;
        Ok(decompressed)
    }
}
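/// Generates an S3 key of the form `{day}/{ms}-{uuid}` so that emitted
/// events are bucketed by the day they were produced.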
pub fn time_based_key_fn(event: &[u8]) -> String {
    info!("event length {}", event.len());
    let cur_ms = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(n) => n.as_millis(),
        Err(_) => panic!("SystemTime before UNIX EPOCH!"),
    };
    // Truncate to the start of the current day (86_400_000 ms per day).
    let cur_day = cur_ms - (cur_ms % 86_400_000);
    format!("{}/{}-{}", cur_day, cur_ms, uuid::Uuid::new_v4())
}
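/// Converts an `aws_lambda_events` SQS message into its `rusoto_sqs`
/// equivalent. Message attributes are not carried over.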
pub fn map_sqs_message(event: aws_lambda_events::event::sqs::SqsMessage) -> rusoto_sqs::Message {
rusoto_sqs::Message {
attributes: Some(event.attributes),
body: event.body,
md5_of_body: event.md5_of_body,
md5_of_message_attributes: event.md5_of_message_attributes,
message_attributes: None,
message_id: event.message_id,
receipt_handle: event.receipt_handle,
}
}
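/// Constructs an `EventProcessorActor` wired to the given consumer and
/// completion handler, retrieving zstd-compressed payloads from S3.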
pub fn default_event_processor<EH, ProcErr>(
region: Region,
sqs_consumer: SqsConsumerActor,
sqs_completion_handler: SqsCompletionHandlerActor<Graph, ProcErr>,
event_handler: EH,
) -> EventProcessorActor
where
ProcErr: Debug + Send + Sync + Clone + 'static,
EH: EventHandler<
InputEvent=Vec<u8>,
OutputEvent=Graph,
Error=ProcErr,
> + Send + Sync + Clone + 'static,
{
EventProcessorActor::new(EventProcessor::new(
sqs_consumer,
sqs_completion_handler,
event_handler,
S3PayloadRetriever::new(S3Client::new(region.clone()), ZstdDecoder::default()),
))
}
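/// Builds the default `ConsumePolicy` for this generator from the Lambda
/// execution context.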
pub fn default_consume_policy(
ctx: lambda_runtime::Context
) -> ConsumePolicy {
ConsumePolicy::new(
ctx,
Duration::from_secs(10),
3,
)
}
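/// Constructs an `SqsConsumerActor` for `queue_url`, wired to the completion
/// handler and the shutdown notification channel.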
pub fn default_consumer<ProcErr>(
queue_url: String,
sqs_completion_handler: SqsCompletionHandlerActor<Graph, ProcErr>,
region: Region,
consume_policy: ConsumePolicy,
shutdown_tx: oneshot::Sender<()>,
) -> SqsConsumerActor
where
ProcErr: Debug + Send + Sync + Clone + 'static,
{
SqsConsumerActor::new(
SqsConsumer::new(
SqsClient::new(region.clone()),
queue_url.clone(),
consume_policy,
sqs_completion_handler.clone(),
shutdown_tx,
)
)
}
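/// Constructs an `SqsCompletionHandlerActor` that serializes completed
/// subgraphs to `output_bucket` and reports each handled message id
/// (successfully deleted or not) over `tx`.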
pub fn default_completion_handler<ProcErr>(
queue_url: String,
region: Region,
output_bucket: String,
cache: RedisCache,
tx: std::sync::mpsc::SyncSender<String>,
) -> SqsCompletionHandlerActor<Graph, ProcErr>
where
ProcErr: Debug + Send + Sync + Clone + 'static,
{
let completion_handler = SqsCompletionHandlerActor::new(
SqsCompletionHandler::new(
SqsClient::new(region.clone()),
queue_url.to_string(),
SubgraphSerializer { proto: Vec::with_capacity(1024) },
S3EventEmitter::new(
S3Client::new(region.clone()),
output_bucket,
time_based_key_fn,
),
CompletionPolicy::new(
1000,
Duration::from_secs(30),
),
move |_self_actor, result: Result<String, String>| {
match result {
Ok(worked) => {
info!("Handled an event, which was successfully deleted: {}", &worked);
tx.send(worked).unwrap();
}
                    Err(failed) => {
                        info!("Handled an event, but failed to delete it from the queue: {}", &failed);
                        tx.send(failed).unwrap();
                    }
}
},
cache.clone()
)
);
completion_handler
}
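/// Runs one generator invocation end to end: spins up a processing runtime on
/// a background thread, fans the SQS records out across a pool of event
/// processors, and blocks until every initial message id has been acked.
/// Returns `GraphGeneratorError::GenerationError` if any event was not acked.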
pub fn generate_subgraphs<ProcErr, EH>(
event: SqsEvent,
ctx: Context,
region: Region,
queue_url: String,
output_bucket: String,
cache_address: String,
event_handler: EH,
) -> Result<(), GraphGeneratorError>
where
ProcErr: Debug + Send + Sync + Clone + 'static,
EH: EventHandler<
InputEvent=Vec<u8>,
OutputEvent=Graph,
Error=ProcErr,
> + Send + Sync + Clone + 'static,
{
    let mut initial_events: HashSet<String> = event.records
        .iter()
        .map(|record| record.message_id.clone().expect("SQS message missing message_id"))
        .collect();
info!("Initial Events {:?}", initial_events);
let (tx, rx) = std::sync::mpsc::sync_channel(10);
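    // All async processing runs on a dedicated thread; handled message ids
    // come back over this sync channel.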
std::thread::spawn(move || {
tokio_compat::run_std(
async move {
let cache = RedisCache::new(cache_address.to_owned()).await.expect("Could not create redis client");
info!("SqsCompletionHandler");
let finished_tx = tx.clone();
let sqs_completion_handler = default_completion_handler(
queue_url.clone(),
region.clone(),
output_bucket.clone(),
cache.clone(),
tx,
);
let (shutdown_tx, shutdown_notify) = tokio::sync::oneshot::channel();
info!("SqsConsumer");
let sqs_consumer = default_consumer(
queue_url.clone(),
sqs_completion_handler.clone(),
region.clone(),
default_consume_policy(ctx),
shutdown_tx,
);
info!("EventProcessors");
let event_processors: Vec<_> = (0..10)
.map(|_| {
default_event_processor(
region.clone(),
sqs_consumer.clone(),
sqs_completion_handler.clone(),
event_handler.clone(),
)
})
.collect();
info!("Start Processing");
futures::future::join_all(event_processors.iter().map(|ep| ep.start_processing())).await;
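                // Round-robin the incoming SQS records across the processor pool.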
let mut proc_iter = event_processors.iter().cycle();
                for record in event.records {
                    let next_proc = proc_iter.next().unwrap();
                    next_proc.process_event(map_sqs_message(record)).await;
                }
info!("Waiting for shutdown notification");
let _ = shutdown_notify.await;
info!("Consumer shutdown");
finished_tx.send("Completed".to_owned()).unwrap();
});
});
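    // Remove each acked message id from the initial set; once the "Completed"
    // sentinel arrives, drain any remaining acks before the final check.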
info!("Checking acks");
for r in &rx {
info!("Acking event: {}", &r);
initial_events.remove(&r);
if r == "Completed" {
let r = rx.recv_timeout(Duration::from_millis(100));
if let Ok(r) = r {
initial_events.remove(&r);
}
while let Ok(r) = rx.try_recv() {
initial_events.remove(&r);
}
break;
}
}
info!("Completed execution");
if initial_events.is_empty() {
info!("Successfully acked all initial events");
Ok(())
} else {
Err(
GraphGeneratorError::GenerationError
)
}
}