#![cfg_attr(not(test), no_std)]
use core::any::TypeId;
use core::fmt;
use core::marker::{PhantomData, PhantomPinned};
use core::pin::Pin;
/// Marker type whose `TypeId` tags a request for a borrowed `&T`.
///
/// In this file it is only ever named inside `TypeId::of::<ReqRef<T>>()`;
/// wrapping `T` this way lets unsized `T` (which has no by-value form) have a
/// tag that is distinct from the by-value tag `ReqVal<T>`.
struct ReqRef<T>(&'static T)
where
    T: ?Sized + 'static;
/// Marker type whose `TypeId` tags a request for an owned `T`.
///
/// In this file it is only ever named inside `TypeId::of::<ReqVal<T>>()`, which
/// keeps by-value requests distinct from by-reference requests (`ReqRef<T>`).
struct ReqVal<T>(T)
where
    T: 'static;
/// Type-erased handle describing which object is being requested.
///
/// A `Request` is only created as the first field of a `RequestBuf`, and it is
/// handed out behind `Pin<&mut _>`. `downcast_unchecked` casts a pointer to
/// this struct into a pointer to the enclosing `RequestBuf`, which is why both
/// structs are `#[repr(C)]`.
#[repr(C)]
pub struct Request<'a> {
// `TypeId` of `ReqRef<T>` or `ReqVal<T>`; compared by `is_ref` / `is_value`.
type_id: TypeId,
// Makes the type `!Unpin`, so safe code cannot move a request out from
// behind the `Pin<&mut Request>` it is given.
_pinned: PhantomPinned,
// Carries the lifetime `'a` that provided references must outlive.
_marker: PhantomData<&'a ()>,
}
impl<'a> Request<'a> {
    /// Offers `value` as the answer to a by-reference request for `T`.
    ///
    /// The reference is stored only when this request is asking for `&T`
    /// (see [`Request::is_ref`]); otherwise nothing happens. The request is
    /// returned so that several `provide_*` calls can be chained.
    pub fn provide_ref<T: ?Sized + 'static>(self: Pin<&mut Self>, value: &'a T) -> Pin<&mut Self> {
        self.provide_ref_with(|| value)
    }

    /// Lazily offers a reference as the answer to a by-reference request for `T`.
    ///
    /// `cb` is invoked only when this request is asking for `&T`. A matching
    /// call overwrites whatever an earlier matching call stored. The request
    /// is returned so that calls can be chained.
    pub fn provide_ref_with<T: ?Sized + 'static, F>(
        mut self: Pin<&mut Self>,
        cb: F,
    ) -> Pin<&mut Self>
    where
        F: FnOnce() -> &'a T,
    {
        if self.is_ref::<T>() {
            // Run the callback before touching the slot, as the plain
            // assignment form would.
            let value = cb();
            // SAFETY: `is_ref::<T>()` just confirmed the tag is
            // `TypeId::of::<ReqRef<T>>()`. That tag is only installed by
            // `RequestBuf::for_ref`, which builds a `RequestBuf<'a, &'a T>`,
            // so the enclosing buffer's slot really is `Option<&'a T>`.
            let slot = unsafe { self.as_mut().downcast_unchecked::<&'a T>() };
            *slot = Some(value);
        }
        self
    }

    /// Offers `value` as the answer to a by-value request for `T`.
    ///
    /// The value is stored only when this request is asking for an owned `T`
    /// (see [`Request::is_value`]); otherwise it is dropped. The request is
    /// returned so that several `provide_*` calls can be chained.
    pub fn provide_value<T: 'static>(self: Pin<&mut Self>, value: T) -> Pin<&mut Self> {
        self.provide_value_with(|| value)
    }

    /// Lazily offers a value as the answer to a by-value request for `T`.
    ///
    /// `cb` is invoked only when this request is asking for an owned `T`. A
    /// matching call overwrites whatever an earlier matching call stored. The
    /// request is returned so that calls can be chained.
    pub fn provide_value_with<T: 'static, F>(mut self: Pin<&mut Self>, cb: F) -> Pin<&mut Self>
    where
        F: FnOnce() -> T,
    {
        if self.is_value::<T>() {
            // Run the callback before touching the slot, as the plain
            // assignment form would.
            let value = cb();
            // SAFETY: `is_value::<T>()` just confirmed the tag is
            // `TypeId::of::<ReqVal<T>>()`. That tag is only installed by
            // `RequestBuf::for_value`, which builds a `RequestBuf<'a, T>`,
            // so the enclosing buffer's slot really is `Option<T>`.
            let slot = unsafe { self.as_mut().downcast_unchecked::<T>() };
            *slot = Some(value);
        }
        self
    }

    /// Returns `true` when this request is asking for a reference to `T`.
    pub fn is_ref<T: ?Sized + 'static>(&self) -> bool {
        self.type_id == TypeId::of::<ReqRef<T>>()
    }

    /// Returns `true` when this request is asking for an owned `T`.
    pub fn is_value<T: 'static>(&self) -> bool {
        self.type_id == TypeId::of::<ReqVal<T>>()
    }

    /// Reinterprets this request as the head of a `RequestBuf<'a, T>` and
    /// returns that buffer's value slot.
    ///
    /// # Safety
    ///
    /// `self` must be the `request` field of a live `RequestBuf<'a, T>` with
    /// exactly this `T`. Callers establish that by checking the type tag with
    /// `is_ref` / `is_value` first. The cast relies on both structs being
    /// `#[repr(C)]` with `request` as the first field.
    unsafe fn downcast_unchecked<T>(self: Pin<&mut Self>) -> &mut Option<T> {
        // SAFETY: the caller guarantees `self` heads a `RequestBuf<'a, T>`, so
        // the cast pointer addresses a valid buffer. Only the `value` field is
        // exposed; the pinned `request` field is never moved.
        unsafe {
            let buf = self.get_unchecked_mut() as *mut Self as *mut RequestBuf<'a, T>;
            &mut (*buf).value
        }
    }

    /// Builds a request for `&'a T`, hands it to `f`, and returns the
    /// reference that `f` provided, or `None` if it provided none.
    pub fn request_ref<T: ?Sized + 'static, F>(f: F) -> Option<&'a T>
    where
        F: FnOnce(Pin<&mut Request<'a>>),
    {
        let mut buf: RequestBuf<'a, &'a T> = RequestBuf::for_ref();
        // SAFETY: `buf` is a local that is shadowed by its pinned reference
        // from here on; it is never moved again before being dropped at the
        // end of this function.
        let mut pinned = unsafe { Pin::new_unchecked(&mut buf) };
        f(pinned.as_mut().request());
        pinned.take()
    }

    /// Builds a request for an owned `T`, hands it to `f`, and returns the
    /// value that `f` provided, or `None` if it provided none.
    pub fn request_value<T: 'static, F>(f: F) -> Option<T>
    where
        F: FnOnce(Pin<&mut Request<'a>>),
    {
        let mut buf: RequestBuf<'a, T> = RequestBuf::for_value();
        // SAFETY: `buf` is a local that is shadowed by its pinned reference
        // from here on; it is never moved again before being dropped at the
        // end of this function.
        let mut pinned = unsafe { Pin::new_unchecked(&mut buf) };
        f(pinned.as_mut().request());
        pinned.take()
    }
}
/// Debug output shows only the requested type tag; the marker fields carry no
/// information worth printing.
impl<'a> fmt::Debug for Request<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Request");
        out.field("type_id", &self.type_id);
        out.finish()
    }
}
/// Concrete storage behind a `Request`: the type-erased header followed by the
/// slot that a provider fills in.
///
/// `#[repr(C)]` with `request` as the first field is what allows
/// `Request::downcast_unchecked` to cast a `Request` pointer back to a pointer
/// to this struct. Do not reorder the fields.
#[repr(C)]
#[derive(Debug)]
struct RequestBuf<'a, T> {
// Header handed to providers; must stay the first field (see above).
request: Request<'a>,
// Filled by the `provide_*` methods on `Request`; emptied by `take`.
value: Option<T>,
}
impl<'a, T> RequestBuf<'a, &'a T>
where
    T: ?Sized + 'static,
{
    /// Creates an empty buffer for a by-reference request for `T`.
    fn for_ref() -> Self {
        let tag = TypeId::of::<ReqRef<T>>();
        // SAFETY: the slot type of `Self` is `&'a T`, which is exactly what
        // `Request::provide_ref_with` writes after matching the `ReqRef<T>` tag.
        unsafe { Self::new_internal(tag) }
    }
}
impl<'a, T> RequestBuf<'a, T>
where
    T: 'static,
{
    /// Creates an empty buffer for a by-value request for `T`.
    fn for_value() -> Self {
        let tag = TypeId::of::<ReqVal<T>>();
        // SAFETY: the slot type of `Self` is `T`, which is exactly what
        // `Request::provide_value_with` writes after matching the `ReqVal<T>` tag.
        unsafe { Self::new_internal(tag) }
    }
}
impl<'a, T> RequestBuf<'a, T> {
    /// Builds an empty buffer whose embedded request carries `type_id`.
    ///
    /// # Safety
    ///
    /// Nothing here checks `type_id` against `T`. `Request` later casts itself
    /// back to a `RequestBuf` based solely on that tag, so the caller must pass
    /// the tag that the `provide_*` methods associate with a slot of type `T`
    /// (as `for_ref` and `for_value` do).
    unsafe fn new_internal(type_id: TypeId) -> Self {
        let request = Request {
            type_id,
            _pinned: PhantomPinned,
            _marker: PhantomData,
        };
        RequestBuf {
            request,
            value: None,
        }
    }

