1use crate::{Error, SearchTarget};
2
3use futures_core::stream::Stream;
4use genawaiter::sync::{Co, Gen};
5use std::{net::SocketAddr, time::Duration};
6use tokio::net::UdpSocket;
7
/// Multicast TTL applied to the search socket when the caller does not supply one.
const DEFAULT_SEARCH_TTL: u32 = 2;
/// Warning logged when a received datagram fills the entire receive buffer.
const INSUFFICIENT_BUFFER_MSG: &str = "buffer size too small, udp packets lost";
10
11#[derive(Debug)]
12pub struct SearchResponse {
14 location: String,
15 st: SearchTarget,
16 usn: String,
17 server: String,
18}
19
20impl SearchResponse {
21 pub fn location(&self) -> &str {
23 &self.location
24 }
25 pub fn search_target(&self) -> &SearchTarget {
27 &self.st
28 }
29 pub fn usn(&self) -> &str {
31 &self.usn
32 }
33 pub fn server(&self) -> &str {
35 &self.server
36 }
37}
38
/// Returns the local address the search socket is bound to on non-Windows
/// targets: the IPv4 wildcard address `0.0.0.0` with port `0`.
///
/// This variant never fails; it returns a `Result` only so that its signature
/// matches the Windows variant.
#[cfg(not(windows))]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let wildcard = SocketAddr::from(([0, 0, 0, 0], 0));
    Ok(wildcard)
}
43
/// Determines the local address the search socket is bound to on Windows.
///
/// A temporary UDP socket is bound to the IPv4 wildcard address, connected to
/// `8.8.8.8:80`, and its resulting local address is returned.
///
/// NOTE(review): presumably this selects the address of the interface that
/// routes to the public internet instead of an arbitrary (e.g. virtual)
/// interface — confirm against the original issue or commit.
///
/// # Errors
///
/// Returns the I/O error if binding, connecting or querying the local address
/// fails.
#[cfg(windows)]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let probe = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0))).await?;
    probe.connect(SocketAddr::from(([8, 8, 8, 8], 80))).await?;
    // The probe socket is dropped on return; only its local address is kept.
    probe.local_addr()
}
56
57pub async fn search(
61 search_target: &SearchTarget,
62 timeout: Duration,
63 mx: usize,
64 ttl: Option<u32>,
65) -> Result<impl Stream<Item = Result<SearchResponse, Error>>, Error> {
66 let bind_addr: SocketAddr = get_bind_addr().await?;
67 let broadcast_address: SocketAddr = ([239, 255, 255, 250], 1900).into();
68
69 let socket = UdpSocket::bind(&bind_addr).await?;
70 socket.set_multicast_ttl_v4(ttl.unwrap_or(DEFAULT_SEARCH_TTL)).ok();
71
72 let msg = format!(
73 "M-SEARCH * HTTP/1.1\r
74Host:239.255.255.250:1900\r
75Man:\"ssdp:discover\"\r
76ST: {}\r
77MX: {}\r\n\r\n",
78 search_target, mx
79 );
80 socket.send_to(msg.as_bytes(), &broadcast_address).await?;
81
82 Ok(Gen::new(move |co| socket_stream(socket, timeout, co)))
83}
84
/// Unwraps a `Result` inside a loop that feeds a generator.
///
/// `yield_try!(co => expr)` evaluates to the `Ok` value of `expr`. On `Err`
/// the error is converted with `.into()`, yielded through `co` as an `Err`
/// item, and the enclosing loop moves on to its next iteration via `continue`.
/// It must therefore be used inside a loop in an async context.
macro_rules! yield_try {
    ( $sink:expr => $result:expr ) => {
        match $result {
            Err(err) => {
                $sink.yield_(Err(err.into())).await;
                continue;
            }
            Ok(value) => value,
        }
    };
}
96
97async fn socket_stream(
98 socket: UdpSocket,
99 timeout: Duration,
100 co: Co<Result<SearchResponse, Error>>,
101) {
102 loop {
103 let mut buf = [0u8; 2048];
104 let text = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await {
105 Err(_) => break,
106 Ok(res) => match res {
107 Ok(read) if read == 2048 => {
108 log::warn!("{}", INSUFFICIENT_BUFFER_MSG);
109 continue;
110 }
111 Ok(read) => yield_try!(co => std::str::from_utf8(&buf[..read])),
112 Err(e) => {
113 co.yield_(Err(e.into())).await;
114 continue;
115 }
116 },
117 };
118
119 let headers = yield_try!(co => parse_headers(text));
120
121 let mut location = None;
122 let mut st = None;
123 let mut usn = None;
124 let mut server = None;
125
126 for (header, value) in headers {
127 if header.eq_ignore_ascii_case("location") {
128 location = Some(value);
129 } else if header.eq_ignore_ascii_case("st") {
130 st = Some(value);
131 } else if header.eq_ignore_ascii_case("usn") {
132 usn = Some(value);
133 } else if header.eq_ignore_ascii_case("server") {
134 server = Some(value);
135 }
136 }
137
138 let location = yield_try!(co => location
139 .ok_or(Error::MissingHeader("location")))
140 .to_string();
141 let st = yield_try!(co => yield_try!(co => st.ok_or(Error::MissingHeader("st"))).parse::<SearchTarget>());
142 let usn = yield_try!(co => usn.ok_or(Error::MissingHeader("urn"))).to_string();
143 let server = yield_try!(co => server.ok_or(Error::MissingHeader("server"))).to_string();
144
145 co.yield_(Ok(SearchResponse {
146 location,
147 st,
148 usn,
149 server,
150 }))
151 .await;
152 }
153}
154
155fn parse_headers(response: &str) -> Result<impl Iterator<Item = (&str, &str)>, Error> {
156 let mut response = response.split("\r\n");
157 let status_code = response
158 .next()
159 .ok_or(Error::InvalidHTTP("http response is empty"))?
160 .trim_start_matches("HTTP/1.1 ")
161 .chars()
162 .take_while(|x| x.is_numeric())
163 .collect::<String>()
164 .parse::<u32>()
165 .map_err(|_| Error::InvalidHTTP("status code is not a number"))?;
166
167 if status_code != 200 {
168 return Err(Error::HTTPError(status_code));
169 }
170
171 let iter = response.filter_map(|l| {
172 let mut split = l.splitn(2, ':');
173 match (split.next(), split.next()) {
174 (Some(header), Some(value)) => Some((header, value.trim())),
175 _ => None,
176 }
177 });
178
179 Ok(iter)
180}