use crate::flow::cfg::{BlockId, CFG, Terminator};
use crate::flow::dataflow::find_node_by_id;
use crate::semantics::LanguageSemantics;
use rma_parser::ParsedFile;
use std::collections::{HashMap, HashSet, VecDeque};
/// A named state of a [`StateMachine`] together with its role flags.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct State {
    /// Name that transitions use to refer to this state.
    pub name: String,
    /// Marks the state that `StateMachine::initial_state` returns.
    pub is_initial: bool,
    /// Marks a state that is accepted at function exit.
    pub is_final: bool,
    /// Marks a state that represents misuse of the tracked value.
    pub is_error: bool,
}

impl State {
    /// Creates a plain state: not initial, not final, not an error.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            is_initial: false,
            is_final: false,
            is_error: false,
        }
    }

    /// Creates a state whose only set flag is `is_initial`.
    pub fn initial(name: impl Into<String>) -> Self {
        Self::new(name).with_initial(true)
    }

    /// Creates a state whose only set flag is `is_final`.
    pub fn final_state(name: impl Into<String>) -> Self {
        Self::new(name).with_final(true)
    }

    /// Creates a state whose only set flag is `is_error`.
    pub fn error(name: impl Into<String>) -> Self {
        Self::new(name).with_error(true)
    }

    /// Returns the state with `is_initial` replaced.
    pub fn with_initial(self, is_initial: bool) -> Self {
        Self { is_initial, ..self }
    }

    /// Returns the state with `is_final` replaced.
    pub fn with_final(self, is_final: bool) -> Self {
        Self { is_final, ..self }
    }

    /// Returns the state with `is_error` replaced.
    pub fn with_error(self, is_error: bool) -> Self {
        Self { is_error, ..self }
    }
}
/// The kind of event that causes a [`Transition`] to fire.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TransitionTrigger {
    /// A method call on the tracked value; the name `"*"` matches any method.
    MethodCall(String),
    /// An assignment to the tracked variable.
    Assignment,
    /// The return of a named function; the name `"*"` matches any function.
    FunctionReturn(String),
    /// Destruction of the tracked value.
    Destructor,
    /// A pattern match, identified by a string.
    PatternMatch(String),
}

impl TransitionTrigger {
    /// Builds a `MethodCall` trigger.
    pub fn method(name: impl Into<String>) -> Self {
        TransitionTrigger::MethodCall(name.into())
    }

    /// Builds a `FunctionReturn` trigger.
    pub fn function_return(name: impl Into<String>) -> Self {
        TransitionTrigger::FunctionReturn(name.into())
    }

    /// True when this is a `MethodCall` trigger for `method` or for the wildcard `"*"`.
    pub fn matches_method(&self, method: &str) -> bool {
        matches!(self, Self::MethodCall(expected) if expected == "*" || expected == method)
    }

    /// True when this is a `FunctionReturn` trigger for `func` or for the wildcard `"*"`.
    pub fn matches_function_return(&self, func: &str) -> bool {
        matches!(self, Self::FunctionReturn(expected) if expected == "*" || expected == func)
    }
}
/// A directed edge of a [`StateMachine`]: `from` --`trigger`--> `to`.
#[derive(Debug, Clone)]
pub struct Transition {
    /// Name of the source state.
    pub from: String,
    /// Name of the destination state.
    pub to: String,
    /// Event that causes the edge to be taken.
    pub trigger: TransitionTrigger,
}

impl Transition {
    /// Creates a transition between two named states for the given trigger.
    pub fn new(from: impl Into<String>, to: impl Into<String>, trigger: TransitionTrigger) -> Self {
        let (from, to) = (from.into(), to.into());
        Self { from, to, trigger }
    }

    /// Creates a transition that fires when `method` is called.
    pub fn on_method(
        from: impl Into<String>,
        to: impl Into<String>,
        method: impl Into<String>,
    ) -> Self {
        let trigger = TransitionTrigger::MethodCall(method.into());
        Self::new(from, to, trigger)
    }

    /// Creates a transition that fires on assignment.
    pub fn on_assignment(from: impl Into<String>, to: impl Into<String>) -> Self {
        let trigger = TransitionTrigger::Assignment;
        Self::new(from, to, trigger)
    }

    /// Creates a transition that fires on destruction.
    pub fn on_destructor(from: impl Into<String>, to: impl Into<String>) -> Self {
        let trigger = TransitionTrigger::Destructor;
        Self::new(from, to, trigger)
    }
}
/// A named set of states and transitions plus the type names it applies to.
#[derive(Debug, Clone)]
pub struct StateMachine {
    /// Name of the machine.
    pub name: String,
    /// States in insertion order.
    pub states: Vec<State>,
    /// Transitions in insertion order; lookups return the first match.
    pub transitions: Vec<Transition>,
    /// Type names consulted by [`StateMachine::tracks_type`].
    pub tracked_types: Vec<String>,
}

impl StateMachine {
    /// Creates an empty machine with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            states: Vec::default(),
            transitions: Vec::default(),
            tracked_types: Vec::default(),
        }
    }

    /// Appends a state.
    pub fn with_state(self, state: State) -> Self {
        let mut machine = self;
        machine.states.push(state);
        machine
    }

    /// Appends a transition.
    pub fn with_transition(self, transition: Transition) -> Self {
        let mut machine = self;
        machine.transitions.push(transition);
        machine
    }

    /// Appends a single tracked type name.
    pub fn with_tracked_type(self, type_name: impl Into<String>) -> Self {
        let mut machine = self;
        machine.tracked_types.push(type_name.into());
        machine
    }

    /// Appends several tracked type names, preserving their order.
    pub fn with_tracked_types(self, type_names: &[&str]) -> Self {
        let mut machine = self;
        for type_name in type_names {
            machine.tracked_types.push((*type_name).to_string());
        }
        machine
    }

    /// Returns the first state flagged as initial, if any.
    pub fn initial_state(&self) -> Option<&State> {
        self.states.iter().find(|state| state.is_initial)
    }

    /// Returns the first state with the given name, if any.
    pub fn get_state(&self, name: &str) -> Option<&State> {
        self.states.iter().find(|state| state.name == name)
    }

    /// True when a state called `name` exists and is final.
    pub fn is_final_state(&self, name: &str) -> bool {
        matches!(self.get_state(name), Some(state) if state.is_final)
    }

    /// True when a state called `name` exists and is an error state.
    pub fn is_error_state(&self, name: &str) -> bool {
        matches!(self.get_state(name), Some(state) if state.is_error)
    }

    /// Returns the first transition leaving `from_state` whose trigger matches `method`.
    pub fn get_method_transition(&self, from_state: &str, method: &str) -> Option<&Transition> {
        self.transitions
            .iter()
            .filter(|candidate| candidate.from == from_state)
            .find(|candidate| candidate.trigger.matches_method(method))
    }

    /// Returns the first transition leaving `from_state` whose trigger equals `trigger`.
    pub fn get_transition(
        &self,
        from_state: &str,
        trigger: &TransitionTrigger,
    ) -> Option<&Transition> {
        self.transitions
            .iter()
            .filter(|candidate| candidate.from == from_state)
            .find(|candidate| &candidate.trigger == trigger)
    }

    /// True when `type_name` ends with one of the tracked type names.
    ///
    /// An exact match is a special case of the suffix match.
    pub fn tracks_type(&self, type_name: &str) -> bool {
        for tracked in &self.tracked_types {
            if type_name.ends_with(tracked.as_str()) {
                return true;
            }
        }
        false
    }
}
/// Category of a [`TypestateViolation`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ViolationKind {
    /// Displayed as "Invalid state transition".
    InvalidTransition,
    /// Displayed as "Missing required state transition".
    MissingTransition,
    /// Displayed as "Use of object in error state".
    UseInErrorState,
    /// Displayed as "Object not in final state at exit".
    NonFinalStateAtExit,
    /// Displayed as "Conflicting states at merge point".
    ConflictingStates,
}

