use crate::{
RUNTIME_EMULATOR_PATH,
error::ServerError,
instance_pool::{InstanceId, InstancePool},
requests::{InvokeRequest, LambdaResponse, NextEvent},
};
use cargo_lambda_metadata::cargo::{binary_targets, watch::FunctionRouter};
use miette::Result;
use mpsc::{Receiver, Sender, channel};
use std::{
collections::{HashMap, HashSet, hash_map::Entry},
net::SocketAddr,
path::PathBuf,
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::Duration,
};
use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
use tracing::debug;
use uuid::Uuid;
/// Shared state of the runtime emulator: listening addresses, project metadata and the
/// caches/pools used to route invocations to function processes.
#[derive(Clone)]
pub(crate) struct RuntimeState {
    runtime_addr: SocketAddr,
    proxy_addr: Option<SocketAddr>,
    /// `http://{runtime_addr}{RUNTIME_EMULATOR_PATH}`, computed once in [`RuntimeState::new`].
    runtime_url: String,
    /// Cargo manifest used to discover the project's binaries on demand.
    manifest_path: PathBuf,
    only_lambda_apis: bool,
    /// Functions known at startup; checked before falling back to manifest discovery.
    pub initial_functions: HashSet<String>,
    pub function_router: Option<FunctionRouter>,
    pub req_cache: RequestCache,
    pub res_cache: ResponseCache,
    pub ext_cache: ExtensionCache,
    pub instance_pools: Arc<RwLock<HashMap<String, InstancePool>>>,
    pub connection_tracker: ConnectionTracker,
    pub max_concurrency: usize,
}

/// Reference-counted handle to the shared [`RuntimeState`].
pub(crate) type RefRuntimeState = Arc<RuntimeState>;

impl RuntimeState {
    /// Builds the state with empty caches, pools and connection tracker.
    pub(crate) fn new(
        runtime_addr: SocketAddr,
        proxy_addr: Option<SocketAddr>,
        manifest_path: PathBuf,
        only_lambda_apis: bool,
        initial_functions: HashSet<String>,
        function_router: Option<FunctionRouter>,
        max_concurrency: usize,
    ) -> RuntimeState {
        // `SocketAddr` is `Copy`, so it can be formatted here and still moved below.
        let runtime_url = format!("http://{runtime_addr}{RUNTIME_EMULATOR_PATH}");
        Self {
            runtime_addr,
            proxy_addr,
            runtime_url,
            manifest_path,
            only_lambda_apis,
            initial_functions,
            function_router,
            req_cache: RequestCache::new(),
            res_cache: ResponseCache::new(),
            ext_cache: ExtensionCache::default(),
            instance_pools: Arc::new(RwLock::new(HashMap::new())),
            connection_tracker: ConnectionTracker::new(),
            max_concurrency,
        }
    }

    /// Returns `(runtime_addr, proxy_addr, runtime_url)`.
    pub(crate) fn addresses(&self) -> (SocketAddr, Option<SocketAddr>, String) {
        let Self {
            runtime_addr,
            proxy_addr,
            runtime_url,
            ..
        } = self;
        (*runtime_addr, *proxy_addr, runtime_url.clone())
    }

    /// Returns the runtime URL with `/{name}` appended.
    pub(crate) fn function_addr(&self, name: &str) -> String {
        format!("{}/{name}", self.runtime_url)
    }

    /// True when only the Lambda APIs are enabled, or exactly one initial function exists.
    pub(crate) fn is_default_function_enabled(&self) -> bool {
        self.only_lambda_apis || self.initial_functions.len() == 1
    }

