use std::path::{Path, PathBuf};
use base64::Engine;
use serial_test::serial;
use sha2::{Digest, Sha256};
use socket_patch_cli::commands::remove::{run as remove_run, RemoveArgs};
use socket_patch_cli::commands::rollback::{run as rollback_run, RollbackArgs};
use socket_patch_cli::commands::scan::{run as scan_run, ScanArgs};
use wiremock::matchers::{method, path, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Organization slug used in every mocked API path and in the CLI arguments.
const ORG: &str = "test-org";
/// Gem name; also names the on-disk gem directory and its `lib/<name>.rb` file.
const GEM_NAME: &str = "nokogiri";
/// Gem version shared by every platform variant in the fixture.
const GEM_VERSION: &str = "1.16.5";
/// Patch UUID for the platform variant that the fixture installs on disk.
const UUID_INSTALLED: &str = "11111111-1111-4111-8111-aaaaaaaaaaaa";
/// Patch UUID for a platform variant offered by the mock API but not installed.
const UUID_OTHER: &str = "22222222-2222-4222-8222-bbbbbbbbbbbb";
/// Platform suffix of the gem directory written by the fixture.
const PLATFORM_INSTALLED: &str = "x86_64-linux";
/// Platform suffix that only appears in the mocked API responses.
const PLATFORM_OTHER: &str = "arm64-darwin";
/// Bytes the installed-platform patch appends; finding them means "patched".
const MARKER_INSTALLED: &[u8] = b"\n# SOCKET-GEM-INSTALLED-X86_64\n";
/// Git-style blob hash over SHA-256: digest of the header `blob <len>\0`
/// followed by `content`, returned as lowercase hex.
fn git_sha256(content: &[u8]) -> String {
    let blob_header = format!("blob {}\0", content.len());
    let mut digest = Sha256::new();
    for chunk in [blob_header.as_bytes(), content] {
        digest.update(chunk);
    }
    hex::encode(digest.finalize())
}
/// Encode `bytes` with the standard base64 alphabet (padded). The mocked
/// patch-view responses use it for blob contents.
fn b64(bytes: &[u8]) -> String {
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    BASE64_STANDARD.encode(bytes)
}
/// Unqualified package URL for the fixture gem, e.g. `pkg:gem/nokogiri@1.16.5`.
fn base_purl() -> String {
    ["pkg:gem/", GEM_NAME, "@", GEM_VERSION].concat()
}
/// Base PURL with a `platform` qualifier appended, e.g.
/// `pkg:gem/nokogiri@1.16.5?platform=x86_64-linux`.
fn qualified(platform: &str) -> String {
    let mut purl = base_purl();
    purl.push_str("?platform=");
    purl.push_str(platform);
    purl
}
/// Create a bundler-style gem install for `platform` under `cwd`
/// (`vendor/bundle/ruby/3.0.0/gems/<name>-<version>-<platform>/lib`).
/// Writes `contents` to `lib/<GEM_NAME>.rb` and returns that file's path.
///
/// Panics if the directory or file cannot be created.
fn install_platform_gem(cwd: &Path, platform: &str, contents: &[u8]) -> PathBuf {
    let mut lib_dir = cwd.to_path_buf();
    for segment in ["vendor", "bundle", "ruby", "3.0.0", "gems"] {
        lib_dir.push(segment);
    }
    lib_dir.push(format!("{GEM_NAME}-{GEM_VERSION}-{platform}"));
    lib_dir.push("lib");
    std::fs::create_dir_all(&lib_dir).expect("create gem lib dir");

    let rb_file = lib_dir.join(format!("{GEM_NAME}.rb"));
    std::fs::write(&rb_file, contents).expect("write gem file");
    rb_file
}
/// Mount every mocked API endpoint used by the scan, remove and rollback commands.
///
/// The batch endpoint (`POST .../patches/batch`) and the by-package endpoint
/// (`GET .../patches/by-package/...`) both list two patches for the same base
/// PURL, one per platform qualifier. The view endpoint for the installed
/// platform serves the hashes and bytes passed in by the caller. The view for
/// the other platform serves a fixed darwin file body whose hashes are
/// computed here.
async fn setup_mock(
    server: &MockServer,
    installed_before_hash: &str,
    installed_after_hash: &str,
    installed_before_bytes: &[u8],
    installed_after_bytes: &[u8],
) {
    let installed_purl = qualified(PLATFORM_INSTALLED);
    let other_purl = qualified(PLATFORM_OTHER);

    // Batch search result: one package entry keyed by the unqualified PURL.
    let batch_body = serde_json::json!({
        "packages": [{
            "purl": base_purl(),
            "patches": [
                {
                    "uuid": UUID_INSTALLED,
                    "purl": installed_purl,
                    "tier": "free",
                    "cveIds": [],
                    "ghsaIds": [],
                    "severity": "high",
                    "title": "linux gem",
                },
                {
                    "uuid": UUID_OTHER,
                    "purl": other_purl,
                    "tier": "free",
                    "cveIds": [],
                    "ghsaIds": [],
                    "severity": "high",
                    "title": "darwin gem",
                },
            ],
        }],
        "canAccessPaidPatches": false,
    });

    // Per-package listing: the same two variants with their metadata.
    let by_package_body = serde_json::json!({
        "patches": [
            {
                "uuid": UUID_INSTALLED,
                "purl": installed_purl,
                "publishedAt": "2024-01-01T00:00:00Z",
                "description": "linux gem",
                "license": "MIT",
                "tier": "free",
                "vulnerabilities": {},
            },
            {
                "uuid": UUID_OTHER,
                "purl": other_purl,
                "publishedAt": "2024-01-01T00:00:00Z",
                "description": "darwin gem",
                "license": "MIT",
                "tier": "free",
                "vulnerabilities": {},
            },
        ],
        "canAccessPaidPatches": false,
    });

    Mock::given(method("POST"))
        .and(path(format!("/v0/orgs/{ORG}/patches/batch")))
        .respond_with(ResponseTemplate::new(200).set_body_json(batch_body))
        .mount(server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(format!("^/v0/orgs/{ORG}/patches/by-package/.+$")))
        .respond_with(ResponseTemplate::new(200).set_body_json(by_package_body))
        .mount(server)
        .await;

    mount_view(
        server,
        UUID_INSTALLED,
        &installed_purl,
        installed_before_hash,
        installed_after_hash,
        installed_before_bytes,
        installed_after_bytes,
    )
    .await;

    // The non-installed variant patches a different (darwin) file body. Its
    // hashes are derived from those bytes so the view payload is consistent.
    let darwin_before: &[u8] = b"# nokogiri.rb from the arm64-darwin gem\n";
    let darwin_after = [darwin_before, &b"\n# DARWIN-MARKER\n"[..]].concat();
    mount_view(
        server,
        UUID_OTHER,
        &other_purl,
        &git_sha256(darwin_before),
        &git_sha256(&darwin_after),
        darwin_before,
        &darwin_after,
    )
    .await;
}
#[allow(clippy::too_many_arguments)]
async fn mount_view(
server: &MockServer,
uuid: &str,
purl: &str,
before_hash: &str,
after_hash: &str,
before_bytes: &[u8],
after_bytes: &[u8],
) {
Mock::given(method("GET"))
.and(path(format!("/v0/orgs/{ORG}/patches/view/{uuid}")))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"uuid": uuid,
"purl": purl,
"publishedAt": "2024-01-01T00:00:00Z",
"files": {
"lib/nokogiri.rb": {
"beforeHash": before_hash,
"afterHash": after_hash,
"blobContent": b64(after_bytes),
"beforeBlobContent": b64(before_bytes),
}
},
"vulnerabilities": {},
"description": "gem multi-platform fixture",
"license": "MIT",
"tier": "free",
})))
.mount(server)
.await;
}
/// Build `ScanArgs` for a gem-only scan that applies patches against the mock
/// API at `api_url`, with JSON output and auto-confirmation.
///
/// The tests expect `all_releases = false` to keep only the installed platform
/// variant and `all_releases = true` to keep every variant.
fn scan_args(cwd: &Path, api_url: String, all_releases: bool) -> ScanArgs {
    let common = socket_patch_cli::args::GlobalArgs {
        api_url,
        api_token: Some("fake".to_string()),
        cwd: cwd.to_path_buf(),
        download_mode: "diff".to_string(),
        dry_run: false,
        ecosystems: Some(vec!["gem".to_string()]),
        global: false,
        global_prefix: None,
        json: true,
        org: Some(ORG.to_string()),
        yes: true,
        ..socket_patch_cli::args::GlobalArgs::default()
    };

    ScanArgs {
        common,
        all_releases,
        apply: true,
        batch_size: 100,
        prune: false,
        sync: false,
        vex: Default::default(),
    }
}
/// Return the keys of the `patches` object in `<cwd>/.socket/manifest.json`.
///
/// A manifest without a `patches` object yields an empty list.
///
/// # Panics
/// Panics, naming the path and the underlying error, if the manifest cannot be
/// read or is not valid JSON.
fn manifest_keys(cwd: &Path) -> Vec<String> {
    let path = cwd.join(".socket").join("manifest.json");
    // Keep the I/O error: a bare "not found" message would hide permission or
    // invalid-UTF-8 failures.
    let raw = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read manifest at {}: {e}", path.display()));
    let v: serde_json::Value = serde_json::from_str(&raw)
        .unwrap_or_else(|e| panic!("manifest at {} is not valid JSON: {e}", path.display()));
    v["patches"]
        .as_object()
        .map(|m| m.keys().cloned().collect())
        .unwrap_or_default()
}
/// Return whether `marker` occurs anywhere in the bytes of `file`.
///
/// An empty `marker` is trivially present.
///
/// # Panics
/// Panics, naming the path and the underlying error, if `file` cannot be read.
fn file_has_marker(file: &Path, marker: &[u8]) -> bool {
    let bytes = std::fs::read(file)
        .unwrap_or_else(|e| panic!("read file {}: {e}", file.display()));
    // `slice::windows(0)` panics, so an empty marker is handled explicitly.
    if marker.is_empty() {
        return true;
    }
    bytes.windows(marker.len()).any(|w| w == marker)
}
/// Install the `PLATFORM_INSTALLED` gem under `cwd` and start a mock server.
/// The server's installed-platform patch appends `MARKER_INSTALLED` to that
/// gem's file. Returns the installed gem file path and the running server.
async fn fixture(cwd: &Path) -> (PathBuf, MockServer) {
    let original: &[u8] = b"module Nokogiri\n VERSION = '1.16.5'\nend\n";
    let patched = [original, MARKER_INSTALLED].concat();

    let gem_file = install_platform_gem(cwd, PLATFORM_INSTALLED, original);
    let server = MockServer::start().await;
    setup_mock(
        &server,
        &git_sha256(original),
        &git_sha256(&patched),
        original,
        &patched,
    )
    .await;
    (gem_file, server)
}
#[tokio::test]
#[serial]
async fn narrow_scan_keeps_only_installed_platform() {
    // Without `all_releases`, only the variant matching the on-disk gem should
    // be recorded in the manifest and applied.
    let tmp = tempfile::tempdir().expect("tempdir");
    let cwd = tmp.path();
    let (gem_file, server) = fixture(cwd).await;

    let code = scan_run(scan_args(cwd, server.uri(), false)).await;
    assert!(matches!(code, 0 | 1), "scan exit: {code}");

    let keys = manifest_keys(cwd);
    let expected = vec![qualified(PLATFORM_INSTALLED)];
    assert_eq!(
        keys, expected,
        "narrow scan must store only the installed platform variant; got {keys:?}"
    );
    assert!(
        file_has_marker(&gem_file, MARKER_INSTALLED),
        "installed platform gem should be patched"
    );
}
#[tokio::test]
#[serial]
async fn broad_scan_keeps_all_platforms() {
    // With `all_releases`, every platform variant is recorded. The file on disk
    // is still patched with the installed platform's variant.
    let tmp = tempfile::tempdir().expect("tempdir");
    let cwd = tmp.path();
    let (gem_file, server) = fixture(cwd).await;

    let code = scan_run(scan_args(cwd, server.uri(), true)).await;
    assert!(matches!(code, 0 | 1), "scan exit: {code}");

    let mut keys = manifest_keys(cwd);
    let mut expected = vec![qualified(PLATFORM_INSTALLED), qualified(PLATFORM_OTHER)];
    keys.sort_unstable();
    expected.sort_unstable();
    assert_eq!(keys, expected, "broad scan must store every platform variant");

    assert!(
        file_has_marker(&gem_file, MARKER_INSTALLED),
        "broad apply should patch with the installed platform variant"
    );
}
#[tokio::test]
#[serial]
async fn remove_base_purl_clears_all_platforms_and_rolls_back() {
    // Removing by the unqualified PURL must drop every platform variant from
    // the manifest and roll the installed gem file back.
    let tmp = tempfile::tempdir().expect("tempdir");
    let (gem_file, server) = fixture(tmp.path()).await;

    // Seed a broad manifest. Check the scan exit code like the scan tests do,
    // so a failed scan is reported directly instead of as a manifest mismatch.
    let scan_code = scan_run(scan_args(tmp.path(), server.uri(), true)).await;
    assert!(scan_code == 0 || scan_code == 1, "scan exit: {scan_code}");
    assert_eq!(manifest_keys(tmp.path()).len(), 2);
    assert!(file_has_marker(&gem_file, MARKER_INSTALLED));

    let remove_args = RemoveArgs {
        identifier: base_purl(),
        common: socket_patch_cli::args::GlobalArgs {
            cwd: tmp.path().to_path_buf(),
            org: Some(ORG.to_string()),
            api_url: server.uri(),
            api_token: Some("fake".to_string()),
            json: true,
            yes: true,
            ecosystems: Some(vec!["gem".to_string()]),
            ..socket_patch_cli::args::GlobalArgs::default()
        },
        skip_rollback: false,
    };
    let code = remove_run(remove_args).await;
    assert_eq!(code, 0, "remove base PURL should succeed (exit 0)");
    assert!(
        manifest_keys(tmp.path()).is_empty(),
        "all platform variants should be removed from the manifest"
    );
    assert!(
        !file_has_marker(&gem_file, MARKER_INSTALLED),
        "remove should roll the gem file back to its original bytes"
    );
}
#[tokio::test]
#[serial]
async fn rollback_all_over_broad_manifest_succeeds() {
    // Rolling back everything (no identifier) over a manifest that also holds a
    // non-installed platform variant must still succeed and restore the gem.
    let tmp = tempfile::tempdir().expect("tempdir");
    let (gem_file, server) = fixture(tmp.path()).await;

    // Seed a broad manifest. Check the scan exit code like the scan tests do,
    // so a failed scan is reported directly instead of as a manifest mismatch.
    let scan_code = scan_run(scan_args(tmp.path(), server.uri(), true)).await;
    assert!(scan_code == 0 || scan_code == 1, "scan exit: {scan_code}");
    assert_eq!(manifest_keys(tmp.path()).len(), 2);
    assert!(file_has_marker(&gem_file, MARKER_INSTALLED));

    let rollback_args = RollbackArgs {
        identifier: None,
        common: socket_patch_cli::args::GlobalArgs {
            cwd: tmp.path().to_path_buf(),
            org: Some(ORG.to_string()),
            api_url: server.uri(),
            api_token: Some("fake".to_string()),
            json: true,
            ecosystems: Some(vec!["gem".to_string()]),
            ..socket_patch_cli::args::GlobalArgs::default()
        },
        one_off: false,
    };
    let code = rollback_run(rollback_args).await;
    assert_eq!(code, 0, "rollback-all over broad manifest should exit 0");
    assert!(
        !file_has_marker(&gem_file, MARKER_INSTALLED),
        "rollback should restore the original gem file"
    );
}