Skip to main content

stateset_nsr/knowledge/
query.rs

1//! Query Language for Knowledge Base
2//!
3//! Provides a query interface for retrieving information from the
4//! knowledge base using pattern matching and logical conditions.
5
6use super::*;
7use std::collections::HashMap;
8
/// A query against the knowledge base
///
/// Assembled with the builder methods on `Query` (or with `QueryBuilder`) and
/// evaluated by `QueryExecutor::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Query {
    /// Patterns to match, applied in order; each one is joined against the
    /// rows produced by the patterns before it.
    pub patterns: Vec<QueryPattern>,
    /// Filter conditions; a row is kept only if every filter accepts it.
    pub filters: Vec<QueryFilter>,
    /// Variables to return; an empty list returns every bound variable.
    pub select: Vec<String>,
    /// Maximum number of rows returned (applied after `offset`).
    pub limit: Option<usize>,
    /// Number of rows skipped before `limit` is applied.
    pub offset: Option<usize>,
    /// Optional ordering, applied after filtering and before `offset`/`limit`.
    pub order_by: Option<OrderBy>,
}
25
26impl Query {
27    pub fn new() -> Self {
28        Self {
29            patterns: Vec::new(),
30            filters: Vec::new(),
31            select: Vec::new(),
32            limit: None,
33            offset: None,
34            order_by: None,
35        }
36    }
37
38    /// Add a triple pattern to match
39    pub fn pattern(mut self, subject: QueryTerm, predicate: &str, object: QueryTerm) -> Self {
40        self.patterns.push(QueryPattern::Triple {
41            subject,
42            predicate: predicate.to_string(),
43            object,
44        });
45        self
46    }
47
48    /// Add a filter condition
49    pub fn filter(mut self, filter: QueryFilter) -> Self {
50        self.filters.push(filter);
51        self
52    }
53
54    /// Specify which variables to return
55    pub fn select_vars(mut self, vars: Vec<&str>) -> Self {
56        self.select = vars.into_iter().map(String::from).collect();
57        self
58    }
59
60    /// Limit the number of results
61    pub fn limit(mut self, limit: usize) -> Self {
62        self.limit = Some(limit);
63        self
64    }
65
66    /// Skip first N results
67    pub fn offset(mut self, offset: usize) -> Self {
68        self.offset = Some(offset);
69        self
70    }
71
72    /// Order results
73    pub fn order_by(mut self, var: &str, ascending: bool) -> Self {
74        self.order_by = Some(OrderBy {
75            variable: var.to_string(),
76            ascending,
77        });
78        self
79    }
80}
81
82impl Default for Query {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
/// A term in a query pattern
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryTerm {
    /// A variable to bind; the executor records the matched value under this name.
    Variable(String),
    /// A specific entity, matched by id.
    Entity(EntityId),
    /// A constant value. In triple patterns it is compared against the entity
    /// id's inner string; in property patterns it is parsed and compared
    /// against the property value.
    Constant(String),
    /// Any value (wildcard); matches without binding anything.
    Any,
}
100
101impl QueryTerm {
102    pub fn var(name: &str) -> Self {
103        QueryTerm::Variable(name.to_string())
104    }
105
106    pub fn entity(id: EntityId) -> Self {
107        QueryTerm::Entity(id)
108    }
109
110    pub fn constant(value: &str) -> Self {
111        QueryTerm::Constant(value.to_string())
112    }
113}
114
/// A pattern to match in the query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryPattern {
    /// Match a triple with exactly this predicate; variables in subject or
    /// object position are bound to entity ids.
    Triple {
        subject: QueryTerm,
        predicate: String,
        object: QueryTerm,
    },
    /// Match entities of `entity_type`, binding `variable` to the full entity
    /// record.
    EntityType {
        variable: String,
        entity_type: EntityType,
    },
    /// Match a property of an entity. `entity` is a variable bound to an
    /// entity (or entity id) or a concrete entity; how `value` is used depends
    /// on its term kind.
    Property {
        entity: QueryTerm,
        property: String,
        value: QueryTerm,
    },
    /// Optional pattern (LEFT JOIN semantics): rows for which the inner
    /// pattern has no match are kept unchanged.
    Optional(Box<QueryPattern>),
    /// Union of patterns (OR): each alternative is applied to the same input
    /// rows and the results are concatenated.
    Union(Vec<QueryPattern>),
    /// Negation (NOT EXISTS): keeps only rows for which the inner pattern has
    /// no match.
    NotExists(Box<QueryPattern>),
}
142
/// Filter condition for query results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryFilter {
    /// Equality comparison; fails when `variable` is unbound.
    Equals {
        variable: String,
        value: PropertyValue,
    },
    /// Not equal; an unbound `variable` passes.
    NotEquals {
        variable: String,
        value: PropertyValue,
    },
    /// Greater than (for numeric values); non-numeric or unbound values fail.
    GreaterThan { variable: String, value: f64 },
    /// Less than (for numeric values); non-numeric or unbound values fail.
    LessThan { variable: String, value: f64 },
    /// String contains; applies to literals, string properties and entity names.
    Contains { variable: String, substring: String },
    /// Regex match over the same string forms as `Contains`; an invalid
    /// pattern matches nothing.
    Regex { variable: String, pattern: String },
    /// Confidence threshold (inclusive); `variable` must be bound to a full
    /// entity record.
    MinConfidence { variable: String, threshold: f32 },
    /// Has embedding; `variable` must be bound to a full entity record.
    HasEmbedding { variable: String },
    /// Embedding similarity to `target` of at least `threshold` (requires embeddings)
    SimilarTo {
        variable: String,
        target: EntityId,
        threshold: f32,
    },
    /// Logical AND of filters
    And(Vec<QueryFilter>),
    /// Logical OR of filters
    Or(Vec<QueryFilter>),
    /// Negation
    Not(Box<QueryFilter>),
}
181
/// Ordering specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderBy {
    /// Variable whose bound value is the sort key; rows without it sort first
    /// when ascending.
    pub variable: String,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}
