#![allow(dead_code)]
use async_trait::async_trait;
use atproto_devtool::commands::test::labeler::create_report::{
CreateReportStageError, CreateReportTee, PdsXrpcClient, RawCreateReportResponse,
RawPdsXrpcResponse,
};
use atproto_devtool::commands::test::labeler::http::{HttpStageError, RawHttpTee, RawXrpcResponse};
use atproto_devtool::commands::test::labeler::subscription::{
FrameStream, SubscriptionStageError, WebSocketClient,
};
use reqwest::StatusCode;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use url::Url;
pub type FakeHttpResponses = Arc<Mutex<HashMap<Option<String>, (reqwest::StatusCode, Vec<u8>)>>>;
pub struct FakeRawHttpTee {
responses: FakeHttpResponses,
transport_error: Arc<Mutex<bool>>,
}
impl FakeRawHttpTee {
pub fn new() -> Self {
Self {
responses: Arc::new(Mutex::new(HashMap::new())),
transport_error: Arc::new(Mutex::new(false)),
}
}
pub fn add_response(&self, cursor: Option<&str>, status: u16, body: Vec<u8>) {
self.responses.lock().unwrap().insert(
cursor.map(|s| s.to_string()),
(reqwest::StatusCode::from_u16(status).unwrap(), body),
);
}
pub fn set_transport_error(&self) {
*self.transport_error.lock().unwrap() = true;
}
}
impl Default for FakeRawHttpTee {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RawHttpTee for FakeRawHttpTee {
    /// Returns the canned reply registered for `cursor`.
    ///
    /// Yields `HttpStageError::Transport` when a transport error is armed or no
    /// reply is registered for the cursor. Yields `HttpStageError::DecodeFailed`
    /// when the canned body does not parse as a `queryLabels` output.
    async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError> {
        const SOURCE_URL: &str = "https://example.com/xrpc/com.atproto.label.queryLabels";

        let refused = || HttpStageError::Transport {
            message: "tcp connect: connection refused".into(),
            source: None,
        };

        if *self.transport_error.lock().unwrap() {
            return Err(refused());
        }

        let responses = self.responses.lock().unwrap();
        let Some((status, body)) = responses.get(&cursor.map(str::to_owned)) else {
            return Err(refused());
        };

        let raw_body: Arc<[u8]> = Arc::from(&body[..]);
        let decoded: atrium_api::com::atproto::label::query_labels::Output =
            match serde_json::from_slice(body) {
                Ok(output) => output,
                Err(source) => {
                    return Err(HttpStageError::DecodeFailed {
                        raw_body: Arc::clone(&raw_body),
                        source,
                        source_url: SOURCE_URL.to_string(),
                    });
                }
            };

        Ok(RawXrpcResponse {
            status: *status,
            raw_body,
            decoded,
            source_url: SOURCE_URL.to_string(),
        })
    }
}
/// One scripted outcome consumed by `FakeCreateReportTee::post_create_report`.
#[derive(Clone, Debug)]
pub enum FakeCreateReportResponse {
    /// Fail the call with `CreateReportStageError::Transport` wrapping `message`.
    Transport { message: String },
    /// Return an HTTP response with this status, optional content type and body.
    Response {
        status: u16,
        content_type: Option<String>,
        body: Vec<u8>,
    },
}

impl FakeCreateReportResponse {
    /// A `200` JSON response whose body is a minimal, populated `createReport`
    /// output (the name notwithstanding, the body is not empty).
    pub fn ok_empty() -> Self {
        let body = br#"{"id":1,"reasonType":"com.atproto.moderation.defs#reasonOther","subject":{"$type":"com.atproto.admin.defs#repoRef","did":"did:plc:aaa22222222222222222bbbbbb"},"reportedBy":"did:web:127.0.0.1%3A0","createdAt":"2026-04-17T00:00:00.000Z"}"#;
        Self::Response {
            status: 200,
            content_type: Some("application/json".to_string()),
            body: body.to_vec(),
        }
    }

    /// A `401` JSON response carrying an `{"error", "message"}` body.
    pub fn unauthorized(error_name: &str, message: &str) -> Self {
        Self::xrpc_error(401, error_name, message)
    }

    /// A `400` JSON response carrying an `{"error", "message"}` body.
    pub fn bad_request(error_name: &str, message: &str) -> Self {
        Self::xrpc_error(400, error_name, message)
    }

