stprobe 0.2.3

A minimal CLI for inspecting safetensors headers
Documentation
mod common;

use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

/// Inspects a remote safetensors file reached through a redirect, using only ranged `GET`s.
///
/// The resolve URL answers `302` with a `Location` pointing at a CDN path. The CDN path is
/// mocked for exactly two ranges:
/// - the fixed-size prefix at the start of the file;
/// - the header bytes that immediately follow it.
///
/// Each mock carries an `.expect(..)` count; the redirect is expected twice, once per ranged
/// request. wiremock verifies these counts when `server` goes out of scope.
#[test]
fn inspects_remote_files_via_range_requests() {
    // Number of bytes the client is expected to request first, before the header itself.
    const PREFIX_LEN: usize = 8;

    let runtime = tokio::runtime::Runtime::new().expect("tokio runtime");
    let server = runtime.block_on(MockServer::start());
    let bytes = common::sample_safetensors_bytes();
    let total_size = bytes.len();
    let (prefix, header_bytes) = common::split_header(&bytes);
    // Range offsets are inclusive, hence the `- 1`.
    let prefix_end = PREFIX_LEN - 1;
    let header_end = PREFIX_LEN + header_bytes.len() - 1;
    let prefix_range = format!("bytes=0-{prefix_end}");
    let header_range = format!("bytes={PREFIX_LEN}-{header_end}");
    let base_url = server.uri();
    let resolve_path = "/resolve/main/sample.safetensors";
    let cdn_path = "/cdn/sample.safetensors";

    runtime.block_on(async {
        Mock::given(method("GET"))
            .and(path(resolve_path))
            .respond_with(
                ResponseTemplate::new(302)
                    .append_header("Location", format!("{base_url}{cdn_path}")),
            )
            .expect(2)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path(cdn_path))
            .and(header("range", prefix_range.as_str()))
            .respond_with(
                ResponseTemplate::new(206)
                    .append_header(
                        "Content-Range",
                        format!("bytes 0-{prefix_end}/{total_size}"),
                    )
                    .set_body_bytes(prefix),
            )
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path(cdn_path))
            .and(header("range", header_range.as_str()))
            .respond_with(
                ResponseTemplate::new(206)
                    .append_header(
                        "Content-Range",
                        format!("bytes {PREFIX_LEN}-{header_end}/{total_size}"),
                    )
                    .set_body_bytes(header_bytes),
            )
            .expect(1)
            .mount(&server)
            .await;
    });

    let url = format!("{base_url}{resolve_path}");
    let report = stprobe::inspect_input(&url).expect("inspect remote safetensors");
    let output = stprobe::render_report(&report);

    // Report each missing fragment alongside the full rendered report, so a failure
    // shows what was actually produced instead of a bare "assertion failed".
    let expected_fragments = [
        format!("File: {url}"),
        "Tensors: 2".to_owned(),
        "Parameters: 4".to_owned(),
        "Tensor-Bytes: 24".to_owned(),
        "  format = pt".to_owned(),
        "  embedding.ids".to_owned(),
        "  embedding.weight".to_owned(),
    ];
    for fragment in &expected_fragments {
        assert!(
            output.contains(fragment.as_str()),
            "missing {fragment:?} in rendered report:\n{output}"
        );
    }
}

/// A server that ignores `Range` and replies `200` must yield a descriptive error.
///
/// The single mocked route answers the `bytes=0-7` request with a plain `200` and eight
/// zero bytes. The test expects `inspect_input` to fail with a message naming the URL.
#[test]
fn reports_servers_without_range_support() {
    let file_path = "/sample.safetensors";
    let runtime = tokio::runtime::Runtime::new().expect("tokio runtime");
    let server = runtime.block_on(MockServer::start());

    // A full-body `200` (no `206`/`Content-Range`), as a range-unaware server would send.
    let full_body = ResponseTemplate::new(200).set_body_bytes([0u8; 8].to_vec());
    runtime.block_on(
        Mock::given(method("GET"))
            .and(path(file_path))
            .and(header("range", "bytes=0-7"))
            .respond_with(full_body)
            .expect(1)
            .mount(&server),
    );

    let url = format!("{}{file_path}", server.uri());
    let expected = format!("remote server does not support byte range requests: {url}");
    let message = stprobe::inspect_input(&url)
        .expect_err("range support error")
        .to_string();

    assert_eq!(message, expected);
}