use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, Float64Array, Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::ipc::writer::StreamWriter;
use datafusion::arrow::record_batch::RecordBatch;
use serde::Serialize;
pub(super) fn encode<T: Serialize>(value: &T) -> crate::Result<Vec<u8>> {
rmp_serde::to_vec_named(value).map_err(|e| crate::Error::Codec {
detail: format!("response serialization: {e}"),
})
}
pub fn encode_as_arrow_ipc(
rows: &[(String, serde_json::Value)],
projection: &[String],
) -> Option<Vec<u8>> {
if rows.is_empty() {
return None;
}
let first_obj = rows[0].1.as_object()?;
let field_names: Vec<&str> = if projection.is_empty() {
first_obj.keys().map(|k| k.as_str()).collect()
} else {
projection.iter().map(|s| s.as_str()).collect()
};
if field_names.is_empty() {
return None;
}
let mut fields = vec![Field::new("id", DataType::Utf8, false)];
for &name in &field_names {
let dt = first_obj
.get(name)
.map(infer_type)
.unwrap_or(DataType::Utf8);
fields.push(Field::new(name, dt, true));
}
let schema = Arc::new(Schema::new(fields));
let mut ids: Vec<String> = Vec::with_capacity(rows.len());
let mut builders: Vec<ColBuilder> = field_names
.iter()
.map(|&name| {
let dt = first_obj
.get(name)
.map(infer_type)
.unwrap_or(DataType::Utf8);
ColBuilder::new(dt, rows.len())
})
.collect();
for (doc_id, data) in rows {
ids.push(doc_id.clone());
let obj = data.as_object();
for (i, &name) in field_names.iter().enumerate() {
match obj.and_then(|o| o.get(name)) {
Some(v) => builders[i].push(v),
None => builders[i].push_null(),
}
}
}
let mut arrays: Vec<ArrayRef> = vec![Arc::new(StringArray::from(ids))];
for b in builders {
arrays.push(b.finish());
}
let batch = RecordBatch::try_new(schema.clone(), arrays).ok()?;
let mut buf = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buf, &schema).ok()?;
writer.write(&batch).ok()?;
writer.finish().ok()?;
}
Some(buf)
}
fn infer_type(v: &serde_json::Value) -> DataType {
match v {
serde_json::Value::Number(n) if n.is_i64() => DataType::Int64,
serde_json::Value::Number(_) => DataType::Float64,
serde_json::Value::Bool(_) => DataType::Boolean,
_ => DataType::Utf8,
}
}
/// Per-column accumulator that buffers one optional value per row until it is
/// converted into an Arrow array.
enum ColBuilder {
/// Values stored as text.
Str(Vec<Option<String>>),
/// Values stored as 64-bit signed integers.
I64(Vec<Option<i64>>),
/// Values stored as 64-bit floats.
F64(Vec<Option<f64>>),
}
impl ColBuilder {
    /// Creates an empty builder for `dt` with room reserved for `cap` rows.
    ///
    /// `Int64` and `Float64` get typed storage; every other data type falls
    /// back to string storage.
    fn new(dt: DataType, cap: usize) -> Self {
        match dt {
            DataType::Int64 => Self::I64(Vec::with_capacity(cap)),
            DataType::Float64 => Self::F64(Vec::with_capacity(cap)),
            _ => Self::Str(Vec::with_capacity(cap)),
        }
    }

    /// Appends one JSON value to the column.
    ///
    /// * String columns store JSON strings verbatim and any other non-null
    ///   value as its JSON text. JSON `null` is stored as a null entry rather
    ///   than the literal text `"null"`, matching the numeric columns.
    /// * Integer columns store the value when `as_i64()` yields one, otherwise
    ///   a null entry.
    /// * Float columns store the value when `as_f64()` yields one, otherwise a
    ///   null entry.
    fn push(&mut self, v: &serde_json::Value) {
        match self {
            Self::Str(vec) => vec.push(match v {
                serde_json::Value::Null => None,
                serde_json::Value::String(s) => Some(s.clone()),
                other => Some(other.to_string()),
            }),
            Self::I64(vec) => vec.push(v.as_i64()),
            Self::F64(vec) => vec.push(v.as_f64()),
        }
    }

