1use std::collections::HashMap;
10
11use surrealdb::Surreal;
12
13use super::embed::Embedder;
14use super::error::GraphError;
15use super::store::Db;
16use super::types::*;
17
18pub async fn query(
20 db: &Surreal<Db>,
21 embedder: &dyn Embedder,
22 query_text: &str,
23 options: &QueryOptions,
24) -> Result<QueryResult, GraphError> {
25 let limit = if options.limit == 0 {
26 10
27 } else {
28 options.limit
29 };
30
31 let semantic_options = SearchOptions {
33 limit: limit * 2,
34 entity_type: options.entity_type.clone(),
35 keyword: options.keyword.clone(),
36 };
37 let semantic_results =
38 super::search::search_with_options(db, embedder, query_text, &semantic_options).await?;
39
40 let mut entity_map: HashMap<String, ScoredEntity> = HashMap::new();
42 for result in semantic_results {
43 entity_map.insert(result.entity.id_string(), result);
44 }
45
46 if options.graph_depth > 0 {
48 let top_n: Vec<(String, f64)> = {
49 let mut entries: Vec<_> = entity_map
50 .values()
51 .map(|e| (e.entity.id_string(), e.score))
52 .collect();
53 entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54 entries.truncate(3); entries
56 };
57
58 for (parent_id, parent_score) in &top_n {
59 let parent_name = entity_map
60 .get(parent_id)
61 .map(|e| e.entity.name.clone())
62 .unwrap_or_default();
63
64 let neighbors = get_neighbor_details(db, parent_id).await?;
65
66 for (neighbor, rel_type, confidence) in neighbors {
67 let neighbor_id = neighbor.id_string();
68 if entity_map.contains_key(&neighbor_id) {
69 continue; }
71
72 if let Some(ref et) = options.entity_type {
74 if neighbor.entity_type.to_string() != *et {
75 continue;
76 }
77 }
78
79 let graph_score = parent_score * confidence;
80 entity_map.insert(
81 neighbor_id,
82 ScoredEntity {
83 entity: neighbor,
84 score: graph_score,
85 source: MatchSource::Graph {
86 parent: parent_name.clone(),
87 rel_type,
88 },
89 },
90 );
91 }
92 }
93 }
94
95 let mut entities: Vec<ScoredEntity> = entity_map.into_values().collect();
97 entities.sort_by(|a, b| {
98 b.score
99 .partial_cmp(&a.score)
100 .unwrap_or(std::cmp::Ordering::Equal)
101 });
102 entities.truncate(limit);
103
104 let episodes = if options.include_episodes {
106 super::search::search_episodes(db, embedder, query_text, limit).await?
107 } else {
108 vec![]
109 };
110
111 Ok(QueryResult { entities, episodes })
112}
113
114async fn get_neighbor_details(
116 db: &Surreal<Db>,
117 entity_id: &str,
118) -> Result<Vec<(EntityDetail, String, f64)>, GraphError> {
119 let mut response = db
121 .query(
122 r#"
123 SELECT rel_type, confidence, out AS target_id
124 FROM relates_to
125 WHERE in = type::record($id) AND valid_until IS NONE AND confidence >= 0.1
126 "#,
127 )
128 .bind(("id", entity_id.to_string()))
129 .await?;
130
131 let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
132
133 let mut response = db
135 .query(
136 r#"
137 SELECT rel_type, confidence, in AS target_id
138 FROM relates_to
139 WHERE out = type::record($id) AND valid_until IS NONE AND confidence >= 0.1
140 "#,
141 )
142 .bind(("id", entity_id.to_string()))
143 .await?;
144
145 let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
146
147 let mut results = Vec::new();
148 let all_edges: Vec<_> = outgoing.into_iter().chain(incoming).collect();
149
150 for edge in all_edges {
151 let tid = match &edge.target_id {
152 serde_json::Value::String(s) => s.clone(),
153 other => other.to_string(),
154 };
155
156 if let Some(detail) = super::crud::get_entity_detail(db, &tid).await? {
157 results.push((detail, edge.rel_type, edge.confidence));
158 }
159 }
160
161 Ok(results)
162}
163
/// Serde default for `RelTarget::confidence`: a row without a `confidence`
/// field is treated as a fully confident edge.
fn default_rel_confidence() -> f64 {
    const FULL_CONFIDENCE: f64 = 1.0;
    FULL_CONFIDENCE
}
167
168#[derive(serde::Deserialize)]
169struct RelTarget {
170 rel_type: String,
171 target_id: serde_json::Value,
172 #[serde(default = "default_rel_confidence")]
173 confidence: f64,
174}
175
176pub async fn pipeline_entities(
180 db: &Surreal<Db>,
181 stage: &str,
182 status: Option<&str>,
183) -> Result<Vec<EntityDetail>, GraphError> {
184 let query = match status {
185 Some(_) => {
186 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
187 FROM entity
188 WHERE attributes.pipeline_stage = $stage
189 AND attributes.pipeline_status = $status
190 ORDER BY updated_at DESC"#
191 }
192 None => {
193 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
194 FROM entity
195 WHERE attributes.pipeline_stage = $stage
196 ORDER BY updated_at DESC"#
197 }
198 };
199
200 let stage_owned = stage.to_string();
201 let mut response = match status {
202 Some(s) => {
203 let status_owned = s.to_string();
204 db.query(query)
205 .bind(("stage", stage_owned))
206 .bind(("status", status_owned))
207 .await?
208 }
209 None => db.query(query).bind(("stage", stage_owned)).await?,
210 };
211
212 let entities: Vec<EntityDetail> = super::deserialize_take(&mut response, 0)?;
213 Ok(entities)
214}
215
216pub async fn pipeline_stats(
218 db: &Surreal<Db>,
219 staleness_days: u32,
220) -> Result<PipelineGraphStats, GraphError> {
221 let mut response = db
223 .query(
224 r#"SELECT
225 attributes.pipeline_stage AS stage,
226 attributes.pipeline_status AS status,
227 count() AS count
228 FROM entity
229 WHERE attributes.pipeline_stage IS NOT NONE
230 GROUP BY attributes.pipeline_stage, attributes.pipeline_status"#,
231 )
232 .await?;
233
234 let rows: Vec<StageStatusCount> = super::deserialize_take(&mut response, 0)?;
235
236 let mut by_stage: std::collections::HashMap<String, std::collections::HashMap<String, u64>> =
237 std::collections::HashMap::new();
238 let mut total = 0u64;
239
240 for row in rows {
241 total += row.count;
242 by_stage
243 .entry(row.stage)
244 .or_default()
245 .insert(row.status, row.count);
246 }
247
248 let mut stale_response = db
250 .query(
251 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
252 FROM entity
253 WHERE attributes.pipeline_stage = 'thoughts'
254 AND attributes.pipeline_status = 'active'
255 AND updated_at < time::now() - type::duration($threshold)
256 ORDER BY updated_at ASC"#,
257 )
258 .bind(("threshold", format!("{}d", staleness_days)))
259 .await?;
260
261 let stale_thoughts: Vec<EntityDetail> = super::deserialize_take(&mut stale_response, 0)?;
262
263 let mut stale_q_response = db
265 .query(
266 r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
267 FROM entity
268 WHERE attributes.pipeline_stage = 'curiosity'
269 AND attributes.pipeline_status = 'active'
270 AND attributes.sub_type IS NONE
271 AND updated_at < time::now() - type::duration($threshold)
272 ORDER BY updated_at ASC"#,
273 )
274 .bind(("threshold", format!("{}d", staleness_days * 2)))
275 .await?;
276
277 let stale_questions: Vec<EntityDetail> = super::deserialize_take(&mut stale_q_response, 0)?;
278
279 let mut movement_response = db
281 .query(
282 r#"SELECT updated_at
283 FROM entity
284 WHERE attributes.pipeline_status IN ['graduated', 'dissolved', 'explored']
285 ORDER BY updated_at DESC
286 LIMIT 1"#,
287 )
288 .await?;
289
290 let movement_rows: Vec<UpdatedAtRow> = super::deserialize_take(&mut movement_response, 0)?;
291 let last_movement = movement_rows.first().map(|r| match &r.updated_at {
292 serde_json::Value::String(s) => s.clone(),
293 other => other.to_string(),
294 });
295
296 Ok(PipelineGraphStats {
297 by_stage,
298 stale_thoughts,
299 stale_questions,
300 total_entities: total,
301 last_movement,
302 })
303}
304
305pub async fn pipeline_flow(
307 db: &Surreal<Db>,
308 entity_name: &str,
309) -> Result<Vec<(EntityDetail, String, EntityDetail)>, GraphError> {
310 let entity = super::crud::get_entity_by_name(db, entity_name)
312 .await?
313 .ok_or_else(|| GraphError::NotFound(format!("entity: {}", entity_name)))?;
314
315 let entity_id = entity.id_string();
316 let mut chain = Vec::new();
317
318 let pipeline_rel_types = [
320 "EVOLVED_FROM",
321 "CRYSTALLIZED_FROM",
322 "INFORMED_BY",
323 "GRADUATED_TO",
324 "CONNECTED_TO",
325 "EXPLORES",
326 "ARCHIVED_FROM",
327 ];
328 let rel_types_str = pipeline_rel_types
329 .iter()
330 .map(|r| format!("'{}'", r))
331 .collect::<Vec<_>>()
332 .join(", ");
333
334 let query_out = format!(
336 r#"SELECT rel_type, out AS target_id
337 FROM relates_to
338 WHERE in = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
339 rel_types_str
340 );
341 let mut response = db.query(&query_out).bind(("id", entity_id.clone())).await?;
342 let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
343
344 for edge in &outgoing {
345 let tid = match &edge.target_id {
346 serde_json::Value::String(s) => s.clone(),
347 other => other.to_string(),
348 };
349 if let Some(target) = super::crud::get_entity_detail(db, &tid).await? {
350 let source_detail = super::crud::get_entity_detail(db, &entity_id)
351 .await?
352 .unwrap();
353 chain.push((source_detail, edge.rel_type.clone(), target));
354 }
355 }
356
357 let query_in = format!(
359 r#"SELECT rel_type, in AS target_id
360 FROM relates_to
361 WHERE out = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
362 rel_types_str
363 );
364 let mut response = db.query(&query_in).bind(("id", entity_id.clone())).await?;
365 let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
366
367 for edge in &incoming {
368 let tid = match &edge.target_id {
369 serde_json::Value::String(s) => s.clone(),
370 other => other.to_string(),
371 };
372 if let Some(source) = super::crud::get_entity_detail(db, &tid).await? {
373 let target_detail = super::crud::get_entity_detail(db, &entity_id)
374 .await?
375 .unwrap();
376 chain.push((source, edge.rel_type.clone(), target_detail));
377 }
378 }
379
380 Ok(chain)
381}
382
383#[derive(serde::Deserialize)]
384struct StageStatusCount {
385 stage: String,
386 status: String,
387 count: u64,
388}
389
390#[derive(serde::Deserialize)]
391struct UpdatedAtRow {
392 updated_at: serde_json::Value,
393}