use std::borrow::{Cow, IntoCow, ToOwned};
use std::io::{Write};
use std::net::{ToSocketAddrs, SocketAddr};
use hyper::{Url};
use hyper::buffer::{BufReader};
use hyper::client::request::{Request};
use hyper::header::{Headers, Header, HeaderFormat, ContentLength, Host};
use hyper::http::{self, Incoming, RawStatus};
use hyper::method::{Method};
use hyper::net::{NetworkConnector, NetworkStream};
use hyper::server::response::{Response};
use hyper::status::{StatusCode};
use hyper::uri::{RequestUri};
use hyper::version::{HttpVersion};
use {SSDPResult, SSDPError};
use header::{HeaderRef, HeaderMut};
use message::{MessageType};
use net::{self};
use receiver::{FromRawSSDP};
/// Method token written on (and accepted for) SSDP search requests.
const SEARCH_METHOD: &'static str = "M-SEARCH";
/// Method token written on (and accepted for) SSDP notify requests.
const NOTIFY_METHOD: &'static str = "NOTIFY";
/// Scheme prefix prepended to a socket address to form a request URL.
const BASE_HOST_URL: &'static str = "http://";
/// The only HTTP status code accepted on an incoming SSDP response.
const VALID_RESPONSE_CODE: u16 = 200;
/// An SSDP message: its kind (notify, search or response) and its headers.
#[derive(Debug, Clone)]
pub struct SSDPMessage {
    // Which kind of SSDP message this is; selects how `send` transmits it.
    method: MessageType,
    // Headers carried by the message (copied onto the wire when sending).
    headers: Headers,
}
impl SSDPMessage {
    /// Creates a message of the given type with an empty header set.
    pub fn new(message_type: MessageType) -> SSDPMessage {
        let headers = Headers::new();
        SSDPMessage { method: message_type, headers: headers }
    }

    /// Returns the type of this message.
    pub fn message_type(&self) -> MessageType {
        self.method
    }

