use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_channel::{bounded, Receiver, Sender};
use parking_lot::Mutex;
use crate::error::{Error, Result};
use crate::sip::message::Message;
use crate::transport::SipTransport;
/// A canned SIP response: a status code paired with its reason phrase.
#[derive(Debug, Clone)]
pub struct Response {
    /// Numeric status code, e.g. `200`.
    pub code: u16,
    /// Reason phrase that accompanies the code, e.g. `"OK"`.
    pub reason: String,
}

impl Response {
    /// Builds a response from a status code and a borrowed reason phrase.
    ///
    /// The reason phrase is copied into an owned `String`.
    pub fn new(code: u16, reason: &str) -> Self {
        let reason = String::from(reason);
        Self { code, reason }
    }
}
/// Record of one outgoing request: its method and the headers stored with it.
#[derive(Debug, Clone)]
pub struct SentMessage {
    /// Request method name, e.g. `"INVITE"`.
    pub method: String,
    /// Headers stored alongside the request, or `None` when none were recorded.
    pub headers: Option<HashMap<String, String>>,
}
/// Mutable state of the mock transport, kept behind a single lock.
struct Inner {
    /// One-off responses, served oldest first once `sequence` is exhausted.
    responses: Vec<Response>,
    /// Scripted responses, served in order before anything in `responses`.
    sequence: Vec<Response>,
    /// Index of the next entry of `sequence` to serve.
    seq_index: usize,
    /// Number of upcoming sends that must fail with a transport error.
    fail_remain: u32,
    /// Log of every request handed to the transport, in send order.
    sent: Vec<SentMessage>,
    /// Count of keepalives sent so far.
    keepalives: u32,
    /// Set once the transport has been closed.
    closed: bool,
    /// Address reported as the advertised address, if one was configured.
    advertised: Option<std::net::SocketAddr>,
    /// Early SDP handed out by the next successful dial, if configured.
    early_sdp: Option<String>,
    /// Hook run when an `INVITE` goes through `send_request`.
    invite_func: Option<Arc<dyn Fn() + Send + Sync>>,
    /// Callback fired when a transport drop is simulated.
    drop_handler: Option<Arc<dyn Fn() + Send + Sync>>,
    /// Callback fired for a simulated incoming INVITE; receives `(from, to)`.
    incoming_handler: Option<Arc<dyn Fn(String, String) + Send + Sync>>,
    /// Callback fired for a simulated dialog INVITE; receives
    /// `(dialog, from, to, remote_sdp)`.
    #[allow(clippy::type_complexity)]
    dialog_invite_handler:
        Option<Arc<dyn Fn(Arc<dyn crate::dialog::Dialog>, String, String, String) + Send + Sync>>,
    /// Callback fired for a simulated INFO DTMF; receives `(call_id, digit)`.
    info_dtmf_handler: Option<Arc<dyn Fn(String, String) + Send + Sync>>,
    /// Callback fired for a simulated MWI NOTIFY; receives the body.
    mwi_notify_handler: Option<Arc<dyn Fn(String) + Send + Sync>>,
    /// Callback fired for a simulated MESSAGE; receives
    /// `(from, content_type, body)`.
    #[allow(clippy::type_complexity)]
    message_handler: Option<Arc<dyn Fn(String, String, String) + Send + Sync>>,
    /// Callback fired for a simulated subscription NOTIFY; receives
    /// `(event, content_type, body, subscription_state, from_uri)`.
    #[allow(clippy::type_complexity)]
    subscription_notify_handler:
        Option<Arc<dyn Fn(String, String, String, String, String) + Send + Sync>>,
    /// Channels waiting to hear that a response with a given status code was sent,
    /// keyed by that status code.
    response_watchers: HashMap<u16, Vec<Sender<bool>>>,
}
/// In-memory `SipTransport` implementation for tests.
///
/// Responses are scripted up front, outgoing requests are recorded for later
/// inspection, and inbound events can be simulated through registered handlers.
pub struct MockTransport {
    /// All mutable mock state, guarded by one lock.
    inner: Mutex<Inner>,
    /// Sending half of the wake-up channel; signalled when a response is queued.
    response_ready_tx: Sender<()>,
    /// Receiving half of the wake-up channel; waited on while no response is available.
    response_ready_rx: Receiver<()>,
}
impl MockTransport {
    /// Creates a mock with no queued responses, no recorded traffic and no handlers.
    pub fn new() -> Self {
        // Capacity 1: the channel only carries a wake-up signal, so a single
        // pending token is enough.
        let (tx, rx) = bounded(1);
        Self {
            inner: Mutex::new(Inner {
                responses: Vec::new(),
                sequence: Vec::new(),
                seq_index: 0,
                fail_remain: 0,
                sent: Vec::new(),
                keepalives: 0,
                closed: false,
                advertised: None,
                early_sdp: None,
                invite_func: None,
                drop_handler: None,
                incoming_handler: None,
                dialog_invite_handler: None,
                info_dtmf_handler: None,
                mwi_notify_handler: None,
                message_handler: None,
                subscription_notify_handler: None,
                response_watchers: HashMap::new(),
            }),
            response_ready_tx: tx,
            response_ready_rx: rx,
        }
    }

