use crate::consistency::ConsistencyEngine;
use crate::Result;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Bidirectional index between workspaces and the state model each one is
/// registered with.
///
/// The two maps describe the same relation from opposite directions: every
/// `(workspace, model)` pair in `workspace_to_model` is expected to also appear as
/// `workspace` inside `model_to_workspaces[model]`.
pub struct StateModelRegistry {
    /// Engine used to read and update per-workspace state when syncing.
    consistency_engine: Arc<ConsistencyEngine>,
    /// State model name -> ids of the workspaces registered with that model.
    model_to_workspaces: Arc<RwLock<HashMap<String, HashSet<String>>>>,
    /// Workspace id -> name of the state model it is registered with.
    workspace_to_model: Arc<RwLock<HashMap<String, String>>>,
}
impl StateModelRegistry {
    /// Creates an empty registry that uses `consistency_engine` for state access.
    pub fn new(consistency_engine: Arc<ConsistencyEngine>) -> Self {
        Self {
            model_to_workspaces: Arc::new(RwLock::new(HashMap::new())),
            workspace_to_model: Arc::new(RwLock::new(HashMap::new())),
            consistency_engine,
        }
    }

    /// Associates `workspace_id` with `state_model`, moving it out of whatever
    /// model it was previously registered with.
    ///
    /// Re-registering a workspace with the model it already has is a no-op (and
    /// logs nothing). Both index maps are updated while holding both write locks,
    /// so concurrent register/unregister calls cannot leave them disagreeing
    /// (for example, a workspace listed under two models at once).
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn register_workspace(
        &self,
        workspace_id: String,
        state_model: String,
    ) -> Result<()> {
        {
            // Lock order: `model_to_workspaces` first, then `workspace_to_model`.
            // Every method that holds both locks must use this order to avoid deadlock.
            let mut model_to_workspaces = self.model_to_workspaces.write().await;
            let mut workspace_to_model = self.workspace_to_model.write().await;

            if let Some(old) = workspace_to_model.get(&workspace_id) {
                if *old == state_model {
                    return Ok(());
                }
                Self::detach_workspace(&mut model_to_workspaces, old, &workspace_id);
            }

            model_to_workspaces
                .entry(state_model.clone())
                .or_default()
                .insert(workspace_id.clone());
            workspace_to_model.insert(workspace_id.clone(), state_model.clone());
        }
        info!("Registered workspace {} with state model {}", workspace_id, state_model);
        Ok(())
    }

    /// Removes `workspace_id` from the registry, if it is registered.
    ///
    /// Both maps are updated under both write locks (same order as
    /// `register_workspace`) so a concurrent re-registration cannot interleave
    /// between the two removals.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`, including when the workspace was unknown.
    pub async fn unregister_workspace(&self, workspace_id: &str) -> Result<()> {
        let removed_model = {
            let mut model_to_workspaces = self.model_to_workspaces.write().await;
            let mut workspace_to_model = self.workspace_to_model.write().await;
            let removed_model = workspace_to_model.remove(workspace_id);
            if let Some(ref model) = removed_model {
                Self::detach_workspace(&mut model_to_workspaces, model, workspace_id);
            }
            removed_model
        };
        if let Some(model) = removed_model {
            info!("Unregistered workspace {} from state model {}", workspace_id, model);
        }
        Ok(())
    }

    /// Removes `workspace_id` from `model`'s workspace set and drops the set once
    /// it is empty, so `list_state_models` only reports models that are in use.
    fn detach_workspace(
        model_to_workspaces: &mut HashMap<String, HashSet<String>>,
        model: &str,
        workspace_id: &str,
    ) {
        if let Some(workspaces) = model_to_workspaces.get_mut(model) {
            workspaces.remove(workspace_id);
            if workspaces.is_empty() {
                model_to_workspaces.remove(model);
            }
        }
    }

    /// Returns the state model `workspace_id` is registered with, if any.
    pub async fn get_state_model(&self, workspace_id: &str) -> Option<String> {
        let workspace_to_model = self.workspace_to_model.read().await;
        workspace_to_model.get(workspace_id).cloned()
    }

