use super::*;
use std::cell::RefCell;
use std::collections::HashSet;
use std::fmt;
use std::ops::Deref;
use std::rc::Rc;
/// A read-only handle to a signal's shared state.
///
/// Every clone points at the same `Rc`-managed cell, so all handles observe values written
/// through the owning [`Signal`]. There is no setter on this type; only reads are exposed.
pub struct StateHandle<T>(Rc<RefCell<SignalInner<T>>>)
where
    T: 'static;
impl<T: 'static> StateHandle<T> {
    /// Returns the current value, recording this signal as a dependency of the active context.
    ///
    /// If `CONTEXTS` is non-empty, a weak reference to this signal's shared state is inserted
    /// into the `dependencies` of the last context in `CONTEXTS`. Use
    /// [`get_untracked`](Self::get_untracked) to read the value without recording anything.
    ///
    /// # Panics
    ///
    /// Panics if the last context in `CONTEXTS` can no longer be upgraded, or if its running
    /// state is not set.
    pub fn get(&self) -> Rc<T> {
        CONTEXTS.with(|contexts| {
            if let Some(last_context) = contexts.borrow().last() {
                // Store only a weak reference so the dependency does not keep the signal alive.
                let signal = Rc::downgrade(&self.0);

                last_context
                    .upgrade()
                    .expect("Running should be valid while inside reactive scope")
                    .borrow_mut()
                    .as_mut()
                    .expect("Running should be initialized while inside reactive scope")
                    .dependencies
                    .insert(Dependency(signal));
            }
        });

        self.get_untracked()
    }

    /// Returns the current value without registering any reactive dependency.
    ///
    /// The returned `Rc` is a snapshot: a later `Signal::set` installs a new `Rc` and does not
    /// change the value seen through this one.
    pub fn get_untracked(&self) -> Rc<T> {
        Rc::clone(&self.0.borrow().inner)
    }
}
impl<T: 'static> Clone for StateHandle<T> {
fn clone(&self) -> Self {
Self(Rc::clone(&self.0))
}
}
impl<T: fmt::Debug> fmt::Debug for StateHandle<T> {
    /// Formats as `StateHandle(<value>)`, reading the value untracked so that formatting never
    /// records a reactive dependency.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.get_untracked();
        let mut tuple = f.debug_tuple("StateHandle");
        tuple.field(&value);
        tuple.finish()
    }
}
#[cfg(feature = "serde")]
impl<T: serde::Serialize> serde::Serialize for StateHandle<T> {
    /// Serializes the current (untracked) value exactly as a plain `T` would be serialized.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let value = self.get_untracked();
        serde::Serialize::serialize(&*value, serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, T: serde::Deserialize<'de>> serde::Deserialize<'de> for StateHandle<T> {
    /// Deserializes a plain `T` and wraps it in a fresh signal, returning that signal's handle.
    ///
    /// `into_handle` moves the handle out of the temporary signal instead of cloning its `Rc`
    /// and immediately dropping the original.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        T::deserialize(deserializer).map(|value| Signal::new(value).into_handle())
    }
}
/// A reactive value that can be written with [`Signal::set`].
///
/// Reads go through the wrapped [`StateHandle`], which this type exposes via `Deref`.
pub struct Signal<T>
where
    T: 'static,
{
    handle: StateHandle<T>,
}
impl<T: 'static> Signal<T> {
pub fn new(value: T) -> Self {
Self {
handle: StateHandle(Rc::new(RefCell::new(SignalInner::new(value)))),
}
}
pub fn set(&self, new_value: T) {
self.handle.0.borrow_mut().update(new_value);
self.trigger_subscribers();
}
pub fn handle(&self) -> StateHandle<T> {
self.handle.clone()
}
pub fn into_handle(self) -> StateHandle<T> {
self.handle
}
pub fn trigger_subscribers(&self) {
let subscribers = self.handle.0.borrow().subscribers.clone();
for subscriber in subscribers {
if let Some(callback) = subscriber.try_callback() {
callback()
}
}
}
}
impl<T: 'static> Deref for Signal<T> {
    type Target = StateHandle<T>;

    /// Exposes the read API of the wrapped [`StateHandle`] directly on the signal.
    fn deref(&self) -> &StateHandle<T> {
        let Self { handle } = self;
        handle
    }
}
impl<T: 'static> Clone for Signal<T> {
fn clone(&self) -> Self {
Self {
handle: self.handle.clone(),
}
}
}
impl<T: PartialEq> PartialEq for Signal<T> {
fn eq(&self, other: &Signal<T>) -> bool {
self.get_untracked().eq(&other.get_untracked())
}
}
impl<T: fmt::Debug> fmt::Debug for Signal<T> {
    /// Formats as `Signal(<value>)`, reading the value untracked so that formatting never
    /// records a reactive dependency.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.get_untracked();
        let mut tuple = f.debug_tuple("Signal");
        tuple.field(&value);
        tuple.finish()
    }
}
#[cfg(feature = "serde")]
impl<T: serde::Serialize> serde::Serialize for Signal<T> {
    /// Serializes the current (untracked) value exactly as a plain `T` would be serialized.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let value = self.get_untracked();
        serde::Serialize::serialize(&*value, serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, T: serde::Deserialize<'de>> serde::Deserialize<'de> for Signal<T> {
    /// Deserializes a plain `T` and wraps it in a new signal with no subscribers.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(Self::new(value))
    }
}
/// Shared state behind a signal: the current value and the callbacks subscribed to it.
pub(super) struct SignalInner<T> {
    /// Current value. It is reference-counted so that readers can keep a snapshot after the
    /// cell borrow ends.
    inner: Rc<T>,
    /// Callbacks invoked by `Signal::trigger_subscribers`.
    subscribers: HashSet<Callback>,
}
impl<T> SignalInner<T> {
    /// Wraps `value` together with an empty subscriber set.
    fn new(value: T) -> Self {
        let inner = Rc::new(value);
        let subscribers = HashSet::new();
        Self { inner, subscribers }
    }

