use crate::logical_plan::producer;
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::prelude::*;
use prost::Message;
use std::path::Path;
use substrait::proto::Plan;
use tokio::{
fs::OpenOptions,
io::{AsyncReadExt, AsyncWriteExt},
};
/// Plans `sql` against `ctx`, encodes the result as Substrait protobuf bytes and writes
/// them to a newly created file at `path`.
///
/// # Errors
///
/// Returns an error if planning or encoding fails. It also returns an error if the file
/// cannot be created, including when something already exists at `path`, since the file
/// is opened with `create_new`. Write and flush failures are returned as well.
pub async fn serialize(
    sql: &str,
    ctx: &SessionContext,
    path: impl AsRef<Path>,
) -> Result<()> {
    let bytes = serialize_bytes(sql, ctx).await?;

    let mut options = OpenOptions::new();
    options.write(true).create_new(true);
    let mut out = options.open(path).await?;

    out.write_all(&bytes).await?;
    // Flush explicitly so a pending write error reaches the caller instead of being
    // lost when the file handle is dropped.
    out.flush().await?;
    Ok(())
}
/// Plans `sql` against `ctx`, optimizes the resulting logical plan, converts it to a
/// Substrait plan and returns its protobuf encoding.
///
/// # Errors
///
/// Returns an error if the SQL cannot be planned or optimized, or if the logical plan
/// cannot be converted to Substrait.
pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result<Vec<u8>> {
    let df = ctx.sql(sql).await?;
    let plan = df.into_optimized_plan()?;
    let proto = producer::to_substrait_plan(&plan, &ctx.state())?;
    // `encode_to_vec` reserves exactly `encoded_len()` bytes up front and is infallible.
    // Encoding into a growable `Vec` can never run out of capacity, so there is no
    // encode error to map.
    Ok(proto.encode_to_vec())
}
/// Reads the whole file at `path` and decodes its contents as a Substrait [`Plan`].
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, or if its contents are not a
/// valid protobuf-encoded plan.
pub async fn deserialize(path: impl AsRef<Path>) -> Result<Box<Plan>> {
    let mut reader = OpenOptions::new().read(true).open(path).await?;
    let mut contents = Vec::new();
    reader.read_to_end(&mut contents).await?;
    deserialize_bytes(contents).await
}
pub async fn deserialize_bytes(proto_bytes: Vec<u8>) -> Result<Box<Plan>> {
Ok(Box::new(Message::decode(&*proto_bytes).map_err(|e| {
DataFusionError::Substrait(format!("Failed to decode plan: {e}"))
})?))
}