1use std::collections::HashMap;
10
11use surrealdb::Surreal;
12
13use super::embed::Embedder;
14use super::error::GraphError;
15use super::store::Db;
16use super::types::*;
17
18pub async fn query(
20 db: &Surreal<Db>,
21 embedder: &dyn Embedder,
22 query_text: &str,
23 options: &QueryOptions,
24) -> Result<QueryResult, GraphError> {
25 let limit = if options.limit == 0 {
26 10
27 } else {
28 options.limit
29 };
30
31 let semantic_options = SearchOptions {
33 limit: limit * 2,
34 entity_type: options.entity_type.clone(),
35 keyword: options.keyword.clone(),
36 };
37 let semantic_results =
38 super::search::search_with_options(db, embedder, query_text, &semantic_options).await?;
39
40 let mut entity_map: HashMap<String, ScoredEntity> = HashMap::new();
42 for result in semantic_results {
43 entity_map.insert(result.entity.id_string(), result);
44 }
45
46 if options.graph_depth > 0 {
48 let top_n: Vec<(String, f64)> = {
49 let mut entries: Vec<_> = entity_map
50 .values()
51 .map(|e| (e.entity.id_string(), e.score))
52 .collect();
53 entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54 entries.truncate(3); entries
56 };
57
58 for (parent_id, parent_score) in &top_n {
59 let parent_name = entity_map
60 .get(parent_id)
61 .map(|e| e.entity.name.clone())
62 .unwrap_or_default();
63
64 let neighbors = get_neighbor_details(db, parent_id).await?;
65
66 for (neighbor, rel_type) in neighbors {
67 let neighbor_id = neighbor.id_string();
68 if entity_map.contains_key(&neighbor_id) {
69 continue; }
71
72 if let Some(ref et) = options.entity_type {
74 if neighbor.entity_type.to_string() != *et {
75 continue;
76 }
77 }
78
79 let graph_score = parent_score * 0.5;
80 entity_map.insert(
81 neighbor_id,
82 ScoredEntity {
83 entity: neighbor,
84 score: graph_score,
85 source: MatchSource::Graph {
86 parent: parent_name.clone(),
87 rel_type,
88 },
89 },
90 );
91 }
92 }
93 }
94
95 let mut entities: Vec<ScoredEntity> = entity_map.into_values().collect();
97 entities.sort_by(|a, b| {
98 b.score
99 .partial_cmp(&a.score)
100 .unwrap_or(std::cmp::Ordering::Equal)
101 });
102 entities.truncate(limit);
103
104 let episodes = if options.include_episodes {
106 super::search::search_episodes(db, embedder, query_text, limit).await?
107 } else {
108 vec![]
109 };
110
111 Ok(QueryResult { entities, episodes })
112}
113
114async fn get_neighbor_details(
116 db: &Surreal<Db>,
117 entity_id: &str,
118) -> Result<Vec<(EntityDetail, String)>, GraphError> {
119 let mut response = db
121 .query(
122 r#"
123 SELECT rel_type, out AS target_id
124 FROM relates_to
125 WHERE in = type::record($id) AND valid_until IS NONE
126 "#,
127 )
128 .bind(("id", entity_id.to_string()))
129 .await?;
130
131 let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
132
133 let mut response = db
135 .query(
136 r#"
137 SELECT rel_type, in AS target_id
138 FROM relates_to
139 WHERE out = type::record($id) AND valid_until IS NONE
140 "#,
141 )
142 .bind(("id", entity_id.to_string()))
143 .await?;
144
145 let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
146
147 let mut results = Vec::new();
148 let all_edges: Vec<_> = outgoing.into_iter().chain(incoming).collect();
149
150 for edge in all_edges {
151 let tid = match &edge.target_id {
152 serde_json::Value::String(s) => s.clone(),
153 other => other.to_string(),
154 };
155
156 if let Some(detail) = super::crud::get_entity_detail(db, &tid).await? {
157 results.push((detail, edge.rel_type));
158 }
159 }
160
161 Ok(results)
162}
163
164#[derive(serde::Deserialize)]
165struct RelTarget {
166 rel_type: String,
167 target_id: serde_json::Value,
168}
169
170pub async fn pipeline_entities(
174 db: &Surreal<Db>,
175 stage: &str,
176 status: Option<&str>,
177) -> Result<Vec<EntityDetail>, GraphError> {
178 let query = match status {
179 Some(_) => {
180 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
181 FROM entity
182 WHERE attributes.pipeline_stage = $stage
183 AND attributes.pipeline_status = $status
184 ORDER BY updated_at DESC"#
185 }
186 None => {
187 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
188 FROM entity
189 WHERE attributes.pipeline_stage = $stage
190 ORDER BY updated_at DESC"#
191 }
192 };
193
194 let stage_owned = stage.to_string();
195 let mut response = match status {
196 Some(s) => {
197 let status_owned = s.to_string();
198 db.query(query)
199 .bind(("stage", stage_owned))
200 .bind(("status", status_owned))
201 .await?
202 }
203 None => db.query(query).bind(("stage", stage_owned)).await?,
204 };
205
206 let entities: Vec<EntityDetail> = super::deserialize_take(&mut response, 0)?;
207 Ok(entities)
208}
209
210pub async fn pipeline_stats(
212 db: &Surreal<Db>,
213 staleness_days: u32,
214) -> Result<PipelineGraphStats, GraphError> {
215 let mut response = db
217 .query(
218 r#"SELECT
219 attributes.pipeline_stage AS stage,
220 attributes.pipeline_status AS status,
221 count() AS count
222 FROM entity
223 WHERE attributes.pipeline_stage IS NOT NONE
224 GROUP BY attributes.pipeline_stage, attributes.pipeline_status"#,
225 )
226 .await?;
227
228 let rows: Vec<StageStatusCount> = super::deserialize_take(&mut response, 0)?;
229
230 let mut by_stage: std::collections::HashMap<String, std::collections::HashMap<String, u64>> =
231 std::collections::HashMap::new();
232 let mut total = 0u64;
233
234 for row in rows {
235 total += row.count;
236 by_stage
237 .entry(row.stage)
238 .or_default()
239 .insert(row.status, row.count);
240 }
241
242 let mut stale_response = db
244 .query(
245 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
246 FROM entity
247 WHERE attributes.pipeline_stage = 'thoughts'
248 AND attributes.pipeline_status = 'active'
249 AND updated_at < time::now() - type::duration($threshold)
250 ORDER BY updated_at ASC"#,
251 )
252 .bind(("threshold", format!("{}d", staleness_days)))
253 .await?;
254
255 let stale_thoughts: Vec<EntityDetail> = super::deserialize_take(&mut stale_response, 0)?;
256
257 let mut stale_q_response = db
259 .query(
260 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
261 FROM entity
262 WHERE attributes.pipeline_stage = 'curiosity'
263 AND attributes.pipeline_status = 'active'
264 AND attributes.sub_type IS NONE
265 AND updated_at < time::now() - type::duration($threshold)
266 ORDER BY updated_at ASC"#,
267 )
268 .bind(("threshold", format!("{}d", staleness_days * 2)))
269 .await?;
270
271 let stale_questions: Vec<EntityDetail> = super::deserialize_take(&mut stale_q_response, 0)?;
272
273 let mut movement_response = db
275 .query(
276 r#"SELECT updated_at
277 FROM entity
278 WHERE attributes.pipeline_status IN ['graduated', 'dissolved', 'explored']
279 ORDER BY updated_at DESC
280 LIMIT 1"#,
281 )
282 .await?;
283
284 let movement_rows: Vec<UpdatedAtRow> = super::deserialize_take(&mut movement_response, 0)?;
285 let last_movement = movement_rows.first().map(|r| match &r.updated_at {
286 serde_json::Value::String(s) => s.clone(),
287 other => other.to_string(),
288 });
289
290 Ok(PipelineGraphStats {
291 by_stage,
292 stale_thoughts,
293 stale_questions,
294 total_entities: total,
295 last_movement,
296 })
297}
298
299pub async fn pipeline_flow(
301 db: &Surreal<Db>,
302 entity_name: &str,
303) -> Result<Vec<(EntityDetail, String, EntityDetail)>, GraphError> {
304 let entity = super::crud::get_entity_by_name(db, entity_name)
306 .await?
307 .ok_or_else(|| GraphError::NotFound(format!("entity: {}", entity_name)))?;
308
309 let entity_id = entity.id_string();
310 let mut chain = Vec::new();
311
312 let pipeline_rel_types = [
314 "EVOLVED_FROM",
315 "CRYSTALLIZED_FROM",
316 "INFORMED_BY",
317 "GRADUATED_TO",
318 "CONNECTED_TO",
319 "EXPLORES",
320 "ARCHIVED_FROM",
321 ];
322 let rel_types_str = pipeline_rel_types
323 .iter()
324 .map(|r| format!("'{}'", r))
325 .collect::<Vec<_>>()
326 .join(", ");
327
328 let query_out = format!(
330 r#"SELECT rel_type, out AS target_id
331 FROM relates_to
332 WHERE in = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
333 rel_types_str
334 );
335 let mut response = db.query(&query_out).bind(("id", entity_id.clone())).await?;
336 let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
337
338 for edge in &outgoing {
339 let tid = match &edge.target_id {
340 serde_json::Value::String(s) => s.clone(),
341 other => other.to_string(),
342 };
343 if let Some(target) = super::crud::get_entity_detail(db, &tid).await? {
344 let source_detail = super::crud::get_entity_detail(db, &entity_id)
345 .await?
346 .unwrap();
347 chain.push((source_detail, edge.rel_type.clone(), target));
348 }
349 }
350
351 let query_in = format!(
353 r#"SELECT rel_type, in AS target_id
354 FROM relates_to
355 WHERE out = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
356 rel_types_str
357 );
358 let mut response = db.query(&query_in).bind(("id", entity_id.clone())).await?;
359 let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
360
361 for edge in &incoming {
362 let tid = match &edge.target_id {
363 serde_json::Value::String(s) => s.clone(),
364 other => other.to_string(),
365 };
366 if let Some(source) = super::crud::get_entity_detail(db, &tid).await? {
367 let target_detail = super::crud::get_entity_detail(db, &entity_id)
368 .await?
369 .unwrap();
370 chain.push((source, edge.rel_type.clone(), target_detail));
371 }
372 }
373
374 Ok(chain)
375}
376
377#[derive(serde::Deserialize)]
378struct StageStatusCount {
379 stage: String,
380 status: String,
381 count: u64,
382}
383
384#[derive(serde::Deserialize)]
385struct UpdatedAtRow {
386 updated_at: serde_json::Value,
387}