Skip to main content

hirn_exec/operators/
quality_gate.rs

1//! `QualityGateExec` — confidence-based gate that passes or escalates results.
2//!
3//! Computes a 4-dimension quality score (coverage, confidence, coherence,
4//! sufficiency) and emits an "escalate" flag when quality falls below threshold.
5//! Target: ≤20% of queries escalate to deliberation.
6
7use std::any::Any;
8use std::fmt;
9use std::sync::Arc;
10
11use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt32Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use datafusion_common::Result;
14use datafusion_execution::{SendableRecordBatchStream, TaskContext};
15use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
16use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
17use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
18
/// Tunable threshold and weights used by the quality gate.
///
/// Each `*_weight` field scales one quality dimension before the dimensions
/// are summed into the combined score; `threshold` is the cut-off that the
/// combined score is compared against.
#[derive(Clone, Debug)]
pub struct QualityGateConfig {
    /// Combined score below which the gate escalates (default: 0.5).
    pub threshold: f32,
    /// Multiplier applied to the coverage dimension (default: 0.3).
    pub coverage_weight: f32,
    /// Multiplier applied to the confidence dimension (default: 0.3).
    pub confidence_weight: f32,
    /// Multiplier applied to the coherence dimension (default: 0.2).
    pub coherence_weight: f32,
    /// Multiplier applied to the sufficiency dimension (default: 0.2).
    pub sufficiency_weight: f32,
    /// Coherence value substituted when fewer than 2 results carry an
    /// embedding (default: 0.6). When embeddings are present, coherence is the
    /// pairwise cosine similarity computed from the `embedding` column.
    pub coherence_fallback: f32,
}
37
38impl Default for QualityGateConfig {
39    fn default() -> Self {
40        Self {
41            threshold: 0.5,
42            coverage_weight: 0.3,
43            confidence_weight: 0.3,
44            coherence_weight: 0.2,
45            sufficiency_weight: 0.2,
46            coherence_fallback: 0.6,
47        }
48    }
49}
50
/// Outcome of scoring one result set: the four individual quality dimensions,
/// their weighted combination, and the resulting gate decision.
#[derive(Clone, Debug)]
pub struct QualityAssessment {
    /// Coverage dimension, derived from the number of retrieved rows.
    pub coverage: f32,
    /// Confidence dimension, taken from the average result score.
    pub confidence: f32,
    /// Coherence dimension, as supplied to the assessment.
    pub coherence: f32,
    /// Sufficiency dimension, derived from retrieved tokens versus the token budget.
    pub sufficiency: f32,
    /// Weighted sum of the four dimensions.
    pub combined: f32,
    /// Whether the combined score fell below the configured threshold.
    pub escalate: bool,
}
61
62/// DataFusion operator that gates retrieval results by quality.
63///
64/// Passes through input batches, appending quality metrics columns.
65/// When quality is below threshold, adds `quality_action = "escalate"`.
66#[derive(Debug)]
67pub struct QualityGateExec {
68    input: Arc<dyn ExecutionPlan>,
69    config: QualityGateConfig,
70    /// Token budget for sufficiency calculation.
71    token_budget: usize,
72    schema: SchemaRef,
73    properties: PlanProperties,
74}
75
76impl QualityGateExec {
77    pub fn new(
78        input: Arc<dyn ExecutionPlan>,
79        config: QualityGateConfig,
80        token_budget: usize,
81    ) -> Self {
82        let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
83        fields.push(Arc::new(Field::new(
84            "quality_score",
85            DataType::Float32,
86            false,
87        )));
88        fields.push(Arc::new(Field::new(
89            "quality_action",
90            DataType::Utf8,
91            false,
92        )));
93        let schema = Arc::new(Schema::new(fields));
94
95        let properties = PlanProperties::new(
96            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
97            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
98            EmissionType::Final,
99            Boundedness::Bounded,
100        );
101
102        Self {
103            input,
104            config,
105            token_budget,
106            schema,
107            properties,
108        }
109    }
110
111    /// Compute quality assessment from batch statistics.
112    ///
113    /// Coverage: ratio of retrieved rows to expected minimum (heuristic: 5 useful results).
114    /// Confidence: average composite score of results.
115    /// Coherence: pairwise cosine similarity computed from the `embedding` column when
116    ///   present; falls back to `coherence_fallback` when embeddings are unavailable.
117    /// Sufficiency: ratio of retrieved tokens to token budget.
118    fn assess_quality(
119        config: &QualityGateConfig,
120        token_budget: usize,
121        row_count: usize,
122        avg_score: f32,
123        total_tokens: usize,
124        coherence: f32,
125    ) -> QualityAssessment {
126        let coverage = if row_count > 0 {
127            1.0_f32.min(row_count as f32 / 5.0)
128        } else {
129            0.0
130        };
131        let confidence = avg_score;
132        let sufficiency = if token_budget > 0 {
133            (total_tokens as f32 / token_budget as f32).min(1.0)
134        } else {
135            1.0
136        };
137
138        let combined = config.coverage_weight * coverage
139            + config.confidence_weight * confidence
140            + config.coherence_weight * coherence
141            + config.sufficiency_weight * sufficiency;
142
143        let escalate = combined < config.threshold;
144
145        QualityAssessment {
146            coverage,
147            confidence,
148            coherence,
149            sufficiency,
150            combined,
151            escalate,
152        }
153    }
154
155    /// Compute pairwise cosine coherence from the `embedding` column of the merged batch.
156    ///
157    /// Returns the mean pairwise cosine similarity, or `fallback` when fewer than 2
158    /// non-null embeddings are available.
159    fn compute_coherence_from_batch(batch: &RecordBatch, fallback: f32) -> f32 {
160        let fsl = match batch
161            .column_by_name("embedding")
162            .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>())
163        {
164            Some(fsl) => fsl,
165            None => return fallback,
166        };
167
168        let embeddings: Vec<Vec<f32>> = (0..fsl.len())
169            .filter(|&i| !fsl.is_null(i))
170            .filter_map(|i| {
171                let values = fsl.value(i);
172                let f32_arr = values.as_any().downcast_ref::<Float32Array>()?;
173                Some(f32_arr.values().to_vec())
174            })
175            .collect();
176
177        if embeddings.len() < 2 {
178            return fallback;
179        }
180
181        let mut sum = 0.0_f32;
182        let mut count = 0_u32;
183        for i in 0..embeddings.len() {
184            for j in (i + 1)..embeddings.len() {
185                sum += cosine_similarity(&embeddings[i], &embeddings[j]);
186                count += 1;
187            }
188        }
189
190        if count > 0 {
191            (sum / count as f32).clamp(0.0, 1.0)
192        } else {
193            fallback
194        }
195    }
196}
197
/// Cosine similarity of two vectors, in `[-1.0, 1.0]`.
///
/// Only the overlapping prefix is compared when the slices differ in length.
/// Returns `0.0` when either vector has (near-)zero magnitude or when the
/// inputs contain NaN or infinite components, so the result is always finite.
///
/// Sums are accumulated in `f64` so that large `f32` components cannot
/// overflow the norm computation and turn the result into NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f64;
    let mut norm_a = 0.0_f64;
    let mut norm_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // A NaN or infinite denominator means the inputs were not finite.
    if !denom.is_finite() || denom < f64::from(f32::EPSILON) {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_nan() {
        0.0
    } else {
        similarity.clamp(-1.0, 1.0) as f32
    }
}
214
215impl DisplayAs for QualityGateExec {
216    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        write!(
218            f,
219            "QualityGateExec: threshold={}, budget={}",
220            self.config.threshold, self.token_budget
221        )
222    }
223}
224
225impl ExecutionPlan for QualityGateExec {
226    fn name(&self) -> &str {
227        "QualityGateExec"
228    }
229
230    fn as_any(&self) -> &dyn Any {
231        self
232    }
233
234    fn schema(&self) -> SchemaRef {
235        self.schema.clone()
236    }
237
238    fn properties(&self) -> &PlanProperties {
239        &self.properties
240    }
241
242    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
243        vec![&self.input]
244    }
245
246    fn with_new_children(
247        self: Arc<Self>,
248        children: Vec<Arc<dyn ExecutionPlan>>,
249    ) -> Result<Arc<dyn ExecutionPlan>> {
250        let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
251            datafusion_common::DataFusionError::Plan(format!(
252                "QualityGateExec requires exactly 1 child, got {}",
253                v.len()
254            ))
255        })?;
256        Ok(Arc::new(Self::new(
257            child,
258            self.config.clone(),
259            self.token_budget,
260        )))
261    }
262
263    fn execute(
264        &self,
265        partition: usize,
266        context: Arc<TaskContext>,
267    ) -> Result<SendableRecordBatchStream> {
268        let input_stream = self.input.execute(partition, context)?;
269        let schema = self.schema.clone();
270        let config = self.config.clone();
271        let token_budget = self.token_budget;
272
273        let stream = futures::stream::once(async move {
274            use futures::StreamExt;
275            let mut batches = Vec::new();
276            let mut input_stream = input_stream;
277            while let Some(batch_result) = input_stream.next().await {
278                batches.push(batch_result?);
279            }
280
281            let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
282            if total_rows == 0 {
283                // No results: return an empty batch.  The engine-level quality
284                // escalation path (recall_quality_should_escalate) handles the
285                // empty-results case independently and does not rely on a
286                // plan-level sentinel row.  Emitting a null-filled sentinel here
287                // violates the non-nullable schema fields (e.g. `id`) and causes
288                // Arrow to reject the batch with a schema-validation error.
289                return Ok(RecordBatch::new_empty(schema));
290            }
291
292            // Compute average score from score column (if present).
293            let mut total_score = 0.0_f32;
294            let mut score_count = 0_usize;
295            let mut total_tokens = 0_usize;
296
297            for batch in &batches {
298                if let Some(score_col) = batch.column_by_name("score") {
299                    if let Some(scores) = score_col.as_any().downcast_ref::<Float32Array>() {
300                        for i in 0..scores.len() {
301                            if !scores.is_null(i) {
302                                total_score += scores.value(i);
303                                score_count += 1;
304                            }
305                        }
306                    }
307                }
308                if let Some(token_col) = batch.column_by_name("token_count") {
309                    if let Some(tokens) = token_col.as_any().downcast_ref::<UInt32Array>() {
310                        for i in 0..tokens.len() {
311                            if !tokens.is_null(i) {
312                                total_tokens += tokens.value(i) as usize;
313                            }
314                        }
315                    }
316                }
317            }
318
319            let avg_score = if score_count > 0 {
320                total_score / score_count as f32
321            } else {
322                0.0
323            };
324
325            // Merge all batches first so coherence can be computed from embeddings.
326            let merged =
327                arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
328
329            let coherence = QualityGateExec::compute_coherence_from_batch(
330                &merged,
331                config.coherence_fallback,
332            );
333            let assessment = QualityGateExec::assess_quality(
334                &config,
335                token_budget,
336                total_rows,
337                avg_score,
338                total_tokens,
339                coherence,
340            );
341            let action = if assessment.escalate {
342                "escalate"
343            } else {
344                "pass"
345            };
346
347            let n = merged.num_rows();
348            let quality_scores = Float32Array::from(vec![assessment.combined; n]);
349            let quality_actions = StringArray::from(vec![action.to_string(); n]);
350
351            let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
352            columns.push(Arc::new(quality_scores));
353            columns.push(Arc::new(quality_actions));
354
355            RecordBatch::try_new(schema, columns).map_err(Into::into)
356        });
357
358        Ok(Box::pin(RecordBatchStreamAdapter::new(
359            self.schema.clone(),
360            stream,
361        )))
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Token budget shared by every assessment in this module.
    const BUDGET: usize = 4096;

    /// Runs `assess_quality` against `config` with the shared token budget.
    fn assess(
        config: &QualityGateConfig,
        rows: usize,
        avg_score: f32,
        tokens: usize,
        coherence: f32,
    ) -> QualityAssessment {
        QualityGateExec::assess_quality(config, BUDGET, rows, avg_score, tokens, coherence)
    }

    /// The default threshold is 0.5 and the four weights add up to roughly 1.0.
    #[test]
    fn default_config() {
        let QualityGateConfig {
            threshold,
            coverage_weight,
            confidence_weight,
            coherence_weight,
            sufficiency_weight,
            ..
        } = QualityGateConfig::default();

        assert!((threshold - 0.5).abs() < f32::EPSILON);
        let weight_sum =
            coverage_weight + confidence_weight + coherence_weight + sufficiency_weight;
        assert!((weight_sum - 1.0).abs() < 0.01);
    }

    /// Plenty of rows, high scores and good coherence pass the gate.
    #[test]
    fn high_quality_no_escalation() {
        let result = assess(&QualityGateConfig::default(), 10, 0.8, 3000, 0.8);
        assert!(!result.escalate);
        assert!(result.combined > 0.5);
    }

    /// A single weak result with few tokens escalates.
    #[test]
    fn low_quality_escalation() {
        let result = assess(&QualityGateConfig::default(), 1, 0.1, 100, 0.3);
        assert!(result.escalate);
        assert!(result.combined < 0.5);
    }

    /// An empty result set scores below the threshold and escalates.
    #[test]
    fn zero_rows_zero_quality() {
        let result = assess(&QualityGateConfig::default(), 0, 0.0, 0, 0.0);
        assert!(result.escalate);
        assert!(result.combined < 0.5);
    }

    /// Moderate quality → escalation with high threshold.
    #[test]
    fn custom_threshold() {
        let strict = QualityGateConfig {
            threshold: 0.8,
            ..Default::default()
        };
        let result = assess(&strict, 5, 0.5, 2000, 0.5);
        assert!(result.escalate);
    }
}