use surrealdb::Surreal;
use crate::embed::Embedder;
use crate::error::GraphError;
use crate::store::Db;
use crate::types::*;
use crate::{deserialize_take, deserialize_take_opt};
pub async fn add_entity(
db: &Surreal<Db>,
embedder: &dyn Embedder,
entity: NewEntity,
) -> Result<Entity, GraphError> {
let embedding = embedder.embed_single(&entity.abstract_text)?;
let mutable = entity.entity_type.is_mutable();
let mut response = db
.query(
r#"
CREATE entity SET
name = $name,
entity_type = $entity_type,
abstract = $abstract,
overview = $overview,
content = $content,
attributes = $attributes,
embedding = $embedding,
mutable = $mutable,
access_count = 0,
created_at = time::now(),
updated_at = time::now(),
source = $source
"#,
)
.bind(("name", entity.name))
.bind(("entity_type", entity.entity_type.to_string()))
.bind(("abstract", entity.abstract_text))
.bind(("overview", entity.overview.unwrap_or_default()))
.bind(("content", entity.content))
.bind(("attributes", entity.attributes))
.bind(("embedding", embedding))
.bind(("mutable", mutable))
.bind(("source", entity.source))
.await?;
let created: Option<Entity> = deserialize_take_opt(&mut response, 0)?;
created
.ok_or_else(|| GraphError::Db(surrealdb::Error::thrown("failed to create entity".into())))
}
/// Looks up an entity by its exact `name`.
///
/// Returns `Ok(None)` when the query yields no row. The query uses `LIMIT 1`
/// with no ordering, so if several entities share the name only one is returned.
pub async fn get_entity_by_name(
    db: &Surreal<Db>,
    name: &str,
) -> Result<Option<Entity>, GraphError> {
    let pending = db
        .query("SELECT * FROM entity WHERE name = $name LIMIT 1")
        .bind(("name", name.to_owned()));
    let mut response = pending.await?;
    deserialize_take_opt(&mut response, 0)
}
/// Fetches the entity whose record id is `id` (resolved with `type::record`).
///
/// Returns `Ok(None)` when the query yields no row.
pub async fn get_entity_by_id(db: &Surreal<Db>, id: &str) -> Result<Option<Entity>, GraphError> {
    let record_id = id.to_owned();
    let mut response = db
        .query("SELECT * FROM type::record($id)")
        .bind(("id", record_id))
        .await?;
    deserialize_take_opt(&mut response, 0)
}
/// Applies a partial update to the entity `id` and returns the record as it
/// reads after the write (`RETURN AFTER`).
///
/// Only fields present in `updates` are written, and `updated_at` is refreshed
/// on every call. A new abstract also triggers a fresh embedding, which is
/// stored alongside it.
///
/// # Errors
/// * `GraphError::NotFound(id)` if the `UPDATE` returns no record.
/// * Any error from the embedder or the database.
pub async fn update_entity(
    db: &Surreal<Db>,
    embedder: &dyn Embedder,
    id: &str,
    updates: EntityUpdate,
) -> Result<Entity, GraphError> {
    // SET clauses and their parameters are collected in lockstep. Values are
    // always bound as query parameters and never spliced into the query text.
    let mut clauses: Vec<String> = Vec::new();
    let mut params: Vec<(String, serde_json::Value)> = Vec::new();
    clauses.push("updated_at = time::now()".to_string());

    if let Some(abs) = &updates.abstract_text {
        let embedding = embedder.embed_single(abs)?;
        clauses.push("abstract = $new_abstract".to_string());
        params.push((
            "new_abstract".to_string(),
            serde_json::Value::String(abs.clone()),
        ));
        clauses.push("embedding = $new_embedding".to_string());
        params.push(("new_embedding".to_string(), serde_json::json!(embedding)));
    }
    if let Some(overview) = &updates.overview {
        clauses.push("overview = $new_overview".to_string());
        params.push((
            "new_overview".to_string(),
            serde_json::Value::String(overview.clone()),
        ));
    }
    if let Some(content) = &updates.content {
        clauses.push("content = $new_content".to_string());
        params.push((
            "new_content".to_string(),
            serde_json::Value::String(content.clone()),
        ));
    }
    if let Some(attributes) = &updates.attributes {
        clauses.push("attributes = $new_attributes".to_string());
        params.push(("new_attributes".to_string(), attributes.clone()));
    }

    let statement = format!(
        "UPDATE type::record($id) SET {} RETURN AFTER",
        clauses.join(", ")
    );
    let pending = params
        .into_iter()
        .fold(db.query(&statement).bind(("id", id.to_string())), |q, param| {
            q.bind(param)
        });
    let mut response = pending.await?;

    let updated: Vec<Entity> = deserialize_take(&mut response, 0)?;
    updated
        .into_iter()
        .next()
        .ok_or_else(|| GraphError::NotFound(id.to_string()))
}
/// Deletes the record `id` together with every `relates_to` edge whose `in`
/// or `out` is that record.
///
/// Both `DELETE` statements are sent as one multi-statement query, and
/// `check()` turns an error in either statement into an `Err`.
pub async fn delete_entity(db: &Surreal<Db>, id: &str) -> Result<(), GraphError> {
    let response = db
        .query(
            r#"
DELETE FROM relates_to WHERE in = type::record($id) OR out = type::record($id);
DELETE FROM type::record($id);
"#,
        )
        .bind(("id", id.to_owned()))
        .await?;
    // Only per-statement failures matter here; the successful response is unused.
    response.check()?;
    Ok(())
}
/// Lists entities ordered by `name`.
///
/// When `entity_type` is `Some`, only entities whose stored `entity_type`
/// equals that string are returned.
pub async fn list_entities(
    db: &Surreal<Db>,
    entity_type: Option<&str>,
) -> Result<Vec<Entity>, GraphError> {
    let mut response = match entity_type {
        Some(kind) => {
            db.query("SELECT * FROM entity WHERE entity_type = $et ORDER BY name")
                .bind(("et", kind.to_owned()))
                .await?
        }
        None => db.query("SELECT * FROM entity ORDER BY name").await?,
    };
    deserialize_take(&mut response, 0)
}
pub async fn add_relationship(
db: &Surreal<Db>,
rel: NewRelationship,
) -> Result<Relationship, GraphError> {
let from = get_entity_by_name(db, &rel.from_entity)
.await?
.ok_or_else(|| GraphError::NotFound(rel.from_entity.clone()))?;
let to = get_entity_by_name(db, &rel.to_entity)
.await?
.ok_or_else(|| GraphError::NotFound(rel.to_entity.clone()))?;
let from_id = from.id_string();
let to_id = to.id_string();
let mut response = db
.query(
r#"
LET $from = type::record($from_id);
LET $to = type::record($to_id);
RELATE $from -> relates_to -> $to SET
rel_type = $rel_type,
description = $description,
valid_from = time::now(),
valid_until = NONE,
confidence = $confidence,
source = $source
"#,
)
.bind(("from_id", from_id))
.bind(("to_id", to_id))
.bind(("rel_type", rel.rel_type))
.bind(("description", rel.description))
.bind(("confidence", rel.confidence.unwrap_or(1.0) as f64))
.bind(("source", rel.source))
.await?;
let created: Option<Relationship> = deserialize_take_opt(&mut response, 2)?;
created.ok_or_else(|| {
GraphError::Db(surrealdb::Error::thrown(
"failed to create relationship".into(),
))
})
}
/// Returns the `relates_to` edges attached to the entity named `entity_name`.
///
/// `Direction::Outgoing` selects edges whose `in` is the entity,
/// `Direction::Incoming` selects edges whose `out` is the entity, and
/// `Direction::Both` selects either. No filter is applied on `valid_until`, so
/// superseded edges are included.
///
/// # Errors
/// `GraphError::NotFound` if no entity has that name; database errors otherwise.
pub async fn get_relationships(
    db: &Surreal<Db>,
    entity_name: &str,
    direction: Direction,
) -> Result<Vec<Relationship>, GraphError> {
    let entity = match get_entity_by_name(db, entity_name).await? {
        Some(found) => found,
        None => return Err(GraphError::NotFound(entity_name.to_string())),
    };

    let statement = match direction {
        Direction::Outgoing => "SELECT * FROM relates_to WHERE in = type::record($id)",
        Direction::Incoming => "SELECT * FROM relates_to WHERE out = type::record($id)",
        Direction::Both => {
            "SELECT * FROM relates_to WHERE in = type::record($id) OR out = type::record($id)"
        }
    };

    let mut response = db
        .query(statement)
        .bind(("id", entity.id_string()))
        .await?;
    deserialize_take(&mut response, 0)
}
pub async fn supersede_relationship(
db: &Surreal<Db>,
old_id: &str,
new: NewRelationship,
) -> Result<Relationship, GraphError> {
let old_id_owned = old_id.to_string();
db.query("UPDATE type::record($id) SET valid_until = time::now()")
.bind(("id", old_id_owned))
.await?
.check()?;
add_relationship(db, new).await
}
/// Fetches the summary projection (`id`, `name`, `entity_type`, `abstract`)
/// of the record `id`.
///
/// Returns `Ok(None)` when the query yields no row.
pub async fn get_entity_summary(
    db: &Surreal<Db>,
    id: &str,
) -> Result<Option<EntitySummary>, GraphError> {
    let record_id = id.to_owned();
    let pending = db
        .query("SELECT id, name, entity_type, abstract FROM type::record($id)")
        .bind(("id", record_id));
    let mut response = pending.await?;
    deserialize_take_opt(&mut response, 0)
}
/// Fetches the detail projection of the record `id`.
///
/// The projection includes the summary fields plus overview, attributes,
/// access count, last update time and source. `content` and `embedding` are
/// not selected. Returns `Ok(None)` when the query yields no row.
pub async fn get_entity_detail(
    db: &Surreal<Db>,
    id: &str,
) -> Result<Option<EntityDetail>, GraphError> {
    let statement = r#"SELECT id, name, entity_type, abstract, overview, attributes,
access_count, updated_at, source
FROM type::record($id)"#;
    let record_id = id.to_owned();
    let mut response = db.query(statement).bind(("id", record_id)).await?;
    deserialize_take_opt(&mut response, 0)
}
/// Increments `access_count` on each record listed in `ids`, one query per id,
/// in order.
///
/// This is best-effort. The outcome of each query is discarded, so transport
/// errors and per-statement failures are both ignored, and the function
/// always returns `Ok(())`. An empty slice issues no queries.
pub async fn increment_access_counts(db: &Surreal<Db>, ids: &[String]) -> Result<(), GraphError> {
    for record_id in ids.iter().cloned() {
        // Result intentionally not inspected (see function docs).
        let _ = db
            .query("UPDATE type::record($id) SET access_count += 1")
            .bind(("id", record_id))
            .await;
    }
    Ok(())
}
pub async fn add_episode(
db: &Surreal<Db>,
embedder: &dyn Embedder,
episode: NewEpisode,
) -> Result<Episode, GraphError> {
let embedding = embedder.embed_single(&episode.abstract_text)?;
let mut response = db
.query(
r#"
CREATE episode SET
session_id = $session_id,
timestamp = time::now(),
abstract = $abstract,
overview = $overview,
content = $content,
embedding = $embedding,
log_number = $log_number
"#,
)
.bind(("session_id", episode.session_id))
.bind(("abstract", episode.abstract_text))
.bind(("overview", episode.overview))
.bind(("content", episode.content))
.bind(("embedding", embedding))
.bind(("log_number", episode.log_number.map(|n| n as i64)))
.await?;
let created: Option<Episode> = deserialize_take_opt(&mut response, 0)?;
created
.ok_or_else(|| GraphError::Db(surrealdb::Error::thrown("failed to create episode".into())))
}
/// Returns every episode belonging to `session_id`, oldest `timestamp` first.
pub async fn get_episodes_by_session(
    db: &Surreal<Db>,
    session_id: &str,
) -> Result<Vec<Episode>, GraphError> {
    let sid = session_id.to_owned();
    let pending = db
        .query("SELECT * FROM episode WHERE session_id = $sid ORDER BY timestamp")
        .bind(("sid", sid));
    let mut response = pending.await?;
    deserialize_take(&mut response, 0)
}
/// Sets `extracted = true` on every episode whose `log_number` equals
/// `log_number`.
///
/// Matching no episodes is not an error. A failing statement is surfaced via
/// `check()`.
pub async fn mark_episodes_extracted(db: &Surreal<Db>, log_number: u32) -> Result<(), GraphError> {
    // u32 -> i64 is lossless; the column is bound as a 64-bit integer.
    let ln = i64::from(log_number);
    let response = db
        .query("UPDATE episode SET extracted = true WHERE log_number = $ln")
        .bind(("ln", ln))
        .await?;
    response.check()?;
    Ok(())
}
pub async fn get_unextracted_log_numbers(db: &Surreal<Db>) -> Result<Vec<i64>, GraphError> {
let mut response = db
.query("SELECT log_number FROM episode WHERE extracted = false AND log_number IS NOT NONE GROUP BY log_number ORDER BY log_number")
.await?;
#[derive(serde::Deserialize)]
struct Row {
log_number: i64,
}
let rows: Vec<Row> = crate::deserialize_take(&mut response, 0)?;
Ok(rows.into_iter().map(|r| r.log_number).collect())
}
/// Returns one episode whose `log_number` equals `log_number`, or `Ok(None)`
/// if the query yields no row.
///
/// The query uses `LIMIT 1` with no ordering, so if several episodes share the
/// number, which one is returned is not specified here.
pub async fn get_episode_by_log_number(
    db: &Surreal<Db>,
    log_number: u32,
) -> Result<Option<Episode>, GraphError> {
    // u32 -> i64 is lossless; the column is bound as a 64-bit integer.
    let ln = i64::from(log_number);
    let mut response = db
        .query("SELECT * FROM episode WHERE log_number = $ln LIMIT 1")
        .bind(("ln", ln))
        .await?;
    deserialize_take_opt(&mut response, 0)
}