use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Type-keyed storage holding at most one value per concrete type.
///
/// Values are stored behind `Arc`, so cloning an `Extensions` shares the stored
/// values rather than deep-copying them.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Each value is keyed by the `TypeId` of its concrete type.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Returns an empty map.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Stores `value`, replacing any previously stored value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.map.insert(key, erased);
    }

    /// Borrows the stored value of type `T`, if there is one.
    #[must_use]
    pub fn get<T: 'static>(&self) -> Option<&T> {
        let erased = self.map.get(&TypeId::of::<T>())?;
        (**erased).downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is stored.
    #[must_use]
    pub fn contains<T: 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Removes and returns the stored value of type `T`, if there is one.
    ///
    /// The value is handed back as an `Arc` because clones of this map may
    /// still share it.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<Arc<T>> {
        let erased = self.map.remove(&TypeId::of::<T>())?;
        erased.downcast::<T>().ok()
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only how many values are stored; the values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.map.len();
        f.debug_struct("Extensions").field("count", &count).finish()
    }
}
/// Request-scoped state: an optional deadline, typed extension values, an
/// optional type-erased transport handle, and a cancellation flag.
///
/// Cloning copies the deadline and the `cancelled` flag by value and clones the
/// `Arc` handles of the extension and transport values. Because `cancelled` is a
/// plain `bool`, calling [`Context::cancel`] on one clone does not affect any
/// other clone.
#[derive(Clone)]
pub struct Context {
    /// Point in time at or after which [`Context::is_done`] reports `true`.
    deadline: Option<Instant>,
    /// Typed values attached via [`Context::with_value`] / [`Context::extensions_mut`].
    extensions: Extensions,
    /// Type-erased handle attached via [`Context::with_transport`].
    transport: Option<Arc<dyn Any + Send + Sync>>,
    /// Set by [`Context::cancel`].
    cancelled: bool,
}

impl Default for Context {
    /// Equivalent to [`Context::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Context {
    /// Creates a context with no deadline, no values, no transport, and not cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            deadline: None,
            extensions: Extensions::new(),
            transport: None,
            cancelled: false,
        }
    }

    /// Creates an otherwise-empty context that expires at `deadline`.
    #[must_use]
    pub fn with_deadline(deadline: Instant) -> Self {
        Self {
            deadline: Some(deadline),
            ..Self::new()
        }
    }

    /// Creates an otherwise-empty context that expires `timeout` from now.
    ///
    /// If `now + timeout` cannot be represented as an [`Instant`] (for example
    /// `Duration::MAX`), the timeout is treated as unbounded and the context
    /// has no deadline, instead of panicking on overflow.
    #[must_use]
    pub fn with_timeout(timeout: Duration) -> Self {
        match Self::deadline_after(timeout) {
            Some(deadline) => Self::with_deadline(deadline),
            None => Self::new(),
        }
    }

    /// Returns the deadline, if one is set.
    #[must_use]
    pub fn deadline(&self) -> Option<Instant> {
        self.deadline
    }

    /// Returns how much time is left before the deadline.
    ///
    /// `None` means the context has no deadline. Once the deadline has been
    /// reached this returns `Some(Duration::ZERO)`, so an expired context is
    /// never mistaken for an unbounded one.
    #[must_use]
    pub fn time_remaining(&self) -> Option<Duration> {
        self.deadline
            .map(|deadline| deadline.saturating_duration_since(Instant::now()))
    }

    /// Returns `true` if the context was cancelled or its deadline has been reached.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.cancelled || matches!(self.deadline, Some(deadline) if Instant::now() >= deadline)
    }

    /// Borrows the attached value of type `T`, if any.
    #[must_use]
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.extensions.get::<T>()
    }

    /// Attaches `value`, replacing any previously attached value of type `T`.
    #[must_use]
    pub fn with_value<T: Send + Sync + 'static>(mut self, value: T) -> Self {
        self.extensions.insert(value);
        self
    }

    /// Mutable access to the attached values.
    pub fn extensions_mut(&mut self) -> &mut Extensions {
        &mut self.extensions
    }

    /// Shared access to the attached values.
    #[must_use]
    pub fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    /// Attaches a transport handle, replacing any previous one.
    #[must_use]
    pub fn with_transport<T: Send + Sync + 'static>(mut self, transport: T) -> Self {
        self.transport = Some(Arc::new(transport));
        self
    }

    /// Returns `true` if a transport handle is attached.
    #[must_use]
    pub fn has_transport(&self) -> bool {
        self.transport.is_some()
    }

    /// Marks this context as cancelled; [`Context::is_done`] returns `true` afterwards.
    pub fn cancel(&mut self) {
        self.cancelled = true;
    }

    /// Sets the deadline to the earlier of the current deadline and `deadline`.
    ///
    /// A context without a deadline simply adopts `deadline`; an existing
    /// deadline is never extended.
    #[must_use]
    pub fn with_shorter_deadline(mut self, deadline: Instant) -> Self {
        self.deadline = Some(match self.deadline {
            Some(existing) => existing.min(deadline),
            None => deadline,
        });
        self
    }

    /// Tightens the deadline to `timeout` from now if that is earlier than the
    /// current one.
    ///
    /// A `timeout` too large to represent as an [`Instant`] can never be the
    /// earlier deadline, so the context is returned unchanged in that case
    /// instead of panicking on overflow.
    #[must_use]
    pub fn with_shorter_timeout(self, timeout: Duration) -> Self {
        match Self::deadline_after(timeout) {
            Some(deadline) => self.with_shorter_deadline(deadline),
            None => self,
        }
    }

    /// `now + timeout`, or `None` if the sum overflows [`Instant`].
    fn deadline_after(timeout: Duration) -> Option<Instant> {
        Instant::now().checked_add(timeout)
    }
}

