from __future__ import annotations
import sys
import unittest
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from bash_redirect_hook import (
BLOCK_GH_PR_DIFF,
BLOCK_MCP_GITHUB_GET_COMMIT,
_advice_with_echo,
_matches_gh_api_contents,
_classify_git_command,
_drop_heredoc_bodies,
_has_pickaxe_flag,
_has_ref_range,
_is_functionally_empty,
_matches_gh_pr_diff,
_emit_advice,
_read_payload,
decide_redirect,
main,
tokenize_command,
)
class TestIsFunctionallyEmpty(unittest.TestCase):
    """Pin which raw inputs ``_is_functionally_empty`` treats as empty."""

    # --- inputs expected to count as empty ---------------------------------

    def test_empty_string_is_empty(self) -> None:
        raw = ""
        self.assertTrue(_is_functionally_empty(raw))

    def test_whitespace_only_is_empty(self) -> None:
        raw = " \t "
        self.assertTrue(_is_functionally_empty(raw))

    def test_newline_only_is_empty(self) -> None:
        raw = "\n"
        self.assertTrue(_is_functionally_empty(raw))

    def test_escaped_newline_sequence_is_empty(self) -> None:
        # Backslash followed by "n" (two characters), not a real newline.
        raw = "\\n \\n"
        self.assertTrue(_is_functionally_empty(raw))

    def test_mixed_escape_sequences_are_empty(self) -> None:
        # Literal backslash escapes for newline, tab and carriage return.
        raw = "\\n\\t\\r"
        self.assertTrue(_is_functionally_empty(raw))

    # --- inputs expected to count as non-empty -----------------------------

    def test_non_whitespace_is_not_empty(self) -> None:
        raw = "git diff"
        self.assertFalse(_is_functionally_empty(raw))

    def test_json_is_not_empty(self) -> None:
        raw = '{"tool_name": "Bash"}'
        self.assertFalse(_is_functionally_empty(raw))
class TestHasRefRange(unittest.TestCase):
    """Pin which token lists ``_has_ref_range`` reports as containing a range."""

    def test_double_dot_range_is_detected(self) -> None:
        tokens = ["main..HEAD"]
        self.assertTrue(_has_ref_range(tokens))

    def test_triple_dot_range_is_detected(self) -> None:
        tokens = ["main...HEAD"]
        self.assertTrue(_has_ref_range(tokens))

    def test_bare_double_dot_is_excluded(self) -> None:
        # A lone ".." has no refs on either side.
        tokens = [".."]
        self.assertFalse(_has_ref_range(tokens))

    def test_bare_triple_dot_is_excluded(self) -> None:
        # A lone "..." has no refs on either side.
        tokens = ["..."]
        self.assertFalse(_has_ref_range(tokens))

    def test_no_range_returns_false(self) -> None:
        tokens = ["git", "diff", "--stat"]
        self.assertFalse(_has_ref_range(tokens))

    def test_range_anywhere_in_list_is_detected(self) -> None:
        # The range is the last token, not the first.
        tokens = ["git", "diff", "feature..main"]
        self.assertTrue(_has_ref_range(tokens))
class TestHasPickaxeFlag(unittest.TestCase):
    """Pin which token lists ``_has_pickaxe_flag`` accepts as ``-S`` / ``-G``."""

    def test_standalone_dash_s(self) -> None:
        tokens = ["-S"]
        self.assertTrue(_has_pickaxe_flag(tokens))

    def test_standalone_dash_g(self) -> None:
        tokens = ["-G"]
        self.assertTrue(_has_pickaxe_flag(tokens))

    def test_concatenated_dash_s_term(self) -> None:
        # Flag and search term fused into a single token.
        tokens = ["-Sfoo"]
        self.assertTrue(_has_pickaxe_flag(tokens))

    def test_concatenated_dash_g_term(self) -> None:
        # Flag and search term fused into a single token.
        tokens = ["-Gbar"]
        self.assertTrue(_has_pickaxe_flag(tokens))

    def test_other_flags_not_matched(self) -> None:
        tokens = ["-p", "--oneline", "-n"]
        self.assertFalse(_has_pickaxe_flag(tokens))

    def test_empty_list_returns_false(self) -> None:
        tokens = []
        self.assertFalse(_has_pickaxe_flag(tokens))
