use std::path::{Path, PathBuf};
use std::process::Command;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Organization slug passed to the CLI via `--org`; every mocked endpoint
/// path in this file is scoped under `/v0/orgs/{ORG_SLUG}/`.
const ORG_SLUG: &str = "telemetry-test-org";
fn binary() -> PathBuf {
PathBuf::from(env!("CARGO_BIN_EXE_socket-patch"))
}
/// Writes a minimal root `package.json` (name `telemetry-test`, version
/// `0.0.0`) into `root`.
///
/// # Panics
/// Panics if the file cannot be written.
fn write_root_package_json(root: &Path) {
    const ROOT_MANIFEST: &str = r#"{"name":"telemetry-test","version":"0.0.0"}"#;
    let target = root.join("package.json");
    std::fs::write(target, ROOT_MANIFEST).unwrap();
}
/// Creates `node_modules/<name>/package.json` under `root`, containing only
/// the `name` and `version` fields.
///
/// `name` and `version` are interpolated verbatim (no JSON escaping).
///
/// # Panics
/// Panics if the directory or the file cannot be created.
fn write_npm_package(root: &Path, name: &str, version: &str) {
    let package_dir = root.join("node_modules").join(name);
    std::fs::create_dir_all(&package_dir).unwrap();
    let body = format!("{{\"name\":\"{name}\",\"version\":\"{version}\"}}");
    let manifest_path = package_dir.join("package.json");
    std::fs::write(manifest_path, body).unwrap();
}
/// Runs the `socket-patch` binary inside `cwd` and captures the result.
///
/// The argument vector is
/// `<subcommand> --json --api-url <api_url> --api-token fake-token-for-test
/// --org ORG_SLUG <extra_args...>`.
///
/// Environment handling: `VITEST`, `SOCKET_TELEMETRY_DISABLED`,
/// `SOCKET_PATCH_TELEMETRY_DISABLED` and `SOCKET_OFFLINE` are removed from the
/// inherited environment. `SOCKET_API_URL` and `SOCKET_PROXY_URL` are both set
/// to `api_url`. `extra_env` is applied last, so it can override any of these.
///
/// Returns `(exit_code, stdout, stderr)`. The exit code is `-1` when the
/// process reports none (e.g. on Unix when it was terminated by a signal).
///
/// # Panics
/// Panics if the binary cannot be spawned.
fn run_cmd(
    cwd: &Path,
    api_url: &str,
    subcommand: &str,
    extra_args: &[&str],
    extra_env: &[(&str, &str)],
) -> (i32, String, String) {
    const SCRUBBED_ENV: [&str; 4] = [
        "VITEST",
        "SOCKET_TELEMETRY_DISABLED",
        "SOCKET_PATCH_TELEMETRY_DISABLED",
        "SOCKET_OFFLINE",
    ];
    let fixed_args = [
        subcommand,
        "--json",
        "--api-url",
        api_url,
        "--api-token",
        "fake-token-for-test",
        "--org",
        ORG_SLUG,
    ];

    let mut command = Command::new(binary());
    command.current_dir(cwd).args(fixed_args).args(extra_args);
    for var in SCRUBBED_ENV {
        command.env_remove(var);
    }
    command
        .env("SOCKET_API_URL", api_url)
        .env("SOCKET_PROXY_URL", api_url)
        .envs(extra_env.iter().copied());

    let output = command.output().expect("run socket-patch");
    let exit_code = output.status.code().unwrap_or(-1);
    let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
    let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
    (exit_code, stdout, stderr)
}
/// Counts the telemetry POSTs that `mock` has recorded so far.
///
/// A request qualifies when its method is `POST` and its path ends with
/// `/v0/orgs/{ORG_SLUG}/telemetry`. When `event_type` is `Some(want)`, the
/// request body must additionally parse as JSON whose string field
/// `event_type` equals `want`. Bodies that fail to parse are not counted.
///
/// # Panics
/// Panics if the server yields no recorded requests (recording disabled).
async fn telemetry_post_count(mock: &MockServer, event_type: Option<&str>) -> usize {
    let received = mock
        .received_requests()
        .await
        .expect("wiremock allows recording");
    // Build the expected path suffix once rather than once per recorded request.
    let telemetry_suffix = format!("/v0/orgs/{ORG_SLUG}/telemetry");
    received
        .iter()
        .filter(|req| {
            req.method == wiremock::http::Method::POST
                && req.url.path().ends_with(&telemetry_suffix)
        })
        .filter(|req| match event_type {
            None => true,
            Some(want) => serde_json::from_slice::<serde_json::Value>(&req.body)
                .map(|body| body.get("event_type").and_then(|t| t.as_str()) == Some(want))
                .unwrap_or(false),
        })
        .count()
}
/// Starts a fresh wiremock server pre-loaded with these endpoints:
///
/// * `POST /v0/orgs/{ORG_SLUG}/patches/batch` returns 200 with `batch_response`.
/// * When `fetch_uuid_response` is `Some`, `GET /v0/orgs/{ORG_SLUG}/patches/<id>`
///   returns 200 with that body. `<id>` must be lowercase hex digits and dashes.
/// * `POST /v0/orgs/{ORG_SLUG}/telemetry` returns 201 with an empty body.
async fn setup_mock(
    batch_response: serde_json::Value,
    fetch_uuid_response: Option<serde_json::Value>,
) -> MockServer {
    let server = MockServer::start().await;
    let org_prefix = format!("/v0/orgs/{ORG_SLUG}");

    let batch_mock = Mock::given(method("POST"))
        .and(path(format!("{org_prefix}/patches/batch")))
        .respond_with(ResponseTemplate::new(200).set_body_json(batch_response));
    batch_mock.mount(&server).await;

    if let Some(patch_body) = fetch_uuid_response {
        let id_pattern = format!("^{org_prefix}/patches/[0-9a-f-]+$");
        let fetch_mock = Mock::given(method("GET"))
            .and(wiremock::matchers::path_regex(id_pattern))
            .respond_with(ResponseTemplate::new(200).set_body_json(patch_body));
        fetch_mock.mount(&server).await;
    }

    let telemetry_mock = Mock::given(method("POST"))
        .and(path(format!("{org_prefix}/telemetry")))
        .respond_with(ResponseTemplate::new(201));
    telemetry_mock.mount(&server).await;

    server
}
/// `scan` against a batch endpoint that returns no packages must exit 0 and
/// POST exactly one `patch_scanned` telemetry event.
#[tokio::test]
async fn scan_emits_patch_scanned_telemetry_on_success() {
    let mock = setup_mock(
        serde_json::json!({ "packages": [], "canAccessPaidPatches": false }),
        None,
    )
    .await;
    let tmp = tempfile::tempdir().expect("tempdir");
    write_root_package_json(tmp.path());
    write_npm_package(tmp.path(), "minimist", "1.2.2");
    let (code, _stdout, stderr) = run_cmd(tmp.path(), &mock.uri(), "scan", &[], &[]);
    // Surface stderr so a non-zero exit is diagnosable from the test output.
    assert_eq!(code, 0, "scan must exit 0; stderr: {stderr}");
    let count = telemetry_post_count(&mock, Some("patch_scanned")).await;
    assert_eq!(
        count, 1,
        "scan must POST exactly one patch_scanned telemetry event"
    );
}
/// With `SOCKET_OFFLINE=1`, `scan` must not send any telemetry POST of any
/// event type. The exit code is not checked.
#[tokio::test]
async fn scan_skips_telemetry_in_airgap_mode() {
    let empty_batch = serde_json::json!({ "packages": [], "canAccessPaidPatches": false });
    let server = setup_mock(empty_batch, None).await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);
    write_npm_package(root, "minimist", "1.2.2");

    let offline_env = [("SOCKET_OFFLINE", "1")];
    let _ = run_cmd(root, &server.uri(), "scan", &[], &offline_env);

    let posts = telemetry_post_count(&server, None).await;
    assert_eq!(
        posts, 0,
        "SOCKET_OFFLINE=1 must suppress every telemetry POST during scan"
    );
}
/// `get --id <uuid>` against a mocked UUID lookup must POST at least one
/// `patch_fetched` or `patch_fetch_failed` telemetry event.
///
/// NOTE(review): despite the `_on_uuid_lookup_success` name, the assertion
/// also accepts `patch_fetch_failed`, and the exit code is not checked.
#[tokio::test]
async fn get_emits_patch_fetched_telemetry_on_uuid_lookup_success() {
    const UUID: &str = "12345678-1234-4123-8123-123456789abc";
    let batch_body = serde_json::json!({ "packages": [], "canAccessPaidPatches": false });
    let patch_body = serde_json::json!({
        "uuid": UUID,
        "purl": "pkg:npm/lodash@4.17.20",
        "tier": "free",
        "publishedAt": "2024-06-01T00:00:00Z",
        "license": "MIT",
        "description": "test patch",
        "files": {},
        "vulnerabilities": {},
    });
    let server = setup_mock(batch_body, Some(patch_body)).await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);
    write_npm_package(root, "lodash", "4.17.20");

    let get_args = ["--id", UUID];
    let _ = run_cmd(root, &server.uri(), "get", &get_args, &[]);

    let (fetched, failed) = (
        telemetry_post_count(&server, Some("patch_fetched")).await,
        telemetry_post_count(&server, Some("patch_fetch_failed")).await,
    );
    assert!(
        fetched + failed > 0,
        "get --id UUID must POST a patch_fetched or patch_fetch_failed event \
         (saw fetched={fetched} failed={failed})"
    );
}
/// With `SOCKET_OFFLINE=1`, `get --id <uuid>` must not send any telemetry
/// POST. The exit code is not checked.
#[tokio::test]
async fn get_skips_telemetry_in_airgap_mode() {
    const UUID: &str = "deadbeef-dead-4eef-8eef-deadbeefdead";
    let batch_body = serde_json::json!({ "packages": [], "canAccessPaidPatches": false });
    let patch_body = serde_json::json!({
        "uuid": UUID,
        "purl": "pkg:npm/lodash@4.17.20",
        "tier": "free",
        "publishedAt": "2024-06-01T00:00:00Z",
        "license": "MIT",
        "description": "test patch",
        "files": {},
        "vulnerabilities": {},
    });
    let server = setup_mock(batch_body, Some(patch_body)).await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);
    write_npm_package(root, "lodash", "4.17.20");

    let get_args = ["--id", UUID];
    let offline_env = [("SOCKET_OFFLINE", "1")];
    let _ = run_cmd(root, &server.uri(), "get", &get_args, &offline_env);

    let total = telemetry_post_count(&server, None).await;
    assert_eq!(
        total, 0,
        "SOCKET_OFFLINE=1 must suppress every telemetry POST during get"
    );
}
/// With `SOCKET_OFFLINE=1`, `apply` over an empty `.socket/manifest.json`
/// must not send any telemetry POST. The exit code is not checked.
#[tokio::test]
async fn apply_skips_telemetry_in_airgap_mode() {
    let empty_batch = serde_json::json!({ "packages": [], "canAccessPaidPatches": false });
    let server = setup_mock(empty_batch, None).await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);

    // Empty patch manifest at `.socket/manifest.json`.
    let manifest_dir = root.join(".socket");
    std::fs::create_dir_all(&manifest_dir).unwrap();
    std::fs::write(manifest_dir.join("manifest.json"), r#"{"patches":{}}"#).unwrap();

    let offline_env = [("SOCKET_OFFLINE", "1")];
    let _outcome = run_cmd(root, &server.uri(), "apply", &[], &offline_env);

    let total = telemetry_post_count(&server, None).await;
    assert_eq!(
        total, 0,
        "SOCKET_OFFLINE=1 must suppress patch_applied telemetry"
    );
}
/// `list` with an empty `.socket/manifest.json` must exit 0 and POST exactly
/// one `patch_listed` telemetry event.
#[tokio::test]
async fn list_emits_patch_listed_telemetry_when_telemetry_enabled() {
    let mock = setup_mock(
        serde_json::json!({ "packages": [], "canAccessPaidPatches": false }),
        None,
    )
    .await;
    let tmp = tempfile::tempdir().expect("tempdir");
    write_root_package_json(tmp.path());
    let socket = tmp.path().join(".socket");
    std::fs::create_dir_all(&socket).unwrap();
    std::fs::write(
        socket.join("manifest.json"),
        r#"{"patches":{}}"#,
    )
    .unwrap();
    let (code, _stdout, stderr) = run_cmd(tmp.path(), &mock.uri(), "list", &[], &[]);
    // Surface stderr so a non-zero exit is diagnosable from the test output.
    assert_eq!(code, 0, "list must exit 0; stderr: {stderr}");
    let count = telemetry_post_count(&mock, Some("patch_listed")).await;
    assert_eq!(count, 1, "list must POST exactly one patch_listed event");
}
/// When the authenticated batch endpoint answers 401, `scan` must:
///
/// * fall back to the proxy given via `SOCKET_PROXY_URL`;
/// * exit 0 and print the fallback warning on stderr;
/// * send a `patch_scanned` telemetry event with
///   `metadata.fallback_to_proxy == true` to the auth server.
#[tokio::test]
async fn scan_falls_back_to_proxy_on_401_and_tags_telemetry() {
    let auth_mock = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path(format!("/v0/orgs/{ORG_SLUG}/patches/batch")))
        .respond_with(ResponseTemplate::new(401).set_body_string("invalid token"))
        .mount(&auth_mock)
        .await;
    Mock::given(method("POST"))
        .and(path(format!("/v0/orgs/{ORG_SLUG}/telemetry")))
        .respond_with(ResponseTemplate::new(201))
        .mount(&auth_mock)
        .await;
    let proxy_mock = MockServer::start().await;
    Mock::given(method("GET"))
        .and(wiremock::matchers::path_regex(r"^/patch/by-package/.*$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "patches": [],
            "canAccessPaidPatches": false,
        })))
        .mount(&proxy_mock)
        .await;
    let tmp = tempfile::tempdir().expect("tempdir");
    write_root_package_json(tmp.path());
    write_npm_package(tmp.path(), "minimist", "1.2.2");
    let (code, _stdout, stderr) = run_cmd(
        tmp.path(),
        &auth_mock.uri(),
        "scan",
        &[],
        &[("SOCKET_PROXY_URL", &proxy_mock.uri())],
    );
    // Include stderr: without it a failing exit code hides why the fallback broke.
    assert_eq!(
        code, 0,
        "scan must succeed after falling back to proxy; stderr: {stderr}"
    );
    assert!(
        stderr.contains("falling back to public patch API proxy"),
        "stderr must carry the fallback warning; got: {stderr}"
    );
    let received = auth_mock
        .received_requests()
        .await
        .expect("recording enabled");
    // Built once instead of once per recorded request.
    let telemetry_suffix = format!("/v0/orgs/{ORG_SLUG}/telemetry");
    let telemetry_bodies: Vec<serde_json::Value> = received
        .iter()
        .filter(|r| {
            r.method == wiremock::http::Method::POST
                && r.url.path().ends_with(&telemetry_suffix)
        })
        .filter_map(|r| serde_json::from_slice(&r.body).ok())
        .collect();
    let scanned = telemetry_bodies
        .iter()
        .find(|v| v.get("event_type").and_then(|t| t.as_str()) == Some("patch_scanned"))
        // Show what *was* recorded so a missing event is diagnosable.
        .unwrap_or_else(|| {
            panic!("a patch_scanned event must reach the recorder; got {telemetry_bodies:?}")
        });
    assert_eq!(
        scanned["metadata"]["fallback_to_proxy"],
        serde_json::Value::Bool(true),
        "fallback must be reflected in telemetry metadata; got {scanned}"
    );
}
/// A 500 from the authenticated batch endpoint must not trigger the proxy
/// fallback. Stderr must not mention "falling back", and the proxy mock must
/// receive no requests. The exit code is not checked.
#[tokio::test]
async fn scan_does_not_fall_back_on_500() {
    let batch_path = format!("/v0/orgs/{ORG_SLUG}/patches/batch");
    let telemetry_path = format!("/v0/orgs/{ORG_SLUG}/telemetry");

    let auth_mock = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path(batch_path))
        .respond_with(ResponseTemplate::new(500).set_body_string("backend on fire"))
        .mount(&auth_mock)
        .await;
    Mock::given(method("POST"))
        .and(path(telemetry_path))
        .respond_with(ResponseTemplate::new(201))
        .mount(&auth_mock)
        .await;

    let proxy_mock = MockServer::start().await;
    let empty_proxy_body = serde_json::json!({
        "patches": [],
        "canAccessPaidPatches": false,
    });
    Mock::given(method("GET"))
        .and(wiremock::matchers::path_regex(r"^/patch/by-package/.*$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(empty_proxy_body))
        .mount(&proxy_mock)
        .await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);
    write_npm_package(root, "minimist", "1.2.2");

    let proxy_uri = proxy_mock.uri();
    let proxy_env = [("SOCKET_PROXY_URL", proxy_uri.as_str())];
    let (_, _, stderr) = run_cmd(root, &auth_mock.uri(), "scan", &[], &proxy_env);

    let fell_back = stderr.contains("falling back");
    assert!(
        !fell_back,
        "5xx must NOT trigger fallback; stderr was: {stderr}"
    );

    let proxy_requests = proxy_mock
        .received_requests()
        .await
        .expect("recording enabled");
    let proxy_hits = proxy_requests.len();
    assert_eq!(
        proxy_hits, 0,
        "proxy must not be queried after a 500 from the auth endpoint"
    );
}
/// With `SOCKET_OFFLINE=1`, `list` over an empty `.socket/manifest.json`
/// must not send any telemetry POST. The exit code is not checked.
#[tokio::test]
async fn list_skips_telemetry_in_airgap_mode() {
    let empty_batch = serde_json::json!({ "packages": [], "canAccessPaidPatches": false });
    let server = setup_mock(empty_batch, None).await;

    let project = tempfile::tempdir().expect("tempdir");
    let root = project.path();
    write_root_package_json(root);

    // Empty patch manifest at `.socket/manifest.json`.
    let manifest_dir = root.join(".socket");
    std::fs::create_dir_all(&manifest_dir).unwrap();
    std::fs::write(manifest_dir.join("manifest.json"), r#"{"patches":{}}"#).unwrap();

    let offline_env = [("SOCKET_OFFLINE", "1")];
    let _outcome = run_cmd(root, &server.uri(), "list", &[], &offline_env);

    let total = telemetry_post_count(&server, None).await;
    assert_eq!(total, 0, "SOCKET_OFFLINE=1 must suppress patch_listed");
}