use dashmap::DashMap;
use riglr_config::Config;
use std::any::{Any, TypeId};
use std::sync::Arc;
use std::time::Duration;
use crate::util::RateLimiter;
/// Type-erased extension storage, keyed by the `TypeId` of each value's concrete type.
type ExtensionMap = DashMap<TypeId, Arc<dyn Any + Send + Sync>>;

/// Shared application state: configuration, a rate limiter, and a registry of
/// arbitrary typed extensions.
///
/// Cloning is cheap for the rate limiter and the extension registry, because both
/// sit behind an `Arc`. Clones of a context therefore share the same limiter and
/// the same registry. `config` is cloned by value.
#[derive(Debug, Clone)]
pub struct ApplicationContext {
    /// Configuration this context was built from.
    pub config: Config,
    /// Rate limiter shared between all clones of this context.
    pub rate_limiter: Arc<RateLimiter>,
    /// Registry of extensions, at most one per concrete type.
    extensions: Arc<ExtensionMap>,
}
impl ApplicationContext {
pub fn from_config(config: &Config) -> Self {
let rate_limiter = Arc::new(RateLimiter::new(100, Duration::from_secs(60)));
Self {
config: config.clone(),
rate_limiter,
extensions: Arc::new(DashMap::new()),
}
}
#[deprecated(
since = "0.3.0",
note = "Use Config::from_env() followed by ApplicationContext::from_config() instead. This ensures proper separation of concerns."
)]
pub fn from_env() -> Self {
let config = Config::from_env();
Self::from_config(&config)
}
pub fn set_extension<T: Send + Sync + 'static>(&self, extension: Arc<T>) {
self.extensions.insert(TypeId::of::<T>(), extension);
}
pub fn get_extension<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
self.extensions
.get(&TypeId::of::<T>())
.and_then(|ext| ext.clone().downcast::<T>().ok())
}
pub fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
self.extensions.contains_key(&TypeId::of::<T>())
}
pub fn remove_extension<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
self.extensions
.remove(&TypeId::of::<T>())
.and_then(|(_, ext)| ext.downcast::<T>().ok())
}
pub fn clear_extensions(&self) {
self.extensions.clear();
}
pub fn extension_count(&self) -> usize {
self.extensions.len()
}
}
impl Default for ApplicationContext {
    /// Builds a context from the configuration produced by
    /// `riglr_config::ConfigBuilder::new().build()`.
    ///
    /// # Panics
    ///
    /// Panics with "Default config should be valid" if building the default
    /// configuration fails.
    fn default() -> Self {
        let config = riglr_config::ConfigBuilder::new()
            .build()
            .expect("Default config should be valid");
        // Delegate to `from_config` so the rate-limiter parameters and the
        // extension-registry setup are defined in exactly one place.
        Self::from_config(&config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct TestResource {
        value: String,
    }

    /// Convenience constructor for a shared `TestResource`.
    fn make_resource(value: &str) -> Arc<TestResource> {
        Arc::new(TestResource {
            value: value.to_string(),
        })
    }

    #[test]
    fn test_application_context_extensions() {
        let context = ApplicationContext::default();
        let resource = make_resource("test");
        context.set_extension(Arc::clone(&resource));
        let retrieved: Arc<TestResource> = context.get_extension().expect("Resource not found");
        assert_eq!(retrieved.value, "test");
        // The registry hands back the stored allocation, not a copy.
        assert!(Arc::ptr_eq(&retrieved, &resource));
    }

    #[test]
    fn test_application_context_missing_extension() {
        let context = ApplicationContext::default();
        assert!(context.get_extension::<TestResource>().is_none());
        assert!(context.remove_extension::<TestResource>().is_none());
    }

    #[test]
    fn test_application_context_multiple_extensions() {
        let context = ApplicationContext::default();
        context.set_extension(make_resource("test1"));
        context.set_extension(Arc::new(42u32));
        let retrieved1: Arc<TestResource> = context.get_extension().expect("Resource not found");
        let retrieved2: Arc<u32> = context.get_extension().expect("u32 not found");
        assert_eq!(retrieved1.value, "test1");
        assert_eq!(*retrieved2, 42);
        assert_eq!(context.extension_count(), 2);
    }

    #[test]
    fn test_application_context_set_extension_replaces_previous() {
        let context = ApplicationContext::default();
        context.set_extension(make_resource("first"));
        context.set_extension(make_resource("second"));
        let retrieved: Arc<TestResource> = context.get_extension().expect("Resource not found");
        assert_eq!(retrieved.value, "second");
        assert_eq!(context.extension_count(), 1);
    }

    #[test]
    fn test_application_context_has_extension() {
        let context = ApplicationContext::default();
        assert!(!context.has_extension::<TestResource>());
        context.set_extension(make_resource("test"));
        assert!(context.has_extension::<TestResource>());
    }

    #[test]
    fn test_application_context_remove_extension() {
        let context = ApplicationContext::default();
        context.set_extension(make_resource("test"));
        assert!(context.has_extension::<TestResource>());
        let removed: Arc<TestResource> = context
            .remove_extension()
            .expect("Resource should have been removed");
        assert_eq!(removed.value, "test");
        assert!(!context.has_extension::<TestResource>());
        assert!(context.remove_extension::<TestResource>().is_none());
    }

    #[test]
    fn test_application_context_clear_and_count() {
        let context = ApplicationContext::default();
        assert_eq!(context.extension_count(), 0);
        context.set_extension(make_resource("test"));
        context.set_extension(Arc::new(42u32));
        assert_eq!(context.extension_count(), 2);
        context.clear_extensions();
        assert_eq!(context.extension_count(), 0);
        assert!(!context.has_extension::<TestResource>());
        assert!(!context.has_extension::<u32>());
    }

    #[test]
    fn test_application_context_clones_share_extensions() {
        let context = ApplicationContext::default();
        let cloned = context.clone();
        cloned.set_extension(Arc::new(7u64));
        let seen: Arc<u64> = context.get_extension().expect("u64 not found");
        assert_eq!(*seen, 7);
    }
}