188
/// Result of a query execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Bindings for each result row (variable name -> bound value)
    pub bindings: Vec<HashMap<String, QueryValue>>,
    /// Total count after filtering but before limit/offset
    pub total_count: usize,
    /// Execution time in milliseconds
    pub execution_time_ms: u64,
}
199
200impl QueryResult {
201    pub fn empty() -> Self {
202        Self {
203            bindings: Vec::new(),
204            total_count: 0,
205            execution_time_ms: 0,
206        }
207    }
208
209    pub fn is_empty(&self) -> bool {
210        self.bindings.is_empty()
211    }
212
213    pub fn len(&self) -> usize {
214        self.bindings.len()
215    }
216}
217
/// A value bound to a variable in query results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryValue {
    /// A full entity record (bound by entity-type patterns).
    Entity(Entity),
    /// An entity id only (bound by triple patterns and entity-ref properties).
    EntityId(EntityId),
    /// A raw property value.
    Property(PropertyValue),
    /// A string (string properties are bound as literals).
    Literal(String),
    /// A number; integer properties are converted to `f64`.
    Number(f64),
    /// A boolean.
    Boolean(bool),
    /// A property value with no direct representation here.
    Null,
}
229
230impl QueryValue {
231    pub fn as_entity_id(&self) -> Option<&EntityId> {
232        match self {
233            QueryValue::EntityId(id) => Some(id),
234            QueryValue::Entity(e) => Some(&e.id),
235            _ => None,
236        }
237    }
238
239    pub fn as_entity(&self) -> Option<&Entity> {
240        match self {
241            QueryValue::Entity(e) => Some(e),
242            _ => None,
243        }
244    }
245}
246
/// Query executor for the knowledge base
///
/// Borrows the knowledge base for `'a`; apart from that it only stores an
/// optional organization restriction.
pub struct QueryExecutor<'a> {
    /// Knowledge base the queries run against.
    kb: &'a KnowledgeBase,
    /// When set, only triples/entities belonging to this organization are visible.
    org_id: Option<String>,
}
252
253impl<'a> QueryExecutor<'a> {
254    pub fn new(kb: &'a KnowledgeBase) -> Self {
255        Self { kb, org_id: None }
256    }
257
258    /// Restrict query execution to a single organization (multi-tenant isolation).
259    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
260        self.org_id = Some(org_id.into());
261        self
262    }
263
264    /// Execute a query and return results
265    pub fn execute(&self, query: &Query) -> QueryResult {
266        let start = std::time::Instant::now();
267
268        // Start with empty bindings
269        let mut bindings: Vec<HashMap<String, QueryValue>> = vec![HashMap::new()];
270
271        // Process each pattern
272        for pattern in &query.patterns {
273            bindings = self.apply_pattern(pattern, bindings);
274        }
275
276        // Apply filters
277        for filter in &query.filters {
278            bindings.retain(|b| self.evaluate_filter(filter, b));
279        }
280
281        let total_count = bindings.len();
282
283        // Apply ordering
284        if let Some(ref order) = query.order_by {
285            bindings.sort_by(|a, b| {
286                let av = a.get(&order.variable);
287                let bv = b.get(&order.variable);
288                let cmp = self.compare_values(av, bv);
289                if order.ascending {
290                    cmp
291                } else {
292                    cmp.reverse()
293                }
294            });
295        }
296
297        // Apply offset
298        if let Some(offset) = query.offset {
299            bindings = bindings.into_iter().skip(offset).collect();
300        }
301
302        // Apply limit
303        if let Some(limit) = query.limit {
304            bindings = bindings.into_iter().take(limit).collect();
305        }
306
307        // Project to selected variables
308        if !query.select.is_empty() {
309            bindings = bindings
310                .into_iter()
311                .map(|b| {
312                    b.into_iter()
313                        .filter(|(k, _)| query.select.contains(k))
314                        .collect()
315                })
316                .collect();
317        }
318
319        QueryResult {
320            bindings,
321            total_count,
322            execution_time_ms: start.elapsed().as_millis() as u64,
323        }
324    }
325
326    fn apply_pattern(
327        &self,
328        pattern: &QueryPattern,
329        bindings: Vec<HashMap<String, QueryValue>>,
330    ) -> Vec<HashMap<String, QueryValue>> {
331        match pattern {
332            QueryPattern::Triple {
333                subject,
334                predicate,
335                object,
336            } => {
337                let mut new_bindings = Vec::new();
338
339                for binding in bindings {
340                    let matching_triples =
341                        self.find_matching_triples(subject, predicate, object, &binding);
342
343                    for triple in matching_triples {
344                        let mut new_binding = binding.clone();
345
346                        // Bind subject variable
347                        if let QueryTerm::Variable(v) = subject {
348                            new_binding
349                                .insert(v.clone(), QueryValue::EntityId(triple.subject.clone()));
350                        }
351
352                        // Bind object variable
353                        if let QueryTerm::Variable(v) = object {
354                            new_binding
355                                .insert(v.clone(), QueryValue::EntityId(triple.object.clone()));
356                        }
357
358                        new_bindings.push(new_binding);
359                    }
360                }
361
362                new_bindings
363            }
364
365            QueryPattern::EntityType {
366                variable,
367                entity_type,
368            } => {
369                let mut new_bindings = Vec::new();
370
371                for binding in bindings {
372                    let entities = if let Some(ref org_id) = self.org_id {
373                        self.kb.get_entities_by_type_and_org(entity_type, org_id)
374                    } else {
375                        self.kb.get_entities_by_type(entity_type)
376                    };
377
378                    for entity in entities {
379                        let mut new_binding = binding.clone();
380                        new_binding.insert(variable.clone(), QueryValue::Entity(entity));
381                        new_bindings.push(new_binding);
382                    }
383                }
384
385                new_bindings
386            }
387
388            QueryPattern::Property {
389                entity,
390                property,
391                value,
392            } => {
393                let mut new_bindings = Vec::new();
394
395                for binding in bindings {
396                    // Get the entity
397                    let entity_id = match entity {
398                        QueryTerm::Variable(v) => {
399                            binding.get(v).and_then(|qv| qv.as_entity_id()).cloned()
400                        }
401                        QueryTerm::Entity(id) => Some(id.clone()),
402                        _ => None,
403                    };
404
405                    if let Some(id) = entity_id {
406                        if let Some(ent) = self.kb.get_entity(&id) {
407                            if let Some(ref org_id) = self.org_id {
408                                if ent.org_id != *org_id {
409                                    continue;
410                                }
411                            }
412                            if let Some(prop_value) = ent.properties.get(property) {
413                                let mut new_binding = binding.clone();
414
415                                // Handle value binding
416                                match value {
417                                    QueryTerm::Variable(v) => {
418                                        let qv = self.property_to_query_value(&prop_value);
419                                        new_binding.insert(v.clone(), qv);
420                                        new_bindings.push(new_binding);
421                                    }
422                                    QueryTerm::Constant(c) => {
423                                        if self.property_matches_constant(&prop_value, c) {
424                                            new_bindings.push(new_binding);
425                                        }
426                                    }
427                                    QueryTerm::Any => {
428                                        new_bindings.push(new_binding);
429                                    }
430                                    _ => {}
431                                }
432                            }
433                        }
434                    }
435                }
436
437                new_bindings
438            }
439
440            QueryPattern::Optional(inner) => {
441                let mut new_bindings = Vec::new();
442
443                for binding in bindings {
444                    let matched = self.apply_pattern(inner, vec![binding.clone()]);
445
446                    if matched.is_empty() {
447                        new_bindings.push(binding);
448                    } else {
449                        new_bindings.extend(matched);
450                    }
451                }
452
453                new_bindings
454            }
455
456            QueryPattern::Union(patterns) => {
457                let mut all_bindings = Vec::new();
458                for pattern in patterns {
459                    let matched = self.apply_pattern(pattern, bindings.clone());
460                    all_bindings.extend(matched);
461                }
462                all_bindings
463            }
464
465            QueryPattern::NotExists(inner) => bindings
466                .into_iter()
467                .filter(|b| {
468                    let matched = self.apply_pattern(inner, vec![b.clone()]);
469                    matched.is_empty()
470                })
471                .collect(),
472        }
473    }
474
475    fn find_matching_triples(
476        &self,
477        subject: &QueryTerm,
478        predicate: &str,
479        object: &QueryTerm,
480        binding: &HashMap<String, QueryValue>,
481    ) -> Vec<Triple> {
482        self.kb
483            .triples
484            .iter()
485            .filter(|t| {
486                if let Some(ref org_id) = self.org_id {
487                    if t.org_id != *org_id {
488                        return false;
489                    }
490                }
491                // Check predicate
492                if t.predicate != predicate {
493                    return false;
494                }
495
496                // Check subject
497                let subject_matches = match subject {
498                    QueryTerm::Variable(v) => {
499                        if let Some(bound) = binding.get(v) {
500                            bound.as_entity_id() == Some(&t.subject)
501                        } else {
502                            true // Unbound, will bind
503                        }
504                    }
505                    QueryTerm::Entity(id) => &t.subject == id,
506                    QueryTerm::Constant(c) => t.subject.0 == *c,
507                    QueryTerm::Any => true,
508                };
509
510                if !subject_matches {
511                    return false;
512                }
513
514                // Check object
515                match object {
516                    QueryTerm::Variable(v) => {
517                        if let Some(bound) = binding.get(v) {
518                            bound.as_entity_id() == Some(&t.object)
519                        } else {
520                            true // Unbound, will bind
521                        }
522                    }
523                    QueryTerm::Entity(id) => &t.object == id,
524                    QueryTerm::Constant(c) => t.object.0 == *c,
525                    QueryTerm::Any => true,
526                }
527            })
528            .map(|t| t.clone())
529            .collect()
530    }
531
532    fn evaluate_filter(&self, filter: &QueryFilter, binding: &HashMap<String, QueryValue>) -> bool {
533        match filter {
534            QueryFilter::Equals { variable, value } => binding
535                .get(variable)
536                .is_some_and(|v| self.query_value_equals(v, value)),
537
538            QueryFilter::NotEquals { variable, value } => binding
539                .get(variable)
540                .map_or(true, |v| !self.query_value_equals(v, value)),
541
542            QueryFilter::GreaterThan { variable, value } => binding
543                .get(variable)
544                .and_then(|v| self.query_value_as_f64(v))
545                .is_some_and(|n| n > *value),
546
547            QueryFilter::LessThan { variable, value } => binding
548                .get(variable)
549                .and_then(|v| self.query_value_as_f64(v))
550                .is_some_and(|n| n < *value),
551
552            QueryFilter::Contains {
553                variable,
554                substring,
555            } => binding
556                .get(variable)
557                .and_then(|v| self.query_value_as_string(v))
558                .is_some_and(|s| s.contains(substring)),
559
560            QueryFilter::Regex { variable, pattern } => {
561                if let Ok(re) = regex::Regex::new(pattern) {
562                    binding
563                        .get(variable)
564                        .and_then(|v| self.query_value_as_string(v))
565                        .is_some_and(|s| re.is_match(&s))
566                } else {
567                    false
568                }
569            }
570
571            QueryFilter::MinConfidence {
572                variable,
573                threshold,
574            } => binding
575                .get(variable)
576                .and_then(|v| v.as_entity())
577                .is_some_and(|e| e.confidence >= *threshold),
578
579            QueryFilter::HasEmbedding { variable } => binding
580                .get(variable)
581                .and_then(|v| v.as_entity())
582                .is_some_and(|e| e.embedding.is_some()),
583
584            QueryFilter::SimilarTo {
585                variable,
586                target,
587                threshold,
588            } => {
589                if let Some(entity_id) = binding.get(variable).and_then(|v| v.as_entity_id()) {
590                    if let Some(ref org_id) = self.org_id {
591                        if self
592                            .kb
593                            .get_entity(target)
594                            .is_some_and(|e| e.org_id != *org_id)
595                        {
596                            return false;
597                        }
598                    }
599                    self.kb
600                        .embedding_similarity(entity_id, target)
601                        .is_some_and(|sim| sim >= *threshold)
602                } else {
603                    false
604                }
605            }
606
607            QueryFilter::And(filters) => filters.iter().all(|f| self.evaluate_filter(f, binding)),
608
609            QueryFilter::Or(filters) => filters.iter().any(|f| self.evaluate_filter(f, binding)),
610
611            QueryFilter::Not(inner) => !self.evaluate_filter(inner, binding),
612        }
613    }
614
615    fn query_value_equals(&self, qv: &QueryValue, pv: &PropertyValue) -> bool {
616        match (qv, pv) {
617            (QueryValue::Literal(s), PropertyValue::String(ps)) => s == ps,
618            (QueryValue::Number(n), PropertyValue::Float(f)) => (*n - f).abs() < f64::EPSILON,
619            (QueryValue::Number(n), PropertyValue::Integer(i)) => {
620                (*n - *i as f64).abs() < f64::EPSILON
621            }
622            (QueryValue::Boolean(b), PropertyValue::Boolean(pb)) => b == pb,
623            _ => false,
624        }
625    }
626
627    fn query_value_as_f64(&self, qv: &QueryValue) -> Option<f64> {
628        match qv {
629            QueryValue::Number(n) => Some(*n),
630            QueryValue::Property(PropertyValue::Float(f)) => Some(*f),
631            QueryValue::Property(PropertyValue::Integer(i)) => Some(*i as f64),
632            _ => None,
633        }
634    }
635
636    fn query_value_as_string(&self, qv: &QueryValue) -> Option<String> {
637        match qv {
638            QueryValue::Literal(s) => Some(s.clone()),
639            QueryValue::Property(PropertyValue::String(s)) => Some(s.clone()),
640            QueryValue::Entity(e) => Some(e.name.clone()),
641            _ => None,
642        }
643    }
644
645    fn property_to_query_value(&self, pv: &PropertyValue) -> QueryValue {
646        match pv {
647            PropertyValue::String(s) => QueryValue::Literal(s.clone()),
648            PropertyValue::Integer(i) => QueryValue::Number(*i as f64),
649            PropertyValue::Float(f) => QueryValue::Number(*f),
650            PropertyValue::Boolean(b) => QueryValue::Boolean(*b),
651            PropertyValue::EntityRef(id) => QueryValue::EntityId(id.clone()),
652            _ => QueryValue::Null,
653        }
654    }
655
656    fn property_matches_constant(&self, pv: &PropertyValue, constant: &str) -> bool {
657        match pv {
658            PropertyValue::String(s) => s == constant,
659            PropertyValue::Integer(i) => constant.parse::<i64>() == Ok(*i),
660            PropertyValue::Float(f) => constant
661                .parse::<f64>()
662                .is_ok_and(|c| (*f - c).abs() < f64::EPSILON),
663            PropertyValue::Boolean(b) => constant.parse::<bool>() == Ok(*b),
664            _ => false,
665        }
666    }
667
668    fn compare_values(&self, a: Option<&QueryValue>, b: Option<&QueryValue>) -> std::cmp::Ordering {
669        use std::cmp::Ordering;
670
671        match (a, b) {
672            (None, None) => Ordering::Equal,
673            (None, Some(_)) => Ordering::Less,
674            (Some(_), None) => Ordering::Greater,
675            (Some(av), Some(bv)) => match (av, bv) {
676                (QueryValue::Number(an), QueryValue::Number(bn)) => {
677                    an.partial_cmp(bn).unwrap_or(Ordering::Equal)
678                }
679                (QueryValue::Literal(as_), QueryValue::Literal(bs)) => as_.cmp(bs),
680                _ => Ordering::Equal,
681            },
682        }
683    }
684}
685
/// Convenience builder for creating queries
///
/// Holds a `Query` under construction; `build` hands it back to the caller.
#[derive(Default)]
pub struct QueryBuilder {
    /// The query being assembled.
    query: Query,
}
691
692impl QueryBuilder {
693    pub fn new() -> Self {
694        Self {
695            query: Query::new(),
696        }
697    }
698
699    /// Match a triple pattern with variables
700    pub fn matches(mut self, subject: &str, predicate: &str, object: &str) -> Self {
701        let subj = subject
702            .strip_prefix('?')
703            .map(QueryTerm::var)
704            .unwrap_or_else(|| QueryTerm::constant(subject));
705
706        let obj = object
707            .strip_prefix('?')
708            .map(QueryTerm::var)
709            .unwrap_or_else(|| QueryTerm::constant(object));
710
711        self.query.patterns.push(QueryPattern::Triple {
712            subject: subj,
713            predicate: predicate.to_string(),
714            object: obj,
715        });
716        self
717    }
718
719    /// Filter by property value
720    pub fn where_eq(mut self, variable: &str, value: PropertyValue) -> Self {
721        self.query.filters.push(QueryFilter::Equals {
722            variable: variable.to_string(),
723            value,
724        });
725        self
726    }
727
728    /// Select specific variables
729    pub fn select(mut self, vars: &[&str]) -> Self {
730        self.query.select = vars.iter().map(|s| s.to_string()).collect();
731        self
732    }
733
734    /// Limit results
735    pub fn limit(mut self, n: usize) -> Self {
736        self.query.limit = Some(n);
737        self
738    }
739
740    /// Build the query
741    pub fn build(self) -> Query {
742        self.query
743    }
744}
745
746// ============================================================================
747// Cached Query Executor
748// ============================================================================
749
750use crate::cache::{Cache, CacheConfig};
751use std::hash::{Hash, Hasher};
752use std::sync::Arc;
753
/// Cache key for a query: an opaque string derived from the query's
/// serialized form, optionally scoped to an organization.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryCacheKey(String);

