use std::collections::HashMap;
use std::sync::Arc;
use polars::prelude::*;
use polars_plan::dsl::DslPlan;
use crate::error::{Error, Result};
use crate::source::Source;
// URI-style placeholder marker: a scan path of the form `taxa://<name>`
// refers to the named source `<name>` (see `name_from_path`).
const TAXA_SCHEME: &str = "taxa://";
// File-name placeholder marker: a scan whose file stem is `taxa__<name>`
// (optionally with a `.parquet` or `.csv` extension) refers to `<name>`.
const TAXA_PREFIX: &str = "taxa__";
/// Returns the first path of a `DslPlan::Scan` node as an owned string.
///
/// Yields `None` when `plan` is not a `Scan`, or when its sources report no
/// first path. Only the first path is inspected; any further paths on the
/// same scan are ignored.
fn scan_path(plan: &DslPlan) -> Option<String> {
    let DslPlan::Scan { sources, .. } = plan else {
        return None;
    };
    let first = sources.first_path()?;
    Some(first.to_str().to_string())
}
/// Extracts the named-source identifier encoded in a placeholder path.
///
/// Two encodings are recognised:
/// * `taxa://<name>` — everything after the scheme is the name, verbatim.
/// * `<dir>/taxa__<name>[.parquet|.csv]` — the final path component (split on
///   `/` or `\`) has one known extension removed, then the `taxa__` prefix.
///
/// Returns `None` for any path matching neither form.
fn name_from_path(path: &str) -> Option<String> {
    match path.strip_prefix(TAXA_SCHEME) {
        Some(name) => Some(name.to_owned()),
        None => {
            // Last component of the path, tolerating both separator styles.
            let file = path.rsplit(['/', '\\']).next().unwrap_or(path);
            // Drop at most one extension; `.parquet` is tried before `.csv`.
            let stem = [".parquet", ".csv"]
                .iter()
                .find_map(|ext| file.strip_suffix(*ext))
                .unwrap_or(file);
            stem.strip_prefix(TAXA_PREFIX).map(str::to_owned)
        }
    }
}
/// Test-only predicate: does `path` look like a named-source placeholder?
///
/// True when the whole path starts with `taxa://`, or when its final
/// component (split on `/` or `\`) starts with `taxa__`.
#[cfg(test)]
fn is_placeholder(path: &str) -> bool {
    if path.starts_with(TAXA_SCHEME) {
        return true;
    }
    let file_name = match path.rsplit(['/', '\\']).next() {
        Some(component) => component,
        None => path,
    };
    file_name.starts_with(TAXA_PREFIX)
}
pub fn scan_name(plan: &DslPlan) -> Option<String> {
scan_path(plan).as_deref().and_then(name_from_path)
}
/// Replaces every placeholder scan in `plan` with the plan registered under
/// the same name in `named`, returning the rewritten plan.
///
/// A node is a placeholder when `scan_name` yields a name for it. The
/// replacement is a clone of `named[name]` and is returned as-is: plans
/// spliced in from `named` are not themselves traversed.
///
/// Traversal descends into the same child-bearing variants that `children`
/// reports (single-input nodes, `Join`, `Union`, `HConcat`, `SinkMultiple`,
/// `ExtContext`, `MatchToSchema`, `PipeWithSchema`), so a placeholder nested
/// under any of them is bound rather than being left behind for
/// `validate_bound` to reject as a "disallowed file scan".
///
/// # Errors
///
/// Returns `Error::Engine` if a placeholder names a source that is not a key
/// of `named`.
pub fn rebind_named(mut plan: DslPlan, named: &HashMap<String, DslPlan>) -> Result<DslPlan> {
    if let Some(name) = scan_name(&plan) {
        return named.get(&name).cloned().ok_or_else(|| {
            Error::Engine(format!(
                "transform references undeclared source `taxa://{name}`"
            ))
        });
    }

    // Rewrites a shared child in place. The child is cloned out of its `Arc`
    // because other owners may still hold the original.
    let rebind_arc = |slot: &mut Arc<DslPlan>| -> Result<()> {
        *slot = Arc::new(rebind_named((**slot).clone(), named)?);
        Ok(())
    };
    // Rewrites a list of owned children in place, moving each one through the
    // recursion so no element is cloned. On error the list is left empty,
    // which is fine because the whole plan is dropped with the `Err`.
    let rebind_vec = |slots: &mut Vec<DslPlan>| -> Result<()> {
        let owned = std::mem::take(slots);
        *slots = owned
            .into_iter()
            .map(|child| rebind_named(child, named))
            .collect::<Result<Vec<_>>>()?;
        Ok(())
    };

    // Only the child slots are touched; every other field of the node stays
    // exactly as it was, so variants need not be rebuilt field by field.
    match &mut plan {
        DslPlan::Filter { input, .. }
        | DslPlan::Cache { input, .. }
        | DslPlan::Select { input, .. }
        | DslPlan::GroupBy { input, .. }
        | DslPlan::HStack { input, .. }
        | DslPlan::Distinct { input, .. }
        | DslPlan::Sort { input, .. }
        | DslPlan::Slice { input, .. }
        | DslPlan::MapFunction { input, .. }
        | DslPlan::Sink { input, .. }
        | DslPlan::MatchToSchema { input, .. }
        | DslPlan::PipeWithSchema { input, .. } => rebind_arc(input)?,
        DslPlan::Join {
            input_left,
            input_right,
            ..
        } => {
            rebind_arc(input_left)?;
            rebind_arc(input_right)?;
        }
        DslPlan::Union { inputs, .. }
        | DslPlan::HConcat { inputs, .. }
        | DslPlan::SinkMultiple { inputs } => rebind_vec(inputs)?,
        DslPlan::ExtContext { input, contexts } => {
            rebind_arc(input)?;
            rebind_vec(contexts)?;
        }
        // Leaves (scans without a placeholder name, in-memory frames, ...)
        // have nothing to rebind.
        //
        // NOTE(review): `DslPlan::IR` is intentionally not rewritten here. It
        // presumably carries already-lowered state next to `dsl` that would go
        // stale if only `dsl` were replaced — confirm against the Polars
        // version in use. `validate_bound` still walks `IR { dsl, .. }` and
        // rejects any path scan left beneath it.
        _ => {}
    }
    Ok(plan)
}
fn validate_bound(plan: &DslPlan) -> Result<()> {
if let Some(path) = scan_path(plan) {
return Err(Error::Engine(format!(
"transform plan has a disallowed file scan after rebind: `{path}`. \
All data sources must be `taxa://` named sources (bound from the \
manifest's `sources`); direct file scans are not permitted."
)));
}
for child in children(plan) {
validate_bound(child)?;
}
Ok(())
}
/// Lists the direct child plans of `plan`, borrowed from it.
///
/// Single-input nodes yield one child, `Join` yields left then right,
/// `Union` / `HConcat` / `SinkMultiple` yield their inputs in order,
/// `ExtContext` yields its input followed by its contexts, and `IR` yields
/// its `dsl`. Every other variant is treated as a leaf.
fn children(plan: &DslPlan) -> Vec<&DslPlan> {
    let mut kids: Vec<&DslPlan> = Vec::new();
    match plan {
        DslPlan::Filter { input, .. }
        | DslPlan::Cache { input, .. }
        | DslPlan::Select { input, .. }
        | DslPlan::GroupBy { input, .. }
        | DslPlan::HStack { input, .. }
        | DslPlan::Distinct { input, .. }
        | DslPlan::Sort { input, .. }
        | DslPlan::Slice { input, .. }
        | DslPlan::MapFunction { input, .. }
        | DslPlan::Sink { input, .. }
        | DslPlan::MatchToSchema { input, .. }
        | DslPlan::PipeWithSchema { input, .. } => kids.push(input.as_ref()),
        DslPlan::Join {
            input_left,
            input_right,
            ..
        } => {
            kids.push(input_left.as_ref());
            kids.push(input_right.as_ref());
        }
        DslPlan::Union { inputs, .. }
        | DslPlan::HConcat { inputs, .. }
        | DslPlan::SinkMultiple { inputs } => kids.extend(inputs.iter()),
        DslPlan::ExtContext { input, contexts } => {
            kids.push(input.as_ref());
            kids.extend(contexts.iter());
        }
        DslPlan::IR { dsl, .. } => kids.push(dsl.as_ref()),
        // Anything else has no child plans to report.
        _ => {}
    }
    kids
}
/// A `Source` whose frame is produced by replaying a serialized `DslPlan`
/// with its placeholder scans bound to other named sources.
pub struct TransformSource {
/// Upstream sources keyed by the name used in placeholder scans.
sources: HashMap<String, Arc<dyn Source>>,
/// Serialized plan bytes, decoded with `DslPlan::deserialize_versioned`
/// each time `frame` is called.
plan: Vec<u8>,
}
impl TransformSource {
    /// Builds a transform over `sources` from an already-serialized plan.
    ///
    /// `plan_bytes` is stored untouched; it is not decoded or checked here.
    pub fn new(sources: HashMap<String, Arc<dyn Source>>, plan_bytes: Vec<u8>) -> Self {
        let plan = plan_bytes;
        Self { sources, plan }
    }
}
impl Source for TransformSource {
fn frame(&self) -> Result<LazyFrame> {
let plan = DslPlan::deserialize_versioned(&self.plan[..])
.map_err(|e| Error::Engine(format!("failed to deserialize transform plan: {e}")))?;
let mut named: HashMap<String, DslPlan> = HashMap::with_capacity(self.sources.len());
for (name, src) in &self.sources {
named.insert(name.clone(), src.frame()?.logical_plan);
}
let rebound = rebind_named(plan, &named)?;
validate_bound(&rebound)?;
Ok(LazyFrame::from(rebound))
}
}
#[cfg(test)]
mod name_tests {
    use super::{is_placeholder, name_from_path};

    /// Both the `taxa://` scheme form and the `taxa__` file-name form must
    /// resolve to the bare source name; ordinary paths must not.
    #[test]
    fn both_placeholder_forms_resolve() {
        let resolved: [(&str, Option<&str>); 5] = [
            ("taxa://snapshots", Some("snapshots")),
            ("taxa__snapshots.parquet", Some("snapshots")),
            ("/tmp/x/taxa__leaf_map.parquet", Some("leaf_map")),
            ("taxa__canonical", Some("canonical")),
            ("/data/real.parquet", None),
        ];
        for (path, expected) in resolved {
            assert_eq!(name_from_path(path).as_deref(), expected, "path: {path}");
        }

        for path in ["taxa://x", "/a/taxa__x.parquet"] {
            assert!(is_placeholder(path), "path: {path}");
        }
        assert!(!is_placeholder("/data/real.parquet"));
    }
}