1use btleplug::{
2 api::{Central, CentralEvent, Manager as _, Peripheral as _},
3 platform::{Adapter, Manager, Peripheral, PeripheralId},
4 Error,
5};
6use futures::{Stream, StreamExt};
7use std::{
8 collections::HashSet,
9 pin::Pin,
10 sync::{
11 atomic::{AtomicBool, Ordering},
12 Arc, Mutex, RwLock, Weak,
13 },
14 time::{Duration, Instant},
15};
16use stream_cancel::{Trigger, Valved};
17use tokio::sync::broadcast::{self, Sender};
18use tokio_stream::wrappers::BroadcastStream;
19use uuid::Uuid;
20
21use crate::{Device, DeviceEvent};
22
23#[derive(Debug, Clone, Hash, Eq, PartialEq)]
24pub enum Filter {
25 Address(String),
26 Characteristic(Uuid),
27 Name(String),
28 Rssi(i16),
29 Service(Uuid),
30}
31
32#[derive(Default)]
33pub struct ScanConfig {
34 adapter_index: usize,
36 filters: Vec<Filter>,
38 address_filter: Option<Box<dyn Fn(&str, &str) -> bool + Send + Sync>>,
40 name_filter: Option<Box<dyn Fn(&str, &str) -> bool + Send + Sync>>,
42 rssi_filter: Option<Box<dyn Fn(i16, i16) -> bool + Send + Sync>>,
44 service_filter: Option<Box<dyn Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync>>,
46 characteristics_filter: Option<Box<dyn Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync>>,
48 max_results: Option<usize>,
50 timeout: Option<Duration>,
52 force_disconnect: bool,
54}
55
56impl ScanConfig {
57 #[inline]
59 pub fn adapter_index(mut self, index: usize) -> Self {
60 self.adapter_index = index;
61 self
62 }
63
64 #[inline]
65 pub fn with_filters(mut self, filters: &[Filter]) -> Self {
66 self.filters.extend_from_slice(filters);
67 self
68 }
69
70 #[inline]
72 pub fn filter_by_address(
73 mut self,
74 func: impl Fn(&str, &str) -> bool + Send + Sync + 'static,
75 ) -> Self {
76 self.address_filter = Some(Box::new(func));
77 self
78 }
79
80 #[inline]
82 pub fn filter_by_name(
83 mut self,
84 func: impl Fn(&str, &str) -> bool + Send + Sync + 'static,
85 ) -> Self {
86 self.name_filter = Some(Box::new(func));
87 self
88 }
89
90 #[inline]
91 pub fn filter_by_rssi(
92 mut self,
93 func: impl Fn(i16, i16) -> bool + Send + Sync + 'static,
94 ) -> Self {
95 self.rssi_filter = Some(Box::new(func));
96 self
97 }
98
99 #[inline]
100 pub fn filter_by_service(
101 mut self,
102 func: impl Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync + 'static,
103 ) -> Self {
104 self.service_filter = Some(Box::new(func));
105 self
106 }
107
108 #[inline]
110 pub fn filter_by_characteristics(
111 mut self,
112 func: impl Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync + 'static,
113 ) -> Self {
114 self.characteristics_filter = Some(Box::new(func));
115 self
116 }
117
118 #[inline]
120 pub fn stop_after_matches(mut self, max_results: usize) -> Self {
121 self.max_results = Some(max_results);
122 self
123 }
124
125 #[inline]
127 pub fn stop_after_first_match(self) -> Self {
128 self.stop_after_matches(1)
129 }
130
131 #[inline]
133 pub fn stop_after_timeout(mut self, timeout: Duration) -> Self {
134 self.timeout = Some(timeout);
135 self
136 }
137
138 #[inline]
139 pub fn force_disconnect(mut self, force_disconnect: bool) -> Self {
140 self.force_disconnect = force_disconnect;
141 self
142 }
143
144 #[inline]
146 pub fn require_name(self) -> Self {
147 if self.name_filter.is_none() {
148 self.filter_by_name(|src, _dst| !src.is_empty())
149 } else {
150 self
151 }
152 }
153}
154
155#[derive(Debug, Clone)]
156pub(crate) struct Session {
157 pub(crate) _manager: Manager,
158 pub(crate) adapter: Adapter,
159}
160
161#[derive(Debug, Clone)]
162pub struct Scanner {
163 session: Weak<Session>,
164 event_sender: Sender<DeviceEvent>,
165 stoppers: Arc<RwLock<Vec<Trigger>>>,
166 scan_stopper: Arc<AtomicBool>,
167}
168
169impl Default for Scanner {
170 fn default() -> Self {
171 Scanner::new()
172 }
173}
174
175impl Scanner {
176 pub fn new() -> Self {
177 let (event_sender, _) = broadcast::channel(32);
178 Self {
179 scan_stopper: Arc::new(AtomicBool::new(false)),
180 session: Weak::new(),
181 event_sender,
182 stoppers: Arc::new(RwLock::new(Vec::new())),
183 }
184 }
185
186 pub async fn start(&mut self, config: ScanConfig) -> Result<(), Error> {
188 if self.session.upgrade().is_some() {
189 log::info!("Scanner is already started.");
190 return Ok(());
191 }
192
193 let manager = Manager::new().await?;
194 let mut adapters = manager.adapters().await?;
195
196 if config.adapter_index >= adapters.len() {
197 return Err(Error::DeviceNotFound);
198 }
199
200 let adapter = adapters.swap_remove(config.adapter_index);
201 log::trace!("Using adapter: {:?}", adapter);
202
203 let session = Arc::new(Session {
204 _manager: manager,
205 adapter,
206 });
207 self.session = Arc::downgrade(&session);
208
209 let event_sender = self.event_sender.clone();
210
211 let mut worker = ScannerWorker::new(
212 config,
213 session.clone(),
214 event_sender,
215 self.scan_stopper.clone(),
216 );
217 tokio::spawn(async move {
218 let _ = worker.scan().await;
219 });
220
221 Ok(())
222 }
223
224 pub async fn stop(&self) -> Result<(), Error> {
226 self.scan_stopper.store(true, Ordering::Relaxed);
227 self.stoppers.write()?.clear();
228 log::info!("Scanner is stopped.");
229
230 Ok(())
231 }
232
233 pub fn is_active(&self) -> bool {
235 self.session.upgrade().is_some()
236 }
237
238 pub fn device_event_stream(
240 &self,
241 ) -> Result<Valved<Pin<Box<dyn Stream<Item = DeviceEvent> + Send>>>, Error> {
242 let receiver = self.event_sender.subscribe();
243
244 let stream: Pin<Box<dyn Stream<Item = DeviceEvent> + Send>> =
245 Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move {
246 match x {
247 Ok(event) => {
248 log::debug!("Broadcasting device: {:?}", event);
249 Some(event)
250 }
251 Err(e) => {
252 log::warn!("Error: {:?} when broadcasting device event!", e);
253 None
254 }
255 }
256 }));
257
258 let (trigger, stream) = Valved::new(stream);
259 self.stoppers.write()?.push(trigger);
260
261 Ok(stream)
262 }
263
264 pub fn device_stream(
266 &self,
267 ) -> Result<Valved<Pin<Box<dyn Stream<Item = Device> + Send>>>, Error> {
268 let receiver = self.event_sender.subscribe();
269
270 let stream: Pin<Box<dyn Stream<Item = Device> + Send>> =
271 Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move {
272 match x {
273 Ok(DeviceEvent::Discovered(device)) => {
274 log::debug!("Broadcasting device: {:?}", device.address());
275 Some(device)
276 }
277 Err(e) => {
278 log::warn!("Error: {:?} when broadcasting device!", e);
279 None
280 }
281 _ => None,
282 }
283 }));
284
285 let (trigger, stream) = Valved::new(stream);
286 self.stoppers.write()?.push(trigger);
287
288 Ok(stream)
289 }
290}
291
292pub struct ScannerWorker {
293 config: ScanConfig,
295 session: Arc<Session>,
297 result_count: usize,
299 filtered: HashSet<PeripheralId>,
301 connecting: Arc<Mutex<HashSet<PeripheralId>>>,
303 matched: HashSet<PeripheralId>,
305 event_sender: Sender<DeviceEvent>,
307 stopper: Arc<AtomicBool>,
309}
310
311impl ScannerWorker {
312 fn new(
313 config: ScanConfig,
314 session: Arc<Session>,
315 event_sender: Sender<DeviceEvent>,
316 stopper: Arc<AtomicBool>,
317 ) -> Self {
318 Self {
319 config,
320 session,
321 result_count: 0,
322 filtered: HashSet::new(),
323 connecting: Arc::new(Mutex::new(HashSet::new())),
324 matched: HashSet::new(),
325 event_sender,
326 stopper,
327 }
328 }
329
330 async fn scan(&mut self) -> Result<(), Error> {
331 log::info!("Starting the scan");
332
333 self.session.adapter.start_scan(Default::default()).await?;
334
335 while let Ok(mut stream) = self.session.adapter.events().await {
336 let start_time = Instant::now();
337
338 while let Some(event) = stream.next().await {
339 match event {
340 CentralEvent::DeviceDiscovered(v) => self.on_device_discovered(v).await,
341 CentralEvent::DeviceUpdated(v) => self.on_device_updated(v).await,
342 CentralEvent::DeviceConnected(v) => self.on_device_connected(v).await?,
343 CentralEvent::DeviceDisconnected(v) => self.on_device_disconnected(v).await?,
344 _ => {}
345 }
346
347 let timeout_reached = self
348 .config
349 .timeout
350 .filter(|timeout| Instant::now().duration_since(start_time).ge(timeout))
351 .is_some();
352
353 let max_result_reached = self
354 .config
355 .max_results
356 .filter(|max_results| self.result_count >= *max_results)
357 .is_some();
358
359 if timeout_reached || max_result_reached || self.stopper.load(Ordering::Relaxed) {
360 log::info!("Scanner stop condition reached.");
361 return Ok(());
362 }
363 }
364 }
365
366 Ok(())
367 }
368
369 async fn on_device_discovered(&mut self, peripheral_id: PeripheralId) {
370 if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
371 log::trace!("Device discovered: {:?}", peripheral);
372
373 self.apply_filter(peripheral_id).await;
374 }
375 }
376
377 async fn on_device_updated(&mut self, peripheral_id: PeripheralId) {
378 if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
379 log::trace!("Device updated: {:?}", peripheral);
380
381 if self.matched.contains(&peripheral_id) {
382 let address = peripheral.address();
383 match self.event_sender.send(DeviceEvent::Updated(Device::new(
384 self.session.adapter.clone(),
385 peripheral,
386 ))) {
387 Ok(value) => log::debug!("Sent device: {}, size: {}...", address, value),
388 Err(e) => log::debug!("Error: {:?} when Sending device: {}...", e, address),
389 }
390 } else {
391 self.apply_filter(peripheral_id).await;
392 }
393 }
394 }
395
396 async fn on_device_connected(&mut self, peripheral_id: PeripheralId) -> Result<(), Error> {
397 self.connecting.lock()?.remove(&peripheral_id);
398
399 if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
400 log::trace!("Device connected: {:?}", peripheral);
401
402 if self.matched.contains(&peripheral_id) {
403 let address = peripheral.address();
404 match self.event_sender.send(DeviceEvent::Connected(Device::new(
405 self.session.adapter.clone(),
406 peripheral,
407 ))) {
408 Ok(value) => log::trace!("Sent device: {}, size: {}...", address, value),
409 Err(e) => log::warn!("Error: {:?} when Sending device: {}...", e, address),
410 }
411 } else {
412 self.apply_filter(peripheral_id).await;
413 }
414 }
415
416 Ok(())
417 }
418
419 async fn on_device_disconnected(&self, peripheral_id: PeripheralId) -> Result<(), Error> {
420 if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
421 log::trace!("Device disconnected: {:?}", peripheral);
422
423 if self.matched.contains(&peripheral_id) {
424 let address = peripheral.address();
425 match self
426 .event_sender
427 .send(DeviceEvent::Disconnected(Device::new(
428 self.session.adapter.clone(),
429 peripheral,
430 ))) {
431 Ok(value) => log::trace!("Sent device: {}, size: {}...", address, value),
432 Err(e) => log::warn!("Error: {:?} when Sending device: {}...", e, address),
433 }
434 }
435 }
436
437 self.connecting.lock()?.remove(&peripheral_id);
438
439 Ok(())
440 }
441
442 async fn apply_filter(&mut self, peripheral_id: PeripheralId) {
443 if self.filtered.contains(&peripheral_id) {
444 return;
445 }
446
447 if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
448 if let Ok(Some(property)) = peripheral.properties().await {
449 let mut passed = true;
450 log::trace!("filtering: {:?}", property);
451
452 for filter in self.config.filters.iter() {
453 if !passed {
454 break;
455 }
456 match filter {
457 Filter::Name(v) => {
458 passed &= property.local_name.as_ref().is_some_and(|name| {
459 if let Some(name_filter) = &self.config.name_filter {
460 name_filter(name, v)
461 } else {
462 name == v
463 }
464 })
465 }
466 Filter::Rssi(v) => {
467 passed &= property.rssi.is_some_and(|rssi| {
468 if let Some(rssi_filter) = &self.config.rssi_filter {
469 rssi_filter(rssi, *v)
470 } else {
471 rssi >= *v
472 }
473 });
474 }
475 Filter::Service(v) => {
476 let services = &property.services;
477 if let Some(service_filter) = &self.config.service_filter {
478 passed &= service_filter(&services, v);
479 } else {
480 passed &= property.services.contains(v);
481 }
482 }
483 Filter::Address(v) => {
484 let addr = property.address.to_string();
485 if let Some(address_filter) = &self.config.address_filter {
486 passed &= address_filter(&addr, v);
487 } else {
488 passed &= addr == *v;
489 }
490 }
491 Filter::Characteristic(v) => {
492 let _ = self
493 .apply_character_filter(&peripheral, v, &mut passed)
494 .await;
495 }
496 }
497 }
498
499 if passed {
500 self.matched.insert(peripheral_id.clone());
501 self.result_count += 1;
502
503 if let Err(e) = self.event_sender.send(DeviceEvent::Discovered(Device::new(
504 self.session.adapter.clone(),
505 peripheral,
506 ))) {
507 log::warn!("error: {} when sending device", e);
508 }
509 }
510
511 log::debug!(
512 "current matched: {}, current filtered: {}",
513 self.matched.len(),
514 self.filtered.len()
515 );
516 }
517
518 self.filtered.insert(peripheral_id);
519 }
520 }
521
522 async fn apply_character_filter(
523 &self,
524 peripheral: &Peripheral,
525 uuid: &Uuid,
526 passed: &mut bool,
527 ) -> Result<(), Error> {
528 if !peripheral.is_connected().await.unwrap_or(false) {
529 if self.connecting.lock()?.insert(peripheral.id()) {
530 log::debug!("Connecting to device {}", peripheral.address());
531
532 let connecting_map = self.connecting.clone();
535 if let Err(e) = peripheral.connect().await {
536 log::warn!("Could not connect to {}: {:?}", peripheral.address(), e);
537
538 connecting_map.lock()?.remove(&peripheral.id());
539
540 return Ok(());
541 };
542 }
543 }
544
545 let mut characteristics = Vec::new();
546 characteristics.extend(peripheral.characteristics());
547
548 if self.config.force_disconnect {
549 if let Err(e) = peripheral.disconnect().await {
550 log::warn!("Error: {} when disconnect device", e);
551 }
552 }
553
554 *passed &= if characteristics.is_empty() {
555 let address = peripheral.address();
556 log::debug!("Discovering characteristics for {}", address);
557
558 match peripheral.discover_services().await {
559 Ok(()) => {
560 characteristics.extend(peripheral.characteristics());
561 let characteristics = characteristics
562 .into_iter()
563 .map(|c| c.uuid)
564 .collect::<Vec<_>>();
565
566 if let Some(characteristics_filter) = &self.config.characteristics_filter {
567 characteristics_filter(&characteristics, uuid)
568 } else {
569 characteristics.contains(uuid)
570 }
571 }
572 Err(e) => {
573 log::warn!(
574 "Error: `{:?}` when discovering characteristics for {}",
575 e,
576 address
577 );
578 false
579 }
580 }
581 } else {
582 true
583 };
584
585 Ok(())
586 }
587}
588
/// Scanner tests. They start a real scan through `Scanner::start`, so they
/// presumably need a Bluetooth adapter and matching devices in range
/// (TODO confirm how they are meant to be run in CI).
#[cfg(test)]
mod tests {
    use super::{Filter, ScanConfig, Scanner};
    use crate::Device;
    use btleplug::{api::BDAddr, Error};
    use futures::StreamExt;
    use std::{future::Future, time::Duration};
    use uuid::Uuid;