    /// Sends this message to `dst_addr` through `connector`.
    ///
    /// Notify and Search messages are sent as HTTP requests (`NOTIFY` /
    /// `M-SEARCH`) via `send_request`. Response messages are written as an
    /// HTTP 200 response onto a stream opened with `connector.connect`.
    ///
    /// # Errors
    /// Propagates failures from resolving `dst_addr`, connecting, or writing.
    pub fn send<A: ToSocketAddrs, C, S>(&self, connector: &mut C, dst_addr: A) -> SSDPResult<()>
        where C: NetworkConnector<Stream=S>, S: Into<Box<NetworkStream + Send>> {
        let target = try!(net::addr_from_trait(dst_addr));

        let request_method = match self.method {
            MessageType::Notify => NOTIFY_METHOD,
            MessageType::Search => SEARCH_METHOD,
            MessageType::Response => {
                // Responses bypass the HTTP client and are written directly
                // onto a freshly connected stream.
                let host = target.ip().to_string();
                let stream = try!(connector.connect(&host[..], target.port(), "")).into();
                return send_response(&self.headers, stream);
            }
        };

        send_request(request_method, &self.headers, connector, target)
    }
}
/// Sends an HTTP request with the extension method `method` to `dst_addr`.
///
/// The request targets `http://<dst_addr>` (see `url_from_addr`), carries a
/// copy of `headers`, and always has `Content-Length: 0`.
///
/// # Errors
/// Propagates errors from building the URL, from `Request::with_connector`,
/// and from starting the request. The outcome of `send()` is deliberately
/// not reported (see the comment at the call).
fn send_request<C, S>(method: &str, headers: &Headers, connector: &mut C, dst_addr: SocketAddr)
    -> SSDPResult<()> where C: NetworkConnector<Stream=S>, S: Into<Box<NetworkStream + Send>> {
    let url = try!(url_from_addr(dst_addr));
    let mut request = try!(Request::with_connector(
        Method::Extension(method.to_owned()),
        url,
        connector
    ));

    copy_headers(headers, request.headers_mut());
    request.headers_mut().set(ContentLength(0));

    let streaming = try!(request.start());
    // The result of `send()` is intentionally discarded (previously hidden by a
    // function-wide `#[allow(unused)]`). NOTE(review): presumably because an
    // SSDP request gets no HTTP reply on the same stream, so reading a response
    // fails (the test MockStream errors on every read). Confirm whether write
    // errors surfaced by `send()` should be propagated separately.
    let _ = streaming.send();
    Ok(())
}
/// Writes an empty `200 OK` HTTP response carrying a copy of `headers`
/// (plus `Content-Length: 0`) to `dst_writer`.
///
/// # Errors
/// Propagates errors from starting or ending the response.
fn send_response<W>(headers: &Headers, mut dst_writer: W) -> SSDPResult<()>
    where W: Write {
    let mut response_headers = Headers::new();
    copy_headers(headers, &mut response_headers);
    response_headers.set(ContentLength(0));

    let writer: &mut Write = &mut dst_writer;
    let mut response = Response::new(writer, &mut response_headers);
    *response.status_mut() = StatusCode::Ok;

    let streaming = try!(response.start());
    try!(streaming.end());
    Ok(())
}
/// Builds an `http://<ip>:<port>` URL for `addr`.
///
/// The address is rendered with `SocketAddr`'s `Display` impl.
///
/// # Errors
/// Returns the converted URL parse error if the resulting string is not a
/// valid URL.
fn url_from_addr(addr: SocketAddr) -> SSDPResult<Url> {
    // `format!` replaces the previous char-by-char chain/collect concatenation.
    let str_url = format!("{}{}", BASE_HOST_URL, addr);
    Ok(try!(Url::parse(&str_url[..])))
}
/// Copies every header in `src_headers` into `dst_headers` as a raw header.
///
/// Each value is re-serialized with `value_string()` and stored as a single
/// raw line. An existing header of the same name in `dst_headers` is replaced.
fn copy_headers(src_headers: &Headers, dst_headers: &mut Headers) {
    for header in src_headers.iter() {
        let raw_name = header.name().to_owned().into_cow();
        let raw_value = header.value_string().into_bytes();
        dst_headers.set_raw(raw_name, vec![raw_value]);
    }
}
impl HeaderRef for SSDPMessage {
fn get<H>(&self) -> Option<&H> where H: Header + HeaderFormat {
HeaderRef::get::<H>(&self.headers)
}
fn get_raw(&self, name: &str) -> Option<&[Vec<u8>]> {
HeaderRef::get_raw(&self.headers, name)
}
}
impl HeaderMut for SSDPMessage {
fn set<H>(&mut self, value: H) where H: Header + HeaderFormat {
HeaderMut::set(&mut self.headers, value)
}
fn set_raw<K>(&mut self, name: K, value: Vec<Vec<u8>>) where K: Into<Cow<'static, str>> {
HeaderMut::set_raw(&mut self.headers, name, value)
}
}
impl FromRawSSDP for SSDPMessage {
    /// Parses `bytes` as an SSDP message.
    ///
    /// The input is first tried as an HTTP request (NOTIFY / M-SEARCH), then as
    /// an HTTP response. The outcome is logged at debug level.
    ///
    /// # Errors
    /// Returns `SSDPError::InvalidHttp` if the bytes are neither a request nor a
    /// response. Otherwise returns whatever validation error the request or
    /// response conversion reports.
    fn raw_ssdp(bytes: &[u8]) -> SSDPResult<SSDPMessage> {
        // Each attempt gets its own reader over the whole input. A failed
        // request parse therefore cannot leave the response parse starting
        // partway through the message, whatever the parser consumed.
        let message_result = if let Ok(parts) = http::parse_request(&mut BufReader::new(bytes)) {
            message_from_request(parts)
        } else if let Ok(parts) = http::parse_response(&mut BufReader::new(bytes)) {
            message_from_response(parts)
        } else {
            debug!("Received Invalid HTTP: {}", String::from_utf8_lossy(bytes));
            return Err(SSDPError::InvalidHttp(bytes.to_owned()));
        };

        log_message_result(&message_result, bytes);
        message_result
    }
}
fn log_message_result(result: &SSDPResult<SSDPMessage>, message: &[u8]) {
match *result {
Ok(_) => {
debug!("Received Valid SSDPMessage:\n{}", String::from_utf8_lossy(message))
},
Err(ref e) => {
debug!("Received Invalid SSDPMessage Error: {}", e)
}
}
}
/// Converts a parsed HTTP request into a Notify or Search `SSDPMessage`.
///
/// # Errors
/// Checks run in this order:
/// - `InvalidHttpVersion` unless the request is HTTP/1.1.
/// - `MissingHeader` if there is no `Host` header.
/// - `InvalidUri` if the request target is anything other than `*`.
/// - `InvalidMethod` if the method is not `NOTIFY` or `M-SEARCH`.
fn message_from_request(parts: Incoming<(Method, RequestUri)>) -> SSDPResult<SSDPMessage> {
    let headers = parts.headers;
    try!(validate_http_version(parts.version));
    try!(validate_http_host(&headers));

    let (method, uri) = parts.subject;

    // SSDP requests must target `*`; any other form is rejected outright.
    match uri {
        RequestUri::Star => (),
        RequestUri::AbsolutePath(path) => return Err(SSDPError::InvalidUri(path)),
        RequestUri::Authority(authority) => return Err(SSDPError::InvalidUri(authority)),
        RequestUri::AbsoluteUri(url) => return Err(SSDPError::InvalidUri(url.serialize()))
    }

    let message_type = match method {
        Method::Extension(name) => {
            match &name[..] {
                NOTIFY_METHOD => MessageType::Notify,
                SEARCH_METHOD => MessageType::Search,
                _ => return Err(SSDPError::InvalidMethod(name))
            }
        },
        other => return Err(SSDPError::InvalidMethod(other.to_string()))
    };

    Ok(SSDPMessage{ method: message_type, headers: headers })
}
/// Converts a parsed HTTP response into a Response `SSDPMessage`.
///
/// # Errors
/// Returns `InvalidHttpVersion` unless the response is HTTP/1.1. Otherwise
/// returns `ResponseCode` if the status code is not `VALID_RESPONSE_CODE`.
fn message_from_response(parts: Incoming<RawStatus>) -> SSDPResult<SSDPMessage> {
    try!(validate_http_version(parts.version));

    let RawStatus(code, _) = parts.subject;
    try!(validate_response_code(code));

    Ok(SSDPMessage{ method: MessageType::Response, headers: parts.headers })
}
/// Accepts only HTTP/1.1; any other version yields `InvalidHttpVersion`.
fn validate_http_version(version: HttpVersion) -> SSDPResult<()> {
    match version {
        HttpVersion::Http11 => Ok(()),
        _ => Err(SSDPError::InvalidHttpVersion)
    }
}
/// Succeeds when `headers` contains a `Host` header.
///
/// # Errors
/// Returns `MissingHeader` with the `Host` header name otherwise.
fn validate_http_host<T>(headers: T) -> SSDPResult<()>
    where T: HeaderRef {
    let has_host = headers.get::<Host>().is_some();
    if has_host {
        Ok(())
    } else {
        Err(SSDPError::MissingHeader(Host::header_name()))
    }
}
/// Accepts only `VALID_RESPONSE_CODE`; any other code is returned inside a
/// `ResponseCode` error.
fn validate_response_code(code: u16) -> SSDPResult<()> {
    match code {
        VALID_RESPONSE_CODE => Ok(()),
        other => Err(SSDPError::ResponseCode(other))
    }
}
#[cfg(test)]
mod mocks {
    //! Test doubles that capture everything written to a mock network stream.
    use std::cell::{RefCell};
    use std::io::{self, Read, Write, ErrorKind};
    use std::marker::{Reflect};
    use std::net::{SocketAddr};
    use std::sync::mpsc::{self, Sender, Receiver};
    use hyper::error::{self};
    use hyper::net::{NetworkConnector, NetworkStream, ContextVerifier};

