use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::binding::{BindingValue, UiBindable};
/// A cheaply clonable, thread-safe handle to a piece of UI state `S`.
///
/// The state lives behind a single `Arc<RwLock<S>>`. Every clone of a
/// `SharedContext` points at that same lock, so a write through any handle is
/// visible through all of them.
///
/// Trait bounds are intentionally declared on the `impl` blocks that need
/// them rather than on the type itself. That keeps the type nameable in
/// generic code without forcing every mention of it to restate the bounds.
#[derive(Debug)]
pub struct SharedContext<S> {
    /// The single shared lock that all clones of this handle point at.
    state: Arc<RwLock<S>>,
}
impl<S> SharedContext<S>
where
    S: UiBindable + Send + Sync + 'static,
{
    /// Wraps `initial` in a new `Arc<RwLock<_>>`.
    ///
    /// The returned handle is the only one until it is cloned. Cloning hands
    /// out further handles to the same lock.
    pub fn new(initial: S) -> Self {
        Self {
            state: Arc::new(RwLock::new(initial)),
        }
    }

    /// Blocks until a shared read guard can be acquired.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a previous guard holder panicked.
    // A poisoned lock means a writer panicked mid-update and `S` may be
    // inconsistent. Treating that as a fatal bug is a deliberate policy,
    // hence the clippy allowance.
    #[allow(clippy::expect_used)]
    pub fn read(&self) -> RwLockReadGuard<'_, S> {
        self.state.read().expect("SharedContext lock poisoned")
    }

    /// Blocks until an exclusive write guard can be acquired.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a previous guard holder panicked.
    // Same poison policy as `read`; see the comment there.
    #[allow(clippy::expect_used)]
    pub fn write(&self) -> RwLockWriteGuard<'_, S> {
        self.state.write().expect("SharedContext lock poisoned")
    }

    /// Attempts to acquire a read guard without blocking.
    ///
    /// Returns `None` only when the lock cannot be acquired right now, for
    /// example because a writer holds it.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, matching [`Self::read`]. Returning
    /// `None` here would make a poisoned lock indistinguishable from
    /// contention, and a caller that retries would never make progress.
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, S>> {
        match self.state.try_read() {
            Ok(guard) => Some(guard),
            Err(std::sync::TryLockError::WouldBlock) => None,
            Err(std::sync::TryLockError::Poisoned(_)) => Self::lock_poisoned(),
        }
    }

    /// Attempts to acquire a write guard without blocking.
    ///
    /// Returns `None` only when the lock cannot be acquired right now, for
    /// example because another read or write guard is outstanding.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, matching [`Self::write`].
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, S>> {
        match self.state.try_write() {
            Ok(guard) => Some(guard),
            Err(std::sync::TryLockError::WouldBlock) => None,
            Err(std::sync::TryLockError::Poisoned(_)) => Self::lock_poisoned(),
        }
    }

    /// Shared failure path for the non-blocking accessors when they observe a
    /// poisoned lock. It uses the same message as `read`/`write`.
    #[cold]
    #[allow(clippy::panic)] // Poisoning is treated as a fatal invariant violation.
    fn lock_poisoned() -> ! {
        panic!("SharedContext lock poisoned")
    }
}
impl<S> Clone for SharedContext<S>
where
S: UiBindable + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
impl<S> UiBindable for SharedContext<S>
where
    S: UiBindable + Send + Sync + 'static,
{
    /// Resolves `path` against the wrapped state.
    ///
    /// The lookup runs under a read guard that is held only for the duration
    /// of the call.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned (propagated from `read`). Do not call
    /// this while the current thread holds a write guard on the same context.
    /// std's `RwLock` may deadlock or panic on re-entrant locking.
    fn get_field(&self, path: &[&str]) -> Option<BindingValue> {
        self.read().get_field(path)
    }

    /// The set of bindable field names is a property of `S` alone, so this
    /// forwards to it without touching the lock.
    fn available_fields() -> Vec<String> {
        <S as UiBindable>::available_fields()
    }
}
impl SharedContext<()> {
pub fn empty() -> Self {
Self::new(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BindingValue;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// Minimal bindable state used to exercise `SharedContext`.
    #[derive(Default, Clone)]
    struct TestState {
        counter: i32,
        name: String,
    }

    impl UiBindable for TestState {
        fn get_field(&self, path: &[&str]) -> Option<BindingValue> {
            match path {
                ["counter"] => Some(BindingValue::Integer(self.counter as i64)),
                ["name"] => Some(BindingValue::String(self.name.clone())),
                _ => None,
            }
        }

        fn available_fields() -> Vec<String> {
            vec!["counter".to_string(), "name".to_string()]
        }
    }

    #[test]
    fn test_shared_context_new() {
        let ctx = SharedContext::new(TestState {
            counter: 42,
            name: "test".to_string(),
        });
        assert_eq!(ctx.read().counter, 42);
        assert_eq!(ctx.read().name, "test");
    }

    #[test]
    fn test_shared_context_read() {
        let ctx = SharedContext::new(TestState {
            counter: 10,
            name: "hello".to_string(),
        });
        let guard = ctx.read();
        assert_eq!(guard.counter, 10);
        assert_eq!(guard.name, "hello");
    }

    #[test]
    fn test_shared_context_write() {
        let ctx = SharedContext::new(TestState::default());
        {
            let mut guard = ctx.write();
            guard.counter = 100;
            guard.name = "updated".to_string();
        }
        assert_eq!(ctx.read().counter, 100);
        assert_eq!(ctx.read().name, "updated");
    }

    #[test]
    fn test_shared_context_try_read() {
        let ctx = SharedContext::new(TestState {
            counter: 5,
            ..Default::default()
        });
        let guard = ctx
            .try_read()
            .expect("try_read on an uncontended lock must succeed");
        assert_eq!(guard.counter, 5);
    }

    #[test]
    fn test_shared_context_try_write() {
        let ctx = SharedContext::new(TestState::default());
        let mut guard = ctx
            .try_write()
            .expect("try_write on an uncontended lock must succeed");
        guard.counter = 3;
        drop(guard);
        assert_eq!(ctx.read().counter, 3);
    }

    /// The non-blocking accessors must report contention as `None` instead of
    /// blocking.
    #[test]
    fn test_shared_context_try_lock_under_contention() {
        let ctx = SharedContext::new(TestState::default());
        {
            let _writer = ctx.write();
            assert!(ctx.try_read().is_none());
            assert!(ctx.try_write().is_none());
        }
        {
            let _reader = ctx.read();
            assert!(ctx.try_write().is_none());
        }
        assert!(ctx.try_write().is_some());
    }

    #[test]
    fn test_shared_context_clone_shares_state() {
        let ctx1 = SharedContext::new(TestState {
            counter: 0,
            name: "original".to_string(),
        });
        let ctx2 = ctx1.clone();
        ctx1.write().counter = 42;
        ctx1.write().name = "modified".to_string();
        assert_eq!(ctx2.read().counter, 42);
        assert_eq!(ctx2.read().name, "modified");
        ctx2.write().counter = 100;
        assert_eq!(ctx1.read().counter, 100);
    }

    #[test]
    fn test_shared_context_multiple_clones() {
        let original = SharedContext::new(TestState {
            counter: 1,
            ..Default::default()
        });
        let clone1 = original.clone();
        let clone2 = original.clone();
        let clone3 = clone1.clone();
        original.write().counter = 999;
        assert_eq!(clone1.read().counter, 999);
        assert_eq!(clone2.read().counter, 999);
        assert_eq!(clone3.read().counter, 999);
    }

    #[test]
    fn test_shared_context_concurrent_reads() {
        let ctx = SharedContext::new(TestState {
            counter: 42,
            name: "concurrent".to_string(),
        });
        let read_count = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let ctx_clone = ctx.clone();
            let count = Arc::clone(&read_count);
            let handle = thread::spawn(move || {
                let guard = ctx_clone.read();
                assert_eq!(guard.counter, 42);
                assert_eq!(guard.name, "concurrent");
                count.fetch_add(1, Ordering::SeqCst);
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(read_count.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_shared_context_concurrent_writes() {
        let ctx = SharedContext::new(TestState::default());
        let mut handles = vec![];
        for i in 0..10 {
            let ctx_clone = ctx.clone();
            let handle = thread::spawn(move || {
                let mut guard = ctx_clone.write();
                guard.counter += 1;
                guard.name = format!("writer-{}", i);
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(ctx.read().counter, 10);
    }

    #[test]
    fn test_shared_context_mixed_read_write() {
        let ctx = SharedContext::new(TestState {
            counter: 0,
            ..Default::default()
        });
        let mut handles = vec![];
        for _ in 0..5 {
            let ctx_clone = ctx.clone();
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    ctx_clone.write().counter += 1;
                }
            });
            handles.push(handle);
        }
        for _ in 0..5 {
            let ctx_clone = ctx.clone();
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    let _ = ctx_clone.read().counter;
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(ctx.read().counter, 500);
    }

    #[test]
    fn test_shared_context_empty() {
        let ctx = SharedContext::<()>::empty();
        // Acquire and release each kind of guard in turn; neither may leak
        // and block the next acquisition.
        drop(ctx.read());
        drop(ctx.write());
        assert!(ctx.try_write().is_some());
    }

    /// Exercises `TestState`'s own `UiBindable` impl through a read guard.
    #[test]
    fn test_shared_context_ui_bindable_integration() {
        let ctx = SharedContext::new(TestState {
            counter: 123,
            name: "bindable".to_string(),
        });
        let guard = ctx.read();
        assert_eq!(
            guard.get_field(&["counter"]),
            Some(BindingValue::Integer(123))
        );
        assert_eq!(
            guard.get_field(&["name"]),
            Some(BindingValue::String("bindable".to_string()))
        );
        assert_eq!(guard.get_field(&["nonexistent"]), None);
    }

    /// Exercises the `UiBindable` impl on `SharedContext` itself. The test
    /// above only reaches `TestState` through the guard's `Deref`.
    #[test]
    fn test_shared_context_forwards_ui_bindable() {
        let ctx = SharedContext::new(TestState {
            counter: 7,
            name: "ctx".to_string(),
        });
        assert_eq!(ctx.get_field(&["counter"]), Some(BindingValue::Integer(7)));
        ctx.write().counter = 8;
        assert_eq!(ctx.get_field(&["counter"]), Some(BindingValue::Integer(8)));
        assert_eq!(
            ctx.get_field(&["name"]),
            Some(BindingValue::String("ctx".to_string()))
        );
        assert_eq!(ctx.get_field(&["missing"]), None);
        assert_eq!(
            <SharedContext<TestState> as UiBindable>::available_fields(),
            <TestState as UiBindable>::available_fields()
        );
    }
}