use std::collections::HashMap;
use tower_lsp::lsp_types::*;
use crate::Backend;
use crate::code_actions::{CodeActionData, make_code_action_data};
use crate::util::ranges_overlap;
const ALREADY_NARROWED_ID: &str = "function.alreadyNarrowedType";
const ACTION_KIND: &str = "phpstan.removeAssert";
const ASSERT_MESSAGE_PREFIX: &str = "Call to function assert()";
impl Backend {
    /// Offers a "Remove always-true assert()" quick fix for each cached PHPStan diagnostic
    /// that overlaps the requested range, carries the `function.alreadyNarrowedType`
    /// identifier, and whose message starts with `Call to function assert()`.
    ///
    /// No edit is computed here. The action carries a `data` payload holding the
    /// diagnostic's start line (`diagnostic_line`), and the edit is built later by
    /// [`Backend::resolve_remove_assert`].
    pub(crate) fn collect_remove_assert_actions(
        &self,
        uri: &str,
        _content: &str,
        params: &CodeActionParams,
        out: &mut Vec<CodeActionOrCommand>,
    ) {
        // Copy the cached diagnostics out so the lock is released before any work is done.
        let phpstan_diags: Vec<Diagnostic> = {
            let cache = self.phpstan_last_diags.lock();
            cache.get(uri).cloned().unwrap_or_default()
        };

        for diag in &phpstan_diags {
            if !ranges_overlap(&diag.range, &params.range) {
                continue;
            }
            // PHPStan identifiers are strings; numeric or missing codes never match.
            let Some(NumberOrString::String(identifier)) = &diag.code else {
                continue;
            };
            if identifier.as_str() != ALREADY_NARROWED_ID
                || !diag.message.starts_with(ASSERT_MESSAGE_PREFIX)
            {
                continue;
            }

            let diag_line = diag.range.start.line as usize;
            let extra = serde_json::json!({
                "diagnostic_line": diag_line,
            });
            let data = make_code_action_data(ACTION_KIND, uri, &params.range, extra);
            out.push(CodeActionOrCommand::CodeAction(CodeAction {
                title: "Remove always-true assert()".to_string(),
                kind: Some(CodeActionKind::QUICKFIX),
                diagnostics: Some(vec![diag.clone()]),
                edit: None,
                command: None,
                is_preferred: Some(true),
                disabled: None,
                data: Some(data),
            }));
        }
    }

