use std::collections::{HashMap, VecDeque};
use std::io::{BufReader, Write};
use std::net::TcpListener;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use ff_rdp_core::transport::{encode_frame, recv_from};
use serde_json::Value;
/// One scripted reply: the direct response plus any follow-up messages that are
/// written immediately after it.
type SeqEntry = (Value, Vec<Value>);

/// How the mock answers requests of one particular `type`.
enum HandlerKind {
    /// Answer every request with the same response and the same follow-ups.
    Fixed(Value, Vec<Value>),
    /// Answer with successive scripted entries, falling back to `last` once the
    /// queue has been drained.
    Sequence {
        /// Entry replayed for every request that arrives after `queue` is empty.
        last: Arc<Mutex<SeqEntry>>,
        /// Entries not yet sent, front first.
        queue: Arc<Mutex<VecDeque<SeqEntry>>>,
    },
}
/// A single-connection mock RDP server for tests.
///
/// Register replies with the `on*` builder methods, point a client at
/// [`port`](Self::port), then drive one connection with `serve_one`.
pub struct MockRdpServer {
    /// Handlers consulted in registration order; the first one whose method name
    /// equals the request's `type` answers it.
    handlers: Vec<(String, HandlerKind)>,
    /// Per-method request counters handed out by `call_counter`.
    call_counters: HashMap<String, Arc<AtomicUsize>>,
    /// Message `serve_one` writes first on the accepted connection.
    greeting: Value,
    /// Whether `serve_one` stops right after a reply that carried follow-ups.
    close_after_followups: bool,
    /// Listening socket, bound to `127.0.0.1` on an OS-chosen port.
    listener: TcpListener,
}
impl MockRdpServer {
    /// Creates a mock server listening on an OS-assigned port on `127.0.0.1`.
    ///
    /// The server starts with a default greeting (`from: "root"`,
    /// `applicationType: "browser"`, empty `traits`), no handlers, and no call
    /// counters.
    ///
    /// # Panics
    ///
    /// Panics if the listening socket cannot be bound.
    pub fn new() -> Self {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind random port");
        Self {
            listener,
            greeting: serde_json::json!({
                "from": "root",
                "applicationType": "browser",
                "traits": {}
            }),
            handlers: Vec::new(),
            close_after_followups: false,
            call_counters: HashMap::new(),
        }
    }

    /// Returns a shared counter of how many requests with `type == method` the
    /// server receives.
    ///
    /// Repeated calls for the same `method` return handles to the same counter.
    /// Requests are counted whether or not a handler is registered for `method`.
    /// Because [`serve_one`](Self::serve_one) consumes the server, obtain counters
    /// before serving.
    pub fn call_counter(&mut self, method: &str) -> Arc<AtomicUsize> {
        let counter = self
            .call_counters
            .entry(method.to_owned())
            .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
        Arc::clone(counter)
    }