    /// Adds `handler` to the subscriber set. The set is unchanged if an equal callback is
    /// already present.
    fn subscribe(&mut self, handler: Callback) {
        HashSet::insert(&mut self.subscribers, handler);
    }

    /// Removes `handler` from the subscriber set, if it is present.
    fn unsubscribe(&mut self, handler: &Callback) {
        HashSet::remove(&mut self.subscribers, handler);
    }

    /// Installs `new_value` behind a fresh `Rc`; snapshots taken earlier keep the old value.
    fn update(&mut self, new_value: T) {
        let fresh = Rc::new(new_value);
        self.inner = fresh;
    }
}
/// A `T`-independent interface over a signal's shared state.
///
/// Subscribers can be added or removed without naming the value type.
pub(super) trait AnySignalInner {
    /// Registers `callback` as a subscriber.
    fn subscribe(&self, callback: Callback);

    /// Removes `callback` from the subscribers, if it is present.
    fn unsubscribe(&self, callback: &Callback);
}
impl<T> AnySignalInner for RefCell<SignalInner<T>> {
fn subscribe(&self, handler: Callback) {
self.borrow_mut().subscribe(handler);
}
fn unsubscribe(&self, handler: &Callback) {
self.borrow_mut().unsubscribe(handler);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn signals() {
        let state = Signal::new(0);
        assert_eq!(*state.get(), 0);

        state.set(1);
        assert_eq!(*state.get(), 1);
    }

    #[test]
    fn signal_composition() {
        let state = Signal::new(0);
        let double = || *state.get() * 2;

        assert_eq!(double(), 0);
        state.set(1);
        assert_eq!(double(), 2);
    }

    #[test]
    fn state_handle() {
        let state = Signal::new(0);
        let readonly = state.handle();

        assert_eq!(*readonly.get(), 0);
        state.set(1);
        assert_eq!(*readonly.get(), 1);
    }

    /// `into_handle` must hand back the signal's state without losing the value.
    #[test]
    fn into_handle_preserves_value() {
        let handle = Signal::new(3).into_handle();
        assert_eq!(*handle.get(), 3);
    }

    /// Cloned signals share one underlying state rather than copying it.
    #[test]
    fn clones_share_state() {
        let state = Signal::new(0);
        let other = state.clone();

        state.set(5);
        assert_eq!(*other.get_untracked(), 5);
    }

    /// `set` installs a new `Rc`, so previously obtained snapshots keep the old value.
    #[test]
    fn earlier_snapshots_are_unaffected_by_set() {
        let state = Signal::new(1);
        let before = state.get_untracked();

        state.set(2);
        assert_eq!(*before, 1);
        assert_eq!(*state.get_untracked(), 2);
    }

    /// Equality compares current values, not identity.
    #[test]
    fn equality_compares_values() {
        assert_eq!(Signal::new(1), Signal::new(1));
        assert_ne!(Signal::new(1), Signal::new(2));
    }

    /// `Debug` output wraps the current value in the type's name.
    #[test]
    fn debug_formats_current_value() {
        let state = Signal::new(7);
        assert_eq!(format!("{:?}", state), "Signal(7)");
        assert_eq!(format!("{:?}", state.handle()), "StateHandle(7)");
    }
}