    /// Connector that hands out `MockStream`s. It keeps the receiving end of
    /// each stream's channel so tests can inspect the bytes written to it.
    pub struct MockConnector {
        pub receivers: RefCell<Vec<Receiver<Vec<u8>>>>
    }

    impl MockConnector {
        pub fn new() -> MockConnector {
            MockConnector{ receivers: RefCell::new(Vec::new()) }
        }
    }

    impl NetworkConnector for MockConnector {
        type Stream = MockStream;

        /// Ignores the target and returns a stream backed by a new channel.
        fn connect(&self, _: &str, _: u16, _: &str) -> error::Result<Self::Stream> {
            let (send, recv) = mpsc::channel();
            self.receivers.borrow_mut().push(recv);
            Ok(MockStream{ sender: send })
        }

        fn set_ssl_verifier(&mut self, _: ContextVerifier) { }
    }

    /// Write-only stream: every `write` forwards a copy of the buffer over a
    /// channel. Every read fails.
    ///
    /// `MockStream` is `Send` automatically because `Sender<Vec<u8>>` is, so
    /// no `unsafe impl Send` is needed.
    pub struct MockStream {
        sender: Sender<Vec<u8>>
    }

    impl NetworkStream for MockStream {
        fn peer_addr(&mut self) -> io::Result<SocketAddr> {
            Err(io::Error::new(ErrorKind::AddrNotAvailable, ""))
        }
    }

    impl Reflect for MockStream { }

    impl Read for MockStream {
        fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
            Err(io::Error::new(ErrorKind::ConnectionAborted, ""))
        }
    }

    impl Write for MockStream {
        /// Forwards `buf` over the channel. The first `/` is replaced with `*`
        /// unless the chunk starts with `H` (an `HTTP/...` response line).
        ///
        /// This turns hyper's `/` request target into the `*` that SSDP
        /// expects. NOTE(review): presumably mirrors the rewrite done by the
        /// real UDP sender; confirm. The check is per `write` call, so it
        /// assumes the request line arrives at the start of a single chunk.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut buffer = buf.to_vec();
            let is_response = buf.first() == Some(&b'H');
            if !is_response {
                if let Some(slash) = buffer.iter().position(|&byte| byte == b'/') {
                    buffer[slash] = b'*';
                }
            }
            self.sender.send(buffer).unwrap();
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> { Ok(()) }
    }
}
#[cfg(test)]
mod tests {
    mod send {
        use std::sync::mpsc::{Receiver};
        use super::super::mocks::{MockConnector};
        use super::super::{SSDPMessage};
        use message::{MessageType};