    /// Checks whether `name` is a known function.
    ///
    /// `initial_functions` is consulted first. Otherwise the manifest's binary targets are
    /// loaded. The error value is the set of names that *are* available: the discovered
    /// binaries, or `initial_functions` if discovery itself failed (the failure is logged).
    pub(crate) fn is_function_available(&self, name: &str) -> Result<(), HashSet<String>> {
        if self.initial_functions.contains(name) {
            return Ok(());
        }

        let binaries = binary_targets(&self.manifest_path, false).map_err(|err| {
            tracing::error!(?err, "failed to load the project's binaries");
            self.initial_functions.clone()
        })?;

        if binaries.contains(name) {
            Ok(())
        } else {
            Err(binaries)
        }
    }
}
#[derive(Clone, Debug)]
pub(crate) struct RequestQueue {
tx: Arc<Sender<InvokeRequest>>,
rx: Arc<Mutex<Receiver<InvokeRequest>>>,
depth: Arc<AtomicUsize>,
}
impl RequestQueue {
pub fn new() -> RequestQueue {
let (tx, rx) = channel::<InvokeRequest>(100);
RequestQueue {
tx: Arc::new(tx),
rx: Arc::new(Mutex::new(rx)),
depth: Arc::new(AtomicUsize::new(0)),
}
}
pub async fn pop(&self) -> Option<InvokeRequest> {
let mut rx = self.rx.lock().await;
let result = rx.recv().await;
if result.is_some() {
self.depth.fetch_sub(1, Ordering::Relaxed);
}
result
}
pub async fn push(&self, req: InvokeRequest) -> Result<(), ServerError> {
self.tx
.send(req)
.await
.map_err(|e| ServerError::SendInvokeMessage(Box::new(e)))?;
self.depth.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn depth(&self) -> usize {
self.depth.load(Ordering::Relaxed)
}
}
/// Per-function request queues, keyed by function name.
///
/// Clones share the same map through an `Arc`.
#[derive(Clone, Debug)]
pub(crate) struct RequestCache {
    inner: Arc<RwLock<HashMap<String, RequestQueue>>>,
}

impl RequestCache {
    /// Creates a cache with no queues.
    pub fn new() -> RequestCache {
        RequestCache {
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Installs a fresh, empty queue for `function_name`.
    ///
    /// NOTE(review): this uses `insert`, so an existing queue for the same function is
    /// replaced. Requests still buffered in it, and consumers already waiting on it, are
    /// orphaned. Confirm callers only invoke this before the function receives traffic.
    pub async fn init(&self, function_name: &str) {
        let mut inner = self.inner.write().await;
        inner.insert(function_name.into(), RequestQueue::new());
        debug!(
            function_name,
            "request stack initialized before compilation"
        );
    }

    /// Enqueues `req` on the queue for `req.function_name`, creating the queue if needed.
    ///
    /// Returns `Ok(Some(function_name))` when this call created the queue, and `Ok(None)`
    /// when it already existed.
    ///
    /// # Errors
    /// Propagates the error from [`RequestQueue::push`].
    pub async fn upsert(&self, req: InvokeRequest) -> Result<Option<String>, ServerError> {
        let function_name = req.function_name.clone();

        let mut inner = self.inner.write().await;
        let (queue, created) = match inner.entry(function_name.clone()) {
            Entry::Vacant(v) => (v.insert(RequestQueue::new()).clone(), true),
            Entry::Occupied(o) => (o.get().clone(), false),
        };
        // Release the map lock before pushing. `push` waits when the bounded channel is full.
        // Holding the write lock across that wait would block every `pop` (which needs the
        // read lock) from draining the queue, and the server would deadlock.
        drop(inner);

        queue.push(req).await?;

        if created {
            debug!(?function_name, "request stack initialized in first request");
            Ok(Some(function_name))
        } else {
            debug!(?function_name, "request stack increased");
            Ok(None)
        }
    }

    /// Waits for the next request for `function_name`, creating an empty queue if none exists.
    pub async fn pop(&self, function_name: &str) -> Option<InvokeRequest> {
        // Fast path: try a shared lock first; the guard is dropped at the end of this statement.
        let existing = self.inner.read().await.get(function_name).cloned();

        let queue = match existing {
            Some(queue) => queue,
            None => {
                let mut inner = self.inner.write().await;
                // Re-check under the write lock: another task may have created it meanwhile.
                let queue = match inner.entry(function_name.to_owned()) {
                    Entry::Occupied(o) => o.get().clone(),
                    Entry::Vacant(v) => {
                        let queue = v.insert(RequestQueue::new()).clone();
                        debug!(
                            ?function_name,
                            "request stack initialized in first lambda connection"
                        );
                        queue
                    }
                };
                drop(inner);
                queue
            }
        };

        // The map lock is no longer held, so waiting here does not block other functions.
        queue.pop().await
    }

    /// Names of all functions that currently have a queue.
    pub async fn keys(&self) -> Vec<String> {
        let inner = self.inner.read().await;
        inner.keys().cloned().collect()
    }

    /// Pending request count for `function_name`, or 0 if it has no queue.
    pub async fn queue_depth(&self, function_name: &str) -> usize {
        let inner = self.inner.read().await;
        inner
            .get(function_name)
            .map(RequestQueue::depth)
            .unwrap_or(0)
    }
}
/// Maps request ids to the one-shot channels awaiting each request's response.
#[derive(Clone)]
pub(crate) struct ResponseCache {
    inner: Arc<Mutex<HashMap<String, oneshot::Sender<LambdaResponse>>>>,
}

impl ResponseCache {
    /// Creates an empty cache.
    pub fn new() -> ResponseCache {
        let table = HashMap::new();
        ResponseCache {
            inner: Arc::new(Mutex::new(table)),
        }
    }

    /// Removes and returns the response channel registered for `req_id`, if any.
    pub async fn pop(&self, req_id: &str) -> Option<oneshot::Sender<LambdaResponse>> {
        let removed = self.inner.lock().await.remove(req_id);
        removed
    }

    /// Registers `resp_tx` for `req_id`, replacing any channel already stored under that id.
    pub async fn push(&self, req_id: &str, resp_tx: oneshot::Sender<LambdaResponse>) {
        let key = req_id.to_owned();
        self.inner.lock().await.insert(key, resp_tx);
    }
}
/// Registry of extensions, the events each subscribed to, and the channels used to deliver them.
///
/// Sender keys have the form `"{extension_id}_{event}"`.
#[derive(Clone, Default)]
pub(crate) struct ExtensionCache {
    /// extension id -> subscribed event names
    extensions: Arc<Mutex<HashMap<String, Vec<String>>>>,
    /// event name -> subscribed extension ids
    events: Arc<Mutex<HashMap<String, Vec<String>>>>,
    /// `"{extension_id}_{event}"` -> delivery channel
    senders: Arc<Mutex<HashMap<String, mpsc::Sender<NextEvent>>>>,
    has_internal_extension: Arc<AtomicBool>,
    has_external_extension: Arc<AtomicBool>,
}

impl ExtensionCache {
    /// Registers a new extension subscribed to `events` and returns its generated (UUID v4) id.
    pub async fn register(&self, events: Vec<String>, extension_type: ExtensionType) -> String {
        // The `extensions` lock stays held until return, also while `events` is locked below.
        let mut registry = self.extensions.lock().await;
        let id = Uuid::new_v4().to_string();
        registry.insert(id.clone(), events.clone());

        let flag = match extension_type {
            ExtensionType::Internal => &self.has_internal_extension,
            ExtensionType::External => &self.has_external_extension,
        };
        flag.store(true, Ordering::Relaxed);

        let mut subscribers = self.events.lock().await;
        for event in events {
            subscribers.entry(event).or_default().push(id.clone());
        }
        id
    }

    /// Stores `sender` for every event the extension registered for.
    /// Does nothing if `extension_id` is unknown.
    pub async fn set_senders(&self, extension_id: &str, sender: mpsc::Sender<NextEvent>) {
        let registry = self.extensions.lock().await;
        let Some(events) = registry.get(extension_id) else {
            return;
        };
        let mut senders = self.senders.lock().await;
        senders.extend(
            events
                .iter()
                .map(|event| (format!("{extension_id}_{event}"), sender.clone())),
        );
    }

    /// Sends a clone of `event` to every extension subscribed to `event.type_queue()` that has
    /// a sender, stopping at the first failed send.
    ///
    /// NOTE(review): the `events` and `senders` locks stay held while each send is awaited on
    /// the bounded channel. A slow extension therefore also stalls `register`, `set_senders`
    /// and `clear`.
    ///
    /// # Errors
    /// Returns [`ServerError::SendEventMessage`] if a receiving channel is closed.
    pub async fn send_event(&self, event: NextEvent) -> Result<(), ServerError> {
        let subscriptions = self.events.lock().await;
        let queue = event.type_queue();
        let Some(ids) = subscriptions.get(queue) else {
            return Ok(());
        };

        let senders = self.senders.lock().await;
        for id in ids {
            let key = format!("{id}_{queue}");
            let Some(tx) = senders.get(&key) else {
                continue;
            };
            tx.send(event.clone())
                .await
                .map_err(|e| ServerError::SendEventMessage(Box::new(e)))?;
        }
        Ok(())
    }

    /// Drops the extension's senders. Its registration and subscriptions are kept.
    pub async fn clear(&self, extension_id: &str) {
        let registry = self.extensions.lock().await;
        let Some(events) = registry.get(extension_id) else {
            return;
        };
        let mut senders = self.senders.lock().await;
        events
            .iter()
            .map(|event| format!("{extension_id}_{event}"))
            .for_each(|key| {
                senders.remove(&key);
            });
    }

    /// Delay before function shutdown: 500 ms if any internal extension was registered,
    /// otherwise 300 ms if any external one was, otherwise `None`.
    pub async fn function_shutdown_delay(&self) -> Option<Duration> {
        let internal = self.has_internal_extension.load(Ordering::Relaxed);
        let external = self.has_external_extension.load(Ordering::Relaxed);
        match (internal, external) {
            (true, _) => Some(Duration::from_millis(500)),
            (false, true) => Some(Duration::from_millis(300)),
            (false, false) => None,
        }
    }
}
/// Kind of extension being registered; it selects which presence flag
/// `ExtensionCache::register` sets.
#[derive(Debug)]
pub enum ExtensionType {
    /// Internal extension; its presence yields the longer shutdown delay.
    Internal,
    /// External extension.
    External,
}
/// Associates each peer socket address with the function name and instance it serves.
#[derive(Clone)]
pub struct ConnectionTracker {
    connections: Arc<RwLock<HashMap<SocketAddr, (String, InstanceId)>>>,
}

impl ConnectionTracker {
    /// Creates a tracker with no connections.
    pub fn new() -> Self {
        let table = HashMap::new();
        Self {
            connections: Arc::new(RwLock::new(table)),
        }
    }

    /// Records (or replaces) the function/instance pair for `peer`.
    pub async fn register(&self, peer: SocketAddr, function_name: String, instance_id: InstanceId) {
        let mut table = self.connections.write().await;
        // The write lock is held and the insert follows with no await, so no reader can
        // observe the map between this log line and the insertion.
        debug!(?peer, ?function_name, ?instance_id, "registered connection");
        table.insert(peer, (function_name, instance_id));
    }

    /// Returns a copy of the function/instance pair recorded for `peer`, if any.
    pub async fn get(&self, peer: &SocketAddr) -> Option<(String, InstanceId)> {
        let table = self.connections.read().await;
        let entry = table.get(peer);
        entry.cloned()
    }
}