use std::path::{Path, PathBuf};
use std::process::Command;
use base64::Engine;
use serial_test::serial;
use sha2::{Digest, Sha256};
use socket_patch_cli::commands::remove::{run as remove_run, RemoveArgs};
use socket_patch_cli::commands::rollback::{run as rollback_run, RollbackArgs};
use socket_patch_cli::commands::scan::{run as scan_run, ScanArgs};
use wiremock::matchers::{method, path, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Organization slug used in every mocked API route.
const ORG: &str = "test-org";

/// The real PyPI distribution installed into the throwaway venv.
const PYPI_PACKAGE: &str = "six";
const PYPI_VERSION: &str = "1.16.0";

// Release variant whose patch is built from the real, installed `six.py`.
const UUID_INSTALLED: &str = "11111111-1111-4111-8111-111111111111";
const ARTIFACT_INSTALLED: &str = "wheel-cp-installed";
/// Marker appended to the real `six.py` to form the installed variant's
/// "after" content; finding it on disk means that patch was applied.
const MARKER_INSTALLED: &[u8] = b"\n# SOCKET-MULTIRELEASE-INSTALLED\n";

// A second wheel variant whose blobs are synthetic (not the installed file).
const UUID_OTHER_WHEEL: &str = "22222222-2222-4222-8222-222222222222";
const ARTIFACT_OTHER_WHEEL: &str = "wheel-cp-other";

// The source-distribution variant, also with synthetic blobs.
const UUID_SDIST: &str = "33333333-3333-4333-8333-333333333333";
const ARTIFACT_SDIST: &str = "sdist";
/// Hashes `content` with SHA-256 using Git's blob-object framing: the digest
/// covers a `blob <len>\0` header followed by the raw bytes. Returns the
/// digest hex-encoded.
fn git_sha256(content: &[u8]) -> String {
    let header = format!("blob {}\0", content.len());
    let mut digest = Sha256::new();
    for chunk in [header.as_bytes(), content] {
        digest.update(chunk);
    }
    hex::encode(digest.finalize())
}
/// Encodes `bytes` as base64 with the `base64` crate's `STANDARD` engine.
fn b64(bytes: &[u8]) -> String {
    use base64::engine::general_purpose::STANDARD as ENGINE;
    ENGINE.encode(bytes)
}
/// Returns the first of `python3`, `python`, `py` that launches a Python 3
/// interpreter with the standard-library `venv` module, or `None` if none do.
///
/// Probing with `--version` alone would also accept a Python 2 `python`/`py`,
/// which then fails later at `-m venv` and turns a "skip" into a panic.
/// Importing `venv` (absent from Python 2) and checking the major version
/// makes the probe match what `install_six` actually needs.
fn find_python() -> Option<&'static str> {
    const CANDIDATES: [&str; 3] = ["python3", "python", "py"];
    const PROBE: &str = "import sys, venv; sys.exit(0 if sys.version_info[0] >= 3 else 1)";
    CANDIDATES.into_iter().find(|&cmd| {
        Command::new(cmd)
            .args(["-c", PROBE])
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .status()
            // A missing executable is an Err; treat it like a failed probe.
            .map(|s| s.success())
            .unwrap_or(false)
    })
}
/// Reports whether `find_python` located a usable interpreter; the tests use
/// this to skip (rather than fail) on machines without Python.
fn has_python3() -> bool {
    let interpreter = find_python();
    interpreter.is_some()
}
/// Path of the `pip` executable inside the virtualenv rooted at `venv`:
/// `Scripts\pip.exe` on Windows, `bin/pip` elsewhere.
fn venv_pip(venv: &Path) -> PathBuf {
    let (bin_dir, exe) = if cfg!(windows) {
        ("Scripts", "pip.exe")
    } else {
        ("bin", "pip")
    };
    venv.join(bin_dir).join(exe)
}
/// Locates the `site-packages` directory of the virtualenv at `venv`.
///
/// On Windows the layout is fixed (`Lib\site-packages`). Elsewhere it lives
/// under a version-named directory (`lib/<something>/site-packages`), so the
/// first `lib/*` entry that contains a `site-packages` child is returned.
///
/// # Panics
/// If `lib` cannot be read, or no entry under it has `site-packages`.
fn find_site_packages(venv: &Path) -> PathBuf {
    if cfg!(windows) {
        return venv.join("Lib").join("site-packages");
    }
    let lib = venv.join("lib");
    std::fs::read_dir(&lib)
        .expect("lib dir")
        .flatten()
        .map(|entry| entry.path().join("site-packages"))
        .find(|candidate| candidate.exists())
        .unwrap_or_else(|| panic!("site-packages not found under {}", lib.display()))
}
/// Creates a virtualenv at `<tmp>/.venv`, `pip install`s the pinned
/// `PYPI_PACKAGE==PYPI_VERSION` into it, and returns the path of the installed
/// `six.py`.
///
/// # Panics
/// If no interpreter is found, venv creation or `pip install` fails (pip is
/// invoked without an index override, so this presumably needs network
/// access to PyPI), or `six.py` is missing afterwards.
fn install_six(tmp: &Path) -> PathBuf {
    let venv = tmp.join(".venv");
    let python = find_python().expect("python interpreter not on PATH");
    // Pass the venv path as an OsStr: `to_str().unwrap()` would panic on a
    // non-UTF-8 temp directory.
    let status = Command::new(python)
        .arg("-m")
        .arg("venv")
        .arg(&venv)
        .status()
        .expect("python venv");
    assert!(status.success(), "failed to create venv ({status})");

    let pip = venv_pip(&venv);
    let requirement = format!("{PYPI_PACKAGE}=={PYPI_VERSION}");
    let status = Command::new(&pip)
        .args([
            "install",
            "--disable-pip-version-check",
            "--quiet",
            "--no-cache-dir",
            requirement.as_str(),
        ])
        .status()
        .expect("pip install");
    assert!(status.success(), "failed to install {PYPI_PACKAGE} ({status})");

    let candidate = find_site_packages(&venv).join("six.py");
    assert!(
        candidate.exists(),
        "six.py not found after pip install (looked at {})",
        candidate.display()
    );
    candidate
}
/// The unqualified package URL for the pinned PyPI release,
/// e.g. `pkg:pypi/six@1.16.0`.
fn base_purl() -> String {
    ["pkg:pypi/", PYPI_PACKAGE, "@", PYPI_VERSION].concat()
}
/// The base PURL with an `artifact_id` qualifier appended, identifying one
/// release variant of the package.
fn qualified(artifact_id: &str) -> String {
    let mut purl = base_purl();
    purl.push_str("?artifact_id=");
    purl.push_str(artifact_id);
    purl
}
async fn setup_multi_release_mock(server: &MockServer, installed_before_hash: &str) {
let base = base_purl();
Mock::given(method("POST"))
.and(path(format!("/v0/orgs/{ORG}/patches/batch")))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"packages": [{
"purl": base,
"patches": [
{ "uuid": UUID_INSTALLED, "purl": qualified(ARTIFACT_INSTALLED),
"tier": "free", "cveIds": [], "ghsaIds": [],
"severity": "high", "title": "installed wheel" },
{ "uuid": UUID_OTHER_WHEEL, "purl": qualified(ARTIFACT_OTHER_WHEEL),
"tier": "free", "cveIds": [], "ghsaIds": [],
"severity": "high", "title": "other wheel" },
{ "uuid": UUID_SDIST, "purl": qualified(ARTIFACT_SDIST),
"tier": "free", "cveIds": [], "ghsaIds": [],
"severity": "high", "title": "sdist" },
]
}],
"canAccessPaidPatches": false,
})))
.mount(server)
.await;
Mock::given(method("GET"))
.and(path_regex(format!("^/v0/orgs/{ORG}/patches/by-package/.+$")))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"patches": [
{ "uuid": UUID_INSTALLED, "purl": qualified(ARTIFACT_INSTALLED),
"publishedAt": "2024-01-01T00:00:00Z", "description": "installed wheel",
"license": "MIT", "tier": "free", "vulnerabilities": {} },
{ "uuid": UUID_OTHER_WHEEL, "purl": qualified(ARTIFACT_OTHER_WHEEL),
"publishedAt": "2024-01-01T00:00:00Z", "description": "other wheel",
"license": "MIT", "tier": "free", "vulnerabilities": {} },
{ "uuid": UUID_SDIST, "purl": qualified(ARTIFACT_SDIST),
"publishedAt": "2024-01-01T00:00:00Z", "description": "sdist",
"license": "MIT", "tier": "free", "vulnerabilities": {} },
],
"canAccessPaidPatches": false,
})))
.mount(server)
.await;
let other_before = b"# six.py from a DIFFERENT wheel distribution\n";
let mut other_after = other_before.to_vec();
other_after.extend_from_slice(b"\n# OTHER-WHEEL-MARKER\n");
mount_view(
server,
UUID_OTHER_WHEEL,
&qualified(ARTIFACT_OTHER_WHEEL),
&git_sha256(other_before),
&git_sha256(&other_after),
other_before,
&other_after,
)
.await;
let sdist_before = b"# six.py from the sdist distribution\n";
let mut sdist_after = sdist_before.to_vec();
sdist_after.extend_from_slice(b"\n# SDIST-MARKER\n");
mount_view(
server,
UUID_SDIST,
&qualified(ARTIFACT_SDIST),
&git_sha256(sdist_before),
&git_sha256(&sdist_after),
sdist_before,
&sdist_after,
)
.await;
let _ = installed_before_hash;
}
/// Mounts the `view` route for the installed-wheel variant
/// (`UUID_INSTALLED`, `ARTIFACT_INSTALLED`) with caller-supplied hashes and
/// before/after blobs for `six.py`.
async fn mount_installed_view(
    server: &MockServer,
    before_hash: &str,
    after_hash: &str,
    before_bytes: &[u8],
    after_bytes: &[u8],
) {
    let installed_purl = qualified(ARTIFACT_INSTALLED);
    mount_view(
        server,
        UUID_INSTALLED,
        &installed_purl,
        before_hash,
        after_hash,
        before_bytes,
        after_bytes,
    )
    .await;
}
// Flat argument list keeps the call sites readable in this test helper;
// silence clippy's argument-count lint rather than introduce a params struct.
#[allow(clippy::too_many_arguments)]
/// Mounts `GET /v0/orgs/{ORG}/patches/view/{uuid}`, returning a patch for
/// `purl` that rewrites a single file, `six.py`, from `before_bytes`
/// (hash `before_hash`) to `after_bytes` (hash `after_hash`). Both blobs are
/// embedded base64-encoded.
async fn mount_view(
    server: &MockServer,
    uuid: &str,
    purl: &str,
    before_hash: &str,
    after_hash: &str,
    before_bytes: &[u8],
    after_bytes: &[u8],
) {
    let file_entry = serde_json::json!({
        "beforeHash": before_hash,
        "afterHash": after_hash,
        "blobContent": b64(after_bytes),
        "beforeBlobContent": b64(before_bytes),
    });
    let body = serde_json::json!({
        "uuid": uuid,
        "purl": purl,
        "publishedAt": "2024-01-01T00:00:00Z",
        "files": { "six.py": file_entry },
        "vulnerabilities": {},
        "description": "multi-release fixture",
        "license": "MIT",
        "tier": "free",
    });
    let endpoint = format!("/v0/orgs/{ORG}/patches/view/{uuid}");
    Mock::given(method("GET"))
        .and(path(endpoint))
        .respond_with(ResponseTemplate::new(200).set_body_json(body))
        .mount(server)
        .await;
}
/// Builds `ScanArgs` for an applying, JSON-output scan of `tmp`, restricted
/// to the `pypi` ecosystem and pointed at the mock API at `api_url` with a
/// dummy token. `all_releases` selects broad (every release variant) versus
/// narrow scanning.
fn scan_args(tmp: &Path, api_url: String, all_releases: bool) -> ScanArgs {
    let common = socket_patch_cli::args::GlobalArgs {
        cwd: tmp.to_path_buf(),
        org: Some(ORG.to_string()),
        json: true,
        yes: true,
        global: false,
        global_prefix: None,
        api_url,
        api_token: Some("fake".to_string()),
        ecosystems: Some(vec!["pypi".to_string()]),
        download_mode: "diff".to_string(),
        dry_run: false,
        ..socket_patch_cli::args::GlobalArgs::default()
    };
    ScanArgs {
        common,
        batch_size: 100,
        apply: true,
        prune: false,
        sync: false,
        all_releases,
    }
}
/// Returns the keys of the `patches` object in `<tmp>/.socket/manifest.json`.
///
/// Returns an empty vector when `patches` is absent or not a JSON object.
/// Key order follows `serde_json::Map` iteration; callers that need a stable
/// order should sort the result.
///
/// # Panics
/// If the manifest cannot be read (the underlying I/O error is included, so
/// permission problems are not misreported as "not found") or is not valid
/// JSON.
fn manifest_keys(tmp: &Path) -> Vec<String> {
    let path = tmp.join(".socket").join("manifest.json");
    let raw = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read manifest at {}: {e}", path.display()));
    let v: serde_json::Value = serde_json::from_str(&raw).expect("manifest json");
    v["patches"]
        .as_object()
        .map(|m| m.keys().cloned().collect())
        .unwrap_or_default()
}
/// Returns whether the contents of `file` contain `marker` as a contiguous
/// byte sequence. An empty `marker` is trivially present.
///
/// # Panics
/// If `file` cannot be read; the message names the file and the I/O error.
fn file_has_marker(file: &Path, marker: &[u8]) -> bool {
    let bytes = std::fs::read(file)
        .unwrap_or_else(|e| panic!("read file {}: {e}", file.display()));
    // `slice::windows(0)` panics, so the empty needle is handled up front.
    marker.is_empty() || bytes.windows(marker.len()).any(|w| w == marker)
}
/// Installs the real `six` into a venv under `tmp` and starts a mock API
/// server with all multi-release routes mounted.
///
/// The installed variant's patch is built from the actual `six.py` bytes
/// ("before") and those bytes plus `MARKER_INSTALLED` ("after"), so applying
/// it is observable on disk. Returns the path of `six.py` and the running
/// server (which must be kept alive for the duration of the test).
async fn fixture(tmp: &Path) -> (PathBuf, MockServer) {
    let six_path = install_six(tmp);
    let original = std::fs::read(&six_path).expect("read six.py");
    let patched = [original.as_slice(), MARKER_INSTALLED].concat();
    let (before_hash, after_hash) = (git_sha256(&original), git_sha256(&patched));

    let server = MockServer::start().await;
    setup_multi_release_mock(&server, &before_hash).await;
    mount_installed_view(&server, &before_hash, &after_hash, &original, &patched).await;
    (six_path, server)
}
/// Without `all_releases`, the manifest must hold only the installed-dist
/// variant, and that variant must have been applied to `six.py`.
#[tokio::test]
#[serial]
async fn narrow_scan_keeps_only_installed_release() {
    if !has_python3() {
        println!("SKIP: python3 not on PATH");
        return;
    }
    let workdir = tempfile::tempdir().expect("tempdir");
    let root = workdir.path();
    let (six_path, server) = fixture(root).await;

    let code = scan_run(scan_args(root, server.uri(), false)).await;
    assert!(matches!(code, 0 | 1), "scan exit: {code}");

    let keys = manifest_keys(root);
    let expected = vec![qualified(ARTIFACT_INSTALLED)];
    assert_eq!(
        keys, expected,
        "narrow scan must store only the installed-dist variant; got {keys:?}"
    );
    assert!(
        file_has_marker(&six_path, MARKER_INSTALLED),
        "installed variant should have patched six.py"
    );
}
#[tokio::test]
#[serial]
async fn broad_scan_keeps_all_releases() {
if !has_python3() {
println!("SKIP: python3 not on PATH");
return;
}
let tmp = tempfile::tempdir().expect("tempdir");
let (six_path, server) = fixture(tmp.path()).await;
let code = scan_run(scan_args(tmp.path(), server.uri(), true)).await;
assert!(code == 0 || code == 1, "scan exit: {code}");
let mut keys = manifest_keys(tmp.path());
keys.sort();
let mut expected = vec![
qualified(ARTIFACT_INSTALLED),
qualified(ARTIFACT_OTHER_WHEEL),
qualified(ARTIFACT_SDIST),
];
expected.sort();
assert_eq!(keys, expected, "broad scan must store every variant");
assert!(
file_has_marker(&six_path, MARKER_INSTALLED),
"broad apply should still patch with the installed variant"
);
}
/// Removing by the bare base PURL (no `artifact_id` qualifier) must drop every
/// release variant from the manifest and, with `skip_rollback: false`,
/// restore `six.py` so the installed-variant marker is gone.
#[tokio::test]
#[serial]
async fn remove_base_purl_clears_all_variants_and_rolls_back() {
    if !has_python3() {
        println!("SKIP: python3 not on PATH");
        return;
    }
    let tmp = tempfile::tempdir().expect("tempdir");
    let (six_path, server) = fixture(tmp.path()).await;

    // Seed a broad manifest. Check the exit code the same way the scan tests
    // do, so a failed setup is reported as such instead of as a key-count
    // mismatch.
    let code = scan_run(scan_args(tmp.path(), server.uri(), true)).await;
    assert!(code == 0 || code == 1, "setup scan exit: {code}");
    assert_eq!(
        manifest_keys(tmp.path()).len(),
        3,
        "setup scan should store all three release variants"
    );
    assert!(
        file_has_marker(&six_path, MARKER_INSTALLED),
        "setup scan should have patched six.py"
    );

    let remove_args = RemoveArgs {
        identifier: base_purl(),
        common: socket_patch_cli::args::GlobalArgs {
            cwd: tmp.path().to_path_buf(),
            org: Some(ORG.to_string()),
            api_url: server.uri(),
            api_token: Some("fake".to_string()),
            json: true,
            yes: true,
            ecosystems: Some(vec!["pypi".to_string()]),
            ..socket_patch_cli::args::GlobalArgs::default()
        },
        skip_rollback: false,
    };
    let code = remove_run(remove_args).await;
    assert_eq!(code, 0, "remove base PURL should succeed (exit 0)");
    assert!(
        manifest_keys(tmp.path()).is_empty(),
        "all release variants should be removed from the manifest"
    );
    assert!(
        !file_has_marker(&six_path, MARKER_INSTALLED),
        "remove should roll the on-disk file back to its original bytes"
    );
}
/// Rolling back everything (`identifier: None`) over a manifest that holds
/// all three release variants must exit 0 and remove the installed-variant
/// marker from `six.py`.
#[tokio::test]
#[serial]
async fn rollback_all_over_broad_manifest_succeeds() {
    if !has_python3() {
        println!("SKIP: python3 not on PATH");
        return;
    }
    let tmp = tempfile::tempdir().expect("tempdir");
    let (six_path, server) = fixture(tmp.path()).await;

    // Seed a broad manifest; surface a failed setup scan explicitly.
    let code = scan_run(scan_args(tmp.path(), server.uri(), true)).await;
    assert!(code == 0 || code == 1, "setup scan exit: {code}");
    assert_eq!(
        manifest_keys(tmp.path()).len(),
        3,
        "setup scan should store all three release variants"
    );
    assert!(
        file_has_marker(&six_path, MARKER_INSTALLED),
        "setup scan should have patched six.py"
    );

    let rollback_args = RollbackArgs {
        identifier: None,
        common: socket_patch_cli::args::GlobalArgs {
            cwd: tmp.path().to_path_buf(),
            org: Some(ORG.to_string()),
            api_url: server.uri(),
            api_token: Some("fake".to_string()),
            json: true,
            ecosystems: Some(vec!["pypi".to_string()]),
            ..socket_patch_cli::args::GlobalArgs::default()
        },
        one_off: false,
    };
    let code = rollback_run(rollback_args).await;
    assert_eq!(code, 0, "rollback-all over broad manifest should exit 0");
    assert!(
        !file_has_marker(&six_path, MARKER_INSTALLED),
        "rollback should restore the original file bytes"
    );
}