    /// Appends a null entry to the column.
    fn push_null(&mut self) {
        match self {
            Self::Str(v) => v.push(None),
            Self::I64(v) => v.push(None),
            Self::F64(v) => v.push(None),
        }
    }

    /// Consumes the builder and returns the buffered values as an Arrow array
    /// of the matching type (`StringArray`, `Int64Array`, or `Float64Array`).
    fn finish(self) -> ArrayRef {
        match self {
            Self::Str(v) => Arc::new(StringArray::from(v)) as ArrayRef,
            Self::I64(v) => Arc::new(Int64Array::from(v)) as ArrayRef,
            Self::F64(v) => Arc::new(Float64Array::from(v)) as ArrayRef,
        }
    }
}
pub(super) fn encode_count(key: &str, count: usize) -> crate::Result<Vec<u8>> {
let mut map = std::collections::BTreeMap::new();
map.insert(key, count);
rmp_serde::to_vec_named(&map).map_err(|e| crate::Error::Codec {
detail: format!("count response serialization: {e}"),
})
}
/// Renders a response payload as a JSON string.
///
/// An empty payload yields an empty string. A payload whose first byte looks
/// like the start of a JSON document (`[`, `{`, `"`, an ASCII digit, or the
/// first letter of `true` / `false` / `null`) is returned as lossily decoded
/// UTF-8 text without further parsing. Anything else is decoded as
/// MessagePack and re-serialized to JSON; if either step fails, the payload is
/// returned as lossily decoded UTF-8 text.
pub fn decode_payload_to_json(payload: &[u8]) -> String {
    // Shared fallback: the raw bytes interpreted as (lossy) UTF-8.
    let lossy_text = || String::from_utf8_lossy(payload).into_owned();

    let lead = match payload.first() {
        Some(&byte) => byte,
        None => return String::new(),
    };

    if matches!(lead, b'[' | b'{' | b'"' | b't' | b'f' | b'n' | b'0'..=b'9') {
        return lossy_text();
    }

    rmp_serde::from_slice::<serde_json::Value>(payload)
        .ok()
        .and_then(|value| serde_json::to_string(&value).ok())
        .unwrap_or_else(lossy_text)
}
/// One serializable hit of a vector-search response.
#[derive(Serialize)]
pub(super) struct VectorSearchHit {
/// Numeric identifier of the hit (presumably the index-internal vector id — confirm with the caller).
pub id: u32,
/// Distance reported for this hit; the metric is not specified in this file.
pub distance: f32,
/// Optional document identifier; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub doc_id: Option<String>,
}
/// Serializable pairing of a document identifier with its JSON content.
#[derive(Serialize)]
pub(super) struct DocumentRow {
/// Identifier of the document.
pub id: String,
/// Document content as an arbitrary JSON value.
pub data: serde_json::Value,
}
/// Serializable neighbor entry that borrows its strings from the caller.
#[derive(Serialize)]
pub(super) struct NeighborEntry<'a> {
/// Label associated with the entry (presumably the edge label — confirm with the caller).
pub label: &'a str,
/// Node associated with the entry (presumably the neighboring node's id — confirm with the caller).
pub node: &'a str,
}
/// Serializable edge of a subgraph response; all strings are borrowed from the caller.
#[derive(Serialize)]
pub(super) struct SubgraphEdge<'a> {
/// Source endpoint of the edge (presumably a node id — confirm with the caller).
pub src: &'a str,
/// Label of the edge.
pub label: &'a str,
/// Destination endpoint of the edge (presumably a node id — confirm with the caller).
pub dst: &'a str,
}
/// One serializable result entry of a GraphRAG response.
///
/// Each optional field is left out of the serialized output when it is `None`.
#[derive(Serialize)]
pub(super) struct GraphRagResult {
/// Identifier of the node this result refers to.
pub node_id: String,
/// Fused ranking score (presumably reciprocal rank fusion, going by the name — confirm).
pub rrf_score: f64,
/// Rank of the node in the vector-search stage, when available.
#[serde(skip_serializing_if = "Option::is_none")]
pub vector_rank: Option<usize>,
/// Vector distance recorded for the node, when available.
#[serde(skip_serializing_if = "Option::is_none")]
pub vector_distance: Option<f32>,
/// Hop distance recorded for the node, when available (presumably hops from a seed node — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub hop_distance: Option<usize>,
}
/// One serializable hit of a text-search response; `doc_id` is borrowed from the caller.
#[derive(Serialize)]
pub(super) struct TextSearchHit<'a> {
/// Identifier of the matched document.
pub doc_id: &'a str,
/// Score reported for this hit; the scoring function is not specified in this file.
pub score: f32,
/// Fuzzy-match flag (presumably true when the hit came from fuzzy matching — confirm with the caller).
pub fuzzy: bool,
}
/// One serializable hit of a hybrid-search response; `doc_id` is borrowed from the caller.
///
/// Each optional rank is left out of the serialized output when it is `None`.
#[derive(Serialize)]
pub(super) struct HybridSearchHit<'a> {
/// Identifier of the matched document.
pub doc_id: &'a str,
/// Fused ranking score (presumably reciprocal rank fusion, going by the name — confirm).
pub rrf_score: f64,
/// Rank of the document in the vector-search stage, when available.
#[serde(skip_serializing_if = "Option::is_none")]
pub vector_rank: Option<usize>,
/// Rank of the document in the text-search stage, when available.
#[serde(skip_serializing_if = "Option::is_none")]
pub text_rank: Option<usize>,
}
/// Serializable top-level GraphRAG response: the result list plus query metadata.
#[derive(Serialize)]
pub(super) struct GraphRagResponse {
/// Result entries, serialized in the order given.
pub results: Vec<GraphRagResult>,
/// Metadata accompanying the results.
pub metadata: GraphRagMetadata,
}
/// Serializable metadata attached to a GraphRAG response.
#[derive(Serialize)]
pub(super) struct GraphRagMetadata {
/// Count of vector candidates (presumably those considered before graph expansion — confirm).
pub vector_candidates: usize,
/// Count associated with graph expansion (presumably nodes added by it — confirm).
pub graph_expanded: usize,
/// Truncation flag (presumably true when the result set was cut short — confirm).
pub truncated: bool,
/// Watermark value (presumably a log sequence number, going by the name — confirm).
pub watermark_lsn: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// MessagePack-encoded vector hits decode to JSON that keeps the field names.
    #[test]
    fn encode_vector_hits() {
        let hits: Vec<VectorSearchHit> = [(1_u32, 0.5_f32), (2_u32, 0.8_f32)]
            .iter()
            .map(|&(id, distance)| VectorSearchHit {
                id,
                distance,
                doc_id: None,
            })
            .collect();

        let encoded = encode(&hits).unwrap();
        assert!(!encoded.is_empty());

        let rendered = decode_payload_to_json(&encoded);
        assert!(rendered.contains("\"id\""));
        assert!(rendered.contains("\"distance\""));
    }

    /// A count message decodes to JSON containing both the key and the value.
    #[test]
    fn encode_count_msg() {
        let encoded = encode_count("inserted", 42).unwrap();
        let rendered = decode_payload_to_json(&encoded);
        assert!(rendered.contains("\"inserted\""));
        assert!(rendered.contains("42"));
    }

    /// Payloads that already look like JSON are returned unchanged.
    #[test]
    fn json_passthrough() {
        let original = r#"[{"id":1}]"#;
        assert_eq!(decode_payload_to_json(original.as_bytes()), original);
    }

    /// A MessagePack payload is converted into equivalent JSON.
    #[test]
    fn msgpack_to_json_roundtrip() {
        let source = serde_json::json!({"key": "value", "num": 42});
        let packed = rmp_serde::to_vec(&source).unwrap();

        let rendered = decode_payload_to_json(&packed);
        let reparsed: serde_json::Value = serde_json::from_str(&rendered).unwrap();

        assert_eq!(reparsed["key"], "value");
        assert_eq!(reparsed["num"], 42);
    }
}