impl Hash for QueryCacheKey {
    /// Hash exactly as the wrapped string does, consistent with the derived `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let QueryCacheKey(text) = self;
        text.hash(state);
    }
}
763
764impl QueryCacheKey {
765    /// Create a cache key from a query
766    pub fn from_query(query: &Query) -> Self {
767        Self::from_query_with_org(query, None)
768    }
769
770    /// Create a cache key from a query scoped to an organization.
771    pub fn from_query_with_org(query: &Query, org_id: Option<&str>) -> Self {
772        let key = serde_json::to_string(query).unwrap_or_default();
773        let key = match org_id {
774            Some(org) => format!("org:{}|{}", org, key),
775            None => key,
776        };
777        QueryCacheKey(key)
778    }
779}
780
/// Tuning knobs for the query result cache.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
    /// Upper bound on the number of query results kept in the cache.
    pub max_entries: usize,
    /// Time-to-live of a cached result, in seconds.
    pub ttl_seconds: u64,
    /// When `false`, queries skip the cache entirely.
    pub enabled: bool,
}

impl Default for QueryCacheConfig {
    /// 1000 entries, a five-minute TTL, caching switched on.
    fn default() -> Self {
        const FIVE_MINUTES_SECS: u64 = 5 * 60;
        QueryCacheConfig {
            enabled: true,
            max_entries: 1_000,
            ttl_seconds: FIVE_MINUTES_SECS,
        }
    }
}
801
/// A query executor with built-in caching for repeated queries.
///
/// The cached executor wraps a standard `QueryExecutor` and transparently
/// caches query results. Cache keys are computed from the serialized query
/// (plus the organization, when one is set), so identical queries in the
/// same organization will hit the cache.
///
/// A cache hit returns a clone of the stored result, including the
/// `execution_time_ms` measured when that result was first computed.
///
/// # Cache Invalidation
///
/// The cache does **not** automatically invalidate when the knowledge base
/// changes. For write-heavy workloads, consider:
/// - Using a shorter TTL
/// - Calling `invalidate_all()` after batch writes
/// - Using the non-cached `QueryExecutor` directly
///
/// # Example
///
/// ```rust,ignore
/// use stateset_nsr::knowledge::query::{CachedQueryExecutor, QueryCacheConfig};
///
/// let cache_config = QueryCacheConfig {
///     max_entries: 5000,
///     ttl_seconds: 600,
///     enabled: true,
/// };
///
/// let executor = CachedQueryExecutor::new(&kb, cache_config);
///
/// // First query - cache miss, executes query
/// let result1 = executor.execute(&query);
///
/// // Second identical query - cache hit, returns cached result
/// let result2 = executor.execute(&query);
///
/// // Check cache statistics
/// let stats = executor.cache_stats();
/// println!("Cache hit rate: {:.2}%", stats.hit_rate * 100.0);
/// ```
pub struct CachedQueryExecutor<'a> {
    /// Executor used on cache misses and when caching is disabled.
    executor: QueryExecutor<'a>,
    /// Query results keyed by `QueryCacheKey`.
    cache: Arc<Cache<QueryCacheKey, QueryResult>>,
    /// Cache settings; `enabled` is consulted on every `execute`.
    config: QueryCacheConfig,
    /// Organization folded into cache keys; mirrors the one given to `executor`.
    org_id: Option<String>,
}
845
846impl<'a> CachedQueryExecutor<'a> {
847    /// Create a new cached query executor
848    pub fn new(kb: &'a KnowledgeBase, config: QueryCacheConfig) -> Self {
849        let cache = Arc::new(Cache::new(CacheConfig {
850            max_entries: config.max_entries,
851            ttl_seconds: config.ttl_seconds,
852            enable_stats: true,
853        }));
854
855        Self {
856            executor: QueryExecutor::new(kb),
857            cache,
858            config,
859            org_id: None,
860        }
861    }
862
863    /// Restrict query execution (and cache keys) to a single organization.
864    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
865        let org_id = org_id.into();
866        self.executor = self.executor.with_org_id(org_id.clone());
867        self.org_id = Some(org_id);
868        self
869    }
870
871    /// Execute a query with caching
872    ///
873    /// If caching is enabled and a cached result exists for this query,
874    /// returns the cached result. Otherwise, executes the query and
875    /// caches the result before returning.
876    pub fn execute(&self, query: &Query) -> QueryResult {
877        if !self.config.enabled {
878            return self.executor.execute(query);
879        }
880
881        let cache_key = QueryCacheKey::from_query_with_org(query, self.org_id.as_deref());
882
883        // Try to get from cache
884        if let Some(cached) = self.cache.get(&cache_key) {
885            return cached;
886        }
887
888        // Execute query and cache result
889        let result = self.executor.execute(query);
890        self.cache.insert(cache_key, result.clone());
891        result
892    }
893
894    /// Execute a query bypassing the cache
895    ///
896    /// Useful when you need fresh results or when the query is unlikely
897    /// to be repeated.
898    pub fn execute_uncached(&self, query: &Query) -> QueryResult {
899        self.executor.execute(query)
900    }
901
902    /// Invalidate all cached results
903    ///
904    /// Call this after making significant changes to the knowledge base
905    /// to ensure queries return fresh results.
906    pub fn invalidate_all(&self) {
907        self.cache.clear();
908    }
909
910    /// Get cache statistics
911    pub fn cache_stats(&self) -> crate::cache::CacheStats {
912        self.cache.stats()
913    }
914
915    /// Check if a query result is cached
916    pub fn is_cached(&self, query: &Query) -> bool {
917        let cache_key = QueryCacheKey::from_query_with_org(query, self.org_id.as_deref());
918        self.cache.contains(&cache_key)
919    }
920
921    /// Get the number of cached query results
922    pub fn cached_count(&self) -> usize {
923        self.cache.len()
924    }
925}
926
#[cfg(test)]
mod tests {
    use super::*;