    /// Queues one response to be returned by a later send or read.
    ///
    /// Queued responses are served oldest first, after any scripted sequence
    /// has been exhausted.
    pub fn respond_with(&self, code: u16, reason: &str) {
        {
            let mut inner = self.inner.lock();
            inner.responses.push(Response::new(code, reason));
        }
        // A full channel means a wake-up is already pending; ignoring the
        // error is intentional.
        let _ = self.response_ready_tx.try_send(());
    }

    /// Installs a scripted sequence of responses, served in order.
    ///
    /// The new script replaces any previously installed one and is read from
    /// its first entry. Appending while resetting the cursor would replay
    /// responses that had already been served.
    pub fn respond_sequence(&self, responses: Vec<Response>) {
        {
            let mut inner = self.inner.lock();
            inner.sequence = responses;
            inner.seq_index = 0;
        }
        // See `respond_with`: a failed `try_send` just means a wake-up is pending.
        let _ = self.response_ready_tx.try_send(());
    }

    /// Makes the next `n` sends fail with a transport error.
    ///
    /// Overwrites any previously configured failure count.
    pub fn fail_next(&self, n: u32) {
        self.inner.lock().fail_remain = n;
    }

    /// Registers a hook that `send_request` runs for every `INVITE`.
    ///
    /// The hook runs after the request has been recorded and before the mock
    /// waits for a response. It replaces any previously registered hook.
    pub fn on_invite<F: Fn() + Send + Sync + 'static>(&self, f: F) {
        self.inner.lock().invite_func = Some(Arc::new(f));
    }

    /// Fires the drop handler, if one is registered.
    pub fn simulate_drop(&self) {
        // Clone the handler out so the lock is released before user code runs.
        let handler = self.inner.lock().drop_handler.clone();
        if let Some(h) = handler {
            h();
        }
    }

    /// Fires the incoming-call handler with `(from, to)`, if one is registered.
    pub fn simulate_invite(&self, from: &str, to: &str) {
        let handler = self.inner.lock().incoming_handler.clone();
        if let Some(h) = handler {
            h(from.into(), to.into());
        }
    }

    /// Fires the dialog-invite handler, if one is registered.
    ///
    /// The handler receives a freshly created mock dialog together with
    /// `from`, `to` and `remote_sdp`.
    pub fn simulate_dialog_invite(&self, from: &str, to: &str, remote_sdp: &str) {
        let handler = self.inner.lock().dialog_invite_handler.clone();
        if let Some(h) = handler {
            let dlg = Arc::new(crate::mock::dialog::MockDialog::new());
            h(
                dlg as Arc<dyn crate::dialog::Dialog>,
                from.into(),
                to.into(),
                remote_sdp.into(),
            );
        }
    }

    /// Fires the INFO DTMF handler with `(call_id, digit)`, if one is registered.
    pub fn simulate_info_dtmf(&self, call_id: &str, digit: &str) {
        let handler = self.inner.lock().info_dtmf_handler.clone();
        if let Some(h) = handler {
            h(call_id.into(), digit.into());
        }
    }

    /// Fires the MWI NOTIFY handler with `body`, if one is registered.
    pub fn simulate_mwi_notify(&self, body: &str) {
        let handler = self.inner.lock().mwi_notify_handler.clone();
        if let Some(h) = handler {
            h(body.into());
        }
    }

    /// Fires the MESSAGE handler with `(from, content_type, body)`, if one is registered.
    pub fn simulate_message(&self, from: &str, content_type: &str, body: &str) {
        let handler = self.inner.lock().message_handler.clone();
        if let Some(h) = handler {
            h(from.into(), content_type.into(), body.into());
        }
    }

    /// Fires the subscription NOTIFY handler, if one is registered.
    ///
    /// Arguments are forwarded in the order
    /// `(event, content_type, body, subscription_state, from_uri)`.
    pub fn simulate_subscription_notify(
        &self,
        event: &str,
        content_type: &str,
        body: &str,
        subscription_state: &str,
        from_uri: &str,
    ) {
        let handler = self.inner.lock().subscription_notify_handler.clone();
        if let Some(h) = handler {
            h(
                event.into(),
                content_type.into(),
                body.into(),
                subscription_state.into(),
                from_uri.into(),
            );
        }
    }

    /// Returns `true` once the transport has been closed.
    pub fn closed(&self) -> bool {
        self.inner.lock().closed
    }

    /// Returns how many recorded requests used exactly `method`.
    ///
    /// Requests that failed through `fail_next` are counted too, because they
    /// are recorded before the failure is applied.
    pub fn count_sent(&self, method: &str) -> usize {
        let inner = self.inner.lock();
        inner.sent.iter().filter(|m| m.method == method).count()
    }

    /// Returns how many keepalives have been sent.
    pub fn count_keepalives(&self) -> u32 {
        self.inner.lock().keepalives
    }

    /// Returns a copy of the most recently recorded request with `method`, if any.
    pub fn last_sent(&self, method: &str) -> Option<SentMessage> {
        let inner = self.inner.lock();
        inner
            .sent
            .iter()
            .rev()
            .find(|m| m.method == method)
            .cloned()
    }

    /// Returns a receiver that reports whether a response with `code` was sent.
    ///
    /// The receiver yields `true` when `SipTransport::respond` is called with
    /// `code`, and `false` once `timeout` has elapsed. The channel holds one
    /// value, so callers should act on the first value they receive.
    pub fn wait_for_response(&self, code: u16, timeout: Duration) -> Receiver<bool> {
        let (tx, rx) = bounded(1);
        self.inner
            .lock()
            .response_watchers
            .entry(code)
            .or_default()
            .push(tx.clone());
        // A detached timer thread delivers the timeout verdict. `try_send`
        // fails harmlessly if a value is already buffered or the receiver is gone.
        std::thread::spawn(move || {
            std::thread::sleep(timeout);
            let _ = tx.try_send(false);
        });
        rx
    }

    /// Sets the address reported by `advertised_addr`.
    pub fn set_advertised_addr(&self, addr: std::net::SocketAddr) {
        self.inner.lock().advertised = Some(addr);
    }

    /// Sets the early SDP that the next successful `dial` hands out.
    pub fn set_early_sdp(&self, sdp: &str) {
        self.inner.lock().early_sdp = Some(sdp.to_string());
    }

    /// Waits up to `timeout` for the next scripted or queued response.
    ///
    /// Entries of the scripted sequence are served first, then one-off
    /// responses oldest first. A response that is already available is
    /// returned even when `timeout` is zero.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other("mock: response timeout")` when nothing becomes
    /// available before the deadline.
    fn await_response(&self, timeout: Duration) -> Result<(u16, String)> {
        let deadline = std::time::Instant::now() + timeout;
        loop {
            {
                let mut inner = self.inner.lock();
                if inner.seq_index < inner.sequence.len() {
                    let resp = inner.sequence[inner.seq_index].clone();
                    inner.seq_index += 1;
                    return Ok((resp.code, resp.reason));
                }
                if !inner.responses.is_empty() {
                    let resp = inner.responses.remove(0);
                    return Ok((resp.code, resp.reason));
                }
            }
            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
            if remaining.is_zero() {
                return Err(Error::Other("mock: response timeout".into()));
            }
            // Sleep until a producer signals or the deadline passes. Either way
            // the loop re-checks the queues, so the result can be ignored.
            let _ = self.response_ready_rx.recv_timeout(remaining);
        }
    }
}
impl Default for MockTransport {
fn default() -> Self {
Self::new()
}
}
impl SipTransport for MockTransport {
fn send_request(
&self,
method: &str,
headers: Option<&HashMap<String, String>>,
timeout: Duration,
) -> Result<Message> {
let invite_fn = {
let mut inner = self.inner.lock();
inner.sent.push(SentMessage {
method: method.into(),
headers: headers.cloned(),
});
if inner.fail_remain > 0 {
inner.fail_remain -= 1;
return Err(Error::Other("transport error".into()));
}
if method == "INVITE" {
inner.invite_func.clone()
} else {
None
}
};
if let Some(f) = invite_fn {
f();
}
let (code, reason) = self.await_response(timeout)?;
let mut msg = Message::new_response(code, &reason);
msg.set_header("CSeq", &format!("1 {}", method));
Ok(msg)
}
fn read_response(&self, timeout: Duration) -> Result<Message> {
let (code, reason) = self.await_response(timeout)?;
Ok(Message::new_response(code, &reason))
}
fn send_keepalive(&self) -> Result<()> {
self.inner.lock().keepalives += 1;
Ok(())
}
fn respond(&self, code: u16, _reason: &str) {
let watchers = {
let mut inner = self.inner.lock();
inner.response_watchers.remove(&code).unwrap_or_default()
};
for ch in watchers {
let _ = ch.try_send(true);
}
}
fn on_drop(&self, f: Box<dyn Fn() + Send + Sync>) {
self.inner.lock().drop_handler = Some(Arc::from(f));
}
fn on_incoming(&self, f: Box<dyn Fn(String, String) + Send + Sync>) {
self.inner.lock().incoming_handler = Some(Arc::from(f));
}
#[allow(clippy::type_complexity)]
fn on_dialog_invite(
&self,
f: Box<dyn Fn(Arc<dyn crate::dialog::Dialog>, String, String, String) + Send + Sync>,
) {
self.inner.lock().dialog_invite_handler = Some(Arc::from(f));
}
fn on_info_dtmf(&self, f: Box<dyn Fn(String, String) + Send + Sync>) {
self.inner.lock().info_dtmf_handler = Some(Arc::from(f));
}
fn send_subscribe(
&self,
_uri: &str,
_headers: &HashMap<String, String>,
timeout: Duration,
) -> Result<Message> {
{
let mut inner = self.inner.lock();
inner.sent.push(SentMessage {
method: "SUBSCRIBE".into(),
headers: None,
});
if inner.fail_remain > 0 {
inner.fail_remain -= 1;
return Err(Error::Other("transport error".into()));
}
}
let (code, reason) = self.await_response(timeout)?;
let mut msg = Message::new_response(code, &reason);
msg.set_header("CSeq", "1 SUBSCRIBE");
Ok(msg)
}
fn on_mwi_notify(&self, f: Box<dyn Fn(String) + Send + Sync>) {
self.inner.lock().mwi_notify_handler = Some(Arc::from(f));
}
fn send_message(
&self,
_target: &str,
_content_type: &str,
_body: &[u8],
timeout: Duration,
) -> Result<()> {
{
let mut inner = self.inner.lock();
inner.sent.push(SentMessage {
method: "MESSAGE".into(),
headers: None,
});
if inner.fail_remain > 0 {
inner.fail_remain -= 1;
return Err(Error::Other("transport error".into()));
}
}
let (code, _reason) = self.await_response(timeout)?;
if (200..300).contains(&code) {
Ok(())
} else {
Err(Error::Other(format!("MESSAGE rejected: {}", code)))
}
}
fn on_message(&self, f: Box<dyn Fn(String, String, String) + Send + Sync>) {
self.inner.lock().message_handler = Some(Arc::from(f));
}
fn on_subscription_notify(
&self,
f: Box<dyn Fn(String, String, String, String, String) + Send + Sync>,
) {
self.inner.lock().subscription_notify_handler = Some(Arc::from(f));
}
fn dial(
&self,
_target: &str,
_local_sdp: &[u8],
timeout: Duration,
_opts: &crate::config::DialOptions,
) -> Result<crate::transport::DialResult> {
{
let mut inner = self.inner.lock();
inner.sent.push(SentMessage {
method: "INVITE".into(),
headers: None,
});
if inner.fail_remain > 0 {
inner.fail_remain -= 1;
return Err(Error::Other("transport error".into()));
}
}
let (code, reason) = self.await_response(timeout)?;
if code >= 300 {
return Err(Error::Other(format!("INVITE failed: {} {}", code, reason)));
}
let early_sdp = self.inner.lock().early_sdp.take();
let dlg = Arc::new(crate::mock::dialog::MockDialog::new());
Ok(crate::transport::DialResult {
dialog: dlg as Arc<dyn crate::dialog::Dialog>,
remote_sdp: String::new(),
early_sdp,
})
}
fn advertised_addr(&self) -> Option<std::net::SocketAddr> {
self.inner.lock().advertised
}
fn close(&self) -> Result<()> {
self.inner.lock().closed = true;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Timeout shared by every test that waits for a response.
    const WAIT: Duration = Duration::from_secs(1);

    // A single queued response comes back with its code and reason intact.
    #[test]
    fn respond_with_returns_queued_response() {
        let transport = MockTransport::new();
        transport.respond_with(200, "OK");

        let got = transport.await_response(WAIT).unwrap();

        assert_eq!(got.0, 200);
        assert_eq!(got.1, "OK");
    }

    // A scripted sequence is served in the order it was given.
    #[test]
    fn respond_sequence_returns_in_order() {
        let transport = MockTransport::new();
        let script = vec![
            Response::new(100, "Trying"),
            Response::new(180, "Ringing"),
            Response::new(200, "OK"),
        ];
        transport.respond_sequence(script);

        let codes: Vec<u16> = (0..3)
            .map(|_| transport.await_response(WAIT).unwrap().0)
            .collect();

        assert_eq!(codes, [100, 180, 200]);
    }

    // Exactly the configured number of sends fail; the next one succeeds.
    #[test]
    fn fail_next_causes_errors() {
        let transport = MockTransport::new();
        transport.fail_next(2);
        transport.respond_with(200, "OK");

        let first = transport.send_request("REGISTER", None, WAIT);
        assert!(first.is_err());
        let second = transport.send_request("REGISTER", None, WAIT);
        assert!(second.is_err());
        let third = transport.send_request("REGISTER", None, WAIT);
        assert!(third.is_ok());
    }

    // Sent requests are tallied per method.
    #[test]
    fn count_sent_tracks_methods() {
        let transport = MockTransport::new();
        for _ in 0..2 {
            transport.respond_with(200, "OK");
        }

        for method in ["REGISTER", "INVITE"] {
            let _ = transport.send_request(method, None, WAIT);
        }

        assert_eq!(transport.count_sent("REGISTER"), 1);
        assert_eq!(transport.count_sent("INVITE"), 1);
    }

    // Simulating a drop invokes the registered drop handler.
    #[test]
    fn simulate_drop_fires_handler() {
        let transport = Arc::new(MockTransport::new());
        let flag = Arc::new(Mutex::new(false));
        let flag_in_handler = Arc::clone(&flag);
        transport.on_drop(Box::new(move || {
            let mut seen = flag_in_handler.lock();
            *seen = true;
        }));

        transport.simulate_drop();

        let fired = *flag.lock();
        assert!(fired);
    }

    // `close` flips the closed flag from false to true.
    #[test]
    fn close_sets_flag() {
        let transport = MockTransport::new();
        assert!(!transport.closed());

        transport.close().unwrap();

        assert!(transport.closed());
    }

    // Each keepalive bumps the counter by one.
    #[test]
    fn send_keepalive_increments() {
        let transport = MockTransport::new();

        for _ in 0..2 {
            transport.send_keepalive().unwrap();
        }

        assert_eq!(transport.count_keepalives(), 2);
    }

    // Simulating a MESSAGE passes the body through to the handler.
    #[test]
    fn simulate_message_fires_handler() {
        let transport = Arc::new(MockTransport::new());
        let captured = Arc::new(Mutex::new(String::new()));
        let sink = Arc::clone(&captured);
        transport.on_message(Box::new(move |_from, _ct, body| {
            let mut slot = sink.lock();
            *slot = body;
        }));

        transport.simulate_message("sip:1001@pbx.local", "text/plain", "Hello!");

        assert_eq!(*captured.lock(), "Hello!");
    }

    // A 2xx response makes `send_message` succeed and records the request.
    #[test]
    fn send_message_records_sent() {
        let transport = MockTransport::new();
        transport.respond_with(200, "OK");

        let outcome = transport.send_message("sip:1002@pbx.local", "text/plain", b"Hi", WAIT);

        assert!(outcome.is_ok());
        assert_eq!(transport.count_sent("MESSAGE"), 1);
    }

    // A non-2xx response makes `send_message` fail.
    #[test]
    fn send_message_rejected() {
        let transport = MockTransport::new();
        transport.respond_with(403, "Forbidden");

        let outcome = transport.send_message("sip:1002@pbx.local", "text/plain", b"Hi", WAIT);

        assert!(outcome.is_err());
    }
}