    /// Makes [`serve_one`](Self::serve_one) return, closing the connection, right
    /// after it has written a reply that carried at least one follow-up message.
    // NOTE(review): presumably allowed because not every test that includes this
    // helper calls it; confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn close_after_followups(mut self) -> Self {
        self.close_after_followups = true;
        self
    }

    /// Returns the TCP port the server is listening on.
    ///
    /// # Panics
    ///
    /// Panics if the listener's local address cannot be queried.
    pub fn port(&self) -> u16 {
        self.listener.local_addr().expect("local_addr").port()
    }

    /// Registers a fixed `response` for requests whose `type` is `method`.
    ///
    /// If several handlers are registered for the same `method`, the one that was
    /// registered first answers.
    pub fn on(self, method: &str, response: Value) -> Self {
        self.on_with_followups(method, response, Vec::new())
    }

    /// Like [`on`](Self::on), but also writes `followup` right after the response.
    pub fn on_with_followup(self, method: &str, response: Value, followup: Value) -> Self {
        self.on_with_followups(method, response, vec![followup])
    }

    /// Registers a fixed `response` plus `followups`, which are written in order
    /// right after the response every time `method` is requested.
    pub fn on_with_followups(
        mut self,
        method: &str,
        response: Value,
        followups: Vec<Value>,
    ) -> Self {
        self.handlers
            .push((method.to_owned(), HandlerKind::Fixed(response, followups)));
        self
    }

    /// Registers a scripted sequence of `(response, followups)` pairs for `method`.
    ///
    /// Assuming this is the first handler registered for `method`, the n-th
    /// request gets the n-th entry. Once the sequence is exhausted, the final
    /// entry is repeated for every further request.
    ///
    /// # Panics
    ///
    /// Panics if `responses` is empty.
    pub fn on_sequence(mut self, method: &str, responses: Vec<(Value, Vec<Value>)>) -> Self {
        assert!(
            !responses.is_empty(),
            "on_sequence requires at least one response"
        );
        // Seed the fallback with the final entry. The queue is only empty after
        // that entry has been sent, so the fallback never needs updating.
        let last = responses.last().cloned().expect("checked non-empty");
        let queue = VecDeque::from(responses);
        self.handlers.push((
            method.to_owned(),
            HandlerKind::Sequence {
                queue: Arc::new(Mutex::new(queue)),
                last: Arc::new(Mutex::new(last)),
            },
        ));
        self
    }

    /// Accepts a single connection and serves it until the session ends.
    ///
    /// The greeting is written first. Each incoming request is then dispatched on
    /// its `type` field. The matching handler's response and follow-ups are
    /// written back; if no handler matches, an `unknownMethod` error from `root`
    /// is written instead.
    ///
    /// The loop ends in any of these cases:
    /// - a request cannot be received (EOF, reset, or any other receive error);
    /// - a write fails;
    /// - [`close_after_followups`](Self::close_after_followups) was set and a
    ///   reply that carried follow-ups has just been written.
    ///
    /// # Panics
    ///
    /// Panics if accepting the connection, cloning the stream, or encoding or
    /// writing the greeting fails.
    pub fn serve_one(self) {
        let (stream, _peer) = self.listener.accept().expect("accept");
        let mut writer = stream.try_clone().expect("try_clone");
        let mut reader = BufReader::new(stream);

        Self::write_message(&mut writer, &self.greeting, "greeting encode")
            .expect("greeting write");

        loop {
            // Every receive failure ends the session the same way, so a clean
            // EOF/reset needs no separate handling from other errors.
            let request = match recv_from(&mut reader) {
                Ok(request) => request,
                Err(_) => break,
            };
            let method = request
                .get("type")
                .and_then(Value::as_str)
                .unwrap_or_default();
            if let Some(counter) = self.call_counters.get(method) {
                counter.fetch_add(1, Ordering::SeqCst);
            }

            let (reply, followups) = self.reply_for(method);
            if Self::write_message(&mut writer, &reply, "response encode").is_err() {
                break;
            }
            // `try_for_each` stops at the first failed write, just like a loop
            // that breaks out on error.
            let followups_failed = followups
                .iter()
                .try_for_each(|msg| Self::write_message(&mut writer, msg, "followup encode"))
                .is_err();
            if followups_failed || (self.close_after_followups && !followups.is_empty()) {
                break;
            }
        }
    }

    /// Accepts a single connection and never replies, not even with the greeting.
    ///
    /// Blocks until the peer sends at least one byte, closes its end, or the read
    /// fails, then drops the connection. The read result is ignored.
    // NOTE(review): presumably meant for exercising client handling of an
    // unresponsive server; confirm against the callers.
    pub fn serve_one_silent(self) {
        let (stream, _peer) = self.listener.accept().expect("accept");
        let mut buf = [0u8; 1];
        let _ = std::io::Read::read(&mut &stream, &mut buf);
    }

    /// Picks the `(response, followups)` pair to send for a request of type
    /// `method`.
    ///
    /// The first registered handler for `method` is used. When there is none, an
    /// `unknownMethod` error from `root` with no follow-ups is returned.
    fn reply_for(&self, method: &str) -> (Value, Vec<Value>) {
        match self.handlers.iter().find(|(m, _)| m == method) {
            Some((_, HandlerKind::Fixed(response, followups))) => {
                (response.clone(), followups.clone())
            }
            Some((_, HandlerKind::Sequence { queue, last })) => {
                // `last` already holds the final scripted entry (seeded in
                // `on_sequence`), so it is only read once `queue` is drained.
                let next = queue.lock().expect("sequence queue lock").pop_front();
                next.unwrap_or_else(|| last.lock().expect("sequence last lock").clone())
            }
            None => (
                serde_json::json!({
                    "from": "root",
                    "error": "unknownMethod",
                    "message": format!("no handler for type={method:?}")
                }),
                Vec::new(),
            ),
        }
    }

    /// Serializes `message` to JSON and writes it to `writer` as one frame built
    /// by `encode_frame`.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the write fails.
    ///
    /// # Panics
    ///
    /// Panics with `context` as the message if `message` cannot be serialized.
    fn write_message(
        writer: &mut impl Write,
        message: &Value,
        context: &str,
    ) -> std::io::Result<()> {
        let json = serde_json::to_string(message).expect(context);
        writer.write_all(encode_frame(&json).as_bytes())
    }
}