    /// Two organizations each own an identical `Paris capital_of France`
    /// triple. An executor scoped to one org must only see that org's triple.
    #[test]
    fn query_executor_respects_org_id_filter() {
        let kb = KnowledgeBase::new();

        let paris_a = kb.add_entity(Entity::with_org_id(
            "Paris",
            EntityType::Instance,
            "org_a".to_string(),
        ));
        let france_a = kb.add_entity(Entity::with_org_id(
            "France",
            EntityType::Instance,
            "org_a".to_string(),
        ));
        let paris_b = kb.add_entity(Entity::with_org_id(
            "Paris",
            EntityType::Instance,
            "org_b".to_string(),
        ));
        let france_b = kb.add_entity(Entity::with_org_id(
            "France",
            EntityType::Instance,
            "org_b".to_string(),
        ));

        kb.add_triple(Triple::with_org_id(
            paris_a,
            "capital_of",
            france_a,
            "org_a".to_string(),
        ))
        .unwrap();
        kb.add_triple(Triple::with_org_id(
            paris_b,
            "capital_of",
            france_b,
            "org_b".to_string(),
        ))
        .unwrap();

        let query = QueryBuilder::new()
            .matches("?city", "capital_of", "?country")
            .build();

        let result_a = QueryExecutor::new(&kb).with_org_id("org_a").execute(&query);
        assert_eq!(result_a.bindings.len(), 1);

        let result_b = QueryExecutor::new(&kb).with_org_id("org_b").execute(&query);
        assert_eq!(result_b.bindings.len(), 1);
    }