    /// Returns the ids of all workspaces registered with `state_model`, in
    /// unspecified order; empty if the model is unknown.
    pub async fn get_workspaces_for_model(&self, state_model: &str) -> Vec<String> {
        let model_to_workspaces = self.model_to_workspaces.read().await;
        model_to_workspaces
            .get(state_model)
            .map(|workspaces| workspaces.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the names of all state models that have at least one registered
    /// workspace, in unspecified order.
    pub async fn list_state_models(&self) -> Vec<String> {
        let model_to_workspaces = self.model_to_workspaces.read().await;
        model_to_workspaces.keys().cloned().collect()
    }

    /// Prepares a persona-graph sync from `workspace_id` to every other workspace
    /// sharing its state model.
    ///
    /// Returns early (with `Ok`) if the workspace is not registered or the
    /// consistency engine has no state for it. No graph data is copied yet:
    /// with the `persona-graph` feature enabled, each target gets a persona graph
    /// via `get_or_create_persona_graph` and a debug message records that a sync
    /// is needed. NOTE(review): whether changes to `target_state` persist depends
    /// on what `get_or_create_state` returns — confirm.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn sync_persona_graph(&self, workspace_id: &str) -> Result<()> {
        let state_model = match self.get_state_model(workspace_id).await {
            Some(model) => model,
            None => {
                debug!("Workspace {} not registered with any state model", workspace_id);
                return Ok(());
            }
        };
        let _source_state = match self.consistency_engine.get_state(workspace_id).await {
            Some(state) => state,
            None => {
                warn!("Source workspace {} not found", workspace_id);
                return Ok(());
            }
        };
        let _target_workspaces = self.get_workspaces_for_model(&state_model).await;
        #[cfg(feature = "persona-graph")]
        if let Some(_source_graph) = _source_state.persona_graph() {
            for target_workspace_id in _target_workspaces {
                // The source is also in the model's workspace set; don't sync onto itself.
                if target_workspace_id == workspace_id {
                    continue;
                }
                let mut target_state =
                    self.consistency_engine.get_or_create_state(&target_workspace_id).await;
                let _target_graph = target_state.get_or_create_persona_graph();
                debug!(
                    "Persona graph sync needed from {} to {} (state model: {})",
                    workspace_id, target_workspace_id, state_model
                );
            }
        }
        Ok(())
    }

    /// Copies the active persona of one workspace onto every other workspace
    /// registered with `state_model`.
    ///
    /// The source workspace is the lexicographically smallest workspace id, so
    /// repeated calls pick the same source regardless of `HashSet` iteration
    /// order. The source's persona is read once, and its state is released before
    /// any `set_active_persona` call. Nothing is changed if the model has fewer
    /// than two workspaces, or if the source has no state or no active persona.
    /// Per-target failures are logged with `warn!` and do not stop propagation.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; per-workspace errors are only logged.
    pub async fn ensure_consistency(&self, state_model: &str) -> Result<()> {
        let mut workspaces = self.get_workspaces_for_model(state_model).await;
        // HashSet order is unspecified; sort so the source of truth is stable.
        workspaces.sort_unstable();
        let (source_workspace, targets) = match workspaces.split_first() {
            Some(split) => split,
            None => return Ok(()),
        };
        if !targets.is_empty() {
            // Snapshot the persona once instead of re-reading the source per target.
            let persona = self
                .consistency_engine
                .get_state(source_workspace)
                .await
                .and_then(|state| state.active_persona.clone());
            if let Some(persona) = persona {
                for workspace_id in targets {
                    if let Err(e) = self
                        .consistency_engine
                        .set_active_persona(workspace_id, persona.clone())
                        .await
                    {
                        warn!(
                            "Failed to sync persona from {} to {}: {}",
                            source_workspace, workspace_id, e
                        );
                    }
                }
            }
        }
        debug!("Ensured consistency for state model {}", state_model);
        Ok(())
    }
}