    /// Projects the pinned buffer to its pinned `request` header.
    fn request(self: Pin<&mut Self>) -> Pin<&mut Request<'a>> {
        // SAFETY: this only borrows a field of the pinned buffer and re-wraps
        // it in `Pin`; the field is not moved out of the buffer.
        unsafe { self.map_unchecked_mut(|buf| &mut buf.request) }
    }

    /// Removes and returns whatever value has been provided, leaving `None`.
    fn take(self: Pin<&mut Self>) -> Option<T> {
        // SAFETY: only the `value` field is moved out below; the `request`
        // field stays where it is.
        let this = unsafe { self.get_unchecked_mut() };
        this.value.take()
    }
}
/// Implemented by types that can supply objects on request.
///
/// The trait has no generic methods, and the test module in this file uses it
/// as `&dyn ObjectProvider`.
pub trait ObjectProvider {
/// Answers `request` by calling its `provide_*` methods with whatever this
/// object can supply. References handed over must live as long as the
/// borrow of `self` (`'a`). Calls for types that were not requested are
/// ignored by `Request`, so an implementation can offer everything it has.
fn provide<'a>(&'a self, request: Pin<&mut Request<'a>>);
}
/// Typed convenience methods for querying an object provider.
///
/// A blanket implementation in this file covers every `ObjectProvider`.
pub trait ObjectProviderExt {
/// Asks for a reference to a `T`; returns `None` when none was provided.
fn request_ref<T: ?Sized + 'static>(&self) -> Option<&T>;
/// Asks for an owned `T`; returns `None` when none was provided.
fn request_value<T: 'static>(&self) -> Option<T>;
}
/// Blanket implementation: every `ObjectProvider` (including
/// `dyn ObjectProvider`) gets the typed query helpers.
impl<O> ObjectProviderExt for O
where
    O: ?Sized + ObjectProvider,
{
    /// Builds a by-reference request for `T` and lets `self` answer it.
    fn request_ref<T: ?Sized + 'static>(&self) -> Option<&T> {
        Request::request_ref::<T, _>(|request| self.provide(request))
    }

    /// Builds a by-value request for `T` and lets `self` answer it.
    fn request_value<T: 'static>(&self) -> Option<T> {
        Request::request_value::<T, _>(|request| self.provide(request))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::path::{Path, PathBuf};

    /// End-to-end check: a provider offering several objects is queried through
    /// `dyn ObjectProvider` for both offered and non-offered types.
    #[test]
    fn basic_context() {
        struct HasContext {
            int: i32,
            path: PathBuf,
        }

        impl ObjectProvider for HasContext {
            fn provide<'a>(&'a self, request: Pin<&mut Request<'a>>) {
                // Offer everything; the request keeps only what it asked for.
                let request = request.provide_ref::<i32>(&self.int);
                let request = request.provide_ref::<Path>(&self.path);
                let request = request.provide_ref::<dyn fmt::Display>(&self.int);
                request.provide_value::<i32>(self.int);
            }
        }

        let context = HasContext {
            int: 10,
            path: PathBuf::new(),
        };
        let provider: &dyn ObjectProvider = &context;

        // Sized types, by reference and by value.
        assert_eq!(provider.request_ref::<i32>(), Some(&10));
        assert_eq!(provider.request_value::<i32>(), Some(10));
        assert!(provider.request_ref::<u32>().is_none());

        // Trait objects: only the trait that was offered is found.
        let shown = provider
            .request_ref::<dyn fmt::Display>()
            .map(|d| d.to_string());
        assert_eq!(shown, Some("10".to_owned()));
        assert!(provider.request_ref::<dyn fmt::Debug>().is_none());

        // Unsized non-trait type.
        assert_eq!(provider.request_ref::<Path>(), Some(Path::new("")));
    }
}