// cognee-search 0.1.3
//
// Context retrieval (search) over the cognee knowledge graph and vector store.

use std::collections::{HashMap, HashSet};

use cognee_embedding::EmbeddingEngine;
use cognee_graph::GraphDBTrait;
use cognee_vector::VectorDB;
use tracing::debug;

use crate::graph_retrieval::rank_edge_score;
use crate::types::SearchError;

/// Default value for `GraphRetrievalConfig::wide_search_top_k`: the result limit passed
/// to each per-collection `search_similar` call when gathering candidate nodes and
/// edge-type distances.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;

/// Default penalty distance assigned to graph elements (nodes or edges) that have no
/// vector match for the current query. Matches Python's `triplet_distance_penalty`
/// default of 6.5 in
/// `cognee/modules/retrieval/utils/brute_force_triplet_search.py:56,227`.
///
/// `brute_force_triplet_search` substitutes this value (via
/// `GraphRetrievalConfig::triplet_distance_penalty`) for a missing source-node,
/// target-node, or edge-type distance when scoring an edge.
pub const DEFAULT_TRIPLET_DISTANCE_PENALTY: f32 = 6.5;

/// Collections searched to find candidate graph nodes and edge-type distances.
/// Each entry is (data_type, field_name).
///
/// Hits from `("EdgeType", "relationship_name")` are treated as edge-type distances
/// keyed by relationship name; hits from every other entry are treated as node
/// distances keyed by the vector point ID.
///
/// Mirrors Python's dynamic enumeration of all DataPoint subclass index collections
/// (graph_completion_retriever.py:88-99). Python reflects over every DataPoint
/// subclass at query time and includes `Triplet_text` when memify has populated it.
/// Rust uses a static list that covers the same set; the per-collection
/// `has_collection` guard in `brute_force_triplet_search` ensures collections absent
/// from the store are skipped.
///
/// `("Triplet", "text")` is included so that after `memify` runs, triplet vectors
/// influence graph search ranking exactly as they do in Python.
///
/// Note on Triplet point IDs: Triplet vector points are identified by
/// `generate_node_id(start_id + relationship_name + end_id)`, which is NOT the same
/// as any graph node ID. Python's `map_vector_distances_to_graph_nodes` silently skips
/// hits whose IDs don't match a known graph node (CogneeGraph.py:428-429). Rust does
/// the same: the candidate_node_ids set collects point IDs from all non-EdgeType
/// collections, and only IDs that match an actual graph edge endpoint end up
/// contributing to edge scores. Triplet hits that don't match a graph node ID are
/// therefore harmless — they add the point ID to candidate_node_ids but never match
/// any graph edge endpoint, so no spurious edges are surfaced.
const SEARCH_COLLECTIONS: [(&str, &str); 6] = [
    ("Entity", "name"),
    ("TextSummary", "text"),
    ("EntityType", "name"),
    ("DocumentChunk", "text"),
    ("EdgeType", "relationship_name"),
    ("Triplet", "text"),
];

/// Tunable parameters for `brute_force_triplet_search`.
#[derive(Debug, Clone)]
pub struct GraphRetrievalConfig {
    /// Maximum number of ranked edges returned to the caller.
    pub top_k: usize,
    /// Result limit passed to each per-collection vector similarity search.
    pub wide_search_top_k: usize,
    /// Default distance used for nodes/edges not found in vector search.
    /// Matches Python's `triplet_distance_penalty` semantics (default 6.5).
    pub triplet_distance_penalty: f32,
    /// How much per-node `feedback_weight` values influence triplet ranking.
    /// Must be in [0.0, 1.0]; the search returns `SearchError::InvalidInput` otherwise.
    /// 0.0 (default) means pure similarity-based ranking; node `feedback_weight`
    /// properties are only read when this is greater than 0.0.
    pub feedback_influence: f32,
    /// Node type used to restrict the graph before scoring.
    /// Only takes effect together with a non-empty `node_name`: in that case the
    /// search calls `get_nodeset_subgraph` instead of `get_graph_data`. When
    /// `node_name` is `None` or empty, this field is ignored and the full graph is used.
    pub node_type: Option<String>,
    /// Filter graph to nodes with these names (paired with `node_type`).
    /// Ignored when `node_type` is `None`.
    pub node_name: Option<Vec<String>>,
    /// "OR" (default): include neighbors of ANY named node.
    /// "AND": include only neighbors connected to ALL named nodes.
    /// Validated case-insensitively by the search; any other value is rejected with
    /// `SearchError::InvalidInput`.
    pub node_name_filter_operator: String,
}