class TestClassifyGitCommand(unittest.TestCase):
    """Pin the value ``_classify_git_command`` returns for each token list.

    The first group expects a specific string; the second group expects
    ``None``.
    """

    # --- commands that classify to a named result --------------------------

    def test_git_diff_with_ref_range_returns_change_manifest(self) -> None:
        argv = ["git", "diff", "main..HEAD"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_change_manifest")

    def test_git_log_with_ref_range_returns_commit_history(self) -> None:
        argv = ["git", "log", "main..HEAD"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_commit_history")

    def test_git_log_with_pickaxe_returns_function_context(self) -> None:
        argv = ["git", "log", "-S", "foo"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_function_context")

    def test_git_log_pickaxe_priority_over_range(self) -> None:
        # Both a pickaxe flag and a ref range are present; pickaxe must win.
        argv = ["git", "log", "-S", "foo", "main..HEAD"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_function_context")

    def test_git_blame_returns_file_snapshots(self) -> None:
        argv = ["git", "blame", "src/main.rs"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_file_snapshots")

    def test_git_show_returns_file_snapshots(self) -> None:
        argv = ["git", "show", "abc123:src/main.rs"]
        classified = _classify_git_command(argv)
        self.assertEqual(classified, "get_file_snapshots")

    # --- commands that must classify to None -------------------------------

    def test_git_status_returns_none(self) -> None:
        argv = ["git", "status"]
        self.assertIsNone(_classify_git_command(argv))

    def test_git_add_returns_none(self) -> None:
        argv = ["git", "add", "file.txt"]
        self.assertIsNone(_classify_git_command(argv))

    def test_git_commit_returns_none(self) -> None:
        argv = ["git", "commit", "-m", "msg"]
        self.assertIsNone(_classify_git_command(argv))

    def test_git_push_returns_none(self) -> None:
        argv = ["git", "push", "origin"]
        self.assertIsNone(_classify_git_command(argv))

    def test_git_fetch_returns_none(self) -> None:
        argv = ["git", "fetch", "origin"]
        self.assertIsNone(_classify_git_command(argv))

    def test_non_git_command_returns_none(self) -> None:
        argv = ["ls", "-la"]
        self.assertIsNone(_classify_git_command(argv))

    def test_empty_list_returns_none(self) -> None:
        argv = []
        self.assertIsNone(_classify_git_command(argv))

    def test_git_diff_without_range_returns_none(self) -> None:
        # A bare "git diff" carries no ref range.
        argv = ["git", "diff"]
        self.assertIsNone(_classify_git_command(argv))
class TestTokenizeCommand(unittest.TestCase):
    """Exercise ``tokenize_command`` on plain, compound and heredoc commands.

    Many tests only care whether a candidate beginning with particular tokens
    (``git diff``, ``gh pr``, ...) is present in the result, so that filter
    lives in ``_with_prefix`` instead of being re-spelled in every test.
    """

    @staticmethod
    def _with_prefix(candidates: list[list[str]], *prefix: str) -> list[list[str]]:
        """Return the candidates whose leading tokens equal *prefix*.

        A candidate shorter than *prefix* never matches, which also rules out
        empty candidates.
        """
        wanted = list(prefix)
        return [cand for cand in candidates if list(cand[: len(wanted)]) == wanted]

    # --- basic splitting ---------------------------------------------------

    def test_simple_git_diff(self) -> None:
        result = tokenize_command("git diff main..HEAD")
        self.assertEqual(result, [["git", "diff", "main..HEAD"]])

    def test_compound_and_command(self) -> None:
        result = tokenize_command("cd /tmp && git diff main..HEAD")
        self.assertIn(["git", "diff", "main..HEAD"], result)

    def test_subshell_parentheses(self) -> None:
        result = tokenize_command("(git log main..HEAD)")
        self.assertIn(["git", "log", "main..HEAD"], result)

    def test_pipeline(self) -> None:
        result = tokenize_command("git diff main..HEAD | grep foo")
        self.assertIn(["git", "diff", "main..HEAD"], result)

    def test_backtick_normalization(self) -> None:
        result = tokenize_command(
            "cd `git rev-parse --show-toplevel` && git diff main..HEAD"
        )
        git_diff = self._with_prefix(result, "git", "diff")
        self.assertTrue(git_diff, f"git diff candidate missing from {result}")

    def test_variable_not_expanded(self) -> None:
        result = tokenize_command("git diff $BASE..HEAD")
        flat_tokens = [tok for cand in result for tok in cand]
        self.assertTrue(
            any("$BASE" in tok for tok in flat_tokens),
            f"Expected literal '$BASE' in tokens, got: {flat_tokens}",
        )

    def test_empty_command_returns_empty_list(self) -> None:
        self.assertEqual(tokenize_command(""), [])

    # --- heredoc handling --------------------------------------------------

    def test_heredoc_body_is_skipped(self) -> None:
        command = "cat <<EOF\ngit log a..b\nEOF\n"
        result = tokenize_command(command)
        git_log = self._with_prefix(result, "git")
        self.assertFalse(git_log, f"git command inside heredoc body leaked: {result}")

    def test_heredoc_indented_delimiter_in_body_is_not_terminated(self) -> None:
        # The first "EOF" is indented, so it must not end the heredoc.
        command = "cat <<EOF\n EOF\ngit diff main..HEAD\nEOF\n"
        result = tokenize_command(command)
        git_diff_candidates = self._with_prefix(result, "git", "diff")
        self.assertFalse(
            git_diff_candidates,
            f"git diff inside heredoc body leaked after indented faux-tag: {result}",
        )

    def test_tokenizer_resumes_after_heredoc_terminator(self) -> None:
        command = "cat <<EOF\ngit log a..b\nEOF\ngit diff main..HEAD"
        result = tokenize_command(command)
        # The command after the terminator must be seen ...
        git_diff = self._with_prefix(result, "git", "diff")
        self.assertTrue(git_diff, f"git diff after heredoc not found in: {result}")
        # ... while the one inside the body must not.
        git_log = self._with_prefix(result, "git", "log")
        self.assertFalse(
            git_log,
            f"git log inside heredoc body leaked into candidates: {result}",
        )

    def test_digit_starting_heredoc_tag_body_is_skipped(self) -> None:
        command = "echo \"$(cat <<'1'\ngh pr diff 123 --repo owner/repo\n1\n) done\"\n"
        result = tokenize_command(command)
        gh_pr_diff = self._with_prefix(result, "gh", "pr")
        self.assertFalse(
            gh_pr_diff,
            f"gh pr diff inside digit-tag quoted-heredoc body leaked: {result}",
        )

    # NOTE: the three "..._leaks" tests below assert that nothing leaks from
    # the second heredoc body; the names describe the regression they guard.

    def test_two_heredocs_same_line_second_body_leaks(self) -> None:
        command = (
            "cat <<EOF1 <<EOF2\nEOF1 body\nEOF1\ngh pr diff 123\nEOF2\necho done\n"
        )
        result = tokenize_command(command)
        gh_pr_candidates = self._with_prefix(result, "gh", "pr")
        self.assertFalse(
            gh_pr_candidates,
            f"gh pr diff inside second heredoc body on same-line << leaked: {result}",
        )

    def test_two_heredocs_same_line_git_diff_leaks(self) -> None:
        command = (
            "cat <<EOF1 <<EOF2\nEOF1 body\nEOF1\ngit diff main..HEAD\nEOF2\necho done\n"
        )
        result = tokenize_command(command)
        git_diff_candidates = self._with_prefix(result, "git", "diff")
        self.assertFalse(
            git_diff_candidates,
            f"git diff inside second heredoc body on same-line << leaked: {result}",
        )

    def test_two_heredocs_same_line_gh_api_contents_leaks(self) -> None:
        # Uses the module-level import of _matches_gh_api_contents; this test
        # checks the matcher directly rather than tokenize_command.
        command = (
            "cat <<EOF1 <<EOF2\n"
            "EOF1 body\n"
            "EOF1\n"
            "gh api repos/owner/repo/contents/path?ref=abc123\n"
            "EOF2\n"
            "echo done\n"
        )
        matched = _matches_gh_api_contents(command)
        self.assertFalse(
            matched,
            "gh api contents?ref= inside second heredoc body on same-line << "
            f"triggered false match: {matched}",
        )
class TestDropHeredocBodies(unittest.TestCase):
    """Pin which tokens ``_drop_heredoc_bodies`` removes from a token list."""

    def test_simple_heredoc_body_is_dropped(self) -> None:
        tokens = ["cat", "<<", "EOF", "\n", "git", "\n", "EOF", "\n", "echo", "done"]
        kept = _drop_heredoc_bodies(tokens)
        # Body token is gone; tokens after the terminator remain.
        self.assertNotIn("git", kept)
        self.assertIn("echo", kept)
        self.assertIn("done", kept)

    def test_dash_form_heredoc_body_is_dropped(self) -> None:
        # Same as above, but the tag token carries a leading dash.
        tokens = ["cat", "<<", "-EOF", "\n", "git", "\n", "EOF", "\n", "echo", "done"]
        kept = _drop_heredoc_bodies(tokens)
        self.assertNotIn("git", kept)
        self.assertIn("echo", kept)

    def test_content_before_heredoc_is_preserved(self) -> None:
        tokens = ["echo", "hi", "<<", "EOF", "\n", "body", "\n", "EOF"]
        kept = _drop_heredoc_bodies(tokens)
        self.assertIn("echo", kept)
        self.assertIn("hi", kept)
        self.assertNotIn("body", kept)

    def test_empty_token_list_returns_empty(self) -> None:
        kept = _drop_heredoc_bodies([])
        self.assertEqual(kept, [])

    def test_no_heredoc_passes_through_unchanged(self) -> None:
        tokens = ["git", "diff", "main..HEAD"]
        kept = _drop_heredoc_bodies(tokens)
        self.assertEqual(kept, tokens)

    def test_invalid_heredoc_tag_drops_all_remaining_tokens(self) -> None:
        # NOTE(review): despite the method name, the assertions below require
        # the tokens after a malformed "<<" tag to survive and only the "<<"
        # operator itself to be removed.
        tokens = ["echo", "hi", "<<", "''", "\n", "git", "diff", "main..HEAD"]
        result = _drop_heredoc_bodies(tokens)
        self.assertIn(
            "git",
            result,
            msg=f"git diff after malformed << was dropped: {result}",
        )
        self.assertNotIn(
            "<<",
            result,
            msg=f"<< operator was preserved after malformed tag: {result}",
        )
class TestAdviceWithEcho(unittest.TestCase):
    """Pin the text ``_advice_with_echo`` builds from advice plus tokens."""

    def test_echo_appends_verbatim_tokens(self) -> None:
        argv = ["git", "diff", "main..HEAD"]
        rendered = _advice_with_echo("base advice", argv)
        self.assertIn("You ran: git diff main..HEAD", rendered)

    def test_variable_not_expanded_in_echo(self) -> None:
        # The "$BASE" reference must appear literally in the output.
        argv = ["git", "diff", "$BASE..HEAD"]
        result = _advice_with_echo("base advice", argv)
        failure = f"Expected literal '$BASE..HEAD' in advice, got: {result!r}"
        self.assertIn("$BASE..HEAD", result, failure)

    def test_base_advice_is_included(self) -> None:
        argv = ["git", "log", "main..HEAD"]
        rendered = _advice_with_echo("USE GET_COMMIT_HISTORY", argv)
        self.assertIn("USE GET_COMMIT_HISTORY", rendered)
class TestMatchesGhPrDiff(unittest.TestCase):
    """Pin which command strings ``_matches_gh_pr_diff`` accepts."""

    def test_plain_gh_pr_diff_is_matched(self) -> None:
        command = "gh pr diff 123"
        self.assertTrue(_matches_gh_pr_diff(command))

    def test_compound_gh_pr_diff_is_matched(self) -> None:
        # The target command is the second half of an "&&" chain.
        command = "cd /tmp && gh pr diff 123"
        self.assertTrue(_matches_gh_pr_diff(command))

    def test_gh_pr_view_is_not_matched(self) -> None:
        command = "gh pr view 123"
        self.assertFalse(_matches_gh_pr_diff(command))

    def test_git_diff_is_not_matched(self) -> None:
        command = "git diff main..HEAD"
        self.assertFalse(_matches_gh_pr_diff(command))

    def test_empty_command_is_not_matched(self) -> None:
        command = ""
        self.assertFalse(_matches_gh_pr_diff(command))
class TestMatchesGhApiContents(unittest.TestCase):
    """Pin which command strings ``_matches_gh_api_contents`` accepts."""

    def test_gh_api_contents_with_ref_is_matched(self) -> None:
        command = "gh api repos/owner/repo/contents/path?ref=abc123"
        self.assertTrue(_matches_gh_api_contents(command))

    def test_gh_api_contents_with_multiple_params_is_matched(self) -> None:
        # "ref" is the second query parameter rather than the first.
        command = "gh api repos/owner/repo/contents/path?foo=bar&ref=abc123"
        self.assertTrue(_matches_gh_api_contents(command))

    def test_gh_api_contents_without_ref_is_not_matched(self) -> None:
        command = "gh api repos/owner/repo/contents/path"
        self.assertFalse(_matches_gh_api_contents(command))

    def test_gh_api_other_endpoint_is_not_matched(self) -> None:
        command = "gh api repos/owner/repo/issues"
        self.assertFalse(_matches_gh_api_contents(command))

    def test_gh_api_contents_no_path_before_ref_is_matched(self) -> None:
        # Nothing between "contents/" and the query string.
        command = "gh api repos/owner/repo/contents/?ref=abc123"
        self.assertTrue(_matches_gh_api_contents(command))

    def test_empty_command_is_not_matched(self) -> None:
        command = ""
        self.assertFalse(_matches_gh_api_contents(command))
class TestDecideRedirect(unittest.TestCase):
    """Pin the decision ``decide_redirect`` returns for hook payloads.

    Each test checks the decision's ``mode`` ("advise", "block" or "silent")
    and, where relevant, a substring of its ``advice`` or ``message``.
    """

    def _bash_payload(self, command: str) -> dict[str, str | dict[str, str]]:
        # Build a PreToolUse payload for the Bash tool running *command*.
        tool_input = {"command": command}
        return {
            "tool_name": "Bash",
            "tool_input": tool_input,
            "hook_event_name": "PreToolUse",
        }

    # --- git commands ------------------------------------------------------

    def test_git_diff_with_range_returns_advise(self) -> None:
        payload = self._bash_payload("git diff main..HEAD")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_change_manifest", outcome.advice)

    def test_git_log_with_range_returns_advise(self) -> None:
        payload = self._bash_payload("git log main..HEAD")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_commit_history", outcome.advice)

    def test_git_log_pickaxe_returns_advise_for_function_context(self) -> None:
        payload = self._bash_payload("git log -S foo")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_function_context", outcome.advice)

    def test_git_blame_returns_advise_for_file_snapshots(self) -> None:
        payload = self._bash_payload("git blame src/main.rs")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_file_snapshots", outcome.advice)

    def test_git_status_returns_silent(self) -> None:
        payload = self._bash_payload("git status")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "silent")

    # --- gh / MCP GitHub tools that are blocked ----------------------------

    def test_gh_pr_diff_returns_block(self) -> None:
        payload = self._bash_payload("gh pr diff 123")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "block")
        self.assertIn("get_change_manifest", outcome.message)

    def test_mcp_github_get_commit_tool_name_returns_block(self) -> None:
        # The tool is invoked directly, not through Bash.
        hook_input = {
            "tool_name": "mcp__github__get_commit",
            "tool_input": {},
            "hook_event_name": "PreToolUse",
        }
        outcome = decide_redirect(hook_input)
        self.assertEqual(outcome.mode, "block")
        self.assertIn("git-prism", outcome.message)

    def test_mcp_github_get_commit_as_bash_command_returns_block(self) -> None:
        payload = self._bash_payload(
            "mcp__github__get_commit owner=foo repo=bar sha=abc"
        )
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "block")

    def test_mcp_github_list_commits_tool_name_returns_block(self) -> None:
        # The tool is invoked directly, not through Bash.
        hook_input = {
            "tool_name": "mcp__github__list_commits",
            "tool_input": {},
            "hook_event_name": "PreToolUse",
        }
        outcome = decide_redirect(hook_input)
        self.assertEqual(outcome.mode, "block")

    def test_mcp_github_list_commits_as_bash_command_returns_block(self) -> None:
        payload = self._bash_payload("mcp__github__list_commits owner=foo repo=bar")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "block")

    # --- payloads that must pass through silently --------------------------

    def test_non_bash_tool_is_silent(self) -> None:
        hook_input = {
            "tool_name": "Read",
            "tool_input": {"file_path": "/tmp/file"},
            "hook_event_name": "PreToolUse",
        }
        outcome = decide_redirect(hook_input)
        self.assertEqual(outcome.mode, "silent")

    def test_empty_command_is_silent(self) -> None:
        payload = self._bash_payload("")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "silent")

    def test_missing_tool_name_is_silent(self) -> None:
        hook_input = {}
        outcome = decide_redirect(hook_input)
        self.assertEqual(outcome.mode, "silent")

    # --- gh api contents ---------------------------------------------------

    def test_gh_api_contents_with_ref_returns_advise(self) -> None:
        payload = self._bash_payload(
            "gh api repos/owner/repo/contents/path?ref=abc123"
        )
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_file_snapshots", outcome.advice)

    def test_gh_api_contents_without_ref_returns_silent(self) -> None:
        payload = self._bash_payload("gh api repos/owner/repo/contents/path")
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "silent")

    # --- variables and heredocs --------------------------------------------

    def test_variable_in_command_does_not_expand(self) -> None:
        payload = self._bash_payload("git diff $BASE..HEAD")
        decision = decide_redirect(payload)
        self.assertEqual(decision.mode, "advise")
        failure = f"Expected literal '$BASE' in advice, got: {decision.advice!r}"
        self.assertIn("$BASE", decision.advice, failure)

    def test_heredoc_body_git_command_is_not_advised(self) -> None:
        # The only git command sits inside the heredoc body.
        payload = self._bash_payload("cat <<EOF\ngit log a..b\nEOF\n")
        decision = decide_redirect(payload)
        failure = f"Expected silent for heredoc-body git, got mode={decision.mode!r}"
        self.assertEqual(decision.mode, "silent", failure)

    def test_git_diff_after_heredoc_is_advised(self) -> None:
        # A real git diff follows the heredoc terminator.
        payload = self._bash_payload(
            "cat <<EOF\ngit log a..b\nEOF\ngit diff main..HEAD"
        )
        outcome = decide_redirect(payload)
        self.assertEqual(outcome.mode, "advise")
        self.assertIn("get_change_manifest", outcome.advice)