impl std::fmt::Debug for Context {
    /// Shows the deadline, the extension summary, whether a transport is
    /// attached (the transport itself is type-erased), and the cancelled flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Context")
            .field("deadline", &self.deadline)
            .field("extensions", &self.extensions)
            .field("has_transport", &self.transport.is_some())
            .field("cancelled", &self.cancelled)
            .finish()
    }
}
/// Typed access to a type-erased transport handle.
pub trait TransportAs {
    /// Returns the transport viewed as `&T`, or `None` when there is no
    /// transport or it is not a `T`.
    fn transport_as<T: 'static>(&self) -> Option<&T>;
}

impl TransportAs for Context {
    fn transport_as<T: 'static>(&self) -> Option<&T> {
        let handle = self.transport.as_ref()?;
        (**handle).downcast_ref::<T>()
    }
}
/// Step-by-step builder for [`Context`] values.
#[derive(Default)]
pub struct ContextBuilder {
    /// Deadline copied into the built context.
    deadline: Option<Instant>,
    /// Typed values copied into the built context.
    extensions: Extensions,
    /// Transport handle copied into the built context.
    transport: Option<Arc<dyn Any + Send + Sync>>,
}

impl ContextBuilder {
    /// Creates a builder with no deadline, no values, and no transport.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the deadline, replacing any previously configured one.
    #[must_use]
    pub fn with_deadline(mut self, deadline: Instant) -> Self {
        self.deadline = Some(deadline);
        self
    }