    /// The same query issued under different orgs must not share a cache slot.
    /// Otherwise one tenant could be served another tenant's results.
    #[test]
    fn query_cache_key_includes_org_id() {
        let query = QueryBuilder::new()
            .matches("?city", "capital_of", "?country")
            .build();

        let key_a = QueryCacheKey::from_query_with_org(&query, Some("org_a"));
        let key_b = QueryCacheKey::from_query_with_org(&query, Some("org_b"));
        assert_ne!(key_a, key_b);
    }

    /// A single triple pattern with two variables matches every triple that
    /// carries the predicate.
    #[test]
    fn test_query_execution() {
        let kb = KnowledgeBase::new();

        let paris = kb.add_entity(
            Entity::new("Paris", EntityType::Instance)
                .with_property("population", PropertyValue::Integer(2_161_000)),
        );
        let france = kb.add_entity(Entity::new("France", EntityType::Instance));
        let berlin = kb.add_entity(
            Entity::new("Berlin", EntityType::Instance)
                .with_property("population", PropertyValue::Integer(3_645_000)),
        );
        let germany = kb.add_entity(Entity::new("Germany", EntityType::Instance));

        // Each id is used exactly once, so it is moved rather than cloned.
        kb.add_triple(Triple::new(paris, "capital_of", france))
            .unwrap();
        kb.add_triple(Triple::new(berlin, "capital_of", germany))
            .unwrap();

        let executor = QueryExecutor::new(&kb);

        // Query: Find all capitals
        let query = QueryBuilder::new()
            .matches("?city", "capital_of", "?country")
            .build();

        let result = executor.execute(&query);
        assert_eq!(result.bindings.len(), 2);
    }

