use crate::tools::{config::Config as KineticsConfig, resource_name};
use aws_lambda_events::sqs::{BatchItemFailure, SqsBatchResponse, SqsEvent};
use aws_sdk_sqs::operation::send_message::builders::SendMessageFluentBuilder;
use eyre::OptionExt;
use kinetics_parser::ParsedFunction;
use lambda_runtime::LambdaEvent;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{OnceCell, RwLock};
/// Sends messages to a single SQS queue.
///
/// Wraps an SQS `SendMessage` request builder whose queue URL is already set;
/// only the message body is filled in per call to [`Client::send`].
#[derive(Clone)]
pub struct Client {
    /// `SendMessage` builder pre-bound to the target queue URL. It is cloned
    /// for every send so the original stays reusable.
    queue: SendMessageFluentBuilder,
}

/// Per-process cache of queue clients, keyed by the worker's type name
/// (see [`Client::from_worker`]), so the AWS config and queue URL are only
/// resolved once per worker.
static SQS_CLIENT_CACHE: OnceCell<Arc<RwLock<HashMap<String, Client>>>> = OnceCell::const_new();

impl Client {
    /// Wraps an already-configured `SendMessage` builder.
    pub fn new(queue: SendMessageFluentBuilder) -> Self {
        Client { queue }
    }

    /// Sends `message` as the body of one message to the bound queue.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SendMessage` request fails.
    pub async fn send(
        &self,
        message: impl ::std::convert::Into<::std::string::String>,
    ) -> eyre::Result<()> {
        self.queue.clone().message_body(message).send().await?;
        Ok(())
    }

