1pub mod crud;
7pub mod dedup;
8pub mod embed;
9pub mod error;
10pub mod extract;
11pub mod ingest;
12pub mod llm;
13pub mod pipeline;
14pub mod pipeline_sync;
15pub mod query;
16pub mod search;
17pub mod store;
18pub mod traverse;
19pub mod types;
20
21use std::collections::HashMap;
22use std::path::{Path, PathBuf};
23
24use embed::FastEmbedder;
25use error::GraphError;
26use store::Db;
27#[allow(unused_imports)]
28use surrealdb::types::SurrealValue;
29use surrealdb::Surreal;
30use types::*;
31
32pub(crate) fn deserialize_take<T: serde::de::DeserializeOwned>(
35 response: &mut surrealdb::IndexedResults,
36 index: usize,
37) -> Result<Vec<T>, GraphError> {
38 let values: Vec<serde_json::Value> = response.take(index)?;
39 values
40 .into_iter()
41 .map(|v| serde_json::from_value(v).map_err(GraphError::from))
42 .collect()
43}
44
45pub(crate) fn deserialize_take_opt<T: serde::de::DeserializeOwned>(
46 response: &mut surrealdb::IndexedResults,
47 index: usize,
48) -> Result<Option<T>, GraphError> {
49 let values: Vec<T> = deserialize_take(response, index)?;
50 Ok(values.into_iter().next())
51}
52
53pub struct GraphMemory {
55 db: Surreal<Db>,
56 embedder: FastEmbedder,
57 path: PathBuf,
58}
59
60impl GraphMemory {
61 pub async fn open(path: &Path) -> Result<Self, GraphError> {
64 std::fs::create_dir_all(path)?;
65
66 let db = store::open(path).await?;
67 store::init_schema(&db).await?;
68
69 let models_dir = path.join("models");
70 std::fs::create_dir_all(&models_dir)?;
71 let embedder = FastEmbedder::new(&models_dir)?;
72
73 Ok(Self {
74 db,
75 embedder,
76 path: path.to_path_buf(),
77 })
78 }
79
80 pub fn path(&self) -> &Path {
82 &self.path
83 }
84
85 #[allow(dead_code)]
87 pub(crate) fn db(&self) -> &Surreal<Db> {
88 &self.db
89 }
90
91 #[allow(dead_code)]
93 pub(crate) fn embedder(&self) -> &FastEmbedder {
94 &self.embedder
95 }
96
97 pub async fn add_entity(&self, entity: NewEntity) -> Result<Entity, GraphError> {
101 crud::add_entity(&self.db, &self.embedder, entity).await
102 }
103
104 pub async fn get_entity(&self, name: &str) -> Result<Option<Entity>, GraphError> {
106 crud::get_entity_by_name(&self.db, name).await
107 }
108
109 pub async fn get_entity_by_id(&self, id: &str) -> Result<Option<Entity>, GraphError> {
111 crud::get_entity_by_id(&self.db, id).await
112 }
113
114 pub async fn update_entity(
116 &self,
117 id: &str,
118 updates: EntityUpdate,
119 ) -> Result<Entity, GraphError> {
120 crud::update_entity(&self.db, &self.embedder, id, updates).await
121 }
122
123 pub async fn delete_entity(&self, id: &str) -> Result<(), GraphError> {
125 crud::delete_entity(&self.db, id).await
126 }
127
128 pub async fn list_entities(
130 &self,
131 entity_type: Option<&str>,
132 ) -> Result<Vec<Entity>, GraphError> {
133 crud::list_entities(&self.db, entity_type).await
134 }
135
136 pub async fn add_relationship(&self, rel: NewRelationship) -> Result<Relationship, GraphError> {
140 crud::add_relationship(&self.db, rel).await
141 }
142
143 pub async fn get_relationships(
145 &self,
146 entity_name: &str,
147 direction: Direction,
148 ) -> Result<Vec<Relationship>, GraphError> {
149 crud::get_relationships(&self.db, entity_name, direction).await
150 }
151
152 pub async fn supersede_relationship(
154 &self,
155 old_id: &str,
156 new: NewRelationship,
157 ) -> Result<Relationship, GraphError> {
158 crud::supersede_relationship(&self.db, old_id, new).await
159 }
160
161 pub async fn add_episode(&self, episode: NewEpisode) -> Result<Episode, GraphError> {
165 crud::add_episode(&self.db, &self.embedder, episode).await
166 }
167
168 pub async fn get_episodes_by_session(
170 &self,
171 session_id: &str,
172 ) -> Result<Vec<Episode>, GraphError> {
173 crud::get_episodes_by_session(&self.db, session_id).await
174 }
175
176 pub async fn get_episode_by_log_number(
178 &self,
179 log_number: u32,
180 ) -> Result<Option<Episode>, GraphError> {
181 crud::get_episode_by_log_number(&self.db, log_number).await
182 }
183
184 pub async fn ingest_archive(
188 &self,
189 archive_text: &str,
190 session_id: &str,
191 log_number: Option<u32>,
192 llm: Option<&dyn llm::LlmProvider>,
193 ) -> Result<IngestionReport, GraphError> {
194 ingest::ingest_archive(self, archive_text, session_id, log_number, llm).await
195 }
196
197 pub async fn extract_from_archive(
199 &self,
200 archive_text: &str,
201 session_id: &str,
202 log_number: Option<u32>,
203 llm: &dyn llm::LlmProvider,
204 ) -> Result<IngestionReport, GraphError> {
205 ingest::extract_from_archive(self, archive_text, session_id, log_number, llm).await
206 }
207
208 pub async fn mark_extracted(&self, log_number: u32) -> Result<(), GraphError> {
210 crud::mark_episodes_extracted(&self.db, log_number).await
211 }
212
213 pub async fn unextracted_log_numbers(&self) -> Result<Vec<i64>, GraphError> {
215 crud::get_unextracted_log_numbers(&self.db).await
216 }
217
218 pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>, GraphError> {
222 search::search(&self.db, &self.embedder, query, limit).await
223 }
224
225 pub async fn search_with_options(
227 &self,
228 query: &str,
229 options: &SearchOptions,
230 ) -> Result<Vec<ScoredEntity>, GraphError> {
231 search::search_with_options(&self.db, &self.embedder, query, options).await
232 }
233
234 pub async fn search_episodes(
236 &self,
237 query: &str,
238 limit: usize,
239 ) -> Result<Vec<EpisodeSearchResult>, GraphError> {
240 search::search_episodes(&self.db, &self.embedder, query, limit).await
241 }
242
243 pub async fn query(
247 &self,
248 query_text: &str,
249 options: &QueryOptions,
250 ) -> Result<QueryResult, GraphError> {
251 query::query(&self.db, &self.embedder, query_text, options).await
252 }
253
254 pub async fn traverse(
258 &self,
259 entity_name: &str,
260 depth: u32,
261 ) -> Result<TraversalNode, GraphError> {
262 traverse::traverse(&self.db, entity_name, depth).await
263 }
264
265 pub async fn traverse_filtered(
267 &self,
268 entity_name: &str,
269 depth: u32,
270 type_filter: Option<&str>,
271 ) -> Result<TraversalNode, GraphError> {
272 traverse::traverse_filtered(&self.db, entity_name, depth, type_filter).await
273 }
274
275 pub async fn sync_pipeline(
279 &self,
280 docs: &PipelineDocuments,
281 ) -> Result<PipelineSyncReport, GraphError> {
282 pipeline_sync::sync_pipeline(self, docs).await
283 }
284
285 pub async fn pipeline_stats(
287 &self,
288 staleness_days: u32,
289 ) -> Result<PipelineGraphStats, GraphError> {
290 query::pipeline_stats(&self.db, staleness_days).await
291 }
292
293 pub async fn pipeline_entities(
295 &self,
296 stage: &str,
297 status: Option<&str>,
298 ) -> Result<Vec<EntityDetail>, GraphError> {
299 query::pipeline_entities(&self.db, stage, status).await
300 }
301
302 pub async fn pipeline_flow(
304 &self,
305 entity_name: &str,
306 ) -> Result<Vec<(EntityDetail, String, EntityDetail)>, GraphError> {
307 query::pipeline_flow(&self.db, entity_name).await
308 }
309
310 pub async fn stats(&self) -> Result<GraphStats, GraphError> {
314 let entity_count = db_count(&self.db, "entity").await?;
315 let relationship_count = db_count(&self.db, "relates_to").await?;
316 let episode_count = db_count(&self.db, "episode").await?;
317
318 let mut type_response = self
320 .db
321 .query("SELECT entity_type, count() AS count FROM entity GROUP BY entity_type")
322 .await?;
323
324 let type_rows: Vec<TypeCount> = type_response.take(0)?;
325 let entity_type_counts: HashMap<String, u64> = type_rows
326 .into_iter()
327 .map(|r| (r.entity_type, r.count))
328 .collect();
329
330 Ok(GraphStats {
331 entity_count,
332 relationship_count,
333 episode_count,
334 entity_type_counts,
335 })
336 }
337}
338
339async fn db_count(db: &Surreal<Db>, table: &str) -> Result<u64, GraphError> {
340 let query = format!("SELECT count() AS count FROM {} GROUP ALL", table);
341 let mut response = db.query(&query).await?;
342 let rows: Vec<CountRow> = response.take(0)?;
343 Ok(rows.first().map(|r| r.count).unwrap_or(0))
344}
345
346#[derive(serde::Deserialize, surrealdb::types::SurrealValue)]
347struct CountRow {
348 count: u64,
349}
350
351#[derive(serde::Deserialize, surrealdb::types::SurrealValue)]
352struct TypeCount {
353 entity_type: String,
354 count: u64,
355}