use std::{collections::BTreeMap, sync::Arc, time::Duration};
use hass_rs::{HassClient, HassEntity};
use tokio::sync::RwLock;
pub struct PersistentHassConnection {
hass: Arc<RwLock<HassClient>>,
url: String,
token: String,
close: tokio::sync::mpsc::Sender<()>,
states: RwLock<BTreeMap<String, HassEntity>>,
update_interval: Duration,
}
impl PersistentHassConnection {
pub async fn new(
url: String,
token: String,
update_interval: Duration,
) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
let mut hass = HassClient::new(&url).await?;
hass.auth_with_longlivedtoken(&token).await?;
let connection = Self {
hass: Arc::new(RwLock::new(hass)),
url,
token,
close: tx,
states: RwLock::new(BTreeMap::new()),
update_interval,
};
let connection = Arc::new(connection);
let connection_clone = connection.clone();
tokio::spawn(async move {
connection_clone.keep_alive(rx).await;
});
Ok(connection)
}
async fn create_client(&self) -> Result<HassClient, Box<dyn std::error::Error>> {
let mut client = HassClient::new(&self.url).await?;
client.auth_with_longlivedtoken(&self.token).await?;
Ok(client)
}
async fn replace_client(&self) -> Result<(), Box<dyn std::error::Error>> {
let client = self.create_client().await?;
let mut hass = self.hass.write().await;
*hass = client;
Ok(())
}
pub async fn call_service(
&self,
domain: &str,
service: &str,
data: Option<serde_json::Value>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut client = self.hass.write().await;
client
.call_service(domain.to_string(), service.to_string(), data)
.await?;
Ok(())
}
pub async fn fetch_states(&self) -> Result<(), String> {
let mut client = self.hass.write().await;
let states = client.get_states().await.map_err(|e| e.to_string())?;
let mut state_map = self.states.write().await;
for state in states {
state_map.insert(state.entity_id.clone(), state);
}
Ok(())
}
pub async fn get_state(&self, entity_id: &str) -> Option<HassEntity> {
let state_map = self.states.read().await;
state_map.get(entity_id).cloned()
}
async fn keep_alive(self: Arc<Self>, mut end: tokio::sync::mpsc::Receiver<()>) {
loop {
let close_future = end.recv();
let fetch_future = self.fetch_states();
tokio::select! {
_ = close_future => {
println!("Closing connection");
break;
}
result = fetch_future => {
if let Err(e) = result {
eprintln!("Error fetching states: {}", e);
match self.replace_client().await {
Ok(_) => {
println!("Replaced client");
}
Err(e) => {
eprintln!("Error replacing client: {}", e);
}
}
}
}
}
tokio::time::sleep(self.update_interval).await;
}
}
}
impl Drop for PersistentHassConnection {
fn drop(&mut self) {
let _ = self.close.try_send(());
}
}