    /// Installs the logger if none is installed yet.
    ///
    /// All tests share one process, and `pretty_env_logger::init()` panics
    /// when a logger is already set, so the fallible variant is used and its
    /// result ignored.
    fn init_logger() {
        let _ = pretty_env_logger::try_init();
    }

    /// Waits up to 15 seconds for the first device reported by `scanner` and
    /// passes it to `callback`.
    async fn device_stream<T: Future<Output = ()>>(
        scanner: Scanner,
        callback: impl Fn(Device) -> T,
    ) {
        let duration = Duration::from_millis(15_000);
        let first_device = async move {
            match scanner.device_stream() {
                Ok(stream) => {
                    // Pin the stream on the heap so `next` can be called
                    // regardless of whether the valved stream is `Unpin`.
                    let mut stream = Box::pin(stream);
                    if let Some(device) = stream.next().await {
                        callback(device).await;
                    }
                }
                Err(e) => eprintln!("could not open the device stream: {:?}", e),
            }
        };

        if tokio::time::timeout(duration, first_device).await.is_err() {
            eprintln!("timeout....");
        }
    }

    #[tokio::test]
    async fn test_filter_by_address() -> Result<(), Error> {
        init_logger();

        let mac_addr = [0xE3, 0x9E, 0x2A, 0x4D, 0xAA, 0x97];
        let filters = vec![Filter::Address("E3:9E:2A:4D:AA:97".into())];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            assert_eq!(device.address(), BDAddr::from(mac_addr));
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_character() -> Result<(), Error> {
        init_logger();

        let filters = vec![Filter::Characteristic(Uuid::from_u128(
            0x6e400001_b5a3_f393_e0a9_e50e24dcca9e,
        ))];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_name() -> Result<(), Error> {
        init_logger();

        let name = "73429485";
        let filters = vec![Filter::Name(name.into())];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            assert_eq!(device.local_name().await, Some(name.into()));
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_rssi() -> Result<(), Error> {
        init_logger();

        let filters = vec![Filter::Rssi(-70)];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_service() -> Result<(), Error> {
        init_logger();

        let service = Uuid::from_u128(0x6e400001_b5a3_f393_e0a9_e50e24dcca9e);
        let filters = vec![Filter::Service(service)];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }
}