    /// An OPTIONAL pattern that matches nothing must keep the base binding and
    /// simply leave its own variable unbound.
    #[test]
    fn optional_pattern_preserves_base_binding_when_no_match() {
        let kb = KnowledgeBase::new();

        let alice = kb.add_entity(Entity::new("Alice", EntityType::Instance));
        let bob = kb.add_entity(Entity::new("Bob", EntityType::Instance));

        kb.add_triple(Triple::new(alice, "friend_of", bob))
            .unwrap();

        let query = Query {
            patterns: vec![
                QueryPattern::Triple {
                    subject: QueryTerm::Variable("s".into()),
                    predicate: "friend_of".to_string(),
                    object: QueryTerm::Variable("o".into()),
                },
                QueryPattern::Optional(Box::new(QueryPattern::Triple {
                    subject: QueryTerm::Variable("s".into()),
                    predicate: "likes".to_string(),
                    object: QueryTerm::Variable("x".into()),
                })),
            ],
            filters: Vec::new(),
            select: vec!["s".to_string(), "o".to_string(), "x".to_string()],
            limit: None,
            offset: None,
            order_by: None,
        };

        let executor = QueryExecutor::new(&kb);
        let result = executor.execute(&query);

        assert_eq!(result.bindings.len(), 1);
        let binding = &result.bindings[0];
        assert!(binding.contains_key("s"));
        assert!(binding.contains_key("o"));
        assert!(!binding.contains_key("x")); // optional pattern had no match, but base binding remained
    }