impl std::fmt::Display for ViolationKind {
    /// Writes the fixed human-readable description of the kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            ViolationKind::InvalidTransition => "Invalid state transition",
            ViolationKind::MissingTransition => "Missing required state transition",
            ViolationKind::UseInErrorState => "Use of object in error state",
            ViolationKind::NonFinalStateAtExit => "Object not in final state at exit",
            ViolationKind::ConflictingStates => "Conflicting states at merge point",
        };
        f.write_str(description)
    }
}
/// A single finding produced by the typestate analysis.
#[derive(Debug, Clone)]
pub struct TypestateViolation {
    /// Category of the finding.
    pub kind: ViolationKind,
    /// Node id the finding is attached to (the analyzer passes syntax-node ids, or 0).
    pub location: usize,
    /// Source line reported for the finding (the analyzer uses 0 when it has none).
    pub line: usize,
    /// State, or joined list of states, the value was in.
    pub current_state: String,
    /// Name of the attempted operation, when one is recorded.
    pub attempted_transition: Option<String>,
    /// Human-readable description.
    pub message: String,
}

impl TypestateViolation {
    /// Creates a violation with no attempted transition recorded.
    pub fn new(
        kind: ViolationKind,
        location: usize,
        line: usize,
        current_state: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        let current_state = current_state.into();
        let message = message.into();
        Self {
            kind,
            location,
            line,
            current_state,
            attempted_transition: None,
            message,
        }
    }

    /// Returns the violation with `attempted_transition` set.
    pub fn with_attempted_transition(self, transition: impl Into<String>) -> Self {
        Self {
            attempted_transition: Some(transition.into()),
            ..self
        }
    }
}
/// What the analysis knows about a tracked value's state at a program point.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrackedState {
    /// Exactly one state, identified by name.
    Known(String),
    /// Nothing is known; acts as the identity element of [`TrackedState::merge`].
    Unknown,
    /// Several states are possible.
    Conflicting(HashSet<String>),
}

impl TrackedState {
    /// True only for `Known`.
    pub fn is_known(&self) -> bool {
        match self {
            TrackedState::Known(_) => true,
            TrackedState::Unknown | TrackedState::Conflicting(_) => false,
        }
    }

    /// Name of the state when it is `Known`, otherwise `None`.
    pub fn state_name(&self) -> Option<&str> {
        if let TrackedState::Known(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }

    /// Combines two states where control flow joins.
    ///
    /// `Unknown` yields the other operand unchanged, two equal `Known` states stay
    /// `Known`, and every other combination becomes `Conflicting` with the union of
    /// all state names involved.
    pub fn merge(&self, other: &TrackedState) -> TrackedState {
        use TrackedState::{Conflicting, Known, Unknown};

        match (self, other) {
            // `Unknown` contributes nothing, so the other side wins as-is.
            (Unknown, only) | (only, Unknown) => only.clone(),
            (Known(left), Known(right)) => {
                if left == right {
                    Known(left.clone())
                } else {
                    Conflicting(HashSet::from([left.clone(), right.clone()]))
                }
            }
            (Conflicting(names), Known(extra)) | (Known(extra), Conflicting(names)) => {
                let mut union = names.clone();
                union.insert(extra.clone());
                Conflicting(union)
            }
            (Conflicting(left), Conflicting(right)) => {
                Conflicting(left.union(right).cloned().collect())
            }
        }
    }
}
/// Outcome of tracking one variable against one state machine.
#[derive(Debug, Clone)]
pub struct TypestateResult {
    /// Name of the tracked variable.
    pub variable: String,
    /// Name of the state machine that was applied.
    pub state_machine: String,
    /// Findings collected during the analysis.
    pub violations: Vec<TypestateViolation>,
    /// State on entry to each block.
    pub block_states: HashMap<BlockId, TrackedState>,
    /// State on exit from each block.
    pub block_exit_states: HashMap<BlockId, TrackedState>,
}

impl TypestateResult {
    /// Creates an empty result for the given variable and state machine names.
    pub fn new(variable: impl Into<String>, state_machine: impl Into<String>) -> Self {
        let variable = variable.into();
        let state_machine = state_machine.into();
        Self {
            variable,
            state_machine,
            violations: Vec::default(),
            block_states: HashMap::default(),
            block_exit_states: HashMap::default(),
        }
    }

    /// True when at least one violation was recorded.
    pub fn has_violations(&self) -> bool {
        self.violations.first().is_some()
    }

    /// Entry state recorded for `block_id`, if any.
    pub fn state_at_block(&self, block_id: BlockId) -> Option<&TrackedState> {
        let entry_states = &self.block_states;
        entry_states.get(&block_id)
    }

