use super::effects::EffectSet;
use serde::{Deserialize, Serialize};
/// One arm of a [`ShellIR::Case`] node.
///
/// [`ShellIR::effects`] unions the effects of every arm's `body`; the
/// function-usage scan visits each arm's `body` and `guard` (the `pattern`
/// itself is not scanned).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CaseArm {
    /// Pattern for this arm.
    pub pattern: CasePattern,
    /// Optional extra condition for the arm. NOTE(review): how a guard is
    /// lowered to shell is not visible in this file — confirm in the emitter.
    pub guard: Option<ShellValue>,
    /// IR body of the arm.
    pub body: Box<ShellIR>,
}
/// Pattern of a [`CaseArm`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CasePattern {
    /// A literal pattern string. NOTE(review): whether it is quoted or
    /// escaped on emission is not visible here — confirm in the emitter.
    Literal(String),
    /// Presumably the catch-all arm (shell `*)`) — confirm against the emitter.
    Wildcard,
}
/// Intermediate representation of a shell program.
///
/// Only `Let` and `Exec` carry an explicit [`EffectSet`]; nodes that nest other
/// nodes (`If`, `Sequence`, `Function`, loops, `Case`) derive their effects from
/// their nested bodies in [`ShellIR::effects`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ShellIR {
    /// Binds `name` to `value`; `effects` is reported verbatim by `effects()`.
    Let {
        name: String,
        value: ShellValue,
        effects: EffectSet,
    },
    /// Runs `cmd`; `effects` is reported verbatim by `effects()` and
    /// `cmd.program` is recorded by `collect_used_functions()`.
    Exec { cmd: Command, effects: EffectSet },
    /// Conditional on `test` with an optional `else_branch`.
    If {
        test: ShellValue,
        then_branch: Box<ShellIR>,
        else_branch: Option<Box<ShellIR>>,
    },
    /// Exit with status `code` and an optional `message`. NOTE(review): how
    /// the message is emitted is decided elsewhere — confirm in the emitter.
    Exit { code: u8, message: Option<String> },
    /// Ordered list of nodes; its effects are the union of all items.
    Sequence(Vec<ShellIR>),
    /// Does nothing; pure.
    Noop,
    /// Function definition. `effects()` reports the body's effects for the
    /// definition node itself.
    Function {
        name: String,
        params: Vec<String>,
        body: Box<ShellIR>,
    },
    /// Presumably prints `value` (shell `echo`/`printf`) — confirm. Treated as
    /// pure by `effects()`.
    Echo { value: ShellValue },
    /// Loop binding `var` over the range `start`..`end`. NOTE(review): whether
    /// `end` is inclusive is decided by the emitter — confirm.
    For {
        var: String,
        start: ShellValue,
        end: ShellValue,
        body: Box<ShellIR>,
    },
    /// Multi-way branch on `scrutinee`; see [`CaseArm`].
    Case {
        scrutinee: ShellValue,
        arms: Vec<CaseArm>,
    },
    /// Loop on `condition`.
    While {
        condition: ShellValue,
        body: Box<ShellIR>,
    },
    /// Loop binding `var` to each of `items` in turn.
    ForIn {
        var: String,
        items: Vec<ShellValue>,
        body: Box<ShellIR>,
    },
    /// Loop `break`; pure.
    Break,
    /// Loop `continue`; pure.
    Continue,
    /// Return from a function with an optional value; treated as pure.
    Return { value: Option<ShellValue> },
}
impl ShellIR {
    /// Computes the effect set of this node.
    ///
    /// `Let` and `Exec` report the effects stored on the node. `If`, `Sequence`,
    /// `Function`, the loops, and `Case` union the effects of their nested IR
    /// bodies (the `ShellValue` operands — tests, bounds, scrutinees, guards —
    /// are not inspected). Every other node is pure.
    pub fn effects(&self) -> EffectSet {
        match self {
            ShellIR::Exit { .. }
            | ShellIR::Noop
            | ShellIR::Echo { .. }
            | ShellIR::Return { .. }
            | ShellIR::Break
            | ShellIR::Continue => EffectSet::pure(),
            ShellIR::Let { effects, .. } | ShellIR::Exec { effects, .. } => effects.clone(),
            ShellIR::If {
                then_branch,
                else_branch,
                ..
            } => {
                let then_effects = then_branch.effects();
                match else_branch {
                    Some(else_ir) => then_effects.union(&else_ir.effects()),
                    None => then_effects,
                }
            }
            ShellIR::Function { body, .. }
            | ShellIR::For { body, .. }
            | ShellIR::ForIn { body, .. }
            | ShellIR::While { body, .. } => body.effects(),
            ShellIR::Sequence(nodes) => {
                let mut total = EffectSet::pure();
                for node in nodes {
                    total = total.union(&node.effects());
                }
                total
            }
            ShellIR::Case { arms, .. } => {
                let mut total = EffectSet::pure();
                for arm in arms {
                    total = total.union(&arm.body.effects());
                }
                total
            }
        }
    }