    /// Resolves a remove-assert code action into a `WorkspaceEdit` against `content`.
    ///
    /// Reads the `diagnostic_line` stored by `collect_remove_assert_actions`. Returns `None`
    /// in any of these cases:
    /// * the field is missing, not an unsigned integer, or does not fit in `usize`;
    /// * no removable `assert(...);` statement is found on that line;
    /// * `data.uri` does not parse as a URL.
    pub(crate) fn resolve_remove_assert(
        &self,
        data: &CodeActionData,
        content: &str,
    ) -> Option<WorkspaceEdit> {
        let raw_line = data.extra.get("diagnostic_line")?.as_u64()?;
        let diag_line = usize::try_from(raw_line).ok()?;
        let edit = build_remove_assert_edit(content, diag_line)?;
        let doc_uri: Url = data.uri.parse().ok()?;
        Some(WorkspaceEdit {
            changes: Some(HashMap::from([(doc_uri, vec![edit])])),
            document_changes: None,
            change_annotations: None,
        })
    }
}
fn build_remove_assert_edit(content: &str, diag_line: usize) -> Option<TextEdit> {
let lines: Vec<&str> = content.lines().collect();
if diag_line >= lines.len() {
return None;
}
let line_text = lines[diag_line];
let assert_col = line_text.find("assert(")?;
let line_start_byte = lines[..diag_line]
.iter()
.map(|l| l.len() + 1) .sum::<usize>();
let assert_byte = line_start_byte + assert_col;
let after_paren = assert_byte + "assert(".len();
let close_paren_byte = find_matching_close_paren(content, after_paren)?;
let rest_after_paren = &content[close_paren_byte + 1..];
let semi_offset = rest_after_paren
.find(|c: char| !c.is_ascii_whitespace() || c == '\n')
.unwrap_or(0);
let semi_byte = close_paren_byte + 1 + semi_offset;
if content.as_bytes().get(semi_byte) != Some(&b';') {
return None;
}
let stmt_end_byte = semi_byte + 1;
let before_assert = &content[line_start_byte..assert_byte];
let is_only_statement = before_assert.trim().is_empty();
let after_semi = if stmt_end_byte < content.len() {
let next_newline = content[stmt_end_byte..]
.find('\n')
.map(|p| stmt_end_byte + p)
.unwrap_or(content.len());
content[stmt_end_byte..next_newline].trim().is_empty()
} else {
true
};
if is_only_statement && after_semi {
let delete_end_byte = if stmt_end_byte < content.len() {
content[stmt_end_byte..]
.find('\n')
.map(|p| stmt_end_byte + p + 1)
.unwrap_or(content.len())
} else {
content.len()
};
let start_pos = Position::new(diag_line as u32, 0);
let end_line = content[..delete_end_byte].matches('\n').count();
let end_col = if delete_end_byte <= content.len() {
delete_end_byte
- content[..delete_end_byte]
.rfind('\n')
.map(|p| p + 1)
.unwrap_or(0)
} else {
0
};
Some(TextEdit {
range: Range {
start: start_pos,
end: Position::new(end_line as u32, end_col as u32),
},
new_text: String::new(),
})
} else if is_only_statement {
let start_pos = Position::new(diag_line as u32, 0);
let trailing_space = content[stmt_end_byte..]
.chars()
.take_while(|c| *c == ' ' || *c == '\t')
.count();
let end_offset = stmt_end_byte + trailing_space;
let end_line_num = content[..end_offset].matches('\n').count();
let end_line_start = content[..end_offset]
.rfind('\n')
.map(|p| p + 1)
.unwrap_or(0);
let end_col_final = (end_offset - end_line_start) as u32;
Some(TextEdit {
range: Range {
start: start_pos,
end: Position::new(end_line_num as u32, end_col_final),
},
new_text: String::new(),
})
} else {
let leading_space = content[..assert_byte]
.chars()
.rev()
.take_while(|c| *c == ' ' || *c == '\t')
.count();
let remove_start = assert_byte - leading_space;
let remove_start_col = (remove_start - line_start_byte) as u32;
let end_line_num = content[..stmt_end_byte].matches('\n').count();
let end_line_start = content[..stmt_end_byte]
.rfind('\n')
.map(|p| p + 1)
.unwrap_or(0);
let end_col = (stmt_end_byte - end_line_start) as u32;
Some(TextEdit {
range: Range {
start: Position::new(diag_line as u32, remove_start_col),
end: Position::new(end_line_num as u32, end_col),
},
new_text: String::new(),
})
}
}
/// Returns the byte offset of the `)` that closes a parenthesised PHP expression.
///
/// `start_byte` must point just past an opening `(` that has already been consumed, so
/// scanning starts at nesting depth 1. Parentheses are not counted inside:
///
/// * single- and double-quoted strings. A backslash skips the following byte, so `\"` and
///   `\'` do not end the string;
/// * `// ...` and `# ...` line comments. `#[` is left alone because it opens a PHP 8
///   attribute, not a comment;
/// * `/* ... */` block comments.
///
/// Heredoc/nowdoc strings are not recognised. Returns `None` if the input ends before the
/// parenthesis is closed, including inside an unterminated block comment.
fn find_matching_close_paren(content: &str, start_byte: usize) -> Option<usize> {
    let bytes = content.as_bytes();
    // Index of the '\n' ending the line comment that starts at `from`, or the input length.
    let line_comment_end = |from: usize| {
        bytes[from..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |p| from + p)
    };

    let mut depth: u32 = 1;
    let mut i = start_byte;
    while i < bytes.len() {
        let next = bytes.get(i + 1).copied();
        match bytes[i] {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < bytes.len() && bytes[i] != quote {
                    i += if bytes[i] == b'\\' { 2 } else { 1 };
                }
                // `i` now sits on the closing quote (or past the end); the step below
                // moves past it.
            }
            b'/' if next == Some(b'*') => {
                // Park `i` on the final `/` of the closing `*/`.
                let close = content[i + 2..].find("*/")?;
                i += 2 + close + 1;
            }
            b'/' if next == Some(b'/') => i = line_comment_end(i),
            b'#' if next != Some(b'[') => i = line_comment_end(i),
            _ => {}
        }
        i += 1;
    }
    None
}
/// Reports whether a remove-assert action recorded for `diag_line` no longer applies.
///
/// The action is stale when `content` has no line at index `diag_line` (as split by
/// `str::lines`), or when that line no longer contains the text `assert(`.
pub(crate) fn is_remove_assert_stale(content: &str, diag_line: usize) -> bool {
    let still_has_assert = content
        .lines()
        .nth(diag_line)
        .is_some_and(|line| line.contains("assert("));
    !still_has_assert
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a PHPStan-style error diagnostic on `line` spanning columns 0..100.
    fn make_diagnostic(line: u32, message: &str, code: &str) -> Diagnostic {
        Diagnostic {
            range: Range {
                start: Position::new(line, 0),
                end: Position::new(line, 100),
            },
            severity: Some(DiagnosticSeverity::ERROR),
            code: Some(NumberOrString::String(code.to_string())),
            source: Some("PHPStan".to_string()),
            message: message.to_string(),
            ..Default::default()
        }
    }

    /// The diagnostic's string code, or `""` when the code is missing or numeric.
    fn string_code(diag: &Diagnostic) -> &str {
        if let Some(NumberOrString::String(code)) = &diag.code {
            code.as_str()
        } else {
            ""
        }
    }

    /// Runs `build_remove_assert_edit`, failing the test when no edit is produced.
    fn edit_for(content: &str, line: usize) -> TextEdit {
        build_remove_assert_edit(content, line).expect("an edit should be produced")
    }

    /// Each input starts just after an opening `(`; the expected value is the byte
    /// offset of the `)` closing it.
    #[test]
    fn close_paren_offsets() {
        let cases: [(&str, Option<usize>); 6] = [
            ("true);", Some(4)),
            ("is_string($x));", Some(13)),
            (r#""foo(bar)" !== null);"#, Some(19)),
            ("'a)b' === true);", Some(14)),
            (r#""foo\")" !== null);"#, Some(17)),
            ("true", None),
        ];
        for (input, expected) in cases {
            assert_eq!(
                find_matching_close_paren(input, 0),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn whole_line_removed_for_simple_assert() {
        let edit = edit_for("<?php\n assert($x instanceof Foo);\n $x->bar();\n", 1);
        assert_eq!(
            edit.range,
            Range {
                start: Position::new(1, 0),
                end: Position::new(2, 0),
            }
        );
        assert_eq!(edit.new_text, "");
    }

    #[test]
    fn assert_on_last_line_without_newline() {
        let edit = edit_for("<?php\nassert(true);", 1);
        assert_eq!(edit.range.start, Position::new(1, 0));
        assert_eq!(edit.new_text, "");
    }

    #[test]
    fn whole_line_removed_for_nested_calls() {
        let edit = edit_for("<?php\n assert(is_string(trim($x)));\n echo 'ok';\n", 1);
        assert_eq!(
            edit.range,
            Range {
                start: Position::new(1, 0),
                end: Position::new(2, 0),
            }
        );
        assert_eq!(edit.new_text, "");
    }

    #[test]
    fn code_before_assert_is_kept() {
        let edit = edit_for("<?php\n$a = 1; assert(true);\n", 1);
        assert_eq!(edit.range.start.line, 1);
        assert!(edit.range.start.character > 0);
        assert_eq!(edit.new_text, "");
    }

    #[test]
    fn no_edit_for_unusable_input() {
        // The line holds no assert at all.
        assert!(build_remove_assert_edit("<?php\n $x = 1;\n", 1).is_none());
        // The requested line does not exist.
        assert!(build_remove_assert_edit("<?php\n", 5).is_none());
        // The call is not terminated by a semicolon.
        assert!(build_remove_assert_edit("<?php\nassert(true)\n", 1).is_none());
    }

    #[test]
    fn staleness_follows_assert_presence() {
        assert!(is_remove_assert_stale("<?php\n $x->bar();\n", 1));
        assert!(!is_remove_assert_stale("<?php\n assert($x instanceof Foo);\n", 1));
        assert!(is_remove_assert_stale("<?php\n", 5));
    }

    #[test]
    fn message_prefix_filter() {
        let about_assert = "Call to function assert() with true will always evaluate to true.";
        let about_is_string =
            "Call to function is_string() with string will always evaluate to true.";
        assert!(about_assert.starts_with(ASSERT_MESSAGE_PREFIX));
        assert!(!about_is_string.starts_with(ASSERT_MESSAGE_PREFIX));
    }

    #[test]
    fn identifier_filter() {
        let other = make_diagnostic(1, "Call to function assert() with ...", "some.other");
        assert_ne!(string_code(&other), ALREADY_NARROWED_ID);

        let narrowed = make_diagnostic(
            1,
            "Call to function assert() with true will always evaluate to true.",
            ALREADY_NARROWED_ID,
        );
        assert_eq!(string_code(&narrowed), ALREADY_NARROWED_ID);
        assert!(narrowed.message.starts_with(ASSERT_MESSAGE_PREFIX));
    }
}