bleasy/
scanner.rs

1use btleplug::{
2    api::{Central, CentralEvent, Manager as _, Peripheral as _},
3    platform::{Adapter, Manager, Peripheral, PeripheralId},
4    Error,
5};
6use futures::{Stream, StreamExt};
7use std::{
8    collections::HashSet,
9    pin::Pin,
10    sync::{
11        atomic::{AtomicBool, Ordering},
12        Arc, Mutex, RwLock, Weak,
13    },
14    time::{Duration, Instant},
15};
16use stream_cancel::{Trigger, Valved};
17use tokio::sync::broadcast::{self, Sender};
18use tokio_stream::wrappers::BroadcastStream;
19use uuid::Uuid;
20
21use crate::{Device, DeviceEvent};
22
23#[derive(Debug, Clone, Hash, Eq, PartialEq)]
24pub enum Filter {
25    Address(String),
26    Characteristic(Uuid),
27    Name(String),
28    Rssi(i16),
29    Service(Uuid),
30}
31
32#[derive(Default)]
33pub struct ScanConfig {
34    /// Index of the Bluetooth adapter to use. The first found adapter is used by default.
35    adapter_index: usize,
36    /// Filters objects
37    filters: Vec<Filter>,
38    /// Filters the found devices based on device address.
39    address_filter: Option<Box<dyn Fn(&str, &str) -> bool + Send + Sync>>,
40    /// Filters the found devices based on local name.
41    name_filter: Option<Box<dyn Fn(&str, &str) -> bool + Send + Sync>>,
42    /// Filters the found devices based on rssi.
43    rssi_filter: Option<Box<dyn Fn(i16, i16) -> bool + Send + Sync>>,
44    /// Filters the found devices based on service's uuid.
45    service_filter: Option<Box<dyn Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync>>,
46    /// Filters the found devices based on characteristics. Requires a connection to the device.
47    characteristics_filter: Option<Box<dyn Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync>>,
48    /// Maximum results before the scan is stopped.
49    max_results: Option<usize>,
50    /// The scan is stopped when timeout duration is reached.
51    timeout: Option<Duration>,
52    /// Force disconnect when listen the device is connected.
53    force_disconnect: bool,
54}
55
56impl ScanConfig {
57    /// Index of bluetooth adapter to use
58    #[inline]
59    pub fn adapter_index(mut self, index: usize) -> Self {
60        self.adapter_index = index;
61        self
62    }
63
64    #[inline]
65    pub fn with_filters(mut self, filters: &[Filter]) -> Self {
66        self.filters.extend_from_slice(filters);
67        self
68    }
69
70    /// Filter scanned devices based on the device address
71    #[inline]
72    pub fn filter_by_address(
73        mut self,
74        func: impl Fn(&str, &str) -> bool + Send + Sync + 'static,
75    ) -> Self {
76        self.address_filter = Some(Box::new(func));
77        self
78    }
79
80    /// Filter scanned devices based on the device name
81    #[inline]
82    pub fn filter_by_name(
83        mut self,
84        func: impl Fn(&str, &str) -> bool + Send + Sync + 'static,
85    ) -> Self {
86        self.name_filter = Some(Box::new(func));
87        self
88    }
89
90    #[inline]
91    pub fn filter_by_rssi(
92        mut self,
93        func: impl Fn(i16, i16) -> bool + Send + Sync + 'static,
94    ) -> Self {
95        self.rssi_filter = Some(Box::new(func));
96        self
97    }
98
99    #[inline]
100    pub fn filter_by_service(
101        mut self,
102        func: impl Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync + 'static,
103    ) -> Self {
104        self.service_filter = Some(Box::new(func));
105        self
106    }
107
108    /// Filter scanned devices based on available characteristics
109    #[inline]
110    pub fn filter_by_characteristics(
111        mut self,
112        func: impl Fn(&Vec<Uuid>, &Uuid) -> bool + Send + Sync + 'static,
113    ) -> Self {
114        self.characteristics_filter = Some(Box::new(func));
115        self
116    }
117
118    /// Stop the scan after given number of matches
119    #[inline]
120    pub fn stop_after_matches(mut self, max_results: usize) -> Self {
121        self.max_results = Some(max_results);
122        self
123    }
124
125    /// Stop the scan after the first match
126    #[inline]
127    pub fn stop_after_first_match(self) -> Self {
128        self.stop_after_matches(1)
129    }
130
131    /// Stop the scan after given duration
132    #[inline]
133    pub fn stop_after_timeout(mut self, timeout: Duration) -> Self {
134        self.timeout = Some(timeout);
135        self
136    }
137
138    #[inline]
139    pub fn force_disconnect(mut self, force_disconnect: bool) -> Self {
140        self.force_disconnect = force_disconnect;
141        self
142    }
143
144    /// Require that the scanned devices have a name
145    #[inline]
146    pub fn require_name(self) -> Self {
147        if self.name_filter.is_none() {
148            self.filter_by_name(|src, _dst| !src.is_empty())
149        } else {
150            self
151        }
152    }
153}
154
155#[derive(Debug, Clone)]
156pub(crate) struct Session {
157    pub(crate) _manager: Manager,
158    pub(crate) adapter: Adapter,
159}
160
161#[derive(Debug, Clone)]
162pub struct Scanner {
163    session: Weak<Session>,
164    event_sender: Sender<DeviceEvent>,
165    stoppers: Arc<RwLock<Vec<Trigger>>>,
166    scan_stopper: Arc<AtomicBool>,
167}
168
169impl Default for Scanner {
170    fn default() -> Self {
171        Scanner::new()
172    }
173}
174
175impl Scanner {
176    pub fn new() -> Self {
177        let (event_sender, _) = broadcast::channel(32);
178        Self {
179            scan_stopper: Arc::new(AtomicBool::new(false)),
180            session: Weak::new(),
181            event_sender,
182            stoppers: Arc::new(RwLock::new(Vec::new())),
183        }
184    }
185
186    /// Start scanning for ble devices.
187    pub async fn start(&mut self, config: ScanConfig) -> Result<(), Error> {
188        if self.session.upgrade().is_some() {
189            log::info!("Scanner is already started.");
190            return Ok(());
191        }
192
193        let manager = Manager::new().await?;
194        let mut adapters = manager.adapters().await?;
195
196        if config.adapter_index >= adapters.len() {
197            return Err(Error::DeviceNotFound);
198        }
199
200        let adapter = adapters.swap_remove(config.adapter_index);
201        log::trace!("Using adapter: {:?}", adapter);
202
203        let session = Arc::new(Session {
204            _manager: manager,
205            adapter,
206        });
207        self.session = Arc::downgrade(&session);
208
209        let event_sender = self.event_sender.clone();
210
211        let mut worker = ScannerWorker::new(
212            config,
213            session.clone(),
214            event_sender,
215            self.scan_stopper.clone(),
216        );
217        tokio::spawn(async move {
218            let _ = worker.scan().await;
219        });
220
221        Ok(())
222    }
223
224    /// Stop scanning for ble devices.
225    pub async fn stop(&self) -> Result<(), Error> {
226        self.scan_stopper.store(true, Ordering::Relaxed);
227        self.stoppers.write()?.clear();
228        log::info!("Scanner is stopped.");
229
230        Ok(())
231    }
232
233    /// Returns true if the scanner is active.
234    pub fn is_active(&self) -> bool {
235        self.session.upgrade().is_some()
236    }
237
238    /// Create a new stream that receives ble device events.
239    pub fn device_event_stream(
240        &self,
241    ) -> Result<Valved<Pin<Box<dyn Stream<Item = DeviceEvent> + Send>>>, Error> {
242        let receiver = self.event_sender.subscribe();
243
244        let stream: Pin<Box<dyn Stream<Item = DeviceEvent> + Send>> =
245            Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move {
246                match x {
247                    Ok(event) => {
248                        log::debug!("Broadcasting device: {:?}", event);
249                        Some(event)
250                    }
251                    Err(e) => {
252                        log::warn!("Error: {:?} when broadcasting device event!", e);
253                        None
254                    }
255                }
256            }));
257
258        let (trigger, stream) = Valved::new(stream);
259        self.stoppers.write()?.push(trigger);
260
261        Ok(stream)
262    }
263
264    /// Create a new stream that receives discovered ble devices.
265    pub fn device_stream(
266        &self,
267    ) -> Result<Valved<Pin<Box<dyn Stream<Item = Device> + Send>>>, Error> {
268        let receiver = self.event_sender.subscribe();
269
270        let stream: Pin<Box<dyn Stream<Item = Device> + Send>> =
271            Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move {
272                match x {
273                    Ok(DeviceEvent::Discovered(device)) => {
274                        log::debug!("Broadcasting device: {:?}", device.address());
275                        Some(device)
276                    }
277                    Err(e) => {
278                        log::warn!("Error: {:?} when broadcasting device!", e);
279                        None
280                    }
281                    _ => None,
282                }
283            }));
284
285        let (trigger, stream) = Valved::new(stream);
286        self.stoppers.write()?.push(trigger);
287
288        Ok(stream)
289    }
290}
291
292pub struct ScannerWorker {
293    /// Configurations for the scan, such as filters and stop conditions
294    config: ScanConfig,
295    /// Reference to the bluetooth session instance
296    session: Arc<Session>,
297    /// Number of matching devices found so far
298    result_count: usize,
299    /// Set of devices that have been filtered and will be ignored
300    filtered: HashSet<PeripheralId>,
301    /// Set of devices that we are currently connecting to
302    connecting: Arc<Mutex<HashSet<PeripheralId>>>,
303    /// Set of devices that matched the filters
304    matched: HashSet<PeripheralId>,
305    /// Channel for sending events to the client
306    event_sender: Sender<DeviceEvent>,
307    /// Stop the scan event.
308    stopper: Arc<AtomicBool>,
309}
310
311impl ScannerWorker {
312    fn new(
313        config: ScanConfig,
314        session: Arc<Session>,
315        event_sender: Sender<DeviceEvent>,
316        stopper: Arc<AtomicBool>,
317    ) -> Self {
318        Self {
319            config,
320            session,
321            result_count: 0,
322            filtered: HashSet::new(),
323            connecting: Arc::new(Mutex::new(HashSet::new())),
324            matched: HashSet::new(),
325            event_sender,
326            stopper,
327        }
328    }
329
330    async fn scan(&mut self) -> Result<(), Error> {
331        log::info!("Starting the scan");
332
333        self.session.adapter.start_scan(Default::default()).await?;
334
335        while let Ok(mut stream) = self.session.adapter.events().await {
336            let start_time = Instant::now();
337
338            while let Some(event) = stream.next().await {
339                match event {
340                    CentralEvent::DeviceDiscovered(v) => self.on_device_discovered(v).await,
341                    CentralEvent::DeviceUpdated(v) => self.on_device_updated(v).await,
342                    CentralEvent::DeviceConnected(v) => self.on_device_connected(v).await?,
343                    CentralEvent::DeviceDisconnected(v) => self.on_device_disconnected(v).await?,
344                    _ => {}
345                }
346
347                let timeout_reached = self
348                    .config
349                    .timeout
350                    .filter(|timeout| Instant::now().duration_since(start_time).ge(timeout))
351                    .is_some();
352
353                let max_result_reached = self
354                    .config
355                    .max_results
356                    .filter(|max_results| self.result_count >= *max_results)
357                    .is_some();
358
359                if timeout_reached || max_result_reached || self.stopper.load(Ordering::Relaxed) {
360                    log::info!("Scanner stop condition reached.");
361                    return Ok(());
362                }
363            }
364        }
365
366        Ok(())
367    }
368
369    async fn on_device_discovered(&mut self, peripheral_id: PeripheralId) {
370        if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
371            log::trace!("Device discovered: {:?}", peripheral);
372
373            self.apply_filter(peripheral_id).await;
374        }
375    }
376
377    async fn on_device_updated(&mut self, peripheral_id: PeripheralId) {
378        if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
379            log::trace!("Device updated: {:?}", peripheral);
380
381            if self.matched.contains(&peripheral_id) {
382                let address = peripheral.address();
383                match self.event_sender.send(DeviceEvent::Updated(Device::new(
384                    self.session.adapter.clone(),
385                    peripheral,
386                ))) {
387                    Ok(value) => log::debug!("Sent device: {}, size: {}...", address, value),
388                    Err(e) => log::debug!("Error: {:?} when Sending device: {}...", e, address),
389                }
390            } else {
391                self.apply_filter(peripheral_id).await;
392            }
393        }
394    }
395
396    async fn on_device_connected(&mut self, peripheral_id: PeripheralId) -> Result<(), Error> {
397        self.connecting.lock()?.remove(&peripheral_id);
398
399        if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
400            log::trace!("Device connected: {:?}", peripheral);
401
402            if self.matched.contains(&peripheral_id) {
403                let address = peripheral.address();
404                match self.event_sender.send(DeviceEvent::Connected(Device::new(
405                    self.session.adapter.clone(),
406                    peripheral,
407                ))) {
408                    Ok(value) => log::trace!("Sent device: {}, size: {}...", address, value),
409                    Err(e) => log::warn!("Error: {:?} when Sending device: {}...", e, address),
410                }
411            } else {
412                self.apply_filter(peripheral_id).await;
413            }
414        }
415
416        Ok(())
417    }
418
419    async fn on_device_disconnected(&self, peripheral_id: PeripheralId) -> Result<(), Error> {
420        if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
421            log::trace!("Device disconnected: {:?}", peripheral);
422
423            if self.matched.contains(&peripheral_id) {
424                let address = peripheral.address();
425                match self
426                    .event_sender
427                    .send(DeviceEvent::Disconnected(Device::new(
428                        self.session.adapter.clone(),
429                        peripheral,
430                    ))) {
431                    Ok(value) => log::trace!("Sent device: {}, size: {}...", address, value),
432                    Err(e) => log::warn!("Error: {:?} when Sending device: {}...", e, address),
433                }
434            }
435        }
436
437        self.connecting.lock()?.remove(&peripheral_id);
438
439        Ok(())
440    }
441
442    async fn apply_filter(&mut self, peripheral_id: PeripheralId) {
443        if self.filtered.contains(&peripheral_id) {
444            return;
445        }
446
447        if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
448            if let Ok(Some(property)) = peripheral.properties().await {
449                let mut passed = true;
450                log::trace!("filtering: {:?}", property);
451
452                for filter in self.config.filters.iter() {
453                    if !passed {
454                        break;
455                    }
456                    match filter {
457                        Filter::Name(v) => {
458                            passed &= property.local_name.as_ref().is_some_and(|name| {
459                                if let Some(name_filter) = &self.config.name_filter {
460                                    name_filter(name, v)
461                                } else {
462                                    name == v
463                                }
464                            })
465                        }
466                        Filter::Rssi(v) => {
467                            passed &= property.rssi.is_some_and(|rssi| {
468                                if let Some(rssi_filter) = &self.config.rssi_filter {
469                                    rssi_filter(rssi, *v)
470                                } else {
471                                    rssi >= *v
472                                }
473                            });
474                        }
475                        Filter::Service(v) => {
476                            let services = &property.services;
477                            if let Some(service_filter) = &self.config.service_filter {
478                                passed &= service_filter(&services, v);
479                            } else {
480                                passed &= property.services.contains(v);
481                            }
482                        }
483                        Filter::Address(v) => {
484                            let addr = property.address.to_string();
485                            if let Some(address_filter) = &self.config.address_filter {
486                                passed &= address_filter(&addr, v);
487                            } else {
488                                passed &= addr == *v;
489                            }
490                        }
491                        Filter::Characteristic(v) => {
492                            let _ = self
493                                .apply_character_filter(&peripheral, v, &mut passed)
494                                .await;
495                        }
496                    }
497                }
498
499                if passed {
500                    self.matched.insert(peripheral_id.clone());
501                    self.result_count += 1;
502
503                    if let Err(e) = self.event_sender.send(DeviceEvent::Discovered(Device::new(
504                        self.session.adapter.clone(),
505                        peripheral,
506                    ))) {
507                        log::warn!("error: {} when sending device", e);
508                    }
509                }
510
511                log::debug!(
512                    "current matched: {}, current filtered: {}",
513                    self.matched.len(),
514                    self.filtered.len()
515                );
516            }
517
518            self.filtered.insert(peripheral_id);
519        }
520    }
521
522    async fn apply_character_filter(
523        &self,
524        peripheral: &Peripheral,
525        uuid: &Uuid,
526        passed: &mut bool,
527    ) -> Result<(), Error> {
528        if !peripheral.is_connected().await.unwrap_or(false) {
529            if self.connecting.lock()?.insert(peripheral.id()) {
530                log::debug!("Connecting to device {}", peripheral.address());
531
532                // Connect in another thread, so we can keep filtering other devices meanwhile.
533                // let peripheral_clone = peripheral.clone();
534                let connecting_map = self.connecting.clone();
535                if let Err(e) = peripheral.connect().await {
536                    log::warn!("Could not connect to {}: {:?}", peripheral.address(), e);
537
538                    connecting_map.lock()?.remove(&peripheral.id());
539
540                    return Ok(());
541                };
542            }
543        }
544
545        let mut characteristics = Vec::new();
546        characteristics.extend(peripheral.characteristics());
547
548        if self.config.force_disconnect {
549            if let Err(e) = peripheral.disconnect().await {
550                log::warn!("Error: {} when disconnect device", e);
551            }
552        }
553
554        *passed &= if characteristics.is_empty() {
555            let address = peripheral.address();
556            log::debug!("Discovering characteristics for {}", address);
557
558            match peripheral.discover_services().await {
559                Ok(()) => {
560                    characteristics.extend(peripheral.characteristics());
561                    let characteristics = characteristics
562                        .into_iter()
563                        .map(|c| c.uuid)
564                        .collect::<Vec<_>>();
565
566                    if let Some(characteristics_filter) = &self.config.characteristics_filter {
567                        characteristics_filter(&characteristics, uuid)
568                    } else {
569                        characteristics.contains(uuid)
570                    }
571                }
572                Err(e) => {
573                    log::warn!(
574                        "Error: `{:?}` when discovering characteristics for {}",
575                        e,
576                        address
577                    );
578                    false
579                }
580            }
581        } else {
582            true
583        };
584
585        Ok(())
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::{Filter, ScanConfig, Scanner};
    use crate::Device;
    use btleplug::{api::BDAddr, Error};
    use futures::StreamExt;
    use std::{future::Future, time::Duration};
    use uuid::Uuid;

    /// How long each hardware test waits for a matching device.
    const DISCOVERY_TIMEOUT: Duration = Duration::from_millis(15_000);

    /// Installs the logger. Every test calls this in the same process, and `init` panics once
    /// a logger is already set, so `try_init` is used and its error ignored.
    fn init_logger() {
        let _ = pretty_env_logger::try_init();
    }

    /// Runs `callback` on the first device reported by `scanner`, giving up after
    /// `DISCOVERY_TIMEOUT`.
    async fn device_stream<T: Future<Output = ()>>(
        scanner: Scanner,
        callback: impl Fn(Device) -> T,
    ) {
        let mut devices = scanner
            .device_stream()
            .expect("failed to subscribe to the device stream");

        let outcome = tokio::time::timeout(DISCOVERY_TIMEOUT, async move {
            if let Some(device) = devices.next().await {
                callback(device).await;
            }
        })
        .await;

        if outcome.is_err() {
            eprintln!("timeout....");
        }
    }

    #[tokio::test]
    async fn test_filter_by_address() -> Result<(), Error> {
        init_logger();

        let mac_addr = [0xE3, 0x9E, 0x2A, 0x4D, 0xAA, 0x97];
        let filters = vec![Filter::Address("E3:9E:2A:4D:AA:97".into())];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            assert_eq!(device.address(), BDAddr::from(mac_addr));
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_character() -> Result<(), Error> {
        init_logger();

        let filters = vec![Filter::Characteristic(Uuid::from_u128(
            0x6e400001_b5a3_f393_e0a9_e50e24dcca9e,
        ))];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_name() -> Result<(), Error> {
        init_logger();

        let name = "73429485";
        let filters = vec![Filter::Name(name.into())];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            assert_eq!(device.local_name().await, Some(name.into()));
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_rssi() -> Result<(), Error> {
        init_logger();

        let filters = vec![Filter::Rssi(-70)];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn test_filter_by_service() -> Result<(), Error> {
        init_logger();

        let service = Uuid::from_u128(0x6e400001_b5a3_f393_e0a9_e50e24dcca9e);
        let filters = vec![Filter::Service(service)];
        let cfg = ScanConfig::default()
            .with_filters(&filters)
            .stop_after_first_match();
        let mut scanner = Scanner::default();

        scanner.start(cfg).await?;
        device_stream(scanner, |device| async move {
            println!("device: {:?} found", device);
        })
        .await;

        Ok(())
    }
}