1use std::{fmt::Display, time::Duration};
2
3use futures_util::{pin_mut, stream::StreamExt};
4use mdns::RecordKind;
5use reqwest::{Client, StatusCode};
6use serde::de::DeserializeOwned;
7pub use url::Url;
8
9pub mod capabilities;
10use capabilities::ScannerCapabilities;
11
12pub mod status;
13use status::ScannerStatus;
14
15pub mod settings;
16use settings::ScanSettings;
17
/// DNS-SD service name that [`discover`] browses for and matches PTR
/// records against.
const SERVICE_NAME: &str = "_uscan._tcp.local";
19
/// Client for one scanner's HTTP API, rooted at `base_url`.
///
/// All requests go through `http_client`; scan jobs created by
/// [`Scanner::scan`] reuse a clone of the same client.
#[derive(Debug)]
pub struct Scanner {
    /// Root URL that endpoint names (e.g. `ScannerStatus`) are appended to
    /// as path segments.
    base_url: Url,
    /// HTTP client used for every request to this scanner.
    http_client: Client,
}
25
/// Errors returned by [`Scanner`] and [`ScanJob`] operations.
#[derive(Debug)]
pub enum Error {
    /// Sending a request or reading a response body failed.
    Http(reqwest::Error),
    /// Serializing scan settings to XML, or deserializing an XML response,
    /// failed.
    Xml(serde_xml_rs::Error),
    /// The scanner answered with a status code other than the one the
    /// operation expects.
    UnexpectedStatusCode(StatusCode),
    /// The scan-job creation response had a missing or unusable `Location`
    /// header.
    LocationHeader,
}
33
/// Handle to a scan job created by [`Scanner::scan`], used to download the
/// scanned documents.
#[derive(Debug)]
pub struct ScanJob {
    /// URL of the job resource, derived from the `Location` header of the
    /// job-creation response.
    job_url: Url,
    /// HTTP client cloned from the [`Scanner`] that created the job.
    http_client: Client,
}
39
/// A scanner found on the network by [`discover`].
#[derive(Debug)]
pub struct ScannerService {
    /// Base URL built from the advertised address, port and `rs` TXT entry.
    base_url: Url,
    /// Value of the service's `ty` TXT entry.
    name: String,
}
45
/// Errors returned by [`discover`].
#[derive(Debug)]
pub enum DiscoverError {
    /// Error reported by the `mdns` crate while starting discovery or while
    /// receiving a response.
    Mdns(mdns::Error),
}
50
51impl Scanner {
52 pub fn new(base_url: Url) -> Self {
64 Self {
65 base_url,
66 http_client: Client::new(),
67 }
68 }
69
70 pub async fn capabilities(&self) -> Result<ScannerCapabilities, Error> {
71 self.send_get_request(self.extended_url(&["ScannerCapabilities"]))
72 .await
73 }
74
75 pub async fn status(&self) -> Result<ScannerStatus, Error> {
76 self.send_get_request(self.extended_url(&["ScannerStatus"]))
77 .await
78 }
79
80 pub async fn scan(&self, settings: &ScanSettings) -> Result<ScanJob, Error> {
81 let url = self.extended_url(&["ScanJobs"]);
82
83 let request_body = serde_xml_rs::to_string(settings).map_err(Error::Xml)?;
84
85 let response = self
86 .http_client
87 .post(url)
88 .header("Content-Type", "text/xml")
89 .body(request_body)
90 .send()
91 .await
92 .map_err(Error::Http)?;
93
94 let status_code = response.status();
95 if status_code != StatusCode::CREATED {
96 return Err(Error::UnexpectedStatusCode(status_code));
97 }
98
99 let location: Url = response
100 .headers()
101 .get("location")
102 .ok_or(Error::LocationHeader)?
103 .to_str()
104 .map_err(|_| Error::LocationHeader)?
105 .parse()
106 .map_err(|_| Error::LocationHeader)?;
107
108 Ok(ScanJob {
109 job_url: location,
110 http_client: self.http_client.clone(),
111 })
112 }
113
114 fn extended_url(&self, segments: &[&'static str]) -> Url {
115 let mut url = self.base_url.clone();
116 url.path_segments_mut()
117 .expect("Invalid base URL")
118 .extend(segments);
119
120 url
121 }
122
123 async fn send_get_request<T>(&self, url: Url) -> Result<T, Error>
124 where
125 T: DeserializeOwned,
126 {
127 let response = self
128 .http_client
129 .get(url)
130 .send()
131 .await
132 .map_err(Error::Http)?;
133
134 let response_body = response.text().await.map_err(Error::Http)?;
135
136 serde_xml_rs::from_str(&response_body).map_err(Error::Xml)
137 }
138}
139
140impl Display for Error {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 match self {
143 Error::Http(err) => write!(f, "http error: {}", err),
144 Error::Xml(err) => write!(f, "xml error: {}", err),
145 Error::UnexpectedStatusCode(code) => write!(f, "unexpected http status code {}", code),
146 Error::LocationHeader => write!(f, "missing or invalid `Location` header in response"),
147 }
148 }
149}
150
151impl std::error::Error for Error {}
152
153impl ScanJob {
154 pub async fn next_document(&self) -> Result<Option<Vec<u8>>, Error> {
155 let url = self.extended_url(&["NextDocument"]);
156
157 let response = self
158 .http_client
159 .get(url)
160 .send()
161 .await
162 .map_err(Error::Http)?;
163
164 let status_code = response.status();
165 if status_code == StatusCode::NOT_FOUND {
166 return Ok(None);
167 } else if status_code != StatusCode::OK {
168 return Err(Error::UnexpectedStatusCode(status_code));
169 }
170
171 let bytes = response.bytes().await.map_err(Error::Http)?;
172 Ok(Some(bytes.to_vec()))
173 }
174
175 fn extended_url(&self, segments: &[&'static str]) -> Url {
176 let mut url = self.job_url.clone();
177 url.path_segments_mut()
178 .expect("Invalid base URL")
179 .extend(segments);
180
181 url
182 }
183
184 pub fn job_url(&self) -> &Url {
185 &self.job_url
186 }
187}
188
189impl ScannerService {
190 pub fn url(&self) -> &Url {
192 &self.base_url
193 }
194
195 pub fn name(&self) -> &str {
197 &self.name
198 }
199}
200
201impl From<&ScannerService> for Scanner {
202 fn from(value: &ScannerService) -> Self {
203 Self {
204 base_url: value.base_url.clone(),
205 http_client: Client::new(),
206 }
207 }
208}
209
210impl From<ScannerService> for Scanner {
211 fn from(value: ScannerService) -> Self {
212 Self {
213 base_url: value.base_url,
214 http_client: Client::new(),
215 }
216 }
217}
218
219impl Display for DiscoverError {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 match self {
222 Self::Mdns(err) => write!(f, "mDNS error: {}", err),
223 }
224 }
225}
226
227impl std::error::Error for DiscoverError {}
228
229pub async fn discover(timeout: Duration) -> Result<Vec<ScannerService>, DiscoverError> {
231 let mdns_stream = mdns::discover::all(SERVICE_NAME, timeout)
232 .map_err(DiscoverError::Mdns)?
233 .listen();
234 pin_mut!(mdns_stream);
235
236 let services = match mdns_stream.next().await {
237 Some(Ok(response)) => {
238 response
239 .records()
240 .filter_map(|record| {
241 if record.name == SERVICE_NAME {
242 match &record.kind {
243 RecordKind::PTR(ptr_record) => {
244 let txt_record = response.records().find_map(|record| {
248 if &record.name == ptr_record {
249 match &record.kind {
250 RecordKind::TXT(txt) => Some(txt),
251 _ => None,
252 }
253 } else {
254 None
255 }
256 })?;
257
258 let rs = txt_record.iter().find_map(|item| {
260 let (key, value) = item.split_once('=')?;
261
262 if key == "rs" {
263 Some(value)
264 } else {
265 None
266 }
267 })?;
268
269 let ty = txt_record.iter().find_map(|item| {
271 let (key, value) = item.split_once('=')?;
272
273 if key == "ty" {
274 Some(value)
275 } else {
276 None
277 }
278 })?;
279
280 let (srv_record, port) = response.records().find_map(|record| {
282 if &record.name == ptr_record {
283 match &record.kind {
284 RecordKind::SRV { target, port, .. } => {
285 Some((target, port))
286 }
287 _ => None,
288 }
289 } else {
290 None
291 }
292 })?;
293
294 let ip_addr = response.records().find_map(|record| {
296 if &record.name == srv_record {
297 match &record.kind {
298 RecordKind::A(ip_addr) => Some(ip_addr),
299 _ => None,
300 }
301 } else {
302 None
303 }
304 })?;
305
306 let url =
307 Url::parse(&format!("http://{}:{}/{}", ip_addr, port, rs))
308 .ok()?;
309
310 Some(ScannerService {
311 base_url: url,
312 name: ty.to_owned(),
313 })
314 }
315 _ => None,
316 }
317 } else {
318 None
319 }
320 })
321 .collect::<Vec<_>>()
322 }
323 Some(Err(err)) => return Err(DiscoverError::Mdns(err)),
324 _ => {
325 vec![]
326 }
327 };
328
329 Ok(services)
330}