    /// The first execution is a miss that populates the cache. Repeating the
    /// same query is then a hit.
    #[test]
    fn test_cached_query_executor() {
        let kb = KnowledgeBase::new();

        let paris = kb.add_entity(Entity::new("Paris", EntityType::Instance));
        let france = kb.add_entity(Entity::new("France", EntityType::Instance));

        kb.add_triple(Triple::new(paris, "capital_of", france))
            .unwrap();

        let config = QueryCacheConfig::default();
        let cached_executor = CachedQueryExecutor::new(&kb, config);

        let query = QueryBuilder::new()
            .matches("?city", "capital_of", "?country")
            .build();

        // First execution - cache miss
        let result1 = cached_executor.execute(&query);
        assert_eq!(result1.bindings.len(), 1);

        let stats1 = cached_executor.cache_stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);

        // Second execution - cache hit
        let result2 = cached_executor.execute(&query);
        assert_eq!(result2.bindings.len(), 1);

        let stats2 = cached_executor.cache_stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);
        assert!(stats2.hit_rate > 0.4); // Should be ~50%

        // Verify cache contains the query
        assert!(cached_executor.is_cached(&query));
        assert_eq!(cached_executor.cached_count(), 1);
    }

    /// `invalidate_all` must empty the cache so the next execution is recomputed.
    #[test]
    fn test_cached_executor_invalidation() {
        let kb = KnowledgeBase::new();

        let alice = kb.add_entity(Entity::new("Alice", EntityType::Instance));
        let bob = kb.add_entity(Entity::new("Bob", EntityType::Instance));

        kb.add_triple(Triple::new(alice, "knows", bob))
            .unwrap();

        let cached_executor = CachedQueryExecutor::new(&kb, QueryCacheConfig::default());

        let query = QueryBuilder::new()
            .matches("?person", "knows", "?other")
            .build();

        // Execute and cache. Check the result too, so invalidation is tested
        // against a cache known to hold a real answer.
        let result = cached_executor.execute(&query);
        assert_eq!(result.bindings.len(), 1);
        assert_eq!(cached_executor.cached_count(), 1);

        // Invalidate cache
        cached_executor.invalidate_all();
        assert_eq!(cached_executor.cached_count(), 0);
        assert!(!cached_executor.is_cached(&query));
    }
}