use std::collections::BTreeSet;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::Deserialize;
use serde::de::DeserializeOwned;
use serde_json::{Value, json};
use uuid::Uuid;
use crate::domain::contracts::{NodeStore, NodeStoreInitializer};
use crate::domain::models::{
AvecState, BatchRekeyResult, NodeQuery, ScopeRekeyResult, SttpNode,
};
use crate::storage::surrealdb::client::{QueryParams, SurrealDbClient};
use crate::storage::surrealdb::models::{
SurrealAvecRecord, SurrealNodeRecord, SurrealTriggerRecord,
};
use crate::storage::surrealdb::raw_queries;
const DEFAULT_TENANT: &str = "default";
const TENANT_SCOPE_PREFIX: &str = "tenant:";
const TENANT_SCOPE_SEPARATOR: &str = "::session:";
#[derive(Debug, Deserialize)]
struct MissingTenantRecord {
#[serde(default)]
id: Value,
#[serde(default)]
session_id: String,
}
#[derive(Debug, Clone, Deserialize)]
struct ScopeAnchorRecord {
#[serde(rename = "TenantId", default)]
tenant_id: Option<String>,
#[serde(rename = "SessionId", default)]
session_id: String,
}
#[derive(Debug, Clone, Deserialize)]
struct ScopeCountRecord {
#[serde(rename = "Count", default)]
count: usize,
}
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
struct ScopeKey {
tenant_id: String,
session_id: String,
}
pub struct SurrealDbNodeStore {
client: Arc<dyn SurrealDbClient>,
}
impl SurrealDbNodeStore {
pub fn new(client: Arc<dyn SurrealDbClient>) -> Self {
Self { client }
}
async fn backfill_missing_tenant_ids_async(&self) -> Result<()> {
self.backfill_table_tenant_ids_async(
"temporal_node",
raw_queries::SELECT_TEMPORAL_NODE_MISSING_TENANT_QUERY,
)
.await?;
self.backfill_table_tenant_ids_async(
"calibration",
raw_queries::SELECT_CALIBRATION_MISSING_TENANT_QUERY,
)
.await?;
Ok(())
}
async fn backfill_table_tenant_ids_async(&self, table: &str, select_query: &str) -> Result<()> {
let rows = self
.client
.raw_query(select_query, QueryParams::new())
.await?;
let records: Vec<MissingTenantRecord> = decode_rows(rows)?;
for record in records {
let session_id = record.session_id;
if session_id.trim().is_empty() {
continue;
}
let Some(record_id) = normalize_record_id(record.id, table) else {
continue;
};
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(&session_id)),
);
let query = raw_queries::update_record_tenant_query(&record_id);
self.client.raw_query(&query, parameters).await?;
}
Ok(())
}
async fn count_scope_rows_async(
&self,
query: &str,
session_id: &str,
tenant_id: &str,
include_legacy: bool,
) -> Result<usize> {
let mut parameters = QueryParams::new();
parameters.insert("session_id".to_string(), json!(session_id));
parameters.insert("tenant_id".to_string(), json!(tenant_id));
parameters.insert("include_legacy".to_string(), json!(include_legacy));
let rows = self.client.raw_query(query, parameters).await?;
let counts: Vec<ScopeCountRecord> = decode_rows(rows)?;
Ok(counts.first().map(|value| value.count).unwrap_or(0))
}
async fn apply_scope_rekey_async(
&self,
source_session_id: &str,
source_tenant_id: &str,
target_session_id: &str,
target_tenant_id: &str,
) -> Result<()> {
let mut parameters = QueryParams::new();
parameters.insert("source_session_id".to_string(), json!(source_session_id));
parameters.insert("source_tenant_id".to_string(), json!(source_tenant_id));
parameters.insert(
"source_include_legacy".to_string(),
json!(includes_legacy_tenant_bucket(source_tenant_id)),
);
parameters.insert("target_session_id".to_string(), json!(target_session_id));
parameters.insert("target_tenant_id".to_string(), json!(target_tenant_id));
self.client
.raw_query(raw_queries::APPLY_SCOPE_REKEY_QUERY, parameters)
.await?;
Ok(())
}
}
#[async_trait]
impl NodeStoreInitializer for SurrealDbNodeStore {
async fn initialize_async(&self) -> Result<()> {
self.client
.raw_query(raw_queries::INIT_SCHEMA_QUERY, QueryParams::new())
.await?;
self.backfill_missing_tenant_ids_async().await?;
Ok(())
}
}
#[async_trait]
impl NodeStore for SurrealDbNodeStore {
async fn query_nodes_async(&self, query: NodeQuery) -> Result<Vec<SttpNode>> {
let capped_limit = query.limit.max(1);
let mut clauses = Vec::new();
let mut parameters = QueryParams::new();
if let Some(session_id) = query.session_id.as_ref().filter(|s| !s.trim().is_empty()) {
clauses.push(
"(tenant_id = $tenant_id OR tenant_id = NONE OR tenant_id = '')".to_string(),
);
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(session_id)),
);
clauses.push("session_id = $session_id".to_string());
parameters.insert("session_id".to_string(), json!(session_id));
}
if let Some(from_utc) = query.from_utc {
clauses.push("timestamp >= <datetime>$from_utc".to_string());
parameters.insert("from_utc".to_string(), json!(from_utc.to_rfc3339()));
}
if let Some(to_utc) = query.to_utc {
clauses.push("timestamp <= <datetime>$to_utc".to_string());
parameters.insert("to_utc".to_string(), json!(to_utc.to_rfc3339()));
}
let where_clause = if clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", clauses.join(" AND "))
};
let query_text = raw_queries::query_nodes_query(&where_clause, capped_limit);
let rows = self.client.raw_query(&query_text, parameters).await?;
let records: Vec<SurrealNodeRecord> = decode_rows(rows)?;
Ok(records.into_iter().map(map_to_node).collect())
}
async fn store_async(&self, node: SttpNode) -> Result<String> {
let compression_avec_to_use = match node.compression_avec {
Some(avec) if avec.psi() != 0.0 => avec,
_ => node.model_avec,
};
let include_parent_assignment = node.parent_node_id.is_some();
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(&node.session_id)),
);
parameters.insert("session_id".to_string(), json!(node.session_id));
parameters.insert("raw".to_string(), json!(node.raw));
parameters.insert("tier".to_string(), json!(node.tier));
parameters.insert("timestamp".to_string(), json!(node.timestamp.to_rfc3339()));
parameters.insert("compression_depth".to_string(), json!(node.compression_depth));
parameters.insert("psi".to_string(), json!(node.psi));
parameters.insert("rho".to_string(), json!(node.rho));
parameters.insert("kappa".to_string(), json!(node.kappa));
parameters.insert("user_stability".to_string(), json!(node.user_avec.stability));
parameters.insert("user_friction".to_string(), json!(node.user_avec.friction));
parameters.insert("user_logic".to_string(), json!(node.user_avec.logic));
parameters.insert("user_autonomy".to_string(), json!(node.user_avec.autonomy));
parameters.insert("user_psi".to_string(), json!(node.user_avec.psi()));
parameters.insert("model_stability".to_string(), json!(node.model_avec.stability));
parameters.insert("model_friction".to_string(), json!(node.model_avec.friction));
parameters.insert("model_logic".to_string(), json!(node.model_avec.logic));
parameters.insert("model_autonomy".to_string(), json!(node.model_avec.autonomy));
parameters.insert("model_psi".to_string(), json!(node.model_avec.psi()));
parameters.insert(
"comp_stability".to_string(),
json!(compression_avec_to_use.stability),
);
parameters.insert(
"comp_friction".to_string(),
json!(compression_avec_to_use.friction),
);
parameters.insert("comp_logic".to_string(), json!(compression_avec_to_use.logic));
parameters.insert(
"comp_autonomy".to_string(),
json!(compression_avec_to_use.autonomy),
);
parameters.insert("comp_psi".to_string(), json!(compression_avec_to_use.psi()));
if let Some(parent_node_id) = node.parent_node_id {
parameters.insert("parent_node_id".to_string(), json!(parent_node_id));
}
let record_id = Uuid::new_v4().simple().to_string();
let query_text =
raw_queries::create_temporal_node_query(&record_id, include_parent_assignment);
self.client.raw_query(&query_text, parameters).await?;
Ok(record_id)
}
async fn get_by_resonance_async(
&self,
session_id: &str,
current_avec: AvecState,
limit: usize,
) -> Result<Vec<SttpNode>> {
let query_text = raw_queries::get_by_resonance_query(current_avec.psi(), limit.max(1));
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(session_id)),
);
parameters.insert("session_id".to_string(), json!(session_id));
let rows = self.client.raw_query(&query_text, parameters).await?;
let records: Vec<SurrealNodeRecord> = decode_rows(rows)?;
Ok(records.into_iter().map(map_to_node).collect())
}
async fn list_nodes_async(&self, limit: usize, session_id: Option<&str>) -> Result<Vec<SttpNode>> {
self.query_nodes_async(NodeQuery {
limit: limit.clamp(1, 200),
session_id: session_id.map(|s| s.to_string()),
from_utc: None,
to_utc: None,
})
.await
}
async fn get_last_avec_async(&self, session_id: &str) -> Result<Option<AvecState>> {
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(session_id)),
);
parameters.insert("session_id".to_string(), json!(session_id));
let rows = self
.client
.raw_query(raw_queries::GET_LAST_AVEC_QUERY, parameters)
.await?;
let records: Vec<SurrealAvecRecord> = decode_rows(rows)?;
Ok(records.first().map(|last| AvecState {
stability: last.stability,
friction: last.friction,
logic: last.logic,
autonomy: last.autonomy,
}))
}
async fn get_trigger_history_async(&self, session_id: &str) -> Result<Vec<String>> {
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(session_id)),
);
parameters.insert("session_id".to_string(), json!(session_id));
let rows = self
.client
.raw_query(raw_queries::GET_TRIGGER_HISTORY_QUERY, parameters)
.await?;
let records: Vec<SurrealTriggerRecord> = decode_rows(rows)?;
Ok(records.into_iter().map(|record| record.trigger).collect())
}
async fn store_calibration_async(
&self,
session_id: &str,
avec: AvecState,
trigger: &str,
) -> Result<()> {
let mut parameters = QueryParams::new();
parameters.insert(
"tenant_id".to_string(),
json!(derive_tenant_id_from_session(session_id)),
);
parameters.insert("session_id".to_string(), json!(session_id));
parameters.insert("stability".to_string(), json!(avec.stability));
parameters.insert("friction".to_string(), json!(avec.friction));
parameters.insert("logic".to_string(), json!(avec.logic));
parameters.insert("autonomy".to_string(), json!(avec.autonomy));
parameters.insert("psi".to_string(), json!(avec.psi()));
parameters.insert("trigger".to_string(), json!(trigger));
parameters.insert("created_at".to_string(), json!(Utc::now().to_rfc3339()));
self.client
.raw_query(raw_queries::STORE_CALIBRATION_QUERY, parameters)
.await?;
Ok(())
}
async fn batch_rekey_scopes_async(
&self,
node_ids: Vec<String>,
target_tenant_id: &str,
target_session_id: &str,
dry_run: bool,
allow_merge: bool,
) -> Result<BatchRekeyResult> {
if node_ids.is_empty() {
return Err(anyhow!("at least one node id is required"));
}
if target_session_id.trim().is_empty() {
return Err(anyhow!("target session id cannot be empty"));
}
let target_tenant_id = normalize_tenant_id(Some(target_tenant_id));
let normalized_node_ids = node_ids
.into_iter()
.filter_map(|node_id| normalize_temporal_node_id(&node_id))
.collect::<Vec<_>>();
if normalized_node_ids.is_empty() {
return Err(anyhow!("no valid node ids were provided"));
}
let mut missing_node_ids = Vec::new();
let mut scope_keys = BTreeSet::new();
for node_id in &normalized_node_ids {
let mut parameters = QueryParams::new();
parameters.insert("node_id".to_string(), json!(node_id));
let rows = self
.client
.raw_query(raw_queries::SELECT_SCOPE_BY_NODE_ID_QUERY, parameters)
.await?;
let anchors: Vec<ScopeAnchorRecord> = decode_rows(rows)?;
let Some(anchor) = anchors.first() else {
missing_node_ids.push(node_id.clone());
continue;
};
if anchor.session_id.trim().is_empty() {
missing_node_ids.push(node_id.clone());
continue;
}
scope_keys.insert(ScopeKey {
tenant_id: normalize_tenant_id(anchor.tenant_id.as_deref()),
session_id: anchor.session_id.clone(),
});
}
let mut scope_results = Vec::new();
let mut temporal_nodes_updated = 0usize;
let mut calibrations_updated = 0usize;
for scope in scope_keys {
let source_include_legacy = includes_legacy_tenant_bucket(&scope.tenant_id);
let temporal_nodes = self
.count_scope_rows_async(
raw_queries::COUNT_TEMPORAL_SCOPE_QUERY,
&scope.session_id,
&scope.tenant_id,
source_include_legacy,
)
.await?;
let calibrations = self
.count_scope_rows_async(
raw_queries::COUNT_CALIBRATION_SCOPE_QUERY,
&scope.session_id,
&scope.tenant_id,
source_include_legacy,
)
.await?;
let same_scope = scope.tenant_id == target_tenant_id && scope.session_id == target_session_id;
let target_include_legacy = includes_legacy_tenant_bucket(&target_tenant_id);
let target_temporal_nodes = if same_scope {
0
} else {
self.count_scope_rows_async(
raw_queries::COUNT_TEMPORAL_SCOPE_QUERY,
target_session_id,
&target_tenant_id,
target_include_legacy,
)
.await?
};
let target_calibrations = if same_scope {
0
} else {
self.count_scope_rows_async(
raw_queries::COUNT_CALIBRATION_SCOPE_QUERY,
target_session_id,
&target_tenant_id,
target_include_legacy,
)
.await?
};
let conflict = !allow_merge
&& !same_scope
&& (target_temporal_nodes > 0 || target_calibrations > 0);
let mut applied = false;
let message = if same_scope {
Some("source and target scopes are identical".to_string())
} else if conflict {
Some("target scope already contains rows; set allow_merge=true to override"
.to_string())
} else {
if !dry_run {
self.apply_scope_rekey_async(
&scope.session_id,
&scope.tenant_id,
target_session_id,
&target_tenant_id,
)
.await?;
applied = true;
temporal_nodes_updated += temporal_nodes;
calibrations_updated += calibrations;
}
None
};
scope_results.push(ScopeRekeyResult {
source_tenant_id: scope.tenant_id,
source_session_id: scope.session_id,
target_tenant_id: target_tenant_id.clone(),
target_session_id: target_session_id.to_string(),
temporal_nodes,
calibrations,
target_temporal_nodes,
target_calibrations,
applied,
conflict,
message,
});
}
Ok(BatchRekeyResult {
dry_run,
requested_node_ids: normalized_node_ids.len(),
resolved_node_ids: normalized_node_ids.len().saturating_sub(missing_node_ids.len()),
missing_node_ids,
scopes: scope_results,
temporal_nodes_updated,
calibrations_updated,
})
}
}
fn map_to_node(record: SurrealNodeRecord) -> SttpNode {
let _ = (record.user_psi, record.model_psi, record.comp_psi, record.resonance_delta);
SttpNode {
raw: record.raw,
session_id: record.session_id,
tier: record.tier,
timestamp: parse_timestamp(&record.timestamp),
compression_depth: record.compression_depth,
parent_node_id: record.parent_node_id,
psi: record.psi as f32,
rho: record.rho as f32,
kappa: record.kappa as f32,
user_avec: AvecState {
stability: record.user_stability as f32,
friction: record.user_friction as f32,
logic: record.user_logic as f32,
autonomy: record.user_autonomy as f32,
},
model_avec: AvecState {
stability: record.model_stability as f32,
friction: record.model_friction as f32,
logic: record.model_logic as f32,
autonomy: record.model_autonomy as f32,
},
compression_avec: Some(AvecState {
stability: record.comp_stability as f32,
friction: record.comp_friction as f32,
logic: record.comp_logic as f32,
autonomy: record.comp_autonomy as f32,
}),
}
}
fn parse_timestamp(value: &str) -> DateTime<Utc> {
DateTime::parse_from_rfc3339(value)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now())
}
fn decode_rows<T>(rows: Vec<serde_json::Value>) -> Result<Vec<T>>
where
T: DeserializeOwned,
{
rows.into_iter()
.map(serde_json::from_value)
.collect::<std::result::Result<Vec<T>, _>>()
.map_err(Into::into)
}
fn normalize_record_id(raw_id: Value, fallback_table: &str) -> Option<String> {
match raw_id {
Value::String(id) => {
if id.trim().is_empty() {
None
} else {
Some(id)
}
}
Value::Object(map) => {
let table_name = map
.get("tb")
.and_then(value_to_record_component)
.unwrap_or_else(|| fallback_table.to_string());
let id_component = map.get("id").and_then(value_to_record_component)?;
Some(format!("{table_name}:{id_component}"))
}
_ => None,
}
}
fn value_to_record_component(value: &Value) -> Option<String> {
match value {
Value::String(v) => Some(v.clone()),
Value::Number(v) => Some(v.to_string()),
Value::Bool(v) => Some(v.to_string()),
_ => None,
}
}
fn derive_tenant_id_from_session(session_id: &str) -> String {
session_id
.strip_prefix(TENANT_SCOPE_PREFIX)
.and_then(|remainder| remainder.split_once(TENANT_SCOPE_SEPARATOR))
.map(|(tenant, _)| tenant)
.filter(|tenant| !tenant.trim().is_empty())
.unwrap_or(DEFAULT_TENANT)
.to_string()
}
fn normalize_tenant_id(value: Option<&str>) -> String {
value
.map(str::trim)
.filter(|tenant| !tenant.is_empty())
.unwrap_or(DEFAULT_TENANT)
.to_string()
}
fn includes_legacy_tenant_bucket(tenant_id: &str) -> bool {
tenant_id == DEFAULT_TENANT
}
fn normalize_temporal_node_id(value: &str) -> Option<String> {
let trimmed = value.trim();
if trimmed.is_empty() {
return None;
}
if let Some((table, id)) = trimmed.split_once(':') {
if table != "temporal_node" {
return None;
}
let id = id.trim();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
} else {
Some(trimmed.to_string())
}
}