    /// Returns `true` when [`ShellIR::effects`] reports a pure effect set.
    pub fn is_pure(&self) -> bool {
        let effects = self.effects();
        effects.is_pure()
    }

    /// Returns the names of every program invoked anywhere in this tree,
    /// via `Exec` nodes or `CommandSubst` values. Names borrow from `self`.
    pub fn collect_used_functions(&self) -> std::collections::HashSet<&str> {
        let mut names = std::collections::HashSet::default();
        self.collect_functions_recursive(&mut names);
        names
    }

    /// Depth-first walk that adds invoked program names to `used`.
    fn collect_functions_recursive<'a>(&'a self, used: &mut std::collections::HashSet<&'a str>) {
        match self {
            ShellIR::Exit { .. } | ShellIR::Noop | ShellIR::Break | ShellIR::Continue => {}
            ShellIR::Exec { cmd, .. } => {
                used.insert(cmd.program.as_str());
                collect_functions_from_values(&cmd.args, used);
            }
            ShellIR::Let { value, .. } | ShellIR::Echo { value } => value.collect_functions(used),
            ShellIR::Return { value } => {
                if let Some(returned) = value.as_ref() {
                    returned.collect_functions(used);
                }
            }
            ShellIR::Sequence(nodes) => {
                nodes
                    .iter()
                    .for_each(|node| node.collect_functions_recursive(used));
            }
            ShellIR::If {
                test,
                then_branch,
                else_branch,
            } => {
                test.collect_functions(used);
                then_branch.collect_functions_recursive(used);
                if let Some(else_ir) = else_branch.as_deref() {
                    else_ir.collect_functions_recursive(used);
                }
            }
            ShellIR::Function { body, .. } => body.collect_functions_recursive(used),
            ShellIR::While { condition, body } => {
                condition.collect_functions(used);
                body.collect_functions_recursive(used);
            }
            ShellIR::For {
                start, end, body, ..
            } => {
                start.collect_functions(used);
                end.collect_functions(used);
                body.collect_functions_recursive(used);
            }
            ShellIR::ForIn { items, body, .. } => {
                collect_functions_from_values(items, used);
                body.collect_functions_recursive(used);
            }
            ShellIR::Case { scrutinee, arms } => {
                scrutinee.collect_functions(used);
                collect_functions_from_arms(arms, used);
            }
        }
    }
}
/// A command invocation: a program name plus its argument values.
///
/// `program` is what the function-usage scans record, both for
/// `ShellIR::Exec` and for `ShellValue::CommandSubst`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Command {
    /// Name of the program (or shell function) to run.
    pub program: String,
    /// Arguments, in order.
    pub args: Vec<ShellValue>,
}
/// Adds the program names referenced by each value in `values` to `used`.
fn collect_functions_from_values<'a>(
    values: &'a [ShellValue],
    used: &mut std::collections::HashSet<&'a str>,
) {
    values
        .iter()
        .for_each(|value| value.collect_functions(used));
}
fn collect_functions_from_arms<'a>(
arms: &'a [CaseArm],
used: &mut std::collections::HashSet<&'a str>,
) {
for arm in arms {
arm.body.collect_functions_recursive(used);
if let Some(guard) = &arm.guard {
guard.collect_functions(used);
}
}
}
impl Command {
pub fn new(program: impl Into<String>) -> Self {
Self {
program: program.into(),
args: Vec::new(),
}
}
pub fn arg(mut self, arg: ShellValue) -> Self {
self.args.push(arg);
self
}
pub fn args(mut self, args: Vec<ShellValue>) -> Self {
self.args.extend(args);
self
}
}
/// A value or expression appearing in the shell IR.
///
/// See `ShellValue::is_constant` for which variants are considered
/// runtime-independent, and `ShellValue::as_constant_string` for which can be
/// folded to a literal string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ShellValue {
    /// Literal string; constant and foldable.
    String(String),
    /// Literal boolean; constant and folds to `"true"` / `"false"`.
    Bool(bool),
    /// Reference to a shell variable by name; not constant.
    Variable(String),
    /// Concatenation of parts; constant when every part is.
    Concat(Vec<ShellValue>),
    /// Command substitution; its `program` is recorded as a used function.
    CommandSubst(Command),
    /// Binary comparison of `left` and `right` using `op`.
    Comparison {
        op: ComparisonOp,
        left: Box<ShellValue>,
        right: Box<ShellValue>,
    },
    /// Binary arithmetic on `left` and `right` using `op`.
    Arithmetic {
        op: ArithmeticOp,
        left: Box<ShellValue>,
        right: Box<ShellValue>,
    },
    /// Logical conjunction of two operands.
    LogicalAnd {
        left: Box<ShellValue>,
        right: Box<ShellValue>,
    },
    /// Logical disjunction of two operands.
    LogicalOr {
        left: Box<ShellValue>,
        right: Box<ShellValue>,
    },
    /// Logical negation of `operand`.
    LogicalNot { operand: Box<ShellValue> },
    /// Environment variable `name` with an optional fallback `default`; not constant.
    EnvVar {
        name: String,
        default: Option<String>,
    },
    /// Positional script argument. NOTE(review): `None` presumably means
    /// "all arguments" — confirm against the emitter.
    Arg {
        position: Option<usize>, },
    /// Positional argument with a fallback `default`; not constant.
    ArgWithDefault { position: usize, default: String },
    /// Number of script arguments; not constant.
    ArgCount,
    /// Exit status of the previous command; not constant.
    ExitCode,
    /// Element of `array` selected by a computed `index`; the index is scanned
    /// for used functions.
    DynamicArrayAccess {
        array: String,
        index: Box<ShellValue>,
    },
    /// Glob pattern; counted as constant by `is_constant` but never folded by
    /// `as_constant_string`.
    Glob(String),
}
/// Operator of a [`ShellValue::Comparison`].
///
/// NOTE(review): the naming suggests `NumEq`/`NumNe`/`Gt`/`Ge`/`Lt`/`Le` are
/// numeric comparisons and `StrEq`/`StrNe` are string comparisons — confirm
/// against the emitter that lowers them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ComparisonOp {
    NumEq,
    NumNe,
    Gt,
    Ge,
    Lt,
    Le,
    StrEq,
    StrNe,
}
/// Operator of a [`ShellValue::Arithmetic`].
///
/// NOTE(review): semantics such as rounding of `Div`/`Mod` on negative operands
/// are not defined here; presumably they follow the target shell's arithmetic
/// expansion — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ArithmeticOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    BitAnd,
    BitOr,
    BitXor,
    Shl,
    Shr,
}
impl ShellValue {
    /// Returns `true` when this value contains no runtime-dependent part.
    ///
    /// Literal strings, booleans, and globs are constant. Variables, command
    /// substitutions, environment variables, positional arguments, the argument
    /// count, the exit code, and dynamic array accesses are not. Compound values
    /// (`Concat`, comparisons, arithmetic, logical operators) are constant
    /// exactly when all of their operands are.
    ///
    /// Being constant does not imply [`ShellValue::as_constant_string`] can fold
    /// the value: only strings, booleans, and concatenations of those fold.
    pub fn is_constant(&self) -> bool {
        match self {
            ShellValue::String(_) | ShellValue::Bool(_) | ShellValue::Glob(_) => true,
            ShellValue::Variable(_)
            | ShellValue::CommandSubst(_)
            | ShellValue::EnvVar { .. }
            | ShellValue::Arg { .. }
            | ShellValue::ArgWithDefault { .. }
            | ShellValue::ArgCount
            | ShellValue::ExitCode
            | ShellValue::DynamicArrayAccess { .. } => false,
            ShellValue::Concat(parts) => parts.iter().all(|p| p.is_constant()),
            ShellValue::Comparison { left, right, .. }
            | ShellValue::Arithmetic { left, right, .. }
            | ShellValue::LogicalAnd { left, right }
            | ShellValue::LogicalOr { left, right } => left.is_constant() && right.is_constant(),
            ShellValue::LogicalNot { operand } => operand.is_constant(),
        }
    }

