Skip to main content

hirn_exec/udfs/
source_reliability.rs

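//! `source_reliability` UDF — score how trustworthy a memory's provenance is.
//!
//! `source_reliability(source_type: Utf8) → Float32`
//!
//! Known source types map to fixed scores: `direct_observation` = 1.0,
//! `agent_generated` = 0.8, `inferred` = 0.6, `cross_agent` = 0.5. Any other
//! string scores 0.4, and a NULL input yields a NULL score.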
4
5use std::any::Any;
6use std::sync::Arc;
7
8use arrow_array::Array;
9use arrow_array::Float32Array;
10use arrow_array::cast::AsArray;
11use datafusion_common::Result;
12use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
13
14use arrow_schema::DataType;
15
16#[derive(Debug, PartialEq, Eq, Hash)]
17pub struct SourceReliabilityUdf {
18    signature: Signature,
19}
20
21impl Default for SourceReliabilityUdf {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl SourceReliabilityUdf {
28    pub fn new() -> Self {
29        Self {
30            signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
31        }
32    }
33}
34
/// Map a provenance label to its reliability score (between 0.4 and 1.0).
///
/// Matching is exact and case-sensitive. Any label not listed in the table
/// (including the empty string) falls back to `0.4`.
fn reliability_score(source_type: &str) -> f32 {
    const SCORES: [(&str, f32); 4] = [
        ("direct_observation", 1.0),
        ("agent_generated", 0.8),
        ("inferred", 0.6),
        ("cross_agent", 0.5),
    ];
    const FALLBACK: f32 = 0.4;

    SCORES
        .iter()
        .find(|(label, _)| *label == source_type)
        .map_or(FALLBACK, |&(_, score)| score)
}
44
45impl ScalarUDFImpl for SourceReliabilityUdf {
46    fn as_any(&self) -> &dyn Any {
47        self
48    }
49
50    fn name(&self) -> &str {
51        "source_reliability"
52    }
53
54    fn signature(&self) -> &Signature {
55        &self.signature
56    }
57
58    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
59        Ok(DataType::Float32)
60    }
61
62    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
63        let num_rows = args.number_rows;
64        let arrays: Vec<_> = args
65            .args
66            .iter()
67            .map(|a| a.to_array(num_rows))
68            .collect::<Result<Vec<_>>>()?;
69
70        let source_types = arrays[0].as_string::<i32>();
71        let len = source_types.len();
72        let mut results = Vec::with_capacity(len);
73
74        for i in 0..len {
75            if source_types.is_null(i) {
76                results.push(None);
77            } else {
78                results.push(Some(reliability_score(source_types.value(i))));
79            }
80        }
81
82        Ok(ColumnarValue::Array(Arc::new(Float32Array::from(results))))
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;

    /// Run the UDF over one Utf8 column built from `types` and return the
    /// resulting Float32 column.
    fn invoke(types: Vec<Option<&str>>) -> Float32Array {
        // Capture the row count before `types` is moved into the array.
        let number_rows = types.len();
        let input = StringArray::from(types);
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(input))],
            number_rows,
            return_field: Arc::new(Field::new("result", DataType::Float32, true)),
            arg_fields: vec![],
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = SourceReliabilityUdf::new().invoke_with_args(args).unwrap();
        let ColumnarValue::Array(out) = output else {
            panic!("expected array");
        };
        out.as_any().downcast_ref::<Float32Array>().unwrap().clone()
    }

    /// Assert two scores agree within f32 rounding slack.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn known_source_types() {
        let vals = invoke(vec![
            Some("direct_observation"),
            Some("agent_generated"),
            Some("inferred"),
            Some("cross_agent"),
        ]);
        let expected: [f32; 4] = [1.0, 0.8, 0.6, 0.5];
        for (row, &score) in expected.iter().enumerate() {
            assert_close(vals.value(row), score);
        }
    }

    #[test]
    fn unknown_returns_default() {
        let vals = invoke(vec![Some("unknown_type")]);
        assert_close(vals.value(0), 0.4);
    }

    #[test]
    fn null_returns_null() {
        let vals = invoke(vec![None]);
        assert!(vals.is_null(0));
    }
}