use chrono::{DateTime, Utc};
use std::collections::{HashMap, HashSet};
use tokio::sync::{Mutex, Notify};
/// Readiness bookkeeping for a single invocation: which extensions are expected, which
/// have signalled readiness, and when the runtime / extensions finished.
#[derive(Debug)]
struct InvocationReadiness {
    /// Extension ids whose readiness the invocation waits for.
    expected_extensions: HashSet<String>,
    /// Extension ids recorded as ready once the runtime was done. Note: ids can be added
    /// here without being checked against `expected_extensions`.
    ready_extensions: HashSet<String>,
    /// Expected extensions that signalled readiness before the runtime was done; they are
    /// moved into `ready_extensions` when the runtime is marked done.
    pending_ready: HashSet<String>,
    /// When the runtime was marked done (`None` while it is still running).
    runtime_done_at: Option<DateTime<Utc>>,
    /// When every expected extension became ready (`None` until then).
    extensions_ready_at: Option<DateTime<Utc>>,
}
impl InvocationReadiness {
    /// Creates state for an invocation that waits on `expected_extensions`, with nothing
    /// ready yet and no timestamps recorded.
    fn new(expected_extensions: HashSet<String>) -> Self {
        Self {
            expected_extensions,
            ready_extensions: HashSet::new(),
            pending_ready: HashSet::new(),
            runtime_done_at: None,
            extensions_ready_at: None,
        }
    }

    /// Returns `true` when every expected extension is in the ready set.
    ///
    /// An invocation with no expected extensions is trivially ready.
    fn all_ready(&self) -> bool {
        self.expected_extensions.is_subset(&self.ready_extensions)
    }