    /// Sets the deadline to `timeout` from now, replacing any previously
    /// configured one.
    ///
    /// If `now + timeout` cannot be represented as an [`Instant`] (for example
    /// `Duration::MAX`), the timeout is treated as unbounded: the builder ends
    /// up with no deadline instead of panicking on overflow.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.deadline = Instant::now().checked_add(timeout);
        self
    }

    /// Adds a typed value; a later value of the same type replaces an earlier one.
    #[must_use]
    pub fn with_value<T: Send + Sync + 'static>(mut self, value: T) -> Self {
        self.extensions.insert(value);
        self
    }

    /// Sets the transport handle, replacing any previously configured one.
    #[must_use]
    pub fn with_transport<T: Send + Sync + 'static>(mut self, transport: T) -> Self {
        self.transport = Some(Arc::new(transport));
        self
    }

    /// Produces the configured context; it always starts out not cancelled.
    #[must_use]
    pub fn build(self) -> Context {
        Context {
            deadline: self.deadline,
            extensions: self.extensions,
            transport: self.transport,
            cancelled: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extensions_insert_and_get() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        ext.insert("hello".to_string());
        assert_eq!(ext.get::<u32>(), Some(&42));
        assert_eq!(ext.get::<String>(), Some(&"hello".to_string()));
        assert!(ext.get::<i64>().is_none());
    }

    #[test]
    fn extensions_contains() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        assert!(ext.contains::<u32>());
        assert!(!ext.contains::<String>());
    }

    #[test]
    fn extensions_insert_replaces_same_type() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        ext.insert(2u32);
        assert_eq!(ext.get::<u32>(), Some(&2));
    }

    #[test]
    fn extensions_remove() {
        let mut ext = Extensions::new();
        ext.insert(7u32);
        assert_eq!(ext.remove::<u32>().as_deref(), Some(&7));
        assert!(!ext.contains::<u32>());
        // Removing a type that is no longer present yields nothing.
        assert!(ext.remove::<u32>().is_none());
    }

    #[test]
    fn context_without_deadline() {
        let ctx = Context::new();
        assert!(ctx.deadline().is_none());
        assert!(ctx.time_remaining().is_none());
        assert!(!ctx.is_done());
    }

    #[test]
    fn context_deadline() {
        let deadline = Instant::now() + Duration::from_secs(60);
        let ctx = Context::with_deadline(deadline);
        assert_eq!(ctx.deadline(), Some(deadline));
        assert!(!ctx.is_done());
    }

    #[test]
    fn context_timeout() {
        let ctx = Context::with_timeout(Duration::from_secs(60));
        assert!(ctx.deadline().is_some());
        assert!(ctx.time_remaining().is_some());
        assert!(!ctx.is_done());
    }

    #[test]
    fn context_expired() {
        let ctx = Context::with_timeout(Duration::from_nanos(1));
        std::thread::sleep(Duration::from_millis(1));
        assert!(ctx.is_done());
    }

    #[test]
    fn context_cancelled() {
        let mut ctx = Context::new();
        assert!(!ctx.is_done());
        ctx.cancel();
        assert!(ctx.is_done());
    }

    #[test]
    fn context_with_value() {
        let ctx = Context::new()
            .with_value(42u32)
            .with_value("test".to_string());
        assert_eq!(ctx.get::<u32>(), Some(&42));
        assert_eq!(ctx.get::<String>(), Some(&"test".to_string()));
    }

    #[test]
    fn context_transport() {
        struct TestTransport {
            id: u32,
        }
        let ctx = Context::new().with_transport(TestTransport { id: 123 });
        assert!(ctx.has_transport());
        let transport = ctx.transport_as::<TestTransport>();
        assert!(transport.is_some());
        assert_eq!(transport.map(|t| t.id), Some(123));
    }

    #[test]
    fn context_transport_wrong_type() {
        let ctx = Context::new().with_transport(5u8);
        assert!(ctx.transport_as::<u16>().is_none());
        assert_eq!(ctx.transport_as::<u8>(), Some(&5));
        let bare = Context::new();
        assert!(!bare.has_transport());
        assert!(bare.transport_as::<u8>().is_none());
    }

    #[test]
    fn context_shorter_deadline() {
        let far = Instant::now() + Duration::from_secs(60);
        let near = Instant::now() + Duration::from_secs(10);
        let ctx = Context::with_deadline(far);
        let child = ctx.with_shorter_deadline(near);
        assert_eq!(child.deadline(), Some(near));
        let extended = Context::with_deadline(near).with_shorter_deadline(far);
        assert_eq!(extended.deadline(), Some(near));
    }

    #[test]
    fn context_shorter_timeout() {
        let tightened = Context::with_timeout(Duration::from_secs(60))
            .with_shorter_timeout(Duration::from_secs(1));
        let remaining = tightened.time_remaining().expect("deadline is set");
        assert!(remaining <= Duration::from_secs(1));

        // A longer timeout never extends an existing deadline.
        let kept = Context::with_timeout(Duration::from_secs(30))
            .with_shorter_timeout(Duration::from_secs(60));
        let remaining = kept.time_remaining().expect("deadline is set");
        assert!(remaining <= Duration::from_secs(30));

        // A context without a deadline adopts the new one.
        let adopted = Context::new().with_shorter_timeout(Duration::from_secs(5));
        assert!(adopted.deadline().is_some());
    }

    #[test]
    fn context_builder() {
        let ctx = ContextBuilder::new()
            .with_timeout(Duration::from_secs(30))
            .with_value(100i32)
            .build();
        assert!(ctx.deadline().is_some());
        assert_eq!(ctx.get::<i32>(), Some(&100));
    }

    #[test]
    fn context_builder_transport_and_defaults() {
        let ctx = ContextBuilder::new().with_transport(9u64).build();
        assert!(ctx.has_transport());
        assert_eq!(ctx.transport_as::<u64>(), Some(&9));
        assert!(ctx.deadline().is_none());
        assert!(!ctx.is_done());
    }
}