1use std::any::Any;
8use std::fmt;
9use std::sync::Arc;
10
11use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt32Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use datafusion_common::Result;
14use datafusion_execution::{SendableRecordBatchStream, TaskContext};
15use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
16use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
17use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
18
/// Tunable parameters for the quality gate.
///
/// The four `*_weight` fields scale the individual quality components before
/// they are summed into a single combined score; `threshold` is the value the
/// combined score is compared against to decide whether to escalate.
#[derive(Clone, Debug)]
pub struct QualityGateConfig {
    /// Combined scores strictly below this value are flagged for escalation.
    pub threshold: f32,
    /// Multiplier applied to the coverage component.
    pub coverage_weight: f32,
    /// Multiplier applied to the confidence component.
    pub confidence_weight: f32,
    /// Multiplier applied to the coherence component.
    pub coherence_weight: f32,
    /// Multiplier applied to the sufficiency component.
    pub sufficiency_weight: f32,
    /// Coherence value to use when it cannot be derived from the data
    /// (for example, no usable `embedding` column or fewer than two embeddings).
    pub coherence_fallback: f32,
}
37
38impl Default for QualityGateConfig {
39 fn default() -> Self {
40 Self {
41 threshold: 0.5,
42 coverage_weight: 0.3,
43 confidence_weight: 0.3,
44 coherence_weight: 0.2,
45 sufficiency_weight: 0.2,
46 coherence_fallback: 0.6,
47 }
48 }
49}
50
/// Outcome of scoring one result set: the four individual components, their
/// weighted combination, and the resulting gate decision.
#[derive(Clone, Debug)]
pub struct QualityAssessment {
    /// Row-count component, capped at 1.0.
    pub coverage: f32,
    /// Average relevance score of the rows, as supplied by the caller.
    pub confidence: f32,
    /// Coherence component, as supplied by the caller.
    pub coherence: f32,
    /// Fraction of the token budget that the rows occupy, capped at 1.0.
    pub sufficiency: f32,
    /// Weighted sum of the four components above.
    pub combined: f32,
    /// Whether the combined score failed the configured threshold.
    pub escalate: bool,
}
61
62#[derive(Debug)]
67pub struct QualityGateExec {
68 input: Arc<dyn ExecutionPlan>,
69 config: QualityGateConfig,
70 token_budget: usize,
72 schema: SchemaRef,
73 properties: PlanProperties,
74}
75
76impl QualityGateExec {
77 pub fn new(
78 input: Arc<dyn ExecutionPlan>,
79 config: QualityGateConfig,
80 token_budget: usize,
81 ) -> Self {
82 let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
83 fields.push(Arc::new(Field::new(
84 "quality_score",
85 DataType::Float32,
86 false,
87 )));
88 fields.push(Arc::new(Field::new(
89 "quality_action",
90 DataType::Utf8,
91 false,
92 )));
93 let schema = Arc::new(Schema::new(fields));
94
95 let properties = PlanProperties::new(
96 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
97 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
98 EmissionType::Final,
99 Boundedness::Bounded,
100 );
101
102 Self {
103 input,
104 config,
105 token_budget,
106 schema,
107 properties,
108 }
109 }
110
111 fn assess_quality(
119 config: &QualityGateConfig,
120 token_budget: usize,
121 row_count: usize,
122 avg_score: f32,
123 total_tokens: usize,
124 coherence: f32,
125 ) -> QualityAssessment {
126 let coverage = if row_count > 0 {
127 1.0_f32.min(row_count as f32 / 5.0)
128 } else {
129 0.0
130 };
131 let confidence = avg_score;
132 let sufficiency = if token_budget > 0 {
133 (total_tokens as f32 / token_budget as f32).min(1.0)
134 } else {
135 1.0
136 };
137
138 let combined = config.coverage_weight * coverage
139 + config.confidence_weight * confidence
140 + config.coherence_weight * coherence
141 + config.sufficiency_weight * sufficiency;
142
143 let escalate = combined < config.threshold;
144
145 QualityAssessment {
146 coverage,
147 confidence,
148 coherence,
149 sufficiency,
150 combined,
151 escalate,
152 }
153 }
154
155 fn compute_coherence_from_batch(batch: &RecordBatch, fallback: f32) -> f32 {
160 let fsl = match batch
161 .column_by_name("embedding")
162 .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>())
163 {
164 Some(fsl) => fsl,
165 None => return fallback,
166 };
167
168 let embeddings: Vec<Vec<f32>> = (0..fsl.len())
169 .filter(|&i| !fsl.is_null(i))
170 .filter_map(|i| {
171 let values = fsl.value(i);
172 let f32_arr = values.as_any().downcast_ref::<Float32Array>()?;
173 Some(f32_arr.values().to_vec())
174 })
175 .collect();
176
177 if embeddings.len() < 2 {
178 return fallback;
179 }
180
181 let mut sum = 0.0_f32;
182 let mut count = 0_u32;
183 for i in 0..embeddings.len() {
184 for j in (i + 1)..embeddings.len() {
185 sum += cosine_similarity(&embeddings[i], &embeddings[j]);
186 count += 1;
187 }
188 }
189
190 if count > 0 {
191 (sum / count as f32).clamp(0.0, 1.0)
192 } else {
193 fallback
194 }
195 }
196}
197
/// Returns the cosine similarity of `a` and `b`, in `[-1.0, 1.0]`.
///
/// Components are paired positionally; if the slices differ in length, the
/// extra trailing components of the longer one are ignored.
///
/// Returns `0.0` when the similarity is undefined: either vector has zero
/// magnitude (including empty input), or an input contains NaN or infinity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate in f64 so that squaring very small or very large f32
    // components neither underflows to zero nor overflows to infinity.
    let mut dot = 0.0_f64;
    let mut norm_a = 0.0_f64;
    let mut norm_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // Only an exactly-zero or non-finite denominator makes the ratio
    // meaningless. An absolute epsilon cutoff would wrongly zero out valid
    // low-magnitude vectors.
    if denom == 0.0 || !denom.is_finite() {
        return 0.0;
    }

    (dot / denom).clamp(-1.0, 1.0) as f32
}
214
215impl DisplayAs for QualityGateExec {
216 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 write!(
218 f,
219 "QualityGateExec: threshold={}, budget={}",
220 self.config.threshold, self.token_budget
221 )
222 }
223}
224
225impl ExecutionPlan for QualityGateExec {
226 fn name(&self) -> &str {
227 "QualityGateExec"
228 }
229
230 fn as_any(&self) -> &dyn Any {
231 self
232 }
233
234 fn schema(&self) -> SchemaRef {
235 self.schema.clone()
236 }
237
238 fn properties(&self) -> &PlanProperties {
239 &self.properties
240 }
241
242 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
243 vec![&self.input]
244 }
245
246 fn with_new_children(
247 self: Arc<Self>,
248 children: Vec<Arc<dyn ExecutionPlan>>,
249 ) -> Result<Arc<dyn ExecutionPlan>> {
250 let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
251 datafusion_common::DataFusionError::Plan(format!(
252 "QualityGateExec requires exactly 1 child, got {}",
253 v.len()
254 ))
255 })?;
256 Ok(Arc::new(Self::new(
257 child,
258 self.config.clone(),
259 self.token_budget,
260 )))
261 }
262
263 fn execute(
264 &self,
265 partition: usize,
266 context: Arc<TaskContext>,
267 ) -> Result<SendableRecordBatchStream> {
268 let input_stream = self.input.execute(partition, context)?;
269 let schema = self.schema.clone();
270 let config = self.config.clone();
271 let token_budget = self.token_budget;
272
273 let stream = futures::stream::once(async move {
274 use futures::StreamExt;
275 let mut batches = Vec::new();
276 let mut input_stream = input_stream;
277 while let Some(batch_result) = input_stream.next().await {
278 batches.push(batch_result?);
279 }
280
281 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
282 if total_rows == 0 {
283 return Ok(RecordBatch::new_empty(schema));
290 }
291
292 let mut total_score = 0.0_f32;
294 let mut score_count = 0_usize;
295 let mut total_tokens = 0_usize;
296
297 for batch in &batches {
298 if let Some(score_col) = batch.column_by_name("score") {
299 if let Some(scores) = score_col.as_any().downcast_ref::<Float32Array>() {
300 for i in 0..scores.len() {
301 if !scores.is_null(i) {
302 total_score += scores.value(i);
303 score_count += 1;
304 }
305 }
306 }
307 }
308 if let Some(token_col) = batch.column_by_name("token_count") {
309 if let Some(tokens) = token_col.as_any().downcast_ref::<UInt32Array>() {
310 for i in 0..tokens.len() {
311 if !tokens.is_null(i) {
312 total_tokens += tokens.value(i) as usize;
313 }
314 }
315 }
316 }
317 }
318
319 let avg_score = if score_count > 0 {
320 total_score / score_count as f32
321 } else {
322 0.0
323 };
324
325 let merged =
327 arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
328
329 let coherence = QualityGateExec::compute_coherence_from_batch(
330 &merged,
331 config.coherence_fallback,
332 );
333 let assessment = QualityGateExec::assess_quality(
334 &config,
335 token_budget,
336 total_rows,
337 avg_score,
338 total_tokens,
339 coherence,
340 );
341 let action = if assessment.escalate {
342 "escalate"
343 } else {
344 "pass"
345 };
346
347 let n = merged.num_rows();
348 let quality_scores = Float32Array::from(vec![assessment.combined; n]);
349 let quality_actions = StringArray::from(vec![action.to_string(); n]);
350
351 let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
352 columns.push(Arc::new(quality_scores));
353 columns.push(Arc::new(quality_actions));
354
355 RecordBatch::try_new(schema, columns).map_err(Into::into)
356 });
357
358 Ok(Box::pin(RecordBatchStreamAdapter::new(
359 self.schema.clone(),
360 stream,
361 )))
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Token budget shared by every assessment in this module.
    const BUDGET: usize = 4096;

    /// Runs `assess_quality` against the shared budget.
    fn assess(
        config: &QualityGateConfig,
        rows: usize,
        avg_score: f32,
        tokens: usize,
        coherence: f32,
    ) -> QualityAssessment {
        QualityGateExec::assess_quality(config, BUDGET, rows, avg_score, tokens, coherence)
    }

    /// The default threshold is 0.5 and the four weights sum to roughly 1.0.
    #[test]
    fn default_config() {
        let config = QualityGateConfig::default();
        let threshold_delta = (config.threshold - 0.5).abs();
        assert!(threshold_delta < f32::EPSILON);

        let weights = [
            config.coverage_weight,
            config.confidence_weight,
            config.coherence_weight,
            config.sufficiency_weight,
        ];
        let weight_sum: f32 = weights.iter().sum();
        assert!((weight_sum - 1.0).abs() < 0.01);
    }

    /// Plenty of rows, high scores and good coherence clear the default gate.
    #[test]
    fn high_quality_no_escalation() {
        let result = assess(&QualityGateConfig::default(), 10, 0.8, 3000, 0.8);
        assert!(!result.escalate);
        assert!(result.combined > 0.5);
    }

    /// A single weak row falls below the default gate.
    #[test]
    fn low_quality_escalation() {
        let result = assess(&QualityGateConfig::default(), 1, 0.1, 100, 0.3);
        assert!(result.escalate);
        assert!(result.combined < 0.5);
    }

    /// No rows and all-zero inputs must escalate.
    #[test]
    fn zero_rows_zero_quality() {
        let result = assess(&QualityGateConfig::default(), 0, 0.0, 0, 0.0);
        assert!(result.escalate);
        assert!(result.combined < 0.5);
    }

    /// A middling result set escalates once the threshold is raised to 0.8.
    #[test]
    fn custom_threshold() {
        let strict = QualityGateConfig {
            threshold: 0.8,
            ..Default::default()
        };
        // With the default weights these inputs combine to roughly 0.65.
        let result = assess(&strict, 5, 0.5, 2000, 0.5);
        assert!(result.escalate);
    }
}