impl Default for GraphRetrievalConfig {
    fn default() -> Self {
        Self {
            top_k: 10,
            wide_search_top_k: DEFAULT_WIDE_SEARCH_TOP_K,
            triplet_distance_penalty: DEFAULT_TRIPLET_DISTANCE_PENALTY,
            feedback_influence: 0.0,
            node_type: None,
            node_name: None,
            node_name_filter_operator: "OR".to_string(),
        }
    }
}

/// One graph edge (source, relationship, target) together with its ranking score and
/// the node details copied from the graph for building search context.
#[derive(Debug, Clone)]
pub struct RankedGraphEdge {
    /// Graph ID of the edge's source node.
    pub source_id: String,
    /// Graph ID of the edge's target node.
    pub target_id: String,
    /// Relationship name of the edge, as stored in the graph.
    pub relationship_name: String,
    /// Triplet score (lower = better match), as returned by `rank_edge_score` from the
    /// source-node, target-node, and edge-type distances plus the feedback inputs.
    /// NOTE(review): presumably the plain sum of the three distances when
    /// `feedback_influence` is 0.0 — confirm against `rank_edge_score`.
    pub score: f32,
    /// `name` property of the source node; falls back to the node ID when absent.
    pub source_name: String,
    /// `name` property of the target node; falls back to the node ID when absent.
    pub target_name: String,
    /// Dataset ID for context scoping, taken from the `dataset_id` metadata of the
    /// source node's vector hit, or the target node's when the source has none.
    pub dataset_id: Option<String>,
    /// Text content of the source node (present on DocumentChunk nodes).
    pub source_text: Option<String>,
    /// Text content of the target node (present on DocumentChunk nodes).
    pub target_text: Option<String>,
    /// Description of the source node (present on Entity nodes).
    pub source_description: Option<String>,
    /// Description of the target node (present on Entity nodes).
    pub target_description: Option<String>,
}