    /// Exit state recorded for `block_id`, if any.
    pub fn exit_state_at_block(&self, block_id: BlockId) -> Option<&TrackedState> {
        let exit_states = &self.block_exit_states;
        exit_states.get(&block_id)
    }
}
/// One method call found on a variable by `find_method_calls_on_var`.
#[derive(Debug, Clone)]
pub struct MethodCallInfo {
/// Id of the call node in the syntax tree.
pub node_id: usize,
/// 1-based line on which the call node starts.
pub line: usize,
/// Text of the called method's name.
pub method_name: String,
/// Name of the variable the method was called on, when recorded.
pub receiver: Option<String>,
}
/// Collects every call of the form `<var_name>.<method>(..)` in `parsed`.
///
/// The whole syntax tree is visited in pre-order over named children. The receiver
/// must textually equal `var_name`. Results are sorted by line; calls on the same
/// line keep their visiting order.
pub fn find_method_calls_on_var(
    parsed: &ParsedFile,
    var_name: &str,
    semantics: &LanguageSemantics,
) -> Vec<MethodCallInfo> {
    /// Returns the method name when `node` is a member call on `var_name`.
    fn method_called_on<'s>(
        node: tree_sitter::Node<'_>,
        source: &'s [u8],
        var_name: &str,
        semantics: &LanguageSemantics,
    ) -> Option<&'s str> {
        if !semantics.is_call(node.kind()) {
            return None;
        }
        let callee = node.child_by_field_name(semantics.function_field)?;
        if !semantics.is_member_access(callee.kind()) {
            return None;
        }
        let receiver = callee.child_by_field_name(semantics.object_field);
        let method = callee.child_by_field_name(semantics.property_field);
        let (receiver, method) = match (receiver, method) {
            (Some(receiver), Some(method)) => (receiver, method),
            _ => return None,
        };
        if receiver.utf8_text(source).ok()? != var_name {
            return None;
        }
        method.utf8_text(source).ok()
    }

    let source = parsed.content.as_bytes();
    let mut found = Vec::new();

    // Explicit stack instead of recursion; children are pushed in reverse so that
    // nodes are still visited in pre-order.
    let mut pending = vec![parsed.tree.root_node()];
    while let Some(node) = pending.pop() {
        if let Some(method_name) = method_called_on(node, source, var_name, semantics) {
            found.push(MethodCallInfo {
                node_id: node.id(),
                line: node.start_position().row + 1,
                method_name: method_name.to_string(),
                receiver: Some(var_name.to_string()),
            });
        }

        let mut cursor = node.walk();
        let children: Vec<_> = node.named_children(&mut cursor).collect();
        pending.extend(children.into_iter().rev());
    }

    found.sort_by_key(|call| call.line);
    found
}
/// Collects `(node id, 1-based line)` for every assignment or variable declaration
/// in `parsed` whose target text is `var_name`.
///
/// The target is the node's `left_field` child, falling back to its `name_field`
/// child. Leading `"mut "` prefixes and surrounding whitespace are ignored when
/// comparing. Results are sorted by line.
pub fn find_assignments_to_var(
    parsed: &ParsedFile,
    var_name: &str,
    semantics: &LanguageSemantics,
) -> Vec<(usize, usize)> {
    /// True when `node` assigns to or declares `var_name`.
    fn targets_var(
        node: tree_sitter::Node<'_>,
        source: &[u8],
        var_name: &str,
        semantics: &LanguageSemantics,
    ) -> bool {
        let kind = node.kind();
        if !(semantics.is_assignment(kind) || semantics.is_variable_declaration(kind)) {
            return false;
        }
        let target = node
            .child_by_field_name(semantics.left_field)
            .or_else(|| node.child_by_field_name(semantics.name_field));
        match target.map(|target| target.utf8_text(source)) {
            Some(Ok(text)) => {
                text == var_name || text.trim_start_matches("mut ").trim() == var_name
            }
            _ => false,
        }
    }

    /// Visits `node` and then its named children, recording matches.
    fn collect(
        node: tree_sitter::Node<'_>,
        source: &[u8],
        var_name: &str,
        semantics: &LanguageSemantics,
        hits: &mut Vec<(usize, usize)>,
    ) {
        if targets_var(node, source, var_name, semantics) {
            hits.push((node.id(), node.start_position().row + 1));
        }
        let mut cursor = node.walk();
        for child in node.named_children(&mut cursor) {
            collect(child, source, var_name, semantics, hits);
        }
    }

    let mut hits = Vec::new();
    collect(
        parsed.tree.root_node(),
        parsed.content.as_bytes(),
        var_name,
        semantics,
        &mut hits,
    );
    hits.sort_by_key(|&(_, line)| line);
    hits
}
/// Applies a set of state machines to the variables of a parsed file.
pub struct TypestateAnalyzer {
// State machines to apply; the first one that matches a declaration is used.
state_machines: Vec<StateMachine>,
// Language-specific node-kind and field-name helpers used while walking the tree.
semantics: &'static LanguageSemantics,
}
impl TypestateAnalyzer {
/// Creates an analyzer with no state machines registered.
pub fn new(semantics: &'static LanguageSemantics) -> Self {
Self {
state_machines: Vec::new(),
semantics,
}
}
/// Registers one additional state machine and returns the analyzer.
pub fn with_state_machine(mut self, sm: StateMachine) -> Self {
self.state_machines.push(sm);
self
}
/// Registers several state machines, appended in the given order.
pub fn with_state_machines(mut self, machines: Vec<StateMachine>) -> Self {
self.state_machines.extend(machines);
self
}
/// Returns the registered state machines in registration order.
pub fn state_machines(&self) -> &[StateMachine] {
&self.state_machines
}
/// Finds the variables covered by a registered state machine and tracks each one
/// through `cfg`, returning one result per variable in discovery order.
pub fn analyze(&self, parsed: &ParsedFile, cfg: &CFG) -> Vec<TypestateResult> {
    self.find_tracked_variables(parsed)
        .into_iter()
        .map(|(var_name, sm)| self.track_variable_state(&var_name, sm, cfg, parsed))
        .collect()
}
/// Finds declared variables whose initializer matches a registered state machine.
///
/// A declaration matches when its initializer is either
/// * a call whose callee text is a tracked type of a machine, or matches one of the
///   machine's function-return triggers, or
/// * a member access whose text is a tracked type of a machine.
///
/// For each check the first matching machine (registration order) wins.
///
/// NOTE(review): children that are function definitions are never descended into,
/// and that includes functions directly under the root — confirm this is intended.
fn find_tracked_variables<'a>(
    &'a self,
    parsed: &ParsedFile,
) -> Vec<(String, &'a StateMachine)> {
    /// Records `node` (a variable declaration) if its initializer matches a machine.
    fn record_declaration<'m>(
        node: tree_sitter::Node,
        source: &[u8],
        semantics: &LanguageSemantics,
        machines: &'m [StateMachine],
        found: &mut Vec<(String, &'m StateMachine)>,
    ) {
        let name_node = node
            .child_by_field_name(semantics.name_field)
            .or_else(|| node.child_by_field_name("name"))
            .or_else(|| node.child_by_field_name("pattern"));
        let value_node = node
            .child_by_field_name(semantics.value_field)
            .or_else(|| node.child_by_field_name("value"));
        let (name_node, value_node) = match (name_node, value_node) {
            (Some(name_node), Some(value_node)) => (name_node, value_node),
            _ => return,
        };
        let var_name = match name_node.utf8_text(source) {
            Ok(raw) => raw.trim_start_matches("mut ").trim(),
            Err(_) => return,
        };

        if semantics.is_call(value_node.kind()) {
            let callee_text = value_node
                .child_by_field_name(semantics.function_field)
                .and_then(|callee| callee.utf8_text(source).ok());
            if let Some(func_name) = callee_text {
                let matching = machines.iter().find(|sm| {
                    sm.tracks_type(func_name)
                        || sm
                            .transitions
                            .iter()
                            .any(|t| t.trigger.matches_function_return(func_name))
                });
                if let Some(sm) = matching {
                    found.push((var_name.to_string(), sm));
                }
            }
        }

        if semantics.is_member_access(value_node.kind()) {
            if let Ok(expr_text) = value_node.utf8_text(source) {
                if let Some(sm) = machines.iter().find(|sm| sm.tracks_type(expr_text)) {
                    found.push((var_name.to_string(), sm));
                }
            }
        }
    }

    /// Visits `node`, then every named child that is not a function definition.
    fn visit<'m>(
        node: tree_sitter::Node,
        source: &[u8],
        semantics: &LanguageSemantics,
        machines: &'m [StateMachine],
        found: &mut Vec<(String, &'m StateMachine)>,
    ) {
        if semantics.is_variable_declaration(node.kind()) {
            record_declaration(node, source, semantics, machines, found);
        }
        let mut cursor = node.walk();
        for child in node.named_children(&mut cursor) {
            if semantics.is_function_def(child.kind()) {
                continue;
            }
            visit(child, source, semantics, machines, found);
        }
    }

    let mut found = Vec::new();
    visit(
        parsed.tree.root_node(),
        parsed.content.as_bytes(),
        self.semantics,
        &self.state_machines,
        &mut found,
    );
    found
}
/// Tracks the state of `var_name` through `cfg` according to `sm`.
///
/// Method calls on the variable and assignments to it are looked up by node id
/// among each block's statements. A forward worklist pass starts at `cfg.entry`
/// in the machine's initial state (or `Unknown` when it has none), merges
/// predecessor exit states at joins, and re-queues successors whenever a block's
/// exit state changes. The pass is capped at ten visits per block on average.
///
/// The result holds per-block entry and exit states plus all violations found,
/// including the exit-state checks performed at the end.
///
/// NOTE(review): events are keyed by the call/assignment node id and matched
/// against `block.statements`; confirm the CFG records those same node ids.
pub fn track_variable_state(
    &self,
    var_name: &str,
    sm: &StateMachine,
    cfg: &CFG,
    parsed: &ParsedFile,
) -> TypestateResult {
    let mut result = TypestateResult::new(var_name, &sm.name);

    // node id -> (line, Some(method name) for a call | None for an assignment).
    // Assignments are inserted last so they win on a shared node id.
    let mut node_events: HashMap<usize, (usize, Option<String>)> = HashMap::new();
    for call in find_method_calls_on_var(parsed, var_name, self.semantics) {
        node_events.insert(call.node_id, (call.line, Some(call.method_name)));
    }
    for (node_id, line) in find_assignments_to_var(parsed, var_name, self.semantics) {
        node_events.insert(node_id, (line, None));
    }

    let initial_state = sm
        .initial_state()
        .map(|s| TrackedState::Known(s.name.clone()))
        .unwrap_or(TrackedState::Unknown);

    for block in &cfg.blocks {
        result.block_states.insert(block.id, TrackedState::Unknown);
        result
            .block_exit_states
            .insert(block.id, TrackedState::Unknown);
    }
    result.block_states.insert(cfg.entry, initial_state.clone());

    let mut worklist: VecDeque<BlockId> = VecDeque::new();
    let mut in_worklist: HashSet<BlockId> = HashSet::new();
    worklist.push_back(cfg.entry);
    in_worklist.insert(cfg.entry);

    // Safety valve against non-termination.
    let max_iterations = cfg.blocks.len() * 10;
    let mut iterations = 0;

    while let Some(block_id) = worklist.pop_front() {
        in_worklist.remove(&block_id);
        iterations += 1;
        if iterations > max_iterations {
            break;
        }
        if block_id >= cfg.blocks.len() {
            continue;
        }
        let block = &cfg.blocks[block_id];
        if !block.reachable {
            continue;
        }

        let entry_state = if block_id == cfg.entry {
            initial_state.clone()
        } else {
            // `Unknown` is the identity of `merge`, so folding from it equals
            // merging the predecessors' exit states pairwise.
            let mut merged = TrackedState::Unknown;
            for &pred in &block.predecessors {
                if let Some(pred_exit) = result.block_exit_states.get(&pred) {
                    merged = merged.merge(pred_exit);
                }
            }
            merged
        };

        let mut current_state = entry_state.clone();
        for &stmt_node_id in &block.statements {
            let (line, method) = match node_events.get(&stmt_node_id) {
                Some((line, method)) => (*line, method.as_deref()),
                None => continue,
            };
            current_state = match method {
                Some(method) => self.apply_method_transition(
                    &current_state,
                    method,
                    sm,
                    stmt_node_id,
                    line,
                    var_name,
                    &mut result.violations,
                ),
                None => self.apply_assignment_transition(
                    &current_state,
                    sm,
                    stmt_node_id,
                    line,
                    var_name,
                    &mut result.violations,
                ),
            };
        }

        if let TrackedState::Conflicting(states) = &current_state {
            // Sorted so the message does not depend on hash iteration order.
            let mut state_list: Vec<String> = states.iter().cloned().collect();
            state_list.sort();
            result.violations.push(TypestateViolation::new(
                ViolationKind::ConflictingStates,
                block.statements.first().copied().unwrap_or(0),
                self.get_line_for_block(parsed, cfg, block_id),
                state_list.join(" | "),
                format!(
                    "Variable '{}' has conflicting states at this point: {}",
                    var_name,
                    state_list.join(", ")
                ),
            ));
        }

        let state_changed = result.block_exit_states.get(&block_id) != Some(&current_state);
        result.block_states.insert(block_id, entry_state);
        result.block_exit_states.insert(block_id, current_state);

        if state_changed {
            for succ in cfg.successors(block_id) {
                if in_worklist.insert(succ) {
                    worklist.push_back(succ);
                }
            }
        }
    }

    // A block can be visited several times before the states settle, and each
    // visit reports its findings again; keep only the first of each exact duplicate.
    let mut seen = HashSet::new();
    result.violations.retain(|v| {
        seen.insert((
            std::mem::discriminant(&v.kind),
            v.location,
            v.line,
            v.current_state.clone(),
            v.attempted_transition.clone(),
            v.message.clone(),
        ))
    });

    self.check_exit_states(&mut result, sm, cfg, parsed, var_name);
    result
}
/// Computes the state after `method` is called on the tracked variable.
///
/// * `Known` state that is already an error state: reports `UseInErrorState` and
///   stays put.
/// * `Known` state with a matching transition: moves to the target state. When the
///   target is an error state the call itself is the misuse, so an
///   `InvalidTransition` is reported at that point.
/// * `Known` state without a matching transition: reports `InvalidTransition` if
///   the state has any outgoing transition at all, and stays put.
/// * `Unknown` stays `Unknown`.
/// * `Conflicting`: each possible state is advanced independently (unmatched
///   states are kept); a single resulting state collapses to `Known`. No
///   violations are reported in this case.
fn apply_method_transition(
    &self,
    current_state: &TrackedState,
    method: &str,
    sm: &StateMachine,
    node_id: usize,
    line: usize,
    var_name: &str,
    violations: &mut Vec<TypestateViolation>,
) -> TrackedState {
    match current_state {
        TrackedState::Known(state_name) => {
            if sm.is_error_state(state_name) {
                violations.push(
                    TypestateViolation::new(
                        ViolationKind::UseInErrorState,
                        node_id,
                        line,
                        state_name,
                        format!(
                            "Method '{}' called on '{}' which is in error state '{}'",
                            method, var_name, state_name
                        ),
                    )
                    .with_attempted_transition(method.to_string()),
                );
                return current_state.clone();
            }

            if let Some(transition) = sm.get_method_transition(state_name, method) {
                // Entering an error state must be reported here: the exit check
                // ignores error states, so otherwise the misuse would go unnoticed
                // unless another call followed it.
                if sm.is_error_state(&transition.to) {
                    violations.push(
                        TypestateViolation::new(
                            ViolationKind::InvalidTransition,
                            node_id,
                            line,
                            state_name,
                            format!(
                                "Method '{}' called on '{}' in state '{}' leads to error state '{}'",
                                method, var_name, state_name, transition.to
                            ),
                        )
                        .with_attempted_transition(method.to_string()),
                    );
                }
                return TrackedState::Known(transition.to.clone());
            }

            // States without any outgoing transition accept every method silently.
            let has_any_transition = sm.transitions.iter().any(|t| t.from == *state_name);
            if has_any_transition {
                violations.push(
                    TypestateViolation::new(
                        ViolationKind::InvalidTransition,
                        node_id,
                        line,
                        state_name,
                        format!(
                            "Invalid method '{}' called on '{}' in state '{}' - no transition defined",
                            method, var_name, state_name
                        ),
                    )
                    .with_attempted_transition(method.to_string()),
                );
            }
            current_state.clone()
        }
        TrackedState::Unknown => TrackedState::Unknown,
        TrackedState::Conflicting(states) => {
            let new_states: HashSet<String> = states
                .iter()
                .map(|state_name| match sm.get_method_transition(state_name, method) {
                    Some(transition) => transition.to.clone(),
                    None => state_name.clone(),
                })
                .collect();

            if new_states.len() == 1 {
                match new_states.into_iter().next() {
                    Some(only) => TrackedState::Known(only),
                    None => TrackedState::Unknown,
                }
            } else {
                TrackedState::Conflicting(new_states)
            }
        }
    }
}
/// Computes the state after the tracked variable is assigned to.
///
/// A `Known` state with an explicit `Assignment` transition follows it. In every
/// other case the variable is reset to the machine's initial state, or becomes
/// `Unknown` when the machine has none. The node, line, variable name and
/// violation list are accepted but unused.
fn apply_assignment_transition(
    &self,
    current_state: &TrackedState,
    sm: &StateMachine,
    _node_id: usize,
    _line: usize,
    _var_name: &str,
    _violations: &mut Vec<TypestateViolation>,
) -> TrackedState {
    let explicit_target = current_state
        .state_name()
        .and_then(|from| sm.get_transition(from, &TransitionTrigger::Assignment))
        .map(|transition| transition.to.as_str());

    let target = explicit_target.or_else(|| sm.initial_state().map(|s| s.name.as_str()));

    match target {
        Some(name) => TrackedState::Known(name.to_string()),
        None => TrackedState::Unknown,
    }
}
/// Reports `NonFinalStateAtExit` for reachable blocks that end in `Return` or
/// `Unreachable` while the variable is in a state that is neither final nor an
/// error state.
///
/// Error states are skipped for both `Known` and `Conflicting` exit states.
/// `Unknown` exit states are never reported. State names in a report are sorted
/// so the output does not depend on hash iteration order.
fn check_exit_states(
    &self,
    result: &mut TypestateResult,
    sm: &StateMachine,
    cfg: &CFG,
    parsed: &ParsedFile,
    var_name: &str,
) {
    let acceptable_at_exit =
        |state: &str| sm.is_final_state(state) || sm.is_error_state(state);

    for block in &cfg.blocks {
        if !block.reachable {
            continue;
        }
        let is_exit = matches!(
            block.terminator,
            Terminator::Return | Terminator::Unreachable
        );
        if !is_exit {
            continue;
        }

        // (reported state label, message) when the exit state is not acceptable.
        let finding = match result.block_exit_states.get(&block.id) {
            Some(TrackedState::Known(state_name))
                if !acceptable_at_exit(state_name.as_str()) =>
            {
                let expected = sm
                    .states
                    .iter()
                    .filter(|s| s.is_final)
                    .map(|s| s.name.as_str())
                    .collect::<Vec<_>>()
                    .join(", ");
                Some((
                    state_name.clone(),
                    format!(
                        "Variable '{}' is in state '{}' at function exit, but expected a final state ({})",
                        var_name, state_name, expected
                    ),
                ))
            }
            Some(TrackedState::Conflicting(states)) => {
                let mut non_final: Vec<&str> = states
                    .iter()
                    .map(String::as_str)
                    .filter(|s| !acceptable_at_exit(*s))
                    .collect();
                non_final.sort_unstable();
                if non_final.is_empty() {
                    None
                } else {
                    Some((
                        non_final.join(" | "),
                        format!(
                            "Variable '{}' may be in non-final state(s) {} at function exit",
                            var_name,
                            non_final.join(", ")
                        ),
                    ))
                }
            }
            _ => None,
        };

        if let Some((state_label, message)) = finding {
            let line = self.get_line_for_block(parsed, cfg, block.id);
            result.violations.push(TypestateViolation::new(
                ViolationKind::NonFinalStateAtExit,
                block.statements.last().copied().unwrap_or(0),
                line,
                state_label,
                message,
            ));
        }
    }
}
/// Returns the 1-based start line of the node with `node_id`, or 0 when
/// `find_node_by_id` does not yield a node.
fn get_line_for_node(&self, parsed: &ParsedFile, node_id: usize) -> usize {
    let lookup = find_node_by_id(&parsed.tree, node_id);
    lookup.map_or(0, |node| node.start_position().row + 1)
}
/// Returns the line of the first statement of `block_id`, or 0 when the block id
/// is out of range or the block has no statements.
fn get_line_for_block(&self, parsed: &ParsedFile, cfg: &CFG, block_id: BlockId) -> usize {
    let first_stmt = cfg
        .blocks
        .get(block_id)
        .and_then(|block| block.statements.first().copied());
    match first_stmt {
        Some(node_id) => self.get_line_for_node(parsed, node_id),
        None => 0,
    }
}
/// Looks up the transition `sm` takes when `method` is called in `current_state`.
///
/// Delegates to `StateMachine::get_method_transition`; note that the argument
/// order here is (`method`, `current_state`), the reverse of the delegate's.
pub fn get_transition<'a>(
&self,
sm: &'a StateMachine,
method: &str,
current_state: &str,
) -> Option<&'a Transition> {
sm.get_method_transition(current_state, method)
}
/// Checks caller-supplied per-block states against `sm` at every exit block.
///
/// For each reachable block ending in `Return` or `Unreachable` that has an entry
/// in `states`, a `NonFinalStateAtExit` violation is produced when that state is
/// not final. The violations carry line 0 and are returned in block order.
pub fn check_all_paths_final(
    &self,
    sm: &StateMachine,
    cfg: &CFG,
    states: &HashMap<BlockId, String>,
) -> Vec<TypestateViolation> {
    cfg.blocks
        .iter()
        .filter(|block| block.reachable)
        .filter(|block| {
            matches!(
                block.terminator,
                Terminator::Return | Terminator::Unreachable
            )
        })
        .filter_map(|block| {
            let state = states.get(&block.id)?;
            if sm.is_final_state(state) {
                return None;
            }
            Some(TypestateViolation::new(
                ViolationKind::NonFinalStateAtExit,
                block.statements.last().copied().unwrap_or(0),
                0,
                state,
                format!("Path exits with non-final state: {}", state),
            ))
        })
        .collect()
}
}
pub fn file_state_machine() -> StateMachine {
StateMachine::new("File")
.with_state(State::initial("Unopened"))
.with_state(State::new("Open").with_final(false))
.with_state(State::final_state("Closed"))
.with_state(State::error("UseAfterClose"))
.with_transition(Transition::on_method("Unopened", "Open", "open"))
.with_transition(Transition::on_method("Unopened", "Open", "create"))
.with_transition(Transition::on_method("Open", "Open", "read"))
.with_transition(Transition::on_method("Open", "Open", "write"))
.with_transition(Transition::on_method("Open", "Open", "flush"))
.with_transition(Transition::on_method("Open", "Closed", "close"))
.with_transition(Transition::on_method("Closed", "UseAfterClose", "read"))
.with_transition(Transition::on_method("Closed", "UseAfterClose", "write"))
.with_tracked_types(&["File", "std::fs::File", "fs.File", "FileHandle"])
}
pub fn lock_state_machine() -> StateMachine {
StateMachine::new("Lock")
.with_state(State::initial("Unlocked").with_final(true))
.with_state(State::new("Locked").with_final(false))
.with_state(State::error("DoubleLock"))
.with_state(State::error("DoubleUnlock"))
.with_transition(Transition::on_method("Unlocked", "Locked", "lock"))
.with_transition(Transition::on_method("Unlocked", "Locked", "acquire"))
.with_transition(Transition::on_method("Locked", "Unlocked", "unlock"))
.with_transition(Transition::on_method("Locked", "Unlocked", "release"))
.with_transition(Transition::on_method("Locked", "DoubleLock", "lock"))
.with_transition(Transition::on_method("Unlocked", "DoubleUnlock", "unlock"))
.with_tracked_types(&["Lock", "Mutex", "RwLock", "sync.Mutex"])
}
pub fn connection_state_machine() -> StateMachine {
StateMachine::new("Connection")
.with_state(State::initial("Disconnected"))
.with_state(State::new("Connected").with_final(false))
.with_state(State::final_state("Closed"))
.with_state(State::error("UseAfterClose"))
.with_transition(Transition::on_method(
"Disconnected",
"Connected",
"connect",
))
.with_transition(Transition::on_method("Disconnected", "Connected", "open"))
.with_transition(Transition::on_method("Connected", "Connected", "query"))
.with_transition(Transition::on_method("Connected", "Connected", "execute"))
.with_transition(Transition::on_method("Connected", "Closed", "close"))
.with_transition(Transition::on_method("Connected", "Closed", "disconnect"))
.with_transition(Transition::on_method("Closed", "UseAfterClose", "query"))
.with_transition(Transition::on_method("Closed", "UseAfterClose", "execute"))
.with_tracked_types(&["Connection", "DatabaseConnection", "DbConnection", "sql.DB"])
}
/// Builds the "Iterator" state machine.
///
/// States: `Ready` (initial), `Iterating`, `Exhausted`; all three are final.
/// `next` moves `Ready` to `Iterating` and keeps `Iterating` there; `collect`
/// moves `Iterating` to `Exhausted`.
pub fn iterator_state_machine() -> StateMachine {
    // (from, to, method) in lookup order.
    const METHOD_EDGES: [(&str, &str, &str); 3] = [
        ("Ready", "Iterating", "next"),
        ("Iterating", "Iterating", "next"),
        ("Iterating", "Exhausted", "collect"),
    ];

    let states = vec![
        State::initial("Ready").with_final(true),
        State::final_state("Iterating"),
        State::final_state("Exhausted"),
    ];

    let mut machine = StateMachine::new("Iterator");
    for state in states {
        machine = machine.with_state(state);
    }
    for &(from, to, method) in METHOD_EDGES.iter() {
        machine = machine.with_transition(Transition::on_method(from, to, method));
    }
    machine.with_tracked_types(&["Iterator", "IntoIterator"])
}
use crate::flow::FlowContext;
impl FlowContext {
/// Typestate entry point on `FlowContext`.
///
/// Currently a stub: the supplied state machines are ignored and an empty
/// result list is always returned.
/// NOTE(review): presumably a placeholder until the context exposes what the
/// analyzer needs — confirm; `analyze_typestate_with_context` does the real work.
pub fn analyze_typestate(&mut self, _state_machines: &[StateMachine]) -> Vec<TypestateResult> {
Vec::new() }
}
/// Runs the typestate analysis for `parsed`/`cfg` with copies of the given state
/// machines and returns one result per tracked variable.
pub fn analyze_typestate_with_context(
    parsed: &ParsedFile,
    cfg: &CFG,
    semantics: &'static LanguageSemantics,
    state_machines: &[StateMachine],
) -> Vec<TypestateResult> {
    let owned_machines: Vec<StateMachine> = state_machines.iter().cloned().collect();
    TypestateAnalyzer::new(semantics)
        .with_state_machines(owned_machines)
        .analyze(parsed, cfg)
}
// Unit tests for the typestate data model, the built-in state machines, the
// syntax-tree finders and smoke runs of the analyzer.
#[cfg(test)]
mod tests {
use super::*;
use crate::flow::cfg::CFG;
use rma_common::Language;
use rma_parser::ParserEngine;
use std::path::Path;
// Parses `code` as a JavaScript file named "test.js"; panics when parsing fails.
fn parse_js(code: &str) -> ParsedFile {
let config = rma_common::RmaConfig::default();
let parser = ParserEngine::new(config);
parser
.parse_file(Path::new("test.js"), code)
.expect("parse failed")
}
// Parses `code` as a Rust file named "test.rs"; panics when parsing fails.
fn parse_rust(code: &str) -> ParsedFile {
let config = rma_common::RmaConfig::default();
let parser = ParserEngine::new(config);
parser
.parse_file(Path::new("test.rs"), code)
.expect("parse failed")
}
// Each constructor sets exactly the flag it is named after.
#[test]
fn test_state_creation() {
let s = State::new("Open");
assert_eq!(s.name, "Open");
assert!(!s.is_initial);
assert!(!s.is_final);
assert!(!s.is_error);
let initial = State::initial("Start");
assert!(initial.is_initial);
let final_state = State::final_state("End");
assert!(final_state.is_final);
let error = State::error("Error");
assert!(error.is_error);
}
// Builder methods can be chained to set several flags.
#[test]
fn test_state_builder() {
let s = State::new("Test").with_initial(true).with_final(true);
assert!(s.is_initial);
assert!(s.is_final);
}
// `on_method` records both endpoints and a method trigger.
#[test]
fn test_transition_creation() {
let t = Transition::on_method("Open", "Closed", "close");
assert_eq!(t.from, "Open");
assert_eq!(t.to, "Closed");
assert!(t.trigger.matches_method("close"));
assert!(!t.trigger.matches_method("open"));
}
// The file machine has the expected initial, final and error states.
#[test]
fn test_state_machine_creation() {
let sm = file_state_machine();
assert_eq!(sm.name, "File");
assert!(!sm.states.is_empty());
assert!(!sm.transitions.is_empty());
let initial = sm.initial_state();
assert!(initial.is_some());
assert_eq!(initial.unwrap().name, "Unopened");
assert!(sm.is_final_state("Closed"));
assert!(!sm.is_final_state("Open"));
assert!(sm.is_error_state("UseAfterClose"));
}
// Method lookups on the file machine: valid, undefined and error transitions.
#[test]
fn test_state_machine_transitions() {
let sm = file_state_machine();
let t = sm.get_method_transition("Open", "close");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Closed");
let t = sm.get_method_transition("Closed", "close");
assert!(t.is_none());
let t = sm.get_method_transition("Closed", "read");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "UseAfterClose");
}
// `tracks_type` accepts exact names and names ending in a tracked type.
#[test]
fn test_tracks_type() {
let sm = file_state_machine();
assert!(sm.tracks_type("File"));
assert!(sm.tracks_type("std::fs::File"));
assert!(sm.tracks_type("my::module::File")); assert!(!sm.tracks_type("Connection")); assert!(!sm.tracks_type("Lock")); }
// Merging two equal known states keeps the state.
#[test]
fn test_tracked_state_merge_same() {
let a = TrackedState::Known("Open".to_string());
let b = TrackedState::Known("Open".to_string());
let merged = a.merge(&b);
assert_eq!(merged, TrackedState::Known("Open".to_string()));
}
// Merging two different known states yields a conflict containing both.
#[test]
fn test_tracked_state_merge_different() {
let a = TrackedState::Known("Open".to_string());
let b = TrackedState::Known("Closed".to_string());
let merged = a.merge(&b);
match merged {
TrackedState::Conflicting(states) => {
assert!(states.contains("Open"));
assert!(states.contains("Closed"));
}
_ => panic!("Expected Conflicting state"),
}
}
// Merging with `Unknown` keeps the known state.
#[test]
fn test_tracked_state_merge_with_unknown() {
let a = TrackedState::Known("Open".to_string());
let b = TrackedState::Unknown;
let merged = a.merge(&b);
assert_eq!(merged, TrackedState::Known("Open".to_string()));
}
// A violation stores its constructor arguments and the attempted transition.
#[test]
fn test_violation_creation() {
let v = TypestateViolation::new(
ViolationKind::InvalidTransition,
123,
5,
"Open",
"Cannot call close() when file is already closed",
)
.with_attempted_transition("close");
assert_eq!(v.kind, ViolationKind::InvalidTransition);
assert_eq!(v.location, 123);
assert_eq!(v.line, 5);
assert_eq!(v.current_state, "Open");
assert_eq!(v.attempted_transition, Some("close".to_string()));
}
// `Display` produces the fixed description for a kind.
#[test]
fn test_violation_kind_display() {
assert_eq!(
format!("{}", ViolationKind::InvalidTransition),
"Invalid state transition"
);
assert_eq!(
format!("{}", ViolationKind::UseInErrorState),
"Use of object in error state"
);
}
// All method calls on a JavaScript variable are found.
#[test]
fn test_find_method_calls() {
let code = r#"
const file = openFile("test.txt");
file.read();
file.write("data");
file.close();
"#;
let parsed = parse_js(code);
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let calls = find_method_calls_on_var(&parsed, "file", semantics);
let method_names: Vec<_> = calls.iter().map(|c| c.method_name.as_str()).collect();
assert!(method_names.contains(&"read"), "Should find read()");
assert!(method_names.contains(&"write"), "Should find write()");
assert!(method_names.contains(&"close"), "Should find close()");
}
// Calls are attributed to the variable they are made on, not to siblings.
#[test]
fn test_find_method_calls_different_var() {
let code = r#"
const file1 = openFile("a.txt");
const file2 = openFile("b.txt");
file1.read();
file2.write();
"#;
let parsed = parse_js(code);
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let calls1 = find_method_calls_on_var(&parsed, "file1", semantics);
let calls2 = find_method_calls_on_var(&parsed, "file2", semantics);
assert_eq!(calls1.len(), 1);
assert_eq!(calls1[0].method_name, "read");
assert_eq!(calls2.len(), 1);
assert_eq!(calls2[0].method_name, "write");
}
// A fresh result has no violations; pushing one flips `has_violations`.
#[test]
fn test_typestate_result() {
let mut result = TypestateResult::new("file", "File");
assert_eq!(result.variable, "file");
assert_eq!(result.state_machine, "File");
assert!(!result.has_violations());
result.violations.push(TypestateViolation::new(
ViolationKind::InvalidTransition,
0,
1,
"Closed",
"Test violation",
));
assert!(result.has_violations());
}
// Registered state machines are kept by the analyzer.
#[test]
fn test_analyzer_creation() {
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let analyzer = TypestateAnalyzer::new(semantics)
.with_state_machine(file_state_machine())
.with_state_machine(lock_state_machine());
assert_eq!(analyzer.state_machines().len(), 2);
}
// Smoke test: the analysis runs on a small JavaScript function; the results are
// not inspected.
#[test]
fn test_analyzer_basic_file_operations() {
let code = r#"
function process() {
const file = File.open("test.txt");
file.read();
file.close();
}
"#;
let parsed = parse_js(code);
let cfg = CFG::build(&parsed, Language::JavaScript);
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let analyzer = TypestateAnalyzer::new(semantics).with_state_machine(file_state_machine());
let _results = analyzer.analyze(&parsed, &cfg);
}
// `TypestateAnalyzer::get_transition` finds a transition only from the right state.
#[test]
fn test_simple_state_tracking() {
let sm = StateMachine::new("TestSM")
.with_state(State::initial("A"))
.with_state(State::new("B"))
.with_state(State::final_state("C"))
.with_transition(Transition::on_method("A", "B", "step1"))
.with_transition(Transition::on_method("B", "C", "step2"))
.with_tracked_type("TestType");
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let analyzer = TypestateAnalyzer::new(semantics).with_state_machine(sm.clone());
let t = analyzer.get_transition(&sm, "step1", "A");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "B");
let t = analyzer.get_transition(&sm, "step1", "B");
assert!(t.is_none());
}
// The lock machine has the expected initial, final and error states.
#[test]
fn test_lock_state_machine() {
let sm = lock_state_machine();
assert!(sm.initial_state().is_some());
assert_eq!(sm.initial_state().unwrap().name, "Unlocked");
assert!(sm.is_final_state("Unlocked"));
assert!(!sm.is_final_state("Locked"));
assert!(sm.is_error_state("DoubleLock"));
assert!(sm.is_error_state("DoubleUnlock"));
}
// Lock/unlock transitions, including the double-lock and double-unlock errors.
#[test]
fn test_lock_transitions() {
let sm = lock_state_machine();
let t = sm.get_method_transition("Unlocked", "lock");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Locked");
let t = sm.get_method_transition("Locked", "unlock");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Unlocked");
let t = sm.get_method_transition("Locked", "lock");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "DoubleLock");
let t = sm.get_method_transition("Unlocked", "unlock");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "DoubleUnlock");
}
// The connection machine's states and its main transitions.
#[test]
fn test_connection_state_machine() {
let sm = connection_state_machine();
assert_eq!(sm.initial_state().unwrap().name, "Disconnected");
assert!(sm.is_final_state("Closed"));
assert!(sm.is_error_state("UseAfterClose"));
let t = sm.get_method_transition("Disconnected", "connect");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Connected");
let t = sm.get_method_transition("Connected", "query");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Connected");
let t = sm.get_method_transition("Connected", "close");
assert!(t.is_some());
assert_eq!(t.unwrap().to, "Closed");
}
// Smoke test for the free-function entry point; the results are not inspected.
#[test]
fn test_analyze_typestate_with_context() {
let code = r#"
function test() {
const f = File.open("x");
f.read();
return;
}
"#;
let parsed = parse_js(code);
let cfg = CFG::build(&parsed, Language::JavaScript);
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let state_machines = vec![file_state_machine()];
let _results = analyze_typestate_with_context(&parsed, &cfg, semantics, &state_machines);
}
// Several machines can be registered at once.
#[test]
fn test_multiple_state_machines() {
let semantics = crate::semantics::LanguageSemantics::for_language(Language::JavaScript);
let analyzer = TypestateAnalyzer::new(semantics).with_state_machines(vec![
file_state_machine(),
lock_state_machine(),
connection_state_machine(),
]);
assert_eq!(analyzer.state_machines().len(), 3);
}
// A new machine has no states, transitions or initial state.
#[test]
fn test_empty_state_machine() {
let sm = StateMachine::new("Empty");
assert!(sm.initial_state().is_none());
assert!(sm.states.is_empty());
assert!(sm.transitions.is_empty());
}
// A "*" method trigger matches any method name.
#[test]
fn test_wildcard_method_transition() {
let sm = StateMachine::new("Test")
.with_state(State::initial("Any"))
.with_transition(Transition::on_method("Any", "Any", "*"));
assert!(sm.get_method_transition("Any", "foo").is_some());
assert!(sm.get_method_transition("Any", "bar").is_some());
}
// Merging a state that is already part of a conflict does not grow the set.
#[test]
fn test_conflicting_states_at_merge() {
let a = TrackedState::Known("Open".to_string());
let b = TrackedState::Known("Closed".to_string());
let c = TrackedState::Known("Open".to_string());
let merged = a.merge(&b);
match &merged {
TrackedState::Conflicting(states) => {
assert_eq!(states.len(), 2);
}
_ => panic!("Expected conflicting"),
}
let merged2 = merged.merge(&c);
match merged2 {
TrackedState::Conflicting(states) => {
assert_eq!(states.len(), 2); }
_ => panic!("Expected conflicting"),
}
}
// Smoke test: the analysis runs with Rust semantics; the results are not inspected.
#[test]
fn test_rust_semantics() {
let code = r#"
fn main() {
let file = File::open("test.txt").unwrap();
file.read_to_string(&mut s);
}
"#;
let parsed = parse_rust(code);
let cfg = CFG::build(&parsed, Language::Rust);
let semantics = crate::semantics::LanguageSemantics::for_language(Language::Rust);
let analyzer = TypestateAnalyzer::new(semantics).with_state_machine(file_state_machine());
let _results = analyzer.analyze(&parsed, &cfg);
}
}