        /// Concatenates every chunk from every receiver, in connection order.
        ///
        /// Iterating a `Receiver` blocks until its `Sender` is dropped, so call
        /// this only after the message has been sent and the stream released.
        fn join_buffers(recv_list: &[Receiver<Vec<u8>>]) -> Vec<u8> {
            let mut buffer = Vec::new();
            for recv in recv_list {
                for recv_buf in recv {
                    // `extend` replaces the unstable (since removed) `Vec::push_all`.
                    buffer.extend(recv_buf.into_iter());
                }
            }
            buffer
        }

        #[test]
        fn positive_search_method_line() {
            let message = SSDPMessage::new(MessageType::Search);
            let mut connector = MockConnector::new();
            message.send(&mut connector, ("127.0.0.1", 0)).unwrap();
            let sent_message = String::from_utf8(
                join_buffers(&*connector.receivers.borrow())
            ).unwrap();
            assert_eq!(&sent_message[..19], "M-SEARCH * HTTP/1.1");
        }

        #[test]
        fn positive_notify_method_line() {
            let message = SSDPMessage::new(MessageType::Notify);
            let mut connector = MockConnector::new();
            message.send(&mut connector, ("127.0.0.1", 0)).unwrap();
            let sent_message = String::from_utf8(
                join_buffers(&*connector.receivers.borrow())
            ).unwrap();
            assert_eq!(&sent_message[..17], "NOTIFY * HTTP/1.1");
        }

        #[test]
        fn positive_response_method_line() {
            let message = SSDPMessage::new(MessageType::Response);
            let mut connector = MockConnector::new();
            message.send(&mut connector, ("127.0.0.1", 0)).unwrap();
            let sent_message = String::from_utf8(
                join_buffers(&*connector.receivers.borrow())
            ).unwrap();
            assert_eq!(&sent_message[..15], "HTTP/1.1 200 OK");
        }

        #[test]
        fn positive_host_header() {
            let message = SSDPMessage::new(MessageType::Search);
            let mut connector = MockConnector::new();
            message.send(&mut connector, ("127.0.0.1", 0)).unwrap();
            let sent_message = String::from_utf8(
                join_buffers(&*connector.receivers.borrow())
            ).unwrap();
            assert!(sent_message.contains("Host: 127.0.0.1:0"));
        }
    }

    mod parse {
        use super::super::{SSDPMessage};
        use header::{HeaderRef};
        use receiver::{FromRawSSDP};
        use {SSDPError};

        #[test]
        fn positive_valid_http() {
            let raw_message = "NOTIFY * HTTP/1.1\r\nHOST: 192.168.1.1\r\n\r\n";
            SSDPMessage::raw_ssdp(raw_message.as_bytes()).unwrap();
        }

        #[test]
        fn positive_intact_header() {
            let raw_message = "NOTIFY * HTTP/1.1\r\nHOST: 192.168.1.1\r\n\r\n";
            let message = SSDPMessage::raw_ssdp(raw_message.as_bytes()).unwrap();
            assert_eq!(&message.get_raw("Host").unwrap()[0][..], &b"192.168.1.1"[..]);
        }

        #[test]
        fn negative_http_version() {
            let raw_message = "NOTIFY * HTTP/2.0\r\nHOST: 192.168.1.1\r\n\r\n";
            // The exact variant depends on whether the HTTP parser accepts the
            // version token at all, so only rejection is asserted.
            assert!(SSDPMessage::raw_ssdp(raw_message.as_bytes()).is_err());
        }

        #[test]
        fn negative_no_host() {
            let raw_message = "NOTIFY * HTTP/1.1\r\n\r\n";
            match SSDPMessage::raw_ssdp(raw_message.as_bytes()) {
                Err(SSDPError::MissingHeader(_)) => (),
                _ => panic!("expected SSDPError::MissingHeader")
            }
        }

        #[test]
        fn negative_path_included() {
            // HOST is present so the rejection comes from the "/" request target
            // rather than from the missing-Host check, which runs first.
            let raw_message = "NOTIFY / HTTP/1.1\r\nHOST: 192.168.1.1\r\n\r\n";
            match SSDPMessage::raw_ssdp(raw_message.as_bytes()) {
                Err(SSDPError::InvalidUri(_)) => (),
                _ => panic!("expected SSDPError::InvalidUri")
            }
        }
    }
}