/// Ranks graph edges ("triplets") by how closely they match `query`.
///
/// Steps:
/// 1. Embed `query` and run a similarity search over every collection in
///    `SEARCH_COLLECTIONS` that exists in `vector_db`, converting each hit's score to a
///    distance (`1 - score`). `EdgeType`/`relationship_name` hits become
///    per-relationship distances; all other hits become per-node distances keyed by
///    the vector point ID. When the same key is hit more than once, the smallest
///    distance wins.
/// 2. Load the graph: a nodeset subgraph when both `config.node_type` and a non-empty
///    `config.node_name` are set, otherwise the full graph.
/// 3. Score every edge that has at least one endpoint among the vector hits using
///    `rank_edge_score`, substituting `config.triplet_distance_penalty` for any source,
///    target, or relationship distance that had no hit.
/// 4. Sort ascending by score (lower = better) and keep the first `config.top_k`.
///
/// Returns an empty vector, without querying the graph, when the vector search yields
/// no node candidates. The number of returned edges is recorded on the current
/// tracing span as `cognee.result.count`.
///
/// # Errors
///
/// Returns `SearchError::InvalidInput` when `config.feedback_influence` is NaN or
/// outside `[0.0, 1.0]`, when `config.node_name_filter_operator` is not "AND" or "OR"
/// (case-insensitive), or when the embedding engine returns no vector. Failures from
/// the embedding engine, the vector store, or the graph store are propagated.
#[tracing::instrument(
    name = "cognee.retrieval.graph_search",
    skip(graph_db, vector_db, embedding_engine, config),
    fields(
        cognee.result.count = tracing::field::Empty,
    )
)]
pub async fn brute_force_triplet_search(
    query: &str,
    vector_db: &dyn VectorDB,
    embedding_engine: &dyn EmbeddingEngine,
    graph_db: &dyn GraphDBTrait,
    config: &GraphRetrievalConfig,
) -> Result<Vec<RankedGraphEdge>, SearchError> {
    // `contains` is false for NaN, so NaN is rejected together with out-of-range
    // values (separate `<` / `>` comparisons would let NaN through).
    if !(0.0..=1.0).contains(&config.feedback_influence) {
        return Err(SearchError::InvalidInput(
            "feedback_influence must be in range [0.0, 1.0]".to_string(),
        ));
    }

    // Validation is case-insensitive; the normalized (uppercase) form is also what is
    // forwarded to the graph store below, so "and"/"or" behave like "AND"/"OR".
    let filter_operator = config.node_name_filter_operator.to_uppercase();
    if filter_operator != "AND" && filter_operator != "OR" {
        return Err(SearchError::InvalidInput(format!(
            "Invalid node_name_filter_operator: {:?}. Must be AND or OR.",
            config.node_name_filter_operator
        )));
    }

    let query_vectors = embedding_engine.embed(&[query]).await?;
    let query_vector = query_vectors.into_iter().next().ok_or_else(|| {
        SearchError::InvalidInput("embedding engine returned no vectors".to_string())
    })?;

    // node_id -> cosine distance (lower = better)
    let mut node_distances = HashMap::<String, f32>::new();
    let mut candidate_node_ids = HashSet::<String>::new();
    let mut node_dataset_ids = HashMap::<String, String>::new();

    // relationship_name -> cosine distance (lower = better)
    // Keyed by relationship_name because edge_type_id is NOT stored in graph edge
    // properties by cognify. The EdgeType vector points store relationship_name in
    // their metadata (confirmed in cognify tasks.rs).
    let mut edge_type_distances = HashMap::<String, f32>::new();

    for (data_type, field_name) in SEARCH_COLLECTIONS {
        if !vector_db.has_collection(data_type, field_name).await? {
            debug!("vector collection {data_type}/{field_name} does not exist — skipping");
            continue;
        }

        let results = vector_db
            .search_similar(
                data_type,
                field_name,
                &query_vector,
                config.wide_search_top_k,
            )
            .await?;

        for result in results {
            // Convert Qdrant cosine similarity to cosine distance: distance = 1 - similarity
            let distance = 1.0 - result.score;

            if data_type == "EdgeType" && field_name == "relationship_name" {
                // Edge distances keyed by relationship_name from vector point metadata.
                // edge_type_id is NOT stored in graph edge properties, so we key by
                // relationship_name to match graph edges at scoring time.
                if let Some(rel_name) = result
                    .metadata
                    .get("relationship_name")
                    .and_then(|v| v.as_str())
                {
                    // Keep the smallest distance seen for this relationship.
                    edge_type_distances
                        .entry(rel_name.to_string())
                        .and_modify(|best| *best = (*best).min(distance))
                        .or_insert(distance);
                }
            } else {
                // Node distances keyed by vector point ID.
                // Use min to merge across collections (lower distance = better match).
                let node_id = result.id.to_string();
                candidate_node_ids.insert(node_id.clone());
                node_distances
                    .entry(node_id.clone())
                    .and_modify(|best| *best = (*best).min(distance))
                    .or_insert(distance);
                // First dataset_id seen for a node wins.
                if let Some(dataset_id) = result.metadata.get("dataset_id").and_then(|v| v.as_str())
                {
                    node_dataset_ids
                        .entry(node_id)
                        .or_insert_with(|| dataset_id.to_string());
                }
            }
        }
    }

    if candidate_node_ids.is_empty() {
        debug!("no candidate nodes found from vector search — returning empty");
        tracing::Span::current().record("cognee.result.count", 0u64);
        return Ok(vec![]);
    }

    tracing::debug!(
        target: "cognee::search",
        wide_search_results = candidate_node_ids.len(),
        "Vector search complete"
    );

    // The nodeset filter applies only when a node type AND at least one node name are
    // configured; every other combination falls back to the full graph.
    let (graph_nodes, graph_edges) =
        match (config.node_type.as_deref(), config.node_name.as_deref()) {
            (Some(node_type), Some(filter_names)) if !filter_names.is_empty() => {
                graph_db
                    .get_nodeset_subgraph(node_type, filter_names, &filter_operator)
                    .await?
            }
            _ => graph_db.get_graph_data().await?,
        };

    // Extract name, text, description, and (optionally) feedback_weight from each node.
    let mut node_names: HashMap<String, String> = HashMap::new();
    let mut node_texts: HashMap<String, String> = HashMap::new();
    let mut node_descriptions: HashMap<String, String> = HashMap::new();
    let mut node_feedback_weights: HashMap<String, f32> = HashMap::new();

    // Feedback weights are only needed when they can influence the score.
    let use_feedback = config.feedback_influence > 0.0;

    for (node_id, properties) in graph_nodes {
        // Nodes without a string `name` property are displayed by their ID.
        let name = properties
            .get("name")
            .and_then(|value| value.as_str())
            .unwrap_or(node_id.as_str())
            .to_string();
        node_names.insert(node_id.clone(), name);

        if let Some(text) = properties.get("text").and_then(|v| v.as_str()) {
            node_texts.insert(node_id.clone(), text.to_string());
        }
        if let Some(desc) = properties.get("description").and_then(|v| v.as_str()) {
            node_descriptions.insert(node_id.clone(), desc.to_string());
        }
        if use_feedback {
            // Missing or non-numeric feedback_weight defaults to the neutral 0.5.
            let fw = properties
                .get("feedback_weight")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.5) as f32;
            node_feedback_weights.insert(node_id.clone(), fw);
        }
    }

    let default_penalty = config.triplet_distance_penalty;

    let mut ranked_edges = graph_edges
        .into_iter()
        .filter_map(|(source_id, target_id, relationship_name, _properties)| {
            // Only consider edges where at least one endpoint was found in vector search
            if !candidate_node_ids.contains(&source_id) && !candidate_node_ids.contains(&target_id)
            {
                return None;
            }

            // Unmatched nodes get the default penalty distance (not 0.0)
            let source_dist = node_distances
                .get(&source_id)
                .copied()
                .unwrap_or(default_penalty);
            let target_dist = node_distances
                .get(&target_id)
                .copied()
                .unwrap_or(default_penalty);

            // Look up edge distance by relationship_name.
            // Unmatched edge types also get the default penalty distance.
            let edge_dist = edge_type_distances
                .get(&relationship_name)
                .copied()
                .unwrap_or(default_penalty);

            // Endpoints missing from the loaded node set are displayed by their ID;
            // the fallback clone only happens in that case.
            let source_name = node_names
                .get(&source_id)
                .cloned()
                .unwrap_or_else(|| source_id.clone());
            let target_name = node_names
                .get(&target_id)
                .cloned()
                .unwrap_or_else(|| target_id.clone());

            // Prefer the source node's dataset, falling back to the target's.
            let dataset_id = node_dataset_ids
                .get(&source_id)
                .or_else(|| node_dataset_ids.get(&target_id))
                .cloned();

            let source_text = node_texts.get(&source_id).cloned();
            let target_text = node_texts.get(&target_id).cloned();
            let source_description = node_descriptions.get(&source_id).cloned();
            let target_description = node_descriptions.get(&target_id).cloned();

            let source_fw = node_feedback_weights
                .get(&source_id)
                .copied()
                .unwrap_or(0.5);
            let target_fw = node_feedback_weights
                .get(&target_id)
                .copied()
                .unwrap_or(0.5);

            Some(RankedGraphEdge {
                source_id,
                target_id,
                relationship_name,
                score: rank_edge_score(
                    source_dist,
                    target_dist,
                    edge_dist,
                    config.feedback_influence,
                    source_fw,
                    target_fw,
                ),
                source_name,
                target_name,
                dataset_id,
                source_text,
                target_text,
                source_description,
                target_description,
            })
        })
        .collect::<Vec<_>>();

    // Sort ascending: lowest total distance = best match (matches Python heapq.nsmallest).
    // `total_cmp` is a total order, so the result stays well-defined even if a score is
    // NaN; the sort is stable, so ties keep the graph store's edge order.
    ranked_edges.sort_by(|left, right| left.score.total_cmp(&right.score));
    ranked_edges.truncate(config.top_k);

    tracing::Span::current().record("cognee.result.count", ranked_edges.len() as u64);
    Ok(ranked_edges)
}

