use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use awaken_runtime::registry::ModelCapabilityPatch;
use awaken_server_contract::ProviderSpec;
use awaken_server_contract::contract::executor::LlmExecutor;
/// Executor cache keyed by provider id; each entry stores the [`ProviderSpec`] together with
/// the shared (`Arc`) [`LlmExecutor`] handle kept alongside it.
pub(super) type ProviderExecutorCache = HashMap<String, (ProviderSpec, Arc<dyn LlmExecutor>)>;
/// Default maximum age (12 hours) a cached capability snapshot may reach before staging
/// stops carrying it forward.
const CAPABILITY_SNAPSHOT_TTL: Duration = Duration::from_secs(12 * 60 * 60);
/// Capabilities recorded for one provider, together with the data needed to decide whether
/// the record may still be reused.
#[derive(Clone)]
struct CachedCapabilitySnapshot {
/// Provider signature captured when the snapshot was staged; staging compares it with the
/// provider's current signature and drops the snapshot when they differ.
signature: String,
/// Time passed as `now` to the staging call that created this snapshot.
discovered_at: SystemTime,
/// Capability patches keyed by string (the tests use model names such as "gpt-4o" —
/// presumably model ids; confirm against the discovery code).
capabilities: HashMap<String, ModelCapabilityPatch>,
}
impl CachedCapabilitySnapshot {
    /// Reports whether this snapshot is strictly older than `ttl` as of `now`.
    ///
    /// A snapshot whose age equals `ttl` exactly is still considered fresh. When `now` lies
    /// before `discovered_at` the age cannot be computed, and the snapshot is treated as
    /// not expired.
    fn is_expired(&self, now: SystemTime, ttl: Duration) -> bool {
        match now.duration_since(self.discovered_at) {
            Ok(age) => age > ttl,
            // `now` precedes the discovery time: keep the snapshot rather than expire it.
            Err(_) => false,
        }
    }
}
/// Cached capability snapshots keyed by provider id.
type ProviderCapabilityCache = HashMap<String, CachedCapabilitySnapshot>;
/// Outcome of staging a capability refresh: the candidate cache contents plus the
/// capability view derived from them. Staging does not modify the runtime cache; the
/// contents only take effect once this value is handed to
/// `ProviderRuntimeCache::commit_capabilities`.
pub(super) struct StagedCapabilityCache {
/// Snapshots that replace the cached ones when the staged value is committed.
cache: ProviderCapabilityCache,
/// Capability maps of every staged snapshot, keyed by provider id.
pub(super) resolved: HashMap<String, HashMap<String, ModelCapabilityPatch>>,
}
/// Per-provider runtime state: the executors in use and the most recently committed
/// capability snapshots. `Default` yields two empty maps.
#[derive(Default)]
pub(super) struct ProviderRuntimeCache {
/// Provider spec and executor handle, keyed by provider id.
executors: ProviderExecutorCache,
/// Committed capability snapshots, keyed by provider id.
capabilities: ProviderCapabilityCache,
}
impl ProviderRuntimeCache {
pub(super) fn executor_snapshot(&self) -> ProviderExecutorCache {
self.executors.clone()
}
pub(super) fn replace_executors(&mut self, next: ProviderExecutorCache) {
self.executors = next;
}
#[cfg(test)]
pub(super) fn executor_provider(&self, provider_id: &str) -> Option<ProviderSpec> {
self.executors
.get(provider_id)
.map(|(provider, _)| provider.clone())
}
pub(super) fn stage_capability_snapshots(
&self,
providers: &[ProviderSpec],
discovered: HashMap<String, HashMap<String, ModelCapabilityPatch>>,
attempted: &HashSet<String>,
provider_signature: impl Fn(&ProviderSpec) -> String,
now: SystemTime,
) -> StagedCapabilityCache {
self.stage_capability_snapshots_with_ttl(
providers,
discovered,
attempted,
provider_signature,
now,
CAPABILITY_SNAPSHOT_TTL,
)
}
fn stage_capability_snapshots_with_ttl(
&self,
providers: &[ProviderSpec],
discovered: HashMap<String, HashMap<String, ModelCapabilityPatch>>,
attempted: &HashSet<String>,
provider_signature: impl Fn(&ProviderSpec) -> String,
now: SystemTime,
ttl: Duration,
) -> StagedCapabilityCache {
let signatures = providers
.iter()
.map(|provider| (provider.id.clone(), provider_signature(provider)))
.collect::<HashMap<_, _>>();
let mut staged: ProviderCapabilityCache = self
.capabilities
.iter()
.filter(|(provider_id, snapshot)| {
signatures
.get(*provider_id)
.is_some_and(|current| *current == snapshot.signature)
&& !snapshot.is_expired(now, ttl)
})
.map(|(provider_id, snapshot)| (provider_id.clone(), snapshot.clone()))
.collect();
let discovered_provider_ids = discovered.keys().cloned().collect::<HashSet<_>>();
for (provider_id, capabilities) in discovered {
let Some(signature) = signatures.get(&provider_id) else {
continue;
};
staged.insert(
provider_id,
CachedCapabilitySnapshot {
signature: signature.clone(),
discovered_at: now,
capabilities,
},
);
}
for provider_id in staged.keys() {
if !discovered_provider_ids.contains(provider_id) && attempted.contains(provider_id) {
tracing::warn!(
provider_id,
"using stale provider capability snapshot after discovery failure"
);
}
}
let resolved = staged
.iter()
.map(|(provider_id, snapshot)| (provider_id.clone(), snapshot.capabilities.clone()))
.collect();
StagedCapabilityCache {
cache: staged,
resolved,
}
}
pub(super) fn commit_capabilities(&mut self, staged: StagedCapabilityCache) {
self.capabilities = staged.cache;
}
#[cfg(test)]
fn update_capability_snapshots_with_ttl(
&mut self,
providers: &[ProviderSpec],
discovered: HashMap<String, HashMap<String, ModelCapabilityPatch>>,
provider_signature: impl Fn(&ProviderSpec) -> String,
now: SystemTime,
ttl: Duration,
) -> HashMap<String, HashMap<String, ModelCapabilityPatch>> {
let attempted: HashSet<String> = discovered.keys().cloned().collect();
let staged = self.stage_capability_snapshots_with_ttl(
providers,
discovered,
&attempted,
provider_signature,
now,
ttl,
);
let resolved = staged.resolved.clone();
self.commit_capabilities(staged);
resolved
}
#[cfg(test)]
fn update_capability_snapshots(
&mut self,
providers: &[ProviderSpec],
discovered: HashMap<String, HashMap<String, ModelCapabilityPatch>>,
provider_signature: impl Fn(&ProviderSpec) -> String,
now: SystemTime,
) -> HashMap<String, HashMap<String, ModelCapabilityPatch>> {
self.update_capability_snapshots_with_ttl(
providers,
discovered,
provider_signature,
now,
CAPABILITY_SNAPSHOT_TTL,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Capability maps keyed by provider id, then by model name.
    type Discovered = HashMap<String, HashMap<String, ModelCapabilityPatch>>;

    /// Signature function used by the tests: the provider's base URL, or "" when unset.
    fn signature(provider: &ProviderSpec) -> String {
        match &provider.base_url {
            Some(url) => url.clone(),
            None => String::new(),
        }
    }

    /// Builds a patch that only sets the context window.
    fn patch(context_window: u32) -> ModelCapabilityPatch {
        ModelCapabilityPatch {
            context_window: Some(context_window),
            max_output_tokens: None,
            modalities: None,
            knowledge_cutoff: None,
        }
    }

    /// The provider every test starts from.
    fn example_provider() -> ProviderSpec {
        ProviderSpec {
            id: "p".into(),
            adapter: "openai".into(),
            base_url: Some("https://example.test/v1".into()),
            ..ProviderSpec::default()
        }
    }

    /// A successful discovery result for provider "p" reporting a single model.
    fn gpt4o_discovery() -> Discovered {
        let models = HashMap::from([("gpt-4o".to_string(), patch(128_000))]);
        HashMap::from([("p".to_string(), models)])
    }

    /// Staging alone must not change what later staging calls can fall back on.
    #[test]
    fn staged_capability_snapshot_is_not_served_until_committed() {
        let providers = [example_provider()];
        let attempted: HashSet<String> = std::iter::once("p".to_string()).collect();
        let start = SystemTime::UNIX_EPOCH;
        let a_minute_later = start + Duration::from_secs(60);
        let mut cache = ProviderRuntimeCache::default();

        let staged = cache.stage_capability_snapshots(
            &providers,
            gpt4o_discovery(),
            &attempted,
            signature,
            start,
        );
        assert!(staged.resolved.contains_key("p"));

        let after_failed_publish = cache.stage_capability_snapshots(
            &providers,
            HashMap::new(),
            &attempted,
            signature,
            a_minute_later,
        );
        assert!(
            after_failed_publish.resolved.is_empty(),
            "an uncommitted (failed-publish) snapshot must not be served"
        );

        let committed = cache.stage_capability_snapshots(
            &providers,
            gpt4o_discovery(),
            &attempted,
            signature,
            start,
        );
        cache.commit_capabilities(committed);

        let after_commit = cache.stage_capability_snapshots(
            &providers,
            HashMap::new(),
            &attempted,
            signature,
            a_minute_later,
        );
        assert_eq!(
            after_commit.resolved["p"]["gpt-4o"].context_window,
            Some(128_000)
        );
    }

    /// A refresh without results keeps serving the committed snapshot.
    #[test]
    fn capability_snapshot_merge_keeps_cached_snapshot_on_discovery_failure() {
        let providers = [example_provider()];
        let start = SystemTime::UNIX_EPOCH;
        let mut cache = ProviderRuntimeCache::default();

        let first =
            cache.update_capability_snapshots(&providers, gpt4o_discovery(), signature, start);
        let second = cache.update_capability_snapshots(
            &providers,
            HashMap::new(),
            signature,
            start + Duration::from_secs(60),
        );

        assert_eq!(first, second);
    }

    /// One second past the TTL, a snapshot is no longer carried forward.
    #[test]
    fn capability_snapshot_expires_after_ttl_on_discovery_failure() {
        let providers = [example_provider()];
        let ttl = Duration::from_secs(3_600);
        let start = SystemTime::UNIX_EPOCH;
        let mut cache = ProviderRuntimeCache::default();

        let first = cache.update_capability_snapshots_with_ttl(
            &providers,
            gpt4o_discovery(),
            signature,
            start,
            ttl,
        );
        assert!(!first.is_empty());

        let expired = cache.update_capability_snapshots_with_ttl(
            &providers,
            HashMap::new(),
            signature,
            start + ttl + Duration::from_secs(1),
            ttl,
        );
        assert!(expired.is_empty());
    }

    /// A snapshot whose age equals the TTL exactly is still served.
    #[test]
    fn capability_snapshot_within_ttl_is_still_served() {
        let providers = [example_provider()];
        let ttl = Duration::from_secs(3_600);
        let start = SystemTime::UNIX_EPOCH;
        let mut cache = ProviderRuntimeCache::default();

        cache.update_capability_snapshots_with_ttl(
            &providers,
            gpt4o_discovery(),
            signature,
            start,
            ttl,
        );
        let still_fresh = cache.update_capability_snapshots_with_ttl(
            &providers,
            HashMap::new(),
            signature,
            start + ttl,
            ttl,
        );

        assert_eq!(still_fresh["p"]["gpt-4o"].context_window, Some(128_000));
    }

    /// A successful discovery that reports no models overrides the cached models.
    #[test]
    fn capability_snapshot_empty_success_replaces_cached_snapshot() {
        let providers = [example_provider()];
        let start = SystemTime::UNIX_EPOCH;
        let mut cache = ProviderRuntimeCache::default();

        cache.update_capability_snapshots(&providers, gpt4o_discovery(), signature, start);

        let empty_success: Discovered = HashMap::from([("p".to_string(), HashMap::new())]);
        let refreshed = cache.update_capability_snapshots(
            &providers,
            empty_success,
            signature,
            start + Duration::from_secs(60),
        );

        assert_eq!(refreshed.get("p"), Some(&HashMap::new()));
    }

    /// Changing the provider's signature invalidates its cached snapshot.
    #[test]
    fn capability_snapshot_merge_drops_cached_snapshot_after_provider_change() {
        let original = [example_provider()];
        let changed = [ProviderSpec {
            base_url: Some("https://other.example.test/v1".into()),
            ..example_provider()
        }];
        let start = SystemTime::UNIX_EPOCH;
        let mut cache = ProviderRuntimeCache::default();

        cache.update_capability_snapshots(&original, gpt4o_discovery(), signature, start);
        let merged = cache.update_capability_snapshots(&changed, HashMap::new(), signature, start);

        assert!(merged.is_empty());
    }
}