    /// Returns the (cached) [`Client`] for the queue that feeds `worker`,
    /// building and caching it on first use.
    ///
    /// The worker's type name (from [`std::any::type_name_of_val`]) is both the
    /// cache key and the source of the queue name: the text before the first
    /// `::` is taken as the project name and the remainder as the function path.
    ///
    /// Environment variables read when a client is built:
    /// - `AWS_REGION` (default `us-east-1`), used for the default endpoint.
    /// - `KINETICS_QUEUE_ENDPOINT_URL` (default `https://sqs.{region}.amazonaws.com`).
    /// - `KINETICS_IS_LOCAL`: when set, the endpoint is also applied to the AWS config.
    /// - `KINETICS_QUEUE_NAME`: explicit queue name; otherwise the name is derived
    ///   with [`resource_name`] from `KINETICS_USERNAME`, the project name and the
    ///   function path.
    /// - `KINETICS_CLOUD_ACCOUNT_ID`: account segment of the queue URL.
    ///
    /// # Errors
    ///
    /// Returns an error if the worker's type name contains no `::`, or if
    /// `KINETICS_CLOUD_ACCOUNT_ID` (or `KINETICS_USERNAME` when
    /// `KINETICS_QUEUE_NAME` is unset) is missing.
    pub async fn from_worker<'a, Fut>(
        worker: impl Fn(Vec<Record>, &'a HashMap<String, String>, &'a KineticsConfig) -> Fut,
    ) -> eyre::Result<Self>
    where
        Fut:
            std::future::Future<Output = Result<Retries, Box<dyn std::error::Error + Send + Sync>>>,
    {
        let cache_key = std::any::type_name_of_val(&worker);
        let cache = SQS_CLIENT_CACHE
            .get_or_init(|| async { Arc::new(RwLock::new(HashMap::new())) })
            .await;

        // Fast path: cache hits only need a shared lock, so concurrent callers
        // do not serialize on each other. The read guard is dropped at the end
        // of this statement.
        if let Some(client) = cache.read().await.get(cache_key).cloned() {
            return Ok(client);
        }

        // Slow path: take the exclusive lock and re-check, because another task
        // may have built this client while we waited. The lock is held across
        // construction so the same worker's client is never built twice.
        let mut guard = cache.write().await;
        if let Some(client) = guard.get(cache_key) {
            return Ok(client.clone());
        }

        let client = Client::new(Self::build_queue(cache_key).await?);
        guard.insert(cache_key.to_string(), client.clone());
        Ok(client)
    }

    /// Resolves the queue URL for the worker identified by `cache_key` (its
    /// type name) and returns a `SendMessage` builder bound to that URL.
    async fn build_queue(cache_key: &str) -> eyre::Result<SendMessageFluentBuilder> {
        let (project_name, function_path) = cache_key
            .split_once("::")
            .ok_or_eyre("Failed to get the project name from a worker")?;

        let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string());
        let queue_endpoint_url = std::env::var("KINETICS_QUEUE_ENDPOINT_URL")
            .unwrap_or_else(|_| format!("https://sqs.{region}.amazonaws.com"));

        // Resolve every required setting before the (awaited) config load so a
        // misconfiguration fails fast with an error instead of a panic.
        let queue_name = match std::env::var("KINETICS_QUEUE_NAME") {
            Ok(name) => name,
            Err(_) => {
                let username = std::env::var("KINETICS_USERNAME")
                    .ok()
                    .ok_or_eyre("KINETICS_USERNAME is not set")?;
                resource_name(
                    &username,
                    project_name,
                    &ParsedFunction::path_to_name(&function_path.replace("::", "/")),
                )
            }
        };
        let account_id = std::env::var("KINETICS_CLOUD_ACCOUNT_ID")
            .ok()
            .ok_or_eyre("KINETICS_CLOUD_ACCOUNT_ID is not set")?;

        // Locally, the AWS config itself is pointed at the custom endpoint.
        let config = if std::env::var("KINETICS_IS_LOCAL").is_ok() {
            aws_config::defaults(aws_config::BehaviorVersion::latest())
                .endpoint_url(&queue_endpoint_url)
                .load()
                .await
        } else {
            aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await
        };

        Ok(aws_sdk_sqs::Client::new(&config)
            .send_message()
            .queue_url(format!("{queue_endpoint_url}/{account_id}/{queue_name}")))
    }
}
/// Identifiers of the messages a worker wants reported back as failed, so they
/// can be retried.
///
/// [`Retries::collect`] turns them into an [`SqsBatchResponse`].
#[derive(Default)]
pub struct Retries {
    /// Item identifiers, in the order they were added.
    ids: Vec<String>,
}

impl Retries {
    /// Creates an empty list of retries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `item` as an identifier to report as a batch item failure.
    pub fn add(&mut self, item: &str) {
        self.ids.push(String::from(item));
    }

    /// Builds an [`SqsBatchResponse`] with one [`BatchItemFailure`] per added
    /// identifier, preserving insertion order. An empty list yields a response
    /// with no failures.
    pub fn collect(&self) -> SqsBatchResponse {
        let failures = self.ids.iter().map(|id| {
            let mut failure = BatchItemFailure::default();
            failure.item_identifier = id.clone();
            failure
        });

        let mut response = SqsBatchResponse::default();
        response.batch_item_failures.extend(failures);
        response
    }
}
/// A single queue message as handed to a worker: its id and body only.
#[derive(Deserialize, Serialize, Debug)]
pub struct Record {
    /// Message id; defaults to `None` when absent during deserialization.
    #[serde(default)]
    pub message_id: Option<String>,
    /// Message body; defaults to `None` when absent during deserialization.
    #[serde(default)]
    pub body: Option<String>,
}

impl Record {
    /// Converts every record of an SQS Lambda event into a [`Record`],
    /// keeping the original order.
    ///
    /// The event is taken by value, so ids and bodies are moved out rather
    /// than cloned.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API compatibility.
    pub fn from_sqsevent(event: LambdaEvent<SqsEvent>) -> eyre::Result<Vec<Record>> {
        Ok(event
            .payload
            .records
            .into_iter()
            .map(|r| Record {
                message_id: r.message_id,
                body: r.body,
            })
            .collect())
    }
}