#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod penalty_default_tests {
    use super::*;

    /// The exported penalty constant must stay at the Python default.
    /// Python: cognee/modules/retrieval/utils/brute_force_triplet_search.py:56,227
    #[test]
    fn default_triplet_distance_penalty_matches_python() {
        let python_default: f32 = 6.5;
        assert_eq!(DEFAULT_TRIPLET_DISTANCE_PENALTY, python_default);
    }

    /// A default-constructed config must carry the same 6.5 penalty.
    #[test]
    fn graph_retrieval_config_default_uses_python_penalty() {
        let penalty = GraphRetrievalConfig::default().triplet_distance_penalty;
        assert_eq!(penalty, 6.5);
    }

    /// SEARCH_COLLECTIONS must contain the ("Triplet", "text") entry, mirroring
    /// Python's _get_vector_index_collections(), which enumerates all DataPoint
    /// subclasses including Triplet (graph_completion_retriever.py:88-99).
    /// Triplet declares metadata = {"index_fields": ["text"]} (Triplet.py:9).
    #[test]
    fn search_collections_includes_triplet_text() {
        let triplet_entry = ("Triplet", "text");
        let has_triplet_text = SEARCH_COLLECTIONS.contains(&triplet_entry);
        assert!(
            has_triplet_text,
            "SEARCH_COLLECTIONS must include (\"Triplet\", \"text\") to match Python's \
             dynamic enumeration of DataPoint index collections after memify"
        );
    }

    /// Pins the size of SEARCH_COLLECTIONS: the five original collections plus
    /// ("Triplet", "text").
    #[test]
    fn search_collections_has_six_entries() {
        let entry_count = SEARCH_COLLECTIONS.len();
        assert_eq!(
            entry_count,
            6,
            "SEARCH_COLLECTIONS should have 6 entries after adding Triplet_text"
        );
    }
}