    // Shared builder for the JSON error responses above.
    fn xrpc_error(status: u16, error_name: &str, message: &str) -> Self {
        let payload = serde_json::json!({
            "error": error_name,
            "message": message,
        });
        Self::Response {
            status,
            content_type: Some("application/json".to_string()),
            body: serde_json::to_vec(&payload).unwrap(),
        }
    }
}
/// A single `post_create_report` call as captured by `FakeCreateReportTee`.
#[derive(Clone, Debug)]
pub struct RecordedCreateReportRequest {
    // The `auth` argument as received (presumably an Authorization value — unverified).
    pub auth: Option<String>,
    // The JSON body as received.
    pub body: serde_json::Value,
}

/// Scripted fake [`CreateReportTee`]: records every call and answers each with
/// the next queued [`FakeCreateReportResponse`].
pub struct FakeCreateReportTee {
    // FIFO queue of scripted outcomes.
    scripts: Arc<Mutex<Vec<FakeCreateReportResponse>>>,
    // Every request received, oldest first.
    recorded: Arc<Mutex<Vec<RecordedCreateReportRequest>>>,
}
impl FakeCreateReportTee {
pub fn new() -> Self {
Self {
scripts: Arc::new(Mutex::new(Vec::new())),
recorded: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn enqueue(&self, response: FakeCreateReportResponse) {
self.scripts.lock().unwrap().push(response);
}
pub fn recorded_requests(&self) -> Vec<RecordedCreateReportRequest> {
self.recorded.lock().unwrap().clone()
}
pub fn last_request(&self) -> RecordedCreateReportRequest {
self.recorded
.lock()
.unwrap()
.last()
.cloned()
.expect("FakeCreateReportTee: no requests recorded yet")
}
}
impl Default for FakeCreateReportTee {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl CreateReportTee for FakeCreateReportTee {
    /// Records the request, then replays the next scripted outcome.
    ///
    /// A scripted content type is lower-cased before being returned.
    ///
    /// # Panics
    ///
    /// Panics if no script is queued, or if a scripted status is rejected by
    /// `StatusCode::from_u16`.
    async fn post_create_report(
        &self,
        auth: Option<&str>,
        body: &serde_json::Value,
    ) -> Result<RawCreateReportResponse, CreateReportStageError> {
        let request = RecordedCreateReportRequest {
            auth: auth.map(str::to_owned),
            body: body.clone(),
        };
        self.recorded.lock().unwrap().push(request);

        let mut queue = self.scripts.lock().unwrap();
        assert!(
            !queue.is_empty(),
            "FakeCreateReportTee: post_create_report called with no script queued. \
             Each test must enqueue() exactly the responses it expects the stage to consume."
        );

        let (status, content_type, payload) = match queue.remove(0) {
            FakeCreateReportResponse::Transport { message } => {
                return Err(CreateReportStageError::Transport {
                    source: Box::new(std::io::Error::other(message)),
                });
            }
            FakeCreateReportResponse::Response {
                status,
                content_type,
                body,
            } => (status, content_type, body),
        };

        let raw_body: Arc<[u8]> = Arc::from(payload.as_slice());
        Ok(RawCreateReportResponse {
            status: StatusCode::from_u16(status).expect("test must use valid HTTP status"),
            content_type: content_type.as_deref().map(str::to_ascii_lowercase),
            raw_body,
            source_url: "https://labeler.test/xrpc/com.atproto.moderation.createReport"
                .to_string(),
        })
    }
}
/// One scripted outcome consumed by a `FakePdsXrpcClient` call.
#[derive(Clone, Debug)]
pub enum FakePdsXrpcResponse {
    /// Fail the call with `CreateReportStageError::Transport` wrapping `message`.
    Transport {
        message: String,
    },
    /// Return a JSON response with this HTTP status and raw body.
    Response {
        status: u16,
        body: Vec<u8>,
    },
}
/// A single call captured by `FakePdsXrpcClient`.
#[derive(Clone, Debug)]
pub struct RecordedPdsRequest {
    // "POST" or "GET".
    pub method: &'static str,
    // XRPC path as passed by the caller.
    pub path: String,
    // Bearer token argument, if any.
    pub bearer: Option<String>,
    // `atproto_proxy` argument; always `None` for GET calls.
    pub atproto_proxy: Option<String>,
    // JSON body for POST calls; `None` for GET calls.
    pub body: Option<serde_json::Value>,
    // Query pairs for GET calls; empty for POST calls.
    pub query: Vec<(String, String)>,
}

/// Scripted fake [`PdsXrpcClient`]: records every call and answers each with
/// the next queued [`FakePdsXrpcResponse`].
pub struct FakePdsXrpcClient {
    // FIFO queue of scripted outcomes.
    scripts: Arc<Mutex<Vec<FakePdsXrpcResponse>>>,
    // Every request received, oldest first.
    recorded: Arc<Mutex<Vec<RecordedPdsRequest>>>,
}
impl FakePdsXrpcClient {
pub fn new() -> Self {
Self {
scripts: Arc::new(Mutex::new(Vec::new())),
recorded: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn enqueue(&self, response: FakePdsXrpcResponse) {
self.scripts.lock().unwrap().push(response);
}
pub fn recorded_requests(&self) -> Vec<RecordedPdsRequest> {
self.recorded.lock().unwrap().clone()
}
pub fn last_request(&self) -> RecordedPdsRequest {
self.recorded
.lock()
.unwrap()
.last()
.cloned()
.expect("FakePdsXrpcClient: no requests recorded yet")
}
fn dispatch_next(&self) -> Result<RawPdsXrpcResponse, CreateReportStageError> {
let mut scripts = self.scripts.lock().unwrap();
if scripts.is_empty() {
panic!(
"FakePdsXrpcClient: call made with no script queued. \
Each test must enqueue() exactly the responses it expects."
);
}
let script = scripts.remove(0);
match script {
FakePdsXrpcResponse::Transport { message } => Err(CreateReportStageError::Transport {
source: Box::new(std::io::Error::other(message)),
}),
FakePdsXrpcResponse::Response { status, body } => {
let raw_body: Arc<[u8]> = Arc::from(body.as_slice());
Ok(RawPdsXrpcResponse {
status: StatusCode::from_u16(status).expect("test must use valid HTTP status"),
raw_body,
content_type: Some("application/json".to_string()),
source_url: "https://pds.test/xrpc".to_string(),
})
}
}
}
}
impl Default for FakePdsXrpcClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PdsXrpcClient for FakePdsXrpcClient {
    /// Records a POST (with its body and proxy header) and replays the next script.
    async fn post(
        &self,
        path: &str,
        bearer: Option<&str>,
        atproto_proxy: Option<&str>,
        body: &serde_json::Value,
    ) -> Result<RawPdsXrpcResponse, CreateReportStageError> {
        let request = RecordedPdsRequest {
            method: "POST",
            path: path.to_owned(),
            bearer: bearer.map(str::to_owned),
            atproto_proxy: atproto_proxy.map(str::to_owned),
            body: Some(body.clone()),
            query: Vec::new(),
        };
        self.recorded.lock().unwrap().push(request);
        self.dispatch_next()
    }

    /// Records a GET (with its query pairs) and replays the next script.
    async fn get(
        &self,
        path: &str,
        bearer: Option<&str>,
        query: &[(&str, &str)],
    ) -> Result<RawPdsXrpcResponse, CreateReportStageError> {
        let query_pairs: Vec<(String, String)> = query
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect();
        let request = RecordedPdsRequest {
            method: "GET",
            path: path.to_owned(),
            bearer: bearer.map(str::to_owned),
            atproto_proxy: None,
            body: None,
            query: query_pairs,
        };
        self.recorded.lock().unwrap().push(request);
        self.dispatch_next()
    }
}
/// Fake [`WebSocketClient`] that hands out one queued [`FakeScript`] per `connect` call.
pub struct FakeWebSocketClient {
    // FIFO queue of scripts; each `connect` consumes the front one.
    scripts: Arc<Mutex<Vec<FakeScript>>>,
    // When true, `connect` on an empty queue returns an empty stream instead of panicking.
    silent_default: bool,
}

/// Scripted behaviour for a single `FakeWebSocketClient::connect` call.
///
/// `FakeScript::default()` connects successfully and yields no frames, errors or
/// waits. Tests can override just the fields they care about with
/// `FakeScript { frames, ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct FakeScript {
    /// Raw frames the stream yields, in order.
    pub frames: Vec<Vec<u8>>,
    /// Pause before each frame after the first, and before a mid-stream error.
    pub inter_frame_delay: Duration,
    /// Optional one-shot pause after the frames, before the stream reports its end.
    pub final_wait: Option<Duration>,
    /// When true, `connect` itself fails with a transport error.
    pub transport_error: bool,
    /// When true, the stream yields one transport error after its frames.
    pub mid_stream_error: bool,
}
impl FakeWebSocketClient {
pub fn new() -> Self {
Self {
scripts: Arc::new(Mutex::new(Vec::new())),
silent_default: false,
}
}
pub fn empty() -> Self {
Self {
scripts: Arc::new(Mutex::new(Vec::new())),
silent_default: true,
}
}
pub fn add_script(&self, script: FakeScript) {
self.scripts.lock().unwrap().push(script);
}
}
impl Default for FakeWebSocketClient {
fn default() -> Self {
Self::new()
}
}
/// Frame stream returned by `FakeWebSocketClient::connect`, replaying one script.
struct FakeFrameStream {
    // Scripted frames, yielded in order.
    frames: Vec<Vec<u8>>,
    // Index of the next frame to yield.
    current_frame: usize,
    // Pause before every frame except the first, and before the mid-stream error.
    inter_frame_delay: Duration,
    // One-shot pause once the frames (and any error) are exhausted.
    final_wait: Option<Duration>,
    // Whether to yield a single transport error after the last frame.
    mid_stream_error: bool,
    // Set once that error has been handed out.
    mid_stream_error_yielded: bool,
}

#[async_trait]
impl FrameStream for FakeFrameStream {
    /// Yields the next scripted item.
    ///
    /// The stream produces items in this order:
    /// 1. Every frame, sleeping `inter_frame_delay` before each frame after the first.
    /// 2. If `mid_stream_error` is set, one transport error, after another `inter_frame_delay`.
    /// 3. A single `final_wait` sleep, if one was scripted.
    /// 4. `None`, on this call and every later one.
    async fn next_frame(&mut self) -> Option<Result<Vec<u8>, SubscriptionStageError>> {
        let next = self.current_frame;
        if next < self.frames.len() {
            if next != 0 {
                tokio::time::sleep(self.inter_frame_delay).await;
            }
            // Advance only after the sleep, so a call cancelled mid-sleep keeps the frame.
            self.current_frame = next + 1;
            return Some(Ok(self.frames[next].clone()));
        }

        let error_pending = self.mid_stream_error && !self.mid_stream_error_yielded;
        if error_pending {
            // Marked before sleeping: a call cancelled mid-sleep still consumes the error.
            self.mid_stream_error_yielded = true;
            tokio::time::sleep(self.inter_frame_delay).await;
            return Some(Err(SubscriptionStageError::Transport {
                message: "fake mid-stream transport error".to_string(),
                source: None,
            }));
        }

        // `take()` makes the final wait one-shot, even if this call is cancelled.
        if let Some(pause) = self.final_wait.take() {
            tokio::time::sleep(pause).await;
        }
        None
    }

    /// No-op: the fake holds no connection resources.
    async fn close(&mut self) {}
}
#[async_trait]
impl WebSocketClient for FakeWebSocketClient {
async fn connect(&self, _url: &Url) -> Result<Box<dyn FrameStream>, SubscriptionStageError> {
let mut scripts = self.scripts.lock().unwrap();
if scripts.is_empty() {
if self.silent_default {
return Ok(Box::new(FakeFrameStream {
frames: vec![],
current_frame: 0,
inter_frame_delay: Duration::from_millis(0),
final_wait: None,
mid_stream_error: false,
mid_stream_error_yielded: false,
}));
}
panic!(
"FakeWebSocketClient: no script queued for this connect() call. \
Each subscription test must declare exactly the scripts it expects \
the stage to consume. Use fake_ws.add_script() for each connect() \
the stage will make. (Identity tests should use FakeWebSocketClient::empty() instead.)"
);
}
let script = scripts.remove(0);
if script.transport_error {
return Err(SubscriptionStageError::Transport {
message: "fake transport error".to_string(),
source: None,
});
}
Ok(Box::new(FakeFrameStream {
frames: script.frames,
current_frame: 0,
inter_frame_delay: script.inter_frame_delay,
final_wait: script.final_wait,
mid_stream_error: script.mid_stream_error,
mid_stream_error_yielded: false,
}))
}
}
/// Replaces each `elapsed: <number>ms` in `rendered` with `elapsed: XXms`, so that
/// rendered output can be compared without depending on measured durations.
///
/// `<number>` is a non-empty run of ASCII digits and `.`, and it must be followed
/// immediately by the `ms` unit.
///
/// An `elapsed: ` marker followed by anything else is left untouched. Examples are
/// `850µs`, `1.2s` or `n/a`. The function deliberately does not scan ahead to the
/// next `ms` anywhere in the text, because that would swallow unrelated output,
/// such as the `ms` inside "items".
pub fn normalize_timing(rendered: String) -> String {
    const MARKER: &str = "elapsed: ";
    const UNIT: &str = "ms";

    let mut result = String::with_capacity(rendered.len());
    let mut rest = rendered.as_str();
    while let Some(pos) = rest.find(MARKER) {
        let after = pos + MARKER.len();
        result.push_str(&rest[..after]);
        let tail = &rest[after..];
        // Length in bytes of the numeric token; digits and '.' are ASCII, so this
        // is always a valid char boundary.
        let number_len = tail
            .find(|c: char| !(c.is_ascii_digit() || c == '.'))
            .unwrap_or(tail.len());
        if number_len > 0 && tail[number_len..].starts_with(UNIT) {
            result.push_str("XXms");
            rest = &tail[number_len + UNIT.len()..];
        } else {
            // Not a millisecond duration: keep the text and keep scanning.
            rest = tail;
        }
    }
    result.push_str(rest);
    result
}