class TestUnicodeEncoding(unittest.TestCase):
    """Check that hook output can be written to ASCII-only streams.

    Stream redirection uses ``contextlib.redirect_stdout`` /
    ``redirect_stderr`` and ``mock.patch.object`` so the real streams are
    restored however a test exits. The ASCII-stderr setup shared by several
    tests lives in the private helpers below.
    """

    @staticmethod
    def _ascii_stream():
        """Return a fresh text stream that can only encode ASCII."""
        import io

        return io.TextIOWrapper(io.BytesIO(), encoding="ascii")

    def _assert_ascii_stderr_accepts(self, text: str, label: str) -> None:
        """Write *text* plus a newline to an ASCII ``sys.stderr`` and flush.

        Fails the test, naming *label* in the message, if the write or the
        flush raises ``UnicodeEncodeError``.
        """
        from contextlib import redirect_stderr

        stream = self._ascii_stream()
        try:
            with redirect_stderr(stream):
                sys.stderr.write(text)
                sys.stderr.write("\n")
                stream.flush()
        except UnicodeEncodeError:
            self.fail(
                f"{label} raised UnicodeEncodeError on ASCII stderr; "
                "main() would silently downgrade a hard block to allow"
            )

    def test_advisory_path_is_ascii_safe(self) -> None:
        import io
        from contextlib import redirect_stdout

        captured = io.StringIO()
        with redirect_stdout(captured):
            _emit_advice(BLOCK_GH_PR_DIFF)
        output = captured.getvalue()
        # The Unicode arrow must be absent and its ASCII spelling present.
        self.assertNotIn("←", output)
        self.assertIn("-->", output)

    def test_block_path_stderr_fails_on_ascii_locale(self) -> None:
        # NOTE: despite the name, this test passes when the write succeeds;
        # it fails only if UnicodeEncodeError is raised.
        self._assert_ascii_stderr_accepts(BLOCK_GH_PR_DIFF, "BLOCK_GH_PR_DIFF")

    def test_block_mcp_github_get_commit_is_ascii_safe(self) -> None:
        # Checked character by character so a failure reports the offending
        # code point and its index.
        for i, ch in enumerate(BLOCK_MCP_GITHUB_GET_COMMIT):
            self.assertLess(
                ord(ch),
                128,
                f"Non-ASCII char {ch!r} (U+{ord(ch):04X}) at index {i} in "
                f"BLOCK_MCP_GITHUB_GET_COMMIT",
            )

    def test_block_mcp_github_get_commit_stderr_survives_ascii_locale(self) -> None:
        self._assert_ascii_stderr_accepts(
            BLOCK_MCP_GITHUB_GET_COMMIT, "BLOCK_MCP_GITHUB_GET_COMMIT"
        )

    def test_malformed_json_warning_survives_ascii_locale(self) -> None:
        import io
        from contextlib import redirect_stderr

        fake_stdin = io.StringIO("this is not json {")
        try:
            with redirect_stderr(self._ascii_stream()):
                result = _read_payload(fake_stdin)
        except UnicodeEncodeError:
            self.fail(
                "_read_payload raised UnicodeEncodeError on ASCII stderr when "
                "emitting malformed JSON warning; main() would crash instead of "
                "returning fail-open exit 0"
            )
        self.assertIsNone(result)

    def test_main_unexpected_error_handler_survives_ascii_locale(self) -> None:
        from contextlib import redirect_stderr
        from unittest import mock

        class BrokenStdin:
            # Stand-in for sys.stdin whose read() always raises.
            def read(self):
                raise RuntimeError("boom")

        try:
            with redirect_stderr(self._ascii_stream()):
                with mock.patch.object(sys, "stdin", BrokenStdin()):
                    result = main()
        except UnicodeEncodeError:
            self.fail(
                "main() raised UnicodeEncodeError on ASCII stderr when handling "
                "unexpected error; hook would crash instead of fail-open"
            )
        self.assertEqual(result, 0)
# Run the suite when this module is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()