    /// Milliseconds (with sub-millisecond precision) between the runtime finishing and all
    /// expected extensions becoming ready, or `None` if either timestamp is missing.
    ///
    /// Both timestamps come from the wall clock, so a clock adjustment between them can make
    /// the result negative.
    fn extension_overhead_ms(&self) -> Option<f64> {
        let done = self.runtime_done_at?;
        let ready = self.extensions_ready_at?;
        let elapsed = ready - done;
        // `num_microseconds` only returns `None` on i64 overflow; fall back to whole ms then.
        Some(match elapsed.num_microseconds() {
            Some(micros) => micros as f64 / 1_000.0,
            None => elapsed.num_milliseconds() as f64,
        })
    }
}
/// Tracks, per invocation request id, whether every expected extension has signalled
/// readiness after the runtime finished, and lets async callers wait for that point.
#[derive(Debug)]
pub struct ExtensionReadinessTracker {
    /// Readiness state keyed by request id.
    invocations: Mutex<HashMap<String, InvocationReadiness>>,
    /// Signalled (via `notify_waiters`) whenever readiness may have changed.
    readiness_changed: Notify,
    /// Request id most recently passed to `start_invocation`.
    current_request: Mutex<Option<String>>,
    /// Request id most recently passed to `mark_runtime_done`.
    last_completed_request: Mutex<Option<String>>,
}
impl ExtensionReadinessTracker {
pub fn new() -> Self {
Self {
invocations: Mutex::new(HashMap::new()),
readiness_changed: Notify::new(),
current_request: Mutex::new(None),
last_completed_request: Mutex::new(None),
}
}
pub async fn start_invocation(&self, request_id: &str, invoke_extension_ids: Vec<String>) {
let expected: HashSet<String> = invoke_extension_ids.into_iter().collect();
let readiness = InvocationReadiness::new(expected);
self.invocations
.lock()
.await
.insert(request_id.to_string(), readiness);
*self.current_request.lock().await = Some(request_id.to_string());
}
pub async fn mark_runtime_done(&self, request_id: &str) {
let mut invocations = self.invocations.lock().await;
if let Some(readiness) = invocations.get_mut(request_id) {
readiness.runtime_done_at = Some(Utc::now());
for ext in readiness.pending_ready.drain() {
readiness.ready_extensions.insert(ext);
}
if readiness.all_ready() {
readiness.extensions_ready_at = Some(Utc::now());
self.readiness_changed.notify_waiters();
}
}
*self.last_completed_request.lock().await = Some(request_id.to_string());
self.readiness_changed.notify_waiters();
}
pub async fn mark_extension_ready(&self, extension_id: &str) -> bool {
let current_request = self.current_request.lock().await.clone();
let last_request = self.last_completed_request.lock().await.clone();
let mut invocations = self.invocations.lock().await;
if let Some(ref request_id) = last_request
&& let Some(readiness) = invocations.get_mut(request_id)
&& readiness.runtime_done_at.is_some()
&& !readiness.all_ready()
{
readiness.ready_extensions.insert(extension_id.to_string());
if readiness.all_ready() {
readiness.extensions_ready_at = Some(Utc::now());
self.readiness_changed.notify_waiters();
}
return true;
}
if let Some(ref request_id) = current_request
&& let Some(readiness) = invocations.get_mut(request_id)
&& readiness.runtime_done_at.is_none()
&& readiness.expected_extensions.contains(extension_id)
{
readiness.pending_ready.insert(extension_id.to_string());
return true;
}
false
}
pub async fn is_all_ready(&self, request_id: &str) -> bool {
let invocations = self.invocations.lock().await;
invocations.get(request_id).is_none_or(|r| r.all_ready())
}
pub async fn wait_for_all_ready(&self, request_id: &str) {
loop {
if self.is_all_ready(request_id).await {
return;
}
self.readiness_changed.notified().await;
}
}
pub async fn get_extension_overhead_ms(&self, request_id: &str) -> Option<f64> {
let invocations = self.invocations.lock().await;
invocations
.get(request_id)
.and_then(|r| r.extension_overhead_ms())
}
#[allow(dead_code)]
pub async fn get_ready_extensions(&self, request_id: &str) -> Vec<String> {
let invocations = self.invocations.lock().await;
invocations
.get(request_id)
.map(|r| r.ready_extensions.iter().cloned().collect())
.unwrap_or_default()
}
#[allow(dead_code)]
pub async fn get_pending_extensions(&self, request_id: &str) -> Vec<String> {
let invocations = self.invocations.lock().await;
invocations
.get(request_id)
.map(|r| {
r.expected_extensions
.difference(&r.ready_extensions)
.cloned()
.collect()
})
.unwrap_or_default()
}
pub async fn cleanup_invocation(&self, request_id: &str) {
self.invocations.lock().await.remove(request_id);
}
#[allow(dead_code)]
pub(crate) fn notify_readiness_changed(&self) {
self.readiness_changed.notify_waiters();
}
}
impl Default for ExtensionReadinessTracker {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `ExtensionReadinessTracker`, run on the tokio test runtime.
#[cfg(test)]
mod tests {
    use super::*;
    // An invocation with no expected extensions is ready as soon as the runtime is done.
    #[tokio::test]
    async fn test_no_extensions_is_immediately_ready() {
        let tracker = ExtensionReadinessTracker::new();
        tracker.start_invocation("req-1", vec![]).await;
        tracker.mark_runtime_done("req-1").await;
        assert!(tracker.is_all_ready("req-1").await);
    }
    // `wait_for_all_ready` must not block when nothing is expected; the timeout guards
    // against a hang.
    #[tokio::test]
    async fn test_wait_for_all_ready_returns_immediately_with_no_extensions() {
        let tracker = ExtensionReadinessTracker::new();
        tracker.start_invocation("req-1", vec![]).await;
        tracker.mark_runtime_done("req-1").await;
        let result = tokio::time::timeout(std::time::Duration::from_millis(100), async {
            tracker.wait_for_all_ready("req-1").await;
        })
        .await;
        assert!(
            result.is_ok(),
            "wait_for_all_ready should return immediately with no extensions"
        );
    }
    // A single extension signalling after runtime done completes the invocation.
    #[tokio::test]
    async fn test_single_extension_readiness() {
        let tracker = ExtensionReadinessTracker::new();
        tracker
            .start_invocation("req-1", vec!["ext-1".to_string()])
            .await;
        tracker.mark_runtime_done("req-1").await;
        assert!(!tracker.is_all_ready("req-1").await);
        tracker.mark_extension_ready("ext-1").await;
        assert!(tracker.is_all_ready("req-1").await);
    }
    // Readiness is reached only after the last of several expected extensions signals.
    #[tokio::test]
    async fn test_multiple_extensions_readiness() {
        let tracker = ExtensionReadinessTracker::new();
        tracker
            .start_invocation(
                "req-1",
                vec![
                    "ext-1".to_string(),
                    "ext-2".to_string(),
                    "ext-3".to_string(),
                ],
            )
            .await;
        tracker.mark_runtime_done("req-1").await;
        assert!(!tracker.is_all_ready("req-1").await);
        tracker.mark_extension_ready("ext-1").await;
        assert!(!tracker.is_all_ready("req-1").await);
        tracker.mark_extension_ready("ext-2").await;
        assert!(!tracker.is_all_ready("req-1").await);
        tracker.mark_extension_ready("ext-3").await;
        assert!(tracker.is_all_ready("req-1").await);
    }
    // The pending/ready accessors partition the expected set as extensions signal.
    #[tokio::test]
    async fn test_pending_extensions() {
        let tracker = ExtensionReadinessTracker::new();
        tracker
            .start_invocation("req-1", vec!["ext-1".to_string(), "ext-2".to_string()])
            .await;
        tracker.mark_runtime_done("req-1").await;
        let pending = tracker.get_pending_extensions("req-1").await;
        assert_eq!(pending.len(), 2);
        assert!(pending.contains(&"ext-1".to_string()));
        assert!(pending.contains(&"ext-2".to_string()));
        tracker.mark_extension_ready("ext-1").await;
        let pending = tracker.get_pending_extensions("req-1").await;
        assert_eq!(pending.len(), 1);
        assert!(pending.contains(&"ext-2".to_string()));
        let ready = tracker.get_ready_extensions("req-1").await;
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&"ext-1".to_string()));
    }
    // Overhead is unavailable until all extensions are ready, then reflects the delay
    // (at least the 10 ms slept here) between runtime done and extension readiness.
    #[tokio::test]
    async fn test_extension_overhead_calculation() {
        let tracker = ExtensionReadinessTracker::new();
        tracker
            .start_invocation("req-1", vec!["ext-1".to_string()])
            .await;
        tracker.mark_runtime_done("req-1").await;
        assert!(tracker.get_extension_overhead_ms("req-1").await.is_none());
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        tracker.mark_extension_ready("ext-1").await;
        let overhead = tracker.get_extension_overhead_ms("req-1").await;
        assert!(overhead.is_some());
        assert!(overhead.unwrap() >= 10.0);
    }
    // After cleanup the request is no longer tracked and is therefore reported as ready.
    #[tokio::test]
    async fn test_cleanup_invocation() {
        let tracker = ExtensionReadinessTracker::new();
        tracker
            .start_invocation("req-1", vec!["ext-1".to_string()])
            .await;
        assert!(!tracker.is_all_ready("req-1").await);
        tracker.cleanup_invocation("req-1").await;
        assert!(tracker.is_all_ready("req-1").await);
    }
    // Requests that were never started are treated as ready.
    #[tokio::test]
    async fn test_unknown_request_is_ready() {
        let tracker = ExtensionReadinessTracker::new();
        assert!(tracker.is_all_ready("nonexistent").await);
    }
}