#![cfg(feature = "streamable-http")]
use async_trait::async_trait;
use mockito::Server as MockServer;
use pmcp::shared::streamable_http::{
AuthProvider, SendOptions, StreamableHttpTransport, StreamableHttpTransportConfigBuilder,
};
use pmcp::shared::TransportMessage;
use pmcp::types::{ClientNotification, Notification};
use proptest::prelude::*;
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use tokio::runtime::Runtime;
use url::Url;
/// HTTP statuses the mock server is configured to answer with. Of these, only 401 is
/// expected to drive the auth provider's `on_unauthorized` hook.
const STATUS_CODES: &[u16] = &[
    200, 202, //
    400, 401, 403, 404, //
    500, 503,
];
/// Returns the Tokio runtime shared by every proptest case, building it lazily on the
/// first call and reusing it afterwards.
///
/// # Panics
/// Panics if the runtime cannot be constructed.
fn rt() -> &'static Runtime {
    static SHARED_RUNTIME: OnceLock<Runtime> = OnceLock::new();
    SHARED_RUNTIME.get_or_init(|| {
        Runtime::new().expect("failed to build tokio runtime for proptest harness")
    })
}
/// Test double for `AuthProvider` that hands out a fixed token and counts how many
/// times each auth hook is invoked.
struct CountingProvider {
    /// Token handed back by `get_access_token`.
    token: String,
    /// Number of `get_access_token` calls observed so far.
    get_count: AtomicUsize,
    /// Number of `on_unauthorized` calls observed so far.
    unauthorized_count: AtomicUsize,
}

// `token` is intentionally not printed, so the credential never appears in test
// output. `finish_non_exhaustive` renders a trailing `..` to make that omission
// visible instead of silently pretending the struct has only two fields.
impl fmt::Debug for CountingProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CountingProvider")
            .field("get_count", &self.get_count.load(Ordering::SeqCst))
            .field(
                "unauthorized_count",
                &self.unauthorized_count.load(Ordering::SeqCst),
            )
            .finish_non_exhaustive()
    }
}

impl CountingProvider {
    /// Creates a provider that returns `token`, with both call counters at zero.
    fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
            get_count: AtomicUsize::new(0),
            unauthorized_count: AtomicUsize::new(0),
        }
    }
}
#[async_trait]
impl AuthProvider for CountingProvider {
    /// Records the call in `get_count` and returns a copy of the configured token.
    async fn get_access_token(&self) -> pmcp::Result<String> {
        let counter = &self.get_count;
        counter.fetch_add(1, Ordering::SeqCst);
        let issued = self.token.clone();
        Ok(issued)
    }

    /// Records the call in `unauthorized_count`; this double never fails.
    async fn on_unauthorized(&self) -> pmcp::Result<()> {
        let counter = &self.unauthorized_count;
        counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }
}
/// Builds a streamable-HTTP transport pointed at `url`, attaching `provider` as the
/// auth provider when one is supplied.
fn make_transport(url: Url, provider: Option<Arc<dyn AuthProvider>>) -> StreamableHttpTransport {
    let base = StreamableHttpTransportConfigBuilder::new(url);
    let configured = match provider {
        Some(auth) => base.with_auth_provider(auth),
        None => base,
    };
    StreamableHttpTransport::new(configured.build())
}
/// Builds the throwaway message sent in each test case.
///
/// Despite the name, this is a client `Initialized` notification rather than a ping
/// request.
fn ping_message() -> TransportMessage {
    let initialized = Notification::Client(ClientNotification::Initialized);
    TransportMessage::Notification(initialized)
}
/// Sends one message to a fresh mock server that always answers with `status`, then
/// reports `(get_access_token calls, on_unauthorized calls)` observed on `provider`.
///
/// When `provider` is `None` the transport is built without auth and `(0, 0)` is
/// returned unconditionally. The send's own result is ignored.
async fn run_one_case(status: u16, provider: Option<Arc<CountingProvider>>) -> (usize, usize) {
    let mut server = MockServer::new_async().await;
    // Keep the handle bound for the whole function so the mock stays registered.
    let _mock = server
        .mock("POST", "/")
        .with_status(usize::from(status))
        .with_header("content-type", "application/json")
        .with_body(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
        .create_async()
        .await;

    let endpoint = Url::parse(&server.url()).unwrap();
    let auth = provider
        .as_ref()
        .map(|counting| Arc::clone(counting) as Arc<dyn AuthProvider>);
    let mut transport = make_transport(endpoint, auth);
    let _ = transport
        .send_with_options(ping_message(), SendOptions::default())
        .await;

    let Some(counting) = provider else {
        return (0, 0);
    };
    (
        counting.get_count.load(Ordering::SeqCst),
        counting.unauthorized_count.load(Ordering::SeqCst),
    )
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(8))]
#[test]
fn property_on_unauthorized_triggers_iff_401(status in proptest::sample::select(STATUS_CODES)) {
let provider = Arc::new(CountingProvider::new("token"));
let (get_count, unauth_count) =
rt().block_on(run_one_case(status, Some(provider)));
if status == 401 {
prop_assert_eq!(unauth_count, 1, "on_unauthorized must fire exactly once on 401");
prop_assert_eq!(get_count, 2, "get_access_token must be called twice on 401 (original + retry)");
} else {
prop_assert_eq!(unauth_count, 0,
"on_unauthorized must NOT fire on non-401 status {}", status);
prop_assert_eq!(get_count, 1,
"get_access_token must be called exactly once on non-401 status {}", status);
}
}
#[test]
fn property_no_provider_means_no_retry(status in proptest::sample::select(STATUS_CODES)) {
let (get_count, unauth_count) = rt().block_on(run_one_case(status, None));
prop_assert_eq!(get_count, 0, "without a provider, get_access_token cannot be called");
prop_assert_eq!(unauth_count, 0, "without a provider, on_unauthorized cannot be called");
}
}