    /// Adds to `used` the program name of every `CommandSubst` reachable from
    /// this value, including substitutions nested in command arguments,
    /// concatenations, operator operands, and dynamic array indices.
    pub fn collect_functions<'a>(&'a self, used: &mut std::collections::HashSet<&'a str>) {
        match self {
            ShellValue::CommandSubst(cmd) => {
                used.insert(&cmd.program);
                for arg in &cmd.args {
                    arg.collect_functions(used);
                }
            }
            ShellValue::Concat(parts) => {
                for part in parts {
                    part.collect_functions(used);
                }
            }
            ShellValue::Comparison { left, right, .. }
            | ShellValue::Arithmetic { left, right, .. }
            | ShellValue::LogicalAnd { left, right }
            | ShellValue::LogicalOr { left, right } => {
                left.collect_functions(used);
                right.collect_functions(used);
            }
            ShellValue::LogicalNot { operand } => {
                operand.collect_functions(used);
            }
            ShellValue::DynamicArrayAccess { index, .. } => {
                index.collect_functions(used);
            }
            ShellValue::String(_)
            | ShellValue::Bool(_)
            | ShellValue::Variable(_)
            | ShellValue::EnvVar { .. }
            | ShellValue::Arg { .. }
            | ShellValue::ArgWithDefault { .. }
            | ShellValue::ArgCount
            | ShellValue::ExitCode
            | ShellValue::Glob(_) => {}
        }
    }

    /// Folds this value to a literal string, if possible.
    ///
    /// Returns `Some` for `String`, for `Bool` (`"true"` / `"false"`), and for a
    /// `Concat` whose parts all fold (recursively). Returns `None` otherwise —
    /// including for `Glob` and for operator expressions, even when they are
    /// [`ShellValue::is_constant`].
    pub fn as_constant_string(&self) -> Option<String> {
        match self {
            ShellValue::String(s) => Some(s.clone()),
            ShellValue::Bool(b) => Some(b.to_string()),
            // Collecting into `Option<String>` short-circuits on the first part
            // that cannot fold. No separate `is_constant` pre-pass is needed:
            // every non-constant value already folds to `None`, and the pre-pass
            // made nested concatenations be traversed repeatedly.
            ShellValue::Concat(parts) => parts.iter().map(ShellValue::as_constant_string).collect(),
            // Listed explicitly so adding a variant forces a decision here.
            ShellValue::Variable(_)
            | ShellValue::CommandSubst(_)
            | ShellValue::Comparison { .. }
            | ShellValue::Arithmetic { .. }
            | ShellValue::LogicalAnd { .. }
            | ShellValue::LogicalOr { .. }
            | ShellValue::LogicalNot { .. }
            | ShellValue::EnvVar { .. }
            | ShellValue::Arg { .. }
            | ShellValue::ArgWithDefault { .. }
            | ShellValue::ArgCount
            | ShellValue::ExitCode
            | ShellValue::DynamicArrayAccess { .. }
            | ShellValue::Glob(_) => None,
        }
    }
}
/// A rendered shell expression fragment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ShellExpression {
    /// Raw string text; `is_quoted` inspects it for surrounding double quotes.
    String(String),
    /// Variable name plus a flag that `is_quoted` reports verbatim as
    /// "quoted".
    Variable(String, bool), Command(String),
    /// Arithmetic expression text; `is_quoted` always reports it as quoted.
    Arithmetic(String),
}
impl ShellExpression {
    /// Reports whether this expression is considered quoted.
    ///
    /// - `String`: `true` when the text is wrapped in a distinct opening and
    ///   closing double quote (`"..."`, including the empty `""`).
    /// - `Variable`: the stored quoted flag.
    /// - `Command`: always `false`.
    /// - `Arithmetic`: always `true`.
    pub fn is_quoted(&self) -> bool {
        match self {
            // Strip the opening quote first so a lone `"` (which both starts
            // and ends with `"`) is not mistaken for a quoted string; it is an
            // unterminated quote.
            ShellExpression::String(s) => s
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .is_some(),
            ShellExpression::Variable(_, quoted) => *quoted,
            ShellExpression::Command(_) => false,
            ShellExpression::Arithmetic(_) => true,
        }
    }
}
// Unit tests for this module live in `shell_ir_tests.rs` (loaded via `#[path]`).
#[cfg(test)]
#[path = "shell_ir_tests.rs"]
mod tests;