use serde::{Deserialize, Serialize};
use super::operator_node::OperatorNode;
use super::query_dag::{DagError, QueryDag};
/// Wire format for a [`QueryDag`].
///
/// Nodes are stored as a flat list. Edges are `(parent, child)` pairs of node
/// ids, and `root` (when present) is also a node id. Serde emits the fields in
/// declaration order, so this order is the key order of the produced JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableDag {
    /// Every operator in the DAG.
    nodes: Vec<OperatorNode>,
    /// Directed `(parent, child)` edges, expressed as node ids.
    edges: Vec<(usize, usize)>,
    /// Id of the root node, if any.
    root: Option<usize>,
}
/// Converts a DAG into a serialized representation.
pub trait DagSerializer {
    /// Serializes `self` as a JSON document.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`serde_json::Error`] if serialization fails.
    fn to_json(&self) -> Result<String, serde_json::Error>;

    /// Serializes `self` into a byte buffer.
    ///
    /// The signature is infallible; each implementation documents what it
    /// produces when serialization fails.
    fn to_bytes(&self) -> Vec<u8>;
}
/// Reconstructs a DAG from a serialized representation.
pub trait DagDeserializer {
    /// Parses a DAG from a JSON document.
    ///
    /// # Errors
    ///
    /// Returns a [`serde_json::Error`] if `json` cannot be turned into a DAG.
    fn from_json(json: &str) -> Result<Self, serde_json::Error>
    where
        Self: Sized;

    /// Parses a DAG from a byte buffer.
    ///
    /// # Errors
    ///
    /// Returns a [`DagError`] if `bytes` cannot be turned into a DAG.
    fn from_bytes(bytes: &[u8]) -> Result<Self, DagError>
    where
        Self: Sized;
}
impl DagSerializer for QueryDag {
    /// Serializes the DAG as pretty-printed JSON.
    ///
    /// Nodes are emitted sorted by `id`, and edges are sorted by
    /// `(parent, child)`. The node/edge order therefore does not depend on the
    /// iteration order of the DAG's internal maps.
    ///
    /// # Errors
    ///
    /// Propagates any error from `serde_json::to_string_pretty`.
    fn to_json(&self) -> Result<String, serde_json::Error> {
        let mut nodes: Vec<OperatorNode> = self.nodes.values().cloned().collect();
        // Map iteration order can be arbitrary (e.g. `HashMap`). Sort so the
        // output, and the ids reassigned on load, are reproducible.
        nodes.sort_by_key(|node| node.id);

        let mut edges: Vec<(usize, usize)> = Vec::new();
        for (&parent, children) in &self.edges {
            for &child in children {
                edges.push((parent, child));
            }
        }
        edges.sort_unstable();

        let serializable = SerializableDag {
            nodes,
            edges,
            root: self.root,
        };
        serde_json::to_string_pretty(&serializable)
    }

    /// Serializes the DAG as the UTF-8 bytes of [`Self::to_json`].
    ///
    /// The trait signature cannot report errors, so a serialization failure
    /// yields an empty buffer. An empty buffer is not valid JSON and is
    /// rejected by `from_bytes`.
    fn to_bytes(&self) -> Vec<u8> {
        self.to_json().unwrap_or_default().into_bytes()
    }
}
impl DagDeserializer for QueryDag {
fn from_json(json: &str) -> Result<Self, serde_json::Error> {
let serializable: SerializableDag = serde_json::from_str(json)?;
let mut dag = QueryDag::new();
let mut id_map = std::collections::HashMap::new();
for node in serializable.nodes {
let old_id = node.id;
let new_id = dag.add_node(node);
id_map.insert(old_id, new_id);
}
for (parent, child) in serializable.edges {
if let (Some(&new_parent), Some(&new_child)) = (id_map.get(&parent), id_map.get(&child))
{
let _ = dag.add_edge(new_parent, new_child);
}
}
if let Some(old_root) = serializable.root {
dag.root = id_map.get(&old_root).copied();
}
Ok(dag)
}
fn from_bytes(bytes: &[u8]) -> Result<Self, DagError> {
let json = String::from_utf8(bytes.to_vec())
.map_err(|e| DagError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
Self::from_json(&json)
.map_err(|e| DagError::InvalidOperation(format!("Deserialization failed: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    /// A linear three-node chain survives a JSON round trip.
    #[test]
    fn test_json_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        let json = dag.to_json().unwrap();
        assert!(!json.is_empty());
        let deserialized = QueryDag::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), 3);
        assert_eq!(deserialized.edge_count(), 2);
    }

    /// A two-node DAG survives a byte round trip.
    #[test]
    fn test_bytes_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        dag.add_edge(id1, id2).unwrap();
        let bytes = dag.to_bytes();
        assert!(!bytes.is_empty());
        let deserialized = QueryDag::from_bytes(&bytes).unwrap();
        assert_eq!(deserialized.node_count(), 2);
        assert_eq!(deserialized.edge_count(), 1);
    }

    /// A DAG with a join (two parents feeding one node) round-trips intact.
    #[test]
    fn test_complex_dag_roundtrip() {
        let mut dag = QueryDag::new();
        let scan1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let scan2 = dag.add_node(OperatorNode::seq_scan(0, "orders"));
        let join = dag.add_node(OperatorNode::hash_join(0, "user_id"));
        let filter = dag.add_node(OperatorNode::filter(0, "total > 100"));
        let sort = dag.add_node(OperatorNode::sort(0, vec!["date".to_string()]));
        let limit = dag.add_node(OperatorNode::limit(0, 10));
        dag.add_edge(scan1, join).unwrap();
        dag.add_edge(scan2, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(filter, sort).unwrap();
        dag.add_edge(sort, limit).unwrap();
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), dag.node_count());
        assert_eq!(restored.edge_count(), dag.edge_count());
    }

    /// An empty DAG round-trips to an empty DAG.
    #[test]
    fn test_empty_dag_serialization() {
        let dag = QueryDag::new();
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), 0);
        assert_eq!(restored.edge_count(), 0);
    }

    /// Serializing a restored DAG and loading it again keeps its size.
    #[test]
    fn test_double_roundtrip() {
        let mut dag = QueryDag::new();
        let scan = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let filter = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let limit = dag.add_node(OperatorNode::limit(0, 5));
        dag.add_edge(scan, filter).unwrap();
        dag.add_edge(filter, limit).unwrap();
        let once = QueryDag::from_json(&dag.to_json().unwrap()).unwrap();
        let twice = QueryDag::from_json(&once.to_json().unwrap()).unwrap();
        assert_eq!(twice.node_count(), dag.node_count());
        assert_eq!(twice.edge_count(), dag.edge_count());
    }

    /// Bytes that are not UTF-8, or are empty, are reported as errors.
    #[test]
    fn test_from_bytes_rejects_bad_input() {
        assert!(QueryDag::from_bytes(&[0xff, 0xfe, 0xfd]).is_err());
        assert!(QueryDag::from_bytes(&[]).is_err());
    }

    /// Text that is not a document in the wire format is reported as an error.
    #[test]
    fn test_from_json_rejects_malformed_input() {
        assert!(QueryDag::from_json("").is_err());
        assert!(QueryDag::from_json("not json").is_err());
        assert!(QueryDag::from_json(r#"{"nodes": 1}"#).is_err());
    }
}