Skip to main content

engram/context/
search.rs

1//! Scoped Operational Context search.
2
3use chrono::{DateTime, NaiveDateTime, Utc};
4use rusqlite::{params_from_iter, types::Value as SqlValue, Connection, Row};
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value};
7
8use crate::error::Result;
9
/// Row limit used by `search_context` when the request leaves `max_results` unset.
const DEFAULT_MAX_RESULTS: usize = 25;
/// Upper bound that `search_context` clamps a requested `max_results` to.
const MAX_RESULTS: usize = 200;
/// Age threshold, in days, used when the request leaves `stale_after_days` unset.
const DEFAULT_STALE_AFTER_DAYS: i64 = 7;
13
/// Parameters accepted by `search_context`.
///
/// Every field is optional on the wire (`#[serde(default)]`), so an empty
/// request is valid and searches without any scope or filter.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ContextSearchRequest {
    /// Free-text query; trimmed, and ignored when blank. Matched as a
    /// lowercase substring against event and summary text columns.
    #[serde(default)]
    pub query: Option<String>,
    /// Exact-match filter on the event's `repo_id`; ignored when blank.
    #[serde(default)]
    pub repo_id: Option<String>,
    /// Exact-match filter on the event's `workspace_path_hash`; ignored when blank.
    #[serde(default)]
    pub workspace_path_hash: Option<String>,
    /// Fallback for `workspace_path_hash` when that field is absent.
    #[serde(default)]
    pub workspace: Option<String>,
    /// Exact-match filter on the event's `session_id`; ignored when blank.
    #[serde(default)]
    pub session_id: Option<String>,
    /// Exact-match filter on the event's `task_id`; ignored when blank.
    #[serde(default)]
    pub task_id: Option<String>,
    /// Single event type to include; merged with `event_types` and
    /// `event_type_filters` into one de-duplicated list.
    #[serde(default)]
    pub event_type: Option<String>,
    /// Event types to include (merged as described on `event_type`).
    #[serde(default)]
    pub event_types: Vec<String>,
    /// Additional event types to include (merged as described on `event_type`).
    #[serde(default)]
    pub event_type_filters: Vec<String>,
    /// When true, keep only events with a non-zero exit code or whose type
    /// contains "fail" or "error".
    #[serde(default)]
    pub failure_only: bool,
    /// Requested row limit; defaults to `DEFAULT_MAX_RESULTS` and is clamped
    /// to `1..=MAX_RESULTS`.
    #[serde(default)]
    pub max_results: Option<usize>,
    /// When true, each result carries pointers to its raw/source artifacts.
    #[serde(default)]
    pub include_artifact_pointers: bool,
    /// Caller's current branch; compared with each event's branch to emit a
    /// `branch_mismatch` staleness warning.
    #[serde(default)]
    pub current_git_branch: Option<String>,
    /// Caller's current commit; compared with each event's commit to emit a
    /// `commit_mismatch` staleness warning.
    #[serde(default)]
    pub current_commit_hash: Option<String>,
    /// Age in days at which an event is flagged as stale; defaults to
    /// `DEFAULT_STALE_AFTER_DAYS`, with a floor of 1.
    #[serde(default)]
    pub stale_after_days: Option<i64>,
}
47
/// Result envelope returned by `search_context`.
#[derive(Debug, Clone, Serialize)]
pub struct ContextSearchResponse {
    /// The trimmed query that was searched for, or `None` when it was absent/blank.
    pub query: Option<String>,
    /// Number of entries in `results`.
    pub count: usize,
    /// Effective row limit after defaulting and clamping.
    pub max_results: usize,
    /// Scope the search was run with.
    pub scope: ContextScopeView,
    /// Filters the search was run with.
    pub filters: ContextFilterView,
    /// Matching rows, newest event first (ordered by `started_at`, then id).
    pub results: Vec<ContextSearchItem>,
}
57
/// Echo of the repo/workspace/session/task scope a search was run with.
#[derive(Debug, Clone, Serialize)]
pub struct ContextScopeView {
    /// Repository scope, if one was supplied.
    pub repo_id: Option<String>,
    /// Workspace scope, if one was supplied (either directly or through the
    /// request's `workspace` fallback).
    pub workspace_path_hash: Option<String>,
    /// Session scope, if one was supplied.
    pub session_id: Option<String>,
    /// Task scope, if one was supplied.
    pub task_id: Option<String>,
    /// Set by `search_context` to signal that at least one of the repo,
    /// workspace, session, or task scopes was supplied.
    pub isolation_applied: bool,
}
66
/// Echo of the non-scope filters a search was run with.
#[derive(Debug, Clone, Serialize)]
pub struct ContextFilterView {
    /// Merged, trimmed, de-duplicated event types that were filtered on
    /// (empty means no event-type filter).
    pub event_types: Vec<String>,
    /// Whether only failure-like events were requested.
    pub failure_only: bool,
    /// Whether artifact pointers were requested.
    pub include_artifact_pointers: bool,
    /// Effective staleness threshold in days (after defaulting and the floor of 1).
    pub stale_after_days: i64,
}
74
/// One search hit: an event, optionally paired with one of its summaries.
#[derive(Debug, Clone, Serialize)]
pub struct ContextSearchItem {
    /// `"summary"` when a summary row was joined to the event, otherwise `"event"`.
    pub result_type: String,
    /// Heuristic match score for the query; `0.0` when no query was given.
    pub relevance_score: f64,
    /// The matched event.
    pub event: ContextEventView,
    /// The joined summary, if the event has one.
    pub summary: Option<ContextSummaryView>,
    /// Top-level keys of the event metadata when it is a JSON object.
    pub metadata_keys: Vec<String>,
    /// Sorted, de-duplicated metadata strings that sit under a file/path-like
    /// key or look like a path.
    pub extracted_files: Vec<String>,
    /// Pointers to raw/source artifacts; empty unless the request asked for them.
    pub artifact_pointers: Vec<ArtifactPointer>,
    /// Branch, commit, and age warnings for the event.
    pub staleness: Vec<StalenessWarning>,
    /// Origin details for the event and, when present, its summary.
    pub provenance: ContextProvenance,
}
87
/// Serializable view of one `context_events` row.
///
/// Fields are read column-for-column from the row in `row_to_search_item`;
/// the row's `raw_artifact_id` is not part of this view and is only surfaced
/// through artifact pointers.
#[derive(Debug, Clone, Serialize)]
pub struct ContextEventView {
    pub id: i64,
    pub repo_id: Option<String>,
    pub workspace_path_hash: Option<String>,
    pub git_branch: Option<String>,
    pub worktree_name: Option<String>,
    pub commit_hash: Option<String>,
    pub session_id: String,
    pub task_id: Option<String>,
    pub agent_id: Option<String>,
    pub source: String,
    pub event_type: String,
    pub command_name: Option<String>,
    pub tool_name: Option<String>,
    pub cwd: Option<String>,
    /// A non-NULL, non-zero value is treated as a failure by the search filters.
    pub exit_code: Option<i64>,
    /// Timestamp text as stored; staleness checks parse it as RFC 3339 or
    /// `YYYY-MM-DD HH:MM:SS`.
    pub started_at: String,
    pub finished_at: Option<String>,
    pub redaction_status: String,
    pub retention_policy: String,
    /// Parsed metadata JSON; an empty object when the stored text is not valid JSON.
    pub metadata: Value,
    pub created_at: String,
}
112
/// Serializable view of one `context_summaries` row joined to an event.
#[derive(Debug, Clone, Serialize)]
pub struct ContextSummaryView {
    pub id: i64,
    /// Id of the event this summary was joined to.
    pub source_event_id: i64,
    pub source_artifact_id: Option<String>,
    /// Empty string when the column is NULL.
    pub reducer_name: String,
    /// Empty string when the column is NULL.
    pub reducer_version: String,
    /// Always `true` as constructed in this module.
    pub derived: bool,
    /// Stored as an integer flag; a NULL column is treated as lossy.
    pub lossy: bool,
    /// `0.0` when the column is NULL.
    pub confidence: f64,
    /// Empty string when the column is NULL.
    pub summary: String,
    /// Parsed JSON; an empty object when NULL or not valid JSON.
    pub structured_facts: Value,
    /// Parsed JSON; an empty array when NULL or not valid JSON.
    pub warnings: Value,
    pub tokens_raw_est: Option<i64>,
    pub tokens_compact_est: Option<i64>,
    /// Empty string when the column is NULL.
    pub created_at: String,
}
130
/// Reference to an artifact associated with a search hit.
#[derive(Debug, Clone, Serialize)]
pub struct ArtifactPointer {
    /// `"event_raw_artifact"` for the event's own raw artifact, or
    /// `"summary_source_artifact"` for the artifact a summary was built from.
    pub pointer_type: String,
    /// Identifier of the referenced artifact.
    pub artifact_id: String,
    /// Event the pointer belongs to.
    pub event_id: i64,
    /// Set only for `"summary_source_artifact"` pointers.
    pub summary_id: Option<i64>,
    /// Origin details; includes summary fields only for summary pointers.
    pub provenance: ContextProvenance,
}
139
/// A reason an event may no longer reflect the caller's current state.
#[derive(Debug, Clone, Serialize)]
pub struct StalenessWarning {
    /// `"branch_mismatch"`, `"commit_mismatch"`, or `"age"`.
    pub kind: String,
    /// Human-readable explanation.
    pub message: String,
    /// Value recorded on the event (branch, commit, or start timestamp).
    pub observed: Option<String>,
    /// Caller-side value it was compared with (branch, commit, or current time).
    pub current: Option<String>,
    /// Whole days since the event started; set only for `"age"` warnings.
    pub age_days: Option<i64>,
}
148
/// Where a result came from: identifying fields copied from the event and,
/// when a summary is involved, from that summary.
#[derive(Debug, Clone, Serialize)]
pub struct ContextProvenance {
    pub event_id: i64,
    /// `None` when no summary is involved; the same applies to every
    /// summary-derived field below.
    pub summary_id: Option<i64>,
    pub repo_id: Option<String>,
    pub workspace_path_hash: Option<String>,
    pub session_id: String,
    pub task_id: Option<String>,
    pub source: String,
    pub started_at: String,
    pub summary_created_at: Option<String>,
    pub reducer_name: Option<String>,
    pub reducer_version: Option<String>,
    pub lossy: Option<bool>,
}
164
165pub fn search_context(
166    conn: &Connection,
167    request: &ContextSearchRequest,
168) -> Result<ContextSearchResponse> {
169    let max_results = request
170        .max_results
171        .unwrap_or(DEFAULT_MAX_RESULTS)
172        .clamp(1, MAX_RESULTS);
173    let stale_after_days = request
174        .stale_after_days
175        .unwrap_or(DEFAULT_STALE_AFTER_DAYS)
176        .max(1);
177    let query = normalized_nonempty(request.query.as_deref());
178    let workspace_path_hash = request
179        .workspace_path_hash
180        .clone()
181        .or_else(|| request.workspace.clone());
182    let event_types = normalized_event_types(request);
183
184    let mut clauses = Vec::new();
185    let mut params = Vec::new();
186
187    if let Some(repo_id) = normalized_nonempty(request.repo_id.as_deref()) {
188        clauses.push("e.repo_id = ?".to_string());
189        params.push(SqlValue::Text(repo_id));
190    }
191    if let Some(workspace_hash) = normalized_nonempty(workspace_path_hash.as_deref()) {
192        clauses.push("e.workspace_path_hash = ?".to_string());
193        params.push(SqlValue::Text(workspace_hash));
194    }
195    if let Some(session_id) = normalized_nonempty(request.session_id.as_deref()) {
196        clauses.push("e.session_id = ?".to_string());
197        params.push(SqlValue::Text(session_id));
198    }
199    if let Some(task_id) = normalized_nonempty(request.task_id.as_deref()) {
200        clauses.push("e.task_id = ?".to_string());
201        params.push(SqlValue::Text(task_id));
202    }
203    if !event_types.is_empty() {
204        let placeholders = vec!["?"; event_types.len()].join(", ");
205        clauses.push(format!("e.event_type IN ({placeholders})"));
206        for event_type in &event_types {
207            params.push(SqlValue::Text(event_type.clone()));
208        }
209    }
210    if request.failure_only {
211        clauses.push(
212            "(e.exit_code IS NOT NULL AND e.exit_code <> 0
213              OR lower(e.event_type) LIKE '%fail%'
214              OR lower(e.event_type) LIKE '%error%')"
215                .to_string(),
216        );
217    }
218    if let Some(ref q) = query {
219        let like = format!("%{}%", q.to_lowercase());
220        let mut query_clause = String::from(
221            "(lower(e.source) LIKE ?
222              OR lower(e.event_type) LIKE ?
223              OR lower(COALESCE(e.command_name, '')) LIKE ?
224              OR lower(COALESCE(e.tool_name, '')) LIKE ?
225              OR lower(COALESCE(e.cwd, '')) LIKE ?
226              OR lower(COALESCE(e.raw_artifact_id, '')) LIKE ?
227              OR lower(e.metadata) LIKE ?
228              OR lower(COALESCE(s.source_artifact_id, '')) LIKE ?
229              OR lower(COALESCE(s.summary, '')) LIKE ?
230              OR lower(COALESCE(s.structured_facts, '')) LIKE ?
231              OR lower(COALESCE(s.warnings, '')) LIKE ?",
232        );
233        if query_mentions_failure(q) {
234            query_clause.push_str(
235                " OR (e.exit_code IS NOT NULL AND e.exit_code <> 0)
236                  OR lower(e.event_type) LIKE '%fail%'
237                  OR lower(e.event_type) LIKE '%error%'",
238            );
239        }
240        query_clause.push(')');
241        clauses.push(query_clause);
242        for _ in 0..11 {
243            params.push(SqlValue::Text(like.clone()));
244        }
245    }
246
247    let where_sql = if clauses.is_empty() {
248        "1 = 1".to_string()
249    } else {
250        clauses.join(" AND ")
251    };
252    let sql = format!(
253        "SELECT
254             e.id, e.repo_id, e.workspace_path_hash, e.git_branch, e.worktree_name,
255             e.commit_hash, e.session_id, e.task_id, e.agent_id, e.source, e.event_type,
256             e.command_name, e.tool_name, e.cwd, e.exit_code, e.started_at, e.finished_at,
257             e.redaction_status, e.retention_policy, e.raw_artifact_id, e.metadata,
258             e.created_at,
259             s.id, s.source_artifact_id, s.reducer_name, s.reducer_version, s.lossy,
260             s.confidence, s.summary, s.structured_facts, s.warnings, s.tokens_raw_est,
261             s.tokens_compact_est, s.created_at
262         FROM context_events e
263         LEFT JOIN context_summaries s ON s.source_event_id = e.id
264         WHERE {where_sql}
265         ORDER BY e.started_at DESC, e.id DESC, s.created_at DESC
266         LIMIT ?"
267    );
268    params.push(SqlValue::Integer(max_results as i64));
269
270    let mut stmt = conn.prepare(&sql)?;
271    let mut rows = stmt.query(params_from_iter(params))?;
272    let mut results = Vec::new();
273    while let Some(row) = rows.next()? {
274        results.push(row_to_search_item(
275            row,
276            query.as_deref(),
277            request.include_artifact_pointers,
278            request.current_git_branch.as_deref(),
279            request.current_commit_hash.as_deref(),
280            stale_after_days,
281        )?);
282    }
283
284    Ok(ContextSearchResponse {
285        query,
286        count: results.len(),
287        max_results,
288        scope: ContextScopeView {
289            repo_id: request.repo_id.clone(),
290            workspace_path_hash,
291            session_id: request.session_id.clone(),
292            task_id: request.task_id.clone(),
293            isolation_applied: request.repo_id.is_some()
294                || request.workspace_path_hash.is_some()
295                || request.workspace.is_some()
296                || request.session_id.is_some()
297                || request.task_id.is_some(),
298        },
299        filters: ContextFilterView {
300            event_types,
301            failure_only: request.failure_only,
302            include_artifact_pointers: request.include_artifact_pointers,
303            stale_after_days,
304        },
305        results,
306    })
307}
308
309fn row_to_search_item(
310    row: &Row<'_>,
311    query: Option<&str>,
312    include_artifact_pointers: bool,
313    current_git_branch: Option<&str>,
314    current_commit_hash: Option<&str>,
315    stale_after_days: i64,
316) -> rusqlite::Result<ContextSearchItem> {
317    let metadata_raw: String = row.get(20)?;
318    let metadata = parse_json_or(&metadata_raw, json!({}));
319    let event = ContextEventView {
320        id: row.get(0)?,
321        repo_id: row.get(1)?,
322        workspace_path_hash: row.get(2)?,
323        git_branch: row.get(3)?,
324        worktree_name: row.get(4)?,
325        commit_hash: row.get(5)?,
326        session_id: row.get(6)?,
327        task_id: row.get(7)?,
328        agent_id: row.get(8)?,
329        source: row.get(9)?,
330        event_type: row.get(10)?,
331        command_name: row.get(11)?,
332        tool_name: row.get(12)?,
333        cwd: row.get(13)?,
334        exit_code: row.get(14)?,
335        started_at: row.get(15)?,
336        finished_at: row.get(16)?,
337        redaction_status: row.get(17)?,
338        retention_policy: row.get(18)?,
339        metadata,
340        created_at: row.get(21)?,
341    };
342    let raw_artifact_id: Option<String> = row.get(19)?;
343    let summary_id: Option<i64> = row.get(22)?;
344    let summary = if let Some(id) = summary_id {
345        let structured_facts_raw: Option<String> = row.get(29)?;
346        let warnings_raw: Option<String> = row.get(30)?;
347        let lossy_int: Option<i64> = row.get(26)?;
348        Some(ContextSummaryView {
349            id,
350            source_event_id: event.id,
351            source_artifact_id: row.get(23)?,
352            reducer_name: row.get::<_, Option<String>>(24)?.unwrap_or_default(),
353            reducer_version: row.get::<_, Option<String>>(25)?.unwrap_or_default(),
354            derived: true,
355            lossy: lossy_int.unwrap_or(1) != 0,
356            confidence: row.get::<_, Option<f64>>(27)?.unwrap_or(0.0),
357            summary: row.get::<_, Option<String>>(28)?.unwrap_or_default(),
358            structured_facts: structured_facts_raw
359                .as_deref()
360                .map(|raw| parse_json_or(raw, json!({})))
361                .unwrap_or_else(|| json!({})),
362            warnings: warnings_raw
363                .as_deref()
364                .map(|raw| parse_json_or(raw, json!([])))
365                .unwrap_or_else(|| json!([])),
366            tokens_raw_est: row.get(31)?,
367            tokens_compact_est: row.get(32)?,
368            created_at: row.get::<_, Option<String>>(33)?.unwrap_or_default(),
369        })
370    } else {
371        None
372    };
373
374    let metadata_keys = metadata_keys(&event.metadata);
375    let extracted_files = extract_files_from_metadata(&event.metadata);
376    let artifact_pointers = artifact_pointers(
377        &event,
378        summary.as_ref(),
379        raw_artifact_id,
380        include_artifact_pointers,
381    );
382    let staleness = staleness_warnings(
383        &event,
384        current_git_branch,
385        current_commit_hash,
386        stale_after_days,
387    );
388    let provenance = provenance_for(&event, summary.as_ref());
389    let relevance_score = relevance_score(&event, summary.as_ref(), query);
390    let result_type = if summary.is_some() {
391        "summary".to_string()
392    } else {
393        "event".to_string()
394    };
395
396    Ok(ContextSearchItem {
397        result_type,
398        relevance_score,
399        event,
400        summary,
401        metadata_keys,
402        extracted_files,
403        artifact_pointers,
404        staleness,
405        provenance,
406    })
407}
408
409fn artifact_pointers(
410    event: &ContextEventView,
411    summary: Option<&ContextSummaryView>,
412    raw_artifact_id: Option<String>,
413    include: bool,
414) -> Vec<ArtifactPointer> {
415    if !include {
416        return Vec::new();
417    }
418    let mut pointers = Vec::new();
419    if let Some(artifact_id) = raw_artifact_id {
420        pointers.push(ArtifactPointer {
421            pointer_type: "event_raw_artifact".to_string(),
422            artifact_id,
423            event_id: event.id,
424            summary_id: None,
425            provenance: provenance_for(event, None),
426        });
427    }
428    if let Some(summary) = summary {
429        if let Some(artifact_id) = &summary.source_artifact_id {
430            pointers.push(ArtifactPointer {
431                pointer_type: "summary_source_artifact".to_string(),
432                artifact_id: artifact_id.clone(),
433                event_id: event.id,
434                summary_id: Some(summary.id),
435                provenance: provenance_for(event, Some(summary)),
436            });
437        }
438    }
439    pointers
440}
441
442fn staleness_warnings(
443    event: &ContextEventView,
444    current_git_branch: Option<&str>,
445    current_commit_hash: Option<&str>,
446    stale_after_days: i64,
447) -> Vec<StalenessWarning> {
448    let mut warnings = Vec::new();
449    if let (Some(observed), Some(current)) = (event.git_branch.as_deref(), current_git_branch) {
450        if !observed.is_empty() && !current.is_empty() && observed != current {
451            warnings.push(StalenessWarning {
452                kind: "branch_mismatch".to_string(),
453                message: "Event was recorded on a different branch.".to_string(),
454                observed: Some(observed.to_string()),
455                current: Some(current.to_string()),
456                age_days: None,
457            });
458        }
459    }
460    if let (Some(observed), Some(current)) = (event.commit_hash.as_deref(), current_commit_hash) {
461        if !observed.is_empty() && !current.is_empty() && observed != current {
462            warnings.push(StalenessWarning {
463                kind: "commit_mismatch".to_string(),
464                message: "Event was recorded at a different commit.".to_string(),
465                observed: Some(observed.to_string()),
466                current: Some(current.to_string()),
467                age_days: None,
468            });
469        }
470    }
471    if let Some(started_at) = parse_time(&event.started_at) {
472        let age_days = Utc::now()
473            .signed_duration_since(started_at)
474            .num_days()
475            .max(0);
476        if age_days >= stale_after_days {
477            warnings.push(StalenessWarning {
478                kind: "age".to_string(),
479                message: format!("Event is at least {stale_after_days} days old."),
480                observed: Some(event.started_at.clone()),
481                current: Some(Utc::now().to_rfc3339()),
482                age_days: Some(age_days),
483            });
484        }
485    }
486    warnings
487}
488
489fn provenance_for(
490    event: &ContextEventView,
491    summary: Option<&ContextSummaryView>,
492) -> ContextProvenance {
493    ContextProvenance {
494        event_id: event.id,
495        summary_id: summary.map(|s| s.id),
496        repo_id: event.repo_id.clone(),
497        workspace_path_hash: event.workspace_path_hash.clone(),
498        session_id: event.session_id.clone(),
499        task_id: event.task_id.clone(),
500        source: event.source.clone(),
501        started_at: event.started_at.clone(),
502        summary_created_at: summary.map(|s| s.created_at.clone()),
503        reducer_name: summary.map(|s| s.reducer_name.clone()),
504        reducer_version: summary.map(|s| s.reducer_version.clone()),
505        lossy: summary.map(|s| s.lossy),
506    }
507}
508
509fn relevance_score(
510    event: &ContextEventView,
511    summary: Option<&ContextSummaryView>,
512    query: Option<&str>,
513) -> f64 {
514    let Some(query) = query else {
515        return 0.0;
516    };
517    let q = query.to_lowercase();
518    let mut score = 0.0;
519    score += contains_score(&event.event_type, &q, 2.0);
520    score += contains_score(&event.source, &q, 1.0);
521    score += contains_score(event.command_name.as_deref().unwrap_or(""), &q, 2.0);
522    score += contains_score(event.tool_name.as_deref().unwrap_or(""), &q, 2.0);
523    score += contains_score(event.cwd.as_deref().unwrap_or(""), &q, 0.5);
524    score += contains_score(&event.metadata.to_string(), &q, 1.0);
525    if let Some(summary) = summary {
526        score += contains_score(&summary.summary, &q, 3.0);
527        score += contains_score(&summary.structured_facts.to_string(), &q, 1.5);
528        score += summary.confidence.min(1.0);
529    }
530    score
531}
532
/// Returns `weight` when the lowercased `haystack` contains `needle`, else `0.0`.
///
/// `needle` is compared as given, so callers pass it already lowercased.
fn contains_score(haystack: &str, needle: &str, weight: f64) -> f64 {
    match haystack.to_lowercase().find(needle) {
        Some(_) => weight,
        None => 0.0,
    }
}
540
/// Reports whether `query` talks about failures, i.e. contains "fail" or
/// "error" in any letter case.
///
/// Words such as "failure" or "failed" are covered by the "fail" check.
fn query_mentions_failure(query: &str) -> bool {
    let lowered = query.to_lowercase();
    lowered.contains("fail") || lowered.contains("error")
}
546
547fn normalized_event_types(request: &ContextSearchRequest) -> Vec<String> {
548    let mut event_types = Vec::new();
549    if let Some(value) = request.event_type.as_deref() {
550        let trimmed = value.trim();
551        if !trimmed.is_empty() {
552            event_types.push(trimmed.to_string());
553        }
554    }
555    for value in request
556        .event_types
557        .iter()
558        .chain(request.event_type_filters.iter())
559    {
560        let trimmed = value.trim();
561        if !trimmed.is_empty() && !event_types.iter().any(|v| v == trimmed) {
562            event_types.push(trimmed.to_string());
563        }
564    }
565    event_types
566}
567
/// Trims `value` and returns it as an owned string, or `None` when the input
/// is absent or blank after trimming.
fn normalized_nonempty(value: Option<&str>) -> Option<String> {
    let trimmed = value?.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
574
575fn metadata_keys(metadata: &Value) -> Vec<String> {
576    match metadata {
577        Value::Object(map) => map.keys().cloned().collect(),
578        _ => Vec::new(),
579    }
580}
581
582fn extract_files_from_metadata(metadata: &Value) -> Vec<String> {
583    let mut files = Vec::new();
584    collect_file_values(None, metadata, &mut files);
585    files.sort();
586    files.dedup();
587    files
588}
589
590fn collect_file_values(key: Option<&str>, value: &Value, out: &mut Vec<String>) {
591    match value {
592        Value::String(s) => {
593            if key.map(is_file_key).unwrap_or(false) || looks_like_path(s) {
594                out.push(s.clone());
595            }
596        }
597        Value::Array(items) => {
598            for item in items {
599                collect_file_values(key, item, out);
600            }
601        }
602        Value::Object(map) => {
603            for (child_key, child_value) in map {
604                collect_file_values(Some(child_key), child_value, out);
605            }
606        }
607        _ => {}
608    }
609}
610
/// Reports whether a metadata key names a file or path, i.e. contains "file"
/// or "path" in any letter case.
fn is_file_key(key: &str) -> bool {
    let lowered = key.to_lowercase();
    ["file", "path"]
        .iter()
        .any(|&marker| lowered.contains(marker))
}
615
/// Heuristically decides whether a string is a file path.
///
/// After trimming, the value must be non-empty, at most 512 bytes, and contain
/// no newline; it then qualifies when it contains a `/` or ends with one of a
/// fixed set of source/config file extensions (matched case-sensitively).
fn looks_like_path(value: &str) -> bool {
    const PATH_SUFFIXES: [&str; 9] = [
        ".rs", ".md", ".toml", ".json", ".yaml", ".yml", ".ts", ".tsx", ".py",
    ];

    let candidate = value.trim();
    let plausible =
        !candidate.is_empty() && candidate.len() <= 512 && !candidate.contains('\n');

    plausible
        && (candidate.contains('/')
            || PATH_SUFFIXES
                .iter()
                .any(|&suffix| candidate.ends_with(suffix)))
}
632
633fn parse_json_or(raw: &str, fallback: Value) -> Value {
634    serde_json::from_str(raw).unwrap_or(fallback)
635}
636
637fn parse_time(value: &str) -> Option<DateTime<Utc>> {
638    DateTime::parse_from_rfc3339(value)
639        .map(|dt| dt.with_timezone(&Utc))
640        .ok()
641        .or_else(|| {
642            NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S")
643                .map(|dt| dt.and_utc())
644                .ok()
645        })
646}