use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::{RwLock, mpsc::Sender};
use super::{Observer, ObserverValue};
/// Registered observers for every device, keyed by device id.
type DeviceChannels = HashMap<String, PathChannels>;
/// Observers of a single device, keyed by the path each one watches.
type PathChannels = HashMap<String, ObserverSender>;
/// Notification channel of a single observer. Wrapped in `Arc`, which keeps cloning a
/// per-device map cheap.
type ObserverSender = Arc<Sender<ObserverValue>>;
/// In-memory [`Observer`] backend.
///
/// Device documents are kept in `db`, keyed by device id. Observer channels live behind an
/// `Arc<RwLock<..>>`, so clones of a `MemObserver` share their registrations. The derived
/// `Clone` copies `db`, however, so every clone holds its own independent copy of the data.
#[derive(Clone, Debug)]
pub struct MemObserver {
    /// Latest merged JSON document per device id.
    db: HashMap<String, Value>,
    /// Observer senders, keyed by device id and then by watched path.
    channels: Arc<RwLock<DeviceChannels>>,
}
impl MemObserver {
pub fn new() -> Self {
Self {
db: HashMap::new(),
channels: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for MemObserver {
fn default() -> Self {
Self::new()
}
}
use std::fmt;
/// Errors produced by [`MemObserver`].
#[derive(Debug)]
pub enum MemObserverError {
    /// An underlying I/O failure; the wrapped error is also exposed via `Error::source`.
    IoError(std::io::Error),
    /// Reported as "Device ID must be set before use!". Not raised anywhere in this file.
    IdNotSet,
}

impl fmt::Display for MemObserverError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::IdNotSet => f.write_str("Device ID must be set before use!"),
            Self::IoError(inner) => write!(f, "IO error: {}", inner),
        }
    }
}

impl std::error::Error for MemObserverError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IdNotSet => None,
            Self::IoError(inner) => Some(inner),
        }
    }
}

/// Lets `?` lift `std::io::Error` into [`MemObserverError::IoError`].
impl From<std::io::Error> for MemObserverError {
    fn from(source: std::io::Error) -> Self {
        Self::IoError(source)
    }
}
#[async_trait]
impl Observer for MemObserver {
type Error = MemObserverError;
async fn register(
&mut self,
device_id: &str,
path: &str,
sender: Arc<Sender<ObserverValue>>,
) -> Result<(), Self::Error> {
let mut channels = self.channels.write().await;
channels
.entry(device_id.to_string())
.or_insert_with(HashMap::new)
.insert(path.to_string(), sender);
log::debug!(
"Registered observer for device '{}' at path '{}'",
device_id,
path
);
Ok(())
}
async fn unregister(&mut self, device_id: &str, path: &str) -> Result<(), Self::Error> {
let mut channels = self.channels.write().await;
if let Some(device_channels) = channels.get_mut(device_id) {
device_channels.remove(path);
if device_channels.is_empty() {
channels.remove(device_id);
}
}
Ok(())
}
async fn unregister_all(&mut self) -> Result<(), Self::Error> {
self.channels.write().await.clear();
Ok(())
}
async fn write(
&mut self,
device_id: &str,
path: &str,
payload: &Value,
) -> Result<(), Self::Error> {
let new_value = super::path_to_json(path, payload);
log::debug!("New value: {:?} for path: {}", new_value, path);
let mut current_value = Value::Null;
let value = if let Some(value) = self.db.get(device_id) {
current_value = value.clone();
let mut merged_value = value.clone();
super::merge_json(&mut merged_value, &new_value);
log::debug!("Merged value: {:?}", merged_value);
merged_value
} else {
new_value
};
let device_channels = {
let channels = self.channels.read().await;
log::debug!(
"Looking for observers for device '{}' with write to path '{}'",
device_id,
path
);
log::debug!(
"Currently registered devices: {:?}",
channels.keys().collect::<Vec<_>>()
);
channels.get(device_id).cloned()
};
if let Some(device_channels) = device_channels {
log::debug!(
"Found device '{}' with {} observers",
device_id,
device_channels.len()
);
for (obs_path, sender) in device_channels.iter() {
let json_pointer = if obs_path.starts_with('/') {
obs_path.clone()
} else {
format!("/{}", obs_path)
};
let current_value = current_value.pointer(&json_pointer);
let incoming_value = value.pointer(&json_pointer);
log::debug!("Comparing paths: {} for device: {}", obs_path, device_id);
log::debug!("Current value at path: {:?}", current_value);
log::debug!("Incoming value at path: {:?}", incoming_value);
if current_value != incoming_value {
log::debug!(
"Value changed at path: {} for device: {}",
obs_path,
device_id
);
let value = match incoming_value {
Some(value) => value.clone(),
None => Value::Null,
};
if let Err(e) = sender
.send(ObserverValue {
path: obs_path.clone(),
value,
})
.await
{
log::warn!(
"Failed to send observer notification for device {} path {}: {}",
device_id,
obs_path,
e
);
}
}
}
} else {
log::warn!("No observers found for device '{}'", device_id);
}
self.db.insert(device_id.to_string(), value);
Ok(())
}
async fn read(&mut self, device_id: &str, path: &str) -> Result<Option<Value>, Self::Error> {
match self.db.get(device_id) {
Some(value) => {
log::debug!("Got value: {:?}", value);
let pointer_value = value.pointer(path).cloned();
log::debug!("Pointer value: {:?}", pointer_value);
Ok(pointer_value)
}
None => Ok(None),
}
}
async fn clear(&mut self, device_id: &str) -> Result<(), Self::Error> {
let _ = self.db.remove(device_id);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // Each test builds its own `MemObserver` instead of cloning a shared static: every
    // clone gets its own copy of `db` (so sharing buys nothing), while all clones share
    // `channels`, which would let tests running in parallel clobber each other's
    // registrations (e.g. through `unregister_all`).

    #[tokio::test]
    async fn test_mem_observer_write_and_read() {
        let _ = env_logger::try_init();
        let mut observer = MemObserver::new();
        observer
            .write("123", "/test_path", &json!({"test_key": "test_value"}))
            .await
            .unwrap();
        let result = observer.read("123", "/test_path").await.unwrap();
        assert_eq!(result, Some(json!({"test_key": "test_value"})));
        observer
            .write(
                "123",
                "/test_path/second_level",
                &json!({"test_key": "test_value"}),
            )
            .await
            .unwrap();
        let result = observer
            .read("123", "/test_path/second_level")
            .await
            .unwrap();
        assert_eq!(result, Some(json!({"test_key": "test_value"})));
        let result = observer.read("123", "/test_path").await.unwrap();
        assert_eq!(
            result,
            Some(json!({"test_key": "test_value", "second_level": {"test_key": "test_value"}}))
        );

        // Clearing the device removes all of its stored data.
        observer.clear("123").await.unwrap();
        assert_eq!(observer.read("123", "/test_path").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_mem_observer_observe_and_write() {
        let _ = env_logger::try_init();
        let mut observer = MemObserver::new();
        // The capacity exceeds the number of notifications this test produces, so `write`
        // never waits on `send`. Messages are already buffered when `write` returns and
        // can be checked with `try_recv` (no spawned task, no sleeping).
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ObserverValue>(10);
        observer
            .register("123", "/observe_and_write", Arc::new(tx.clone()))
            .await
            .unwrap();
        observer
            .write(
                "123",
                "/observe_and_write",
                &json!({"test_key": "test_value"}),
            )
            .await
            .unwrap();
        observer
            .write("123", "/observe", &json!({"test": "mest"}))
            .await
            .unwrap();

        let notification = rx
            .try_recv()
            .expect("writing to an observed path must notify its observer");
        assert_eq!(notification.value, json!({"test_key": "test_value"}));
        assert_eq!(notification.path, "/observe_and_write");
        // The second write only touched a sibling path, so the observer must stay quiet.
        assert!(rx.try_recv().is_err());

        observer
            .unregister("123", "/observe_and_write")
            .await
            .unwrap();
        assert!(
            !observer
                .channels
                .read()
                .await
                .get("123")
                .map(|device_channels| device_channels.contains_key("/observe_and_write"))
                .unwrap_or(false)
        );
        observer
            .register("123", "/observe_and_write", Arc::new(tx.clone()))
            .await
            .unwrap();
        observer.unregister_all().await.unwrap();
        assert!(observer.channels.read().await.is_empty());
    }
}