use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant};
use btleplug::api::{BDAddr, Central, CentralEvent, Manager as _, Peripheral as _};
use btleplug::platform::{Adapter, Manager, Peripheral, PeripheralId};
use btleplug::Error;
use futures::{Stream, StreamExt};
use uuid::Uuid;
use crate::Device;
use stream_cancel::{Trigger, Valved};
use tokio::sync::broadcast;
use tokio::sync::broadcast::Sender;
use tokio_stream::wrappers::BroadcastStream;
/// Options for a BLE scan started with [`Scanner::start`].
///
/// Build it with the chained setter methods on `ScanConfig`. `#[derive(Default)]` leaves every
/// filter and stop condition unset, uses adapter index 0 and sets `force_disconnect` to `false`.
#[derive(Default)]
pub struct ScanConfig {
    /// Position of the adapter to use in the list returned by the btleplug manager.
    adapter_index: usize,
    /// Predicate on the device address; devices for which it returns `false` are rejected.
    // NOTE(review): unlike the other two filters this box is `Send` but not `Sync`.
    address_filter: Option<Box<dyn Fn(BDAddr) -> bool + Send>>,
    /// Predicate on the advertised local name. A device with no known name is left undecided
    /// until a later event reports one.
    name_filter: Option<Box<dyn Fn(&str) -> bool + Send + Sync>>,
    /// Predicate on the UUIDs of the device's characteristics. Setting it makes the scanner
    /// connect to candidate devices so their characteristics can be read.
    characteristics_filter: Option<Box<dyn Fn(&[Uuid]) -> bool + Send + Sync>>,
    /// Stop the scan once this many matching devices have been broadcast to subscribers.
    max_results: Option<usize>,
    /// Stop the scan once this much time has elapsed. It is checked after each adapter event.
    timeout: Option<Duration>,
    /// Disconnect devices after they have been evaluated, whether they matched or were rejected.
    force_disconnect: bool,
}
impl ScanConfig {
    /// Chooses the adapter at position `index` in the btleplug manager's adapter list.
    pub fn adapter_index(self, index: usize) -> Self {
        Self {
            adapter_index: index,
            ..self
        }
    }

    /// Accepts only devices whose address makes `func` return `true`.
    ///
    /// Replaces any previously configured address filter.
    pub fn filter_by_address(self, func: impl Fn(BDAddr) -> bool + Send + 'static) -> Self {
        Self {
            address_filter: Some(Box::new(func)),
            ..self
        }
    }

    /// Accepts only devices whose advertised local name makes `func` return `true`.
    ///
    /// Replaces any previously configured name filter.
    pub fn filter_by_name(self, func: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
        Self {
            name_filter: Some(Box::new(func)),
            ..self
        }
    }

    /// Accepts only devices whose characteristic UUIDs make `func` return `true`.
    ///
    /// Evaluating this filter requires connecting to each candidate device.
    pub fn filter_by_characteristics(
        self,
        func: impl Fn(&[Uuid]) -> bool + Send + Sync + 'static,
    ) -> Self {
        Self {
            characteristics_filter: Some(Box::new(func)),
            ..self
        }
    }

    /// Stops the scan after `max_results` matching devices have been reported.
    pub fn stop_after_matches(self, max_results: usize) -> Self {
        Self {
            max_results: Some(max_results),
            ..self
        }
    }

    /// Shorthand for `stop_after_matches(1)`.
    pub fn stop_after_first_match(self) -> Self {
        self.stop_after_matches(1)
    }

    /// Stops the scan once `timeout` has elapsed since it started.
    pub fn stop_after_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// When `true`, devices are disconnected after they have been evaluated.
    pub fn force_disconnect(self, force_disconnect: bool) -> Self {
        Self {
            force_disconnect,
            ..self
        }
    }

    /// Requires devices to advertise a non-empty name.
    ///
    /// If a name filter is already configured, it is kept unchanged and no extra check is added.
    pub fn require_name(self) -> Self {
        if self.name_filter.is_some() {
            return self;
        }
        self.filter_by_name(|name| !name.is_empty())
    }
}
/// Resources used by one running scan: the btleplug manager and the selected adapter.
///
/// The scan task holds the strong `Arc<Session>` and [`Scanner`] keeps only a `Weak`. The
/// session is therefore dropped once the scan task finishes.
pub(crate) struct Session {
    /// Held so the manager stays alive alongside the adapter; not read in this file.
    pub(crate) _manager: Manager,
    /// Adapter that performs the scan and resolves peripherals.
    pub(crate) adapter: Adapter,
}
/// Handle for starting and stopping BLE scans and for subscribing to the devices they find.
pub struct Scanner {
    /// Weak reference to the running scan's session; it upgrades only while a scan task is alive.
    session: Weak<Session>,
    /// Broadcast channel (created with capacity 16) that feeds every subscriber stream.
    event_sender: Sender<DeviceEvent>,
    /// Trigger of the valve wrapping the adapter event stream. `stop` drops it to end the scan task.
    scan_stopper: Option<Trigger>,
    /// One trigger per stream returned by `device_event_stream` / `device_stream`. The list is
    /// cleared by `stop` and when the scan task ends, which closes those streams.
    device_stream_stoppers: Arc<RwLock<Vec<Trigger>>>,
}
impl Default for Scanner {
fn default() -> Self {
Scanner::new()
}
}
impl Scanner {
pub fn new() -> Self {
let (event_sender, _) = broadcast::channel(16);
Self {
session: Weak::new(),
event_sender,
scan_stopper: None,
device_stream_stoppers: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn start(&mut self, config: ScanConfig) -> Result<(), Error> {
if self.session.upgrade().is_some() {
log::info!("Scanner is already started.");
return Ok(());
}
let manager = Manager::new().await?;
let mut adapters = manager.adapters().await?;
if config.adapter_index >= adapters.len() {
return Err(Error::DeviceNotFound);
}
let adapter = adapters.swap_remove(config.adapter_index);
log::trace!("Using adapter: {:?}", adapter);
let session = Arc::new(Session {
_manager: manager,
adapter,
});
let stopper = ScanContext::start(
config,
session.clone(),
self.event_sender.clone(),
self.device_stream_stoppers.clone(),
)
.await?;
self.scan_stopper = Some(stopper);
self.session = Arc::downgrade(&session);
Ok(())
}
pub async fn stop(&mut self) -> Result<(), Error> {
if let Some(session) = self.session.upgrade() {
session.adapter.stop_scan().await?;
self.scan_stopper.take();
self.device_stream_stoppers.write().unwrap().clear();
} else {
log::info!("Scanner is already stopped");
}
Ok(())
}
pub fn is_active(&self) -> bool {
self.session.upgrade().is_some()
}
pub fn device_event_stream(
&mut self,
) -> Valved<Pin<Box<dyn Stream<Item = DeviceEvent> + Send>>> {
let receiver = self.event_sender.subscribe();
let stream: Pin<Box<dyn Stream<Item = DeviceEvent> + Send>> =
Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move { x.ok() }));
let (trigger, stream) = Valved::new(stream);
self.device_stream_stoppers.write().unwrap().push(trigger);
stream
}
pub fn device_stream(&mut self) -> Valved<Pin<Box<dyn Stream<Item = Device> + Send>>> {
let receiver = self.event_sender.subscribe();
let stream: Pin<Box<dyn Stream<Item = Device> + Send>> =
Box::pin(BroadcastStream::new(receiver).filter_map(|x| async move {
match x {
Ok(DeviceEvent::Discovered(device)) => Some(device),
_ => None,
}
}));
let (trigger, stream) = Valved::new(stream);
self.device_stream_stoppers.write().unwrap().push(trigger);
stream
}
}
/// State owned by the background task that filters adapter events for one scan.
struct ScanContext {
    /// Number of `DeviceEvent::Discovered` events successfully broadcast; compared with
    /// `config.max_results`.
    result_count: usize,
    /// Keeps the manager and adapter alive while the scan task runs.
    session: Arc<Session>,
    /// User-supplied filters and stop conditions.
    config: ScanConfig,
    /// `true` when a characteristics filter is set, which requires connecting to devices.
    connection_needed: bool,
    /// Peripherals already decided (matched or rejected); they are not filtered again.
    filtered: HashSet<PeripheralId>,
    /// Peripherals with a connection attempt in flight; shared with the spawned connect tasks.
    connecting: Arc<Mutex<HashSet<PeripheralId>>>,
    /// Peripherals that passed every filter. Only these produce Connected/Disconnected/Updated
    /// events.
    matched: HashSet<PeripheralId>,
    /// Broadcast channel shared with the owning `Scanner`.
    event_sender: Sender<DeviceEvent>,
}
impl ScanContext {
async fn start(
config: ScanConfig,
session: Arc<Session>,
sender: Sender<DeviceEvent>,
device_stream_stoppers: Arc<RwLock<Vec<Trigger>>>,
) -> Result<Trigger, Error> {
let connection_needed = config.characteristics_filter.is_some();
log::info!("Starting the scan");
let (stopper, events) = stream_cancel::Valved::new(session.adapter.events().await?);
session.adapter.start_scan(Default::default()).await?;
let ctx = ScanContext {
result_count: 0,
session,
config,
connection_needed,
filtered: HashSet::new(),
connecting: Arc::new(Mutex::new(HashSet::new())),
matched: HashSet::new(),
event_sender: sender,
};
tokio::spawn(async move {
ctx.listen(events, device_stream_stoppers).await;
});
Ok(stopper)
}
async fn listen(
mut self,
mut event_stream: Valved<Pin<Box<dyn Stream<Item = CentralEvent> + Send>>>,
device_stream_stoppers: Arc<RwLock<Vec<Trigger>>>,
) {
let start_time = Instant::now();
while let Some(event) = event_stream.next().await {
match event {
CentralEvent::DeviceDiscovered(peripheral_id) => {
self.on_device_discovered(peripheral_id).await;
}
CentralEvent::DeviceConnected(peripheral_id) => {
self.on_device_connected(peripheral_id).await;
}
CentralEvent::DeviceDisconnected(peripheral_id) => {
self.on_device_disconnected(peripheral_id).await;
}
CentralEvent::DeviceUpdated(peripheral_id) => {
self.on_device_updated(peripheral_id).await;
}
_ => {}
}
let timeout_reached = self
.config
.timeout
.filter(|timeout| Instant::now().duration_since(start_time).ge(timeout))
.is_some();
let max_result_reached = self
.config
.max_results
.filter(|max_results| self.result_count >= *max_results)
.is_some();
if timeout_reached || max_result_reached {
log::info!("Scanner stop condition reached.");
break;
}
}
device_stream_stoppers.write().unwrap().clear();
log::info!("Scanner was stopped.");
}
async fn on_device_discovered(&mut self, peripheral_id: PeripheralId) {
if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
log::trace!("Device discovered: {:?}", peripheral);
self.apply_filter(peripheral).await;
}
}
async fn on_device_updated(&mut self, peripheral_id: PeripheralId) {
if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
log::trace!("Device updated: {:?}", peripheral);
if self.matched.contains(&peripheral_id) {
self.event_sender
.send(DeviceEvent::Updated(Device::new(
self.session.adapter.clone(),
peripheral,
)))
.ok();
} else {
self.apply_filter(peripheral).await;
}
}
}
async fn on_device_connected(&mut self, peripheral_id: PeripheralId) {
self.connecting.lock().unwrap().remove(&peripheral_id);
if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
log::trace!("Device connected: {:?}", peripheral);
if self.matched.contains(&peripheral_id) {
self.event_sender
.send(DeviceEvent::Connected(Device::new(
self.session.adapter.clone(),
peripheral,
)))
.ok();
} else {
self.apply_filter(peripheral).await;
}
}
}
async fn on_device_disconnected(&mut self, peripheral_id: PeripheralId) {
if let Ok(peripheral) = self.session.adapter.peripheral(&peripheral_id).await {
log::trace!("Device disconnected: {:?}", peripheral);
if self.matched.contains(&peripheral_id) {
self.event_sender
.send(DeviceEvent::Disconnected(Device::new(
self.session.adapter.clone(),
peripheral,
)))
.ok();
}
}
self.connecting.lock().unwrap().remove(&peripheral_id);
}
async fn apply_filter(&mut self, peripheral: Peripheral) {
if self.filtered.contains(&peripheral.id()) {
return;
}
match self.passes_pre_connect_filters(&peripheral).await {
Some(false) => {
self.skip_peripheral(&peripheral).await;
return;
}
None => {
return;
}
_ => {
}
};
if self.connection_needed {
if !peripheral.is_connected().await.unwrap_or(false) {
if self.connecting.lock().unwrap().insert(peripheral.id()) {
log::debug!("Connecting to device {}", peripheral.address());
let peripheral_clone = peripheral.clone();
let connecting_map = self.connecting.clone();
tokio::spawn(async move {
if let Err(e) = peripheral_clone.connect().await {
log::warn!(
"Could not connect to {}: {:?}",
peripheral_clone.address(),
e
);
connecting_map
.lock()
.unwrap()
.remove(&peripheral_clone.id());
};
});
}
return;
} else if let Some(false) = self.passes_post_connect_filters(&peripheral).await {
self.skip_peripheral(&peripheral).await;
return;
} else if self.config.force_disconnect {
peripheral.disconnect().await.ok();
}
}
self.add_peripheral(peripheral).await;
}
async fn skip_peripheral(&mut self, peripheral: &Peripheral) {
self.filtered.insert(peripheral.id());
if self.config.force_disconnect {
peripheral.disconnect().await.ok();
return;
}
if let Ok(connected) = peripheral.is_connected().await {
if !connected {
return;
}
}
if self.config.address_filter.is_none() && self.config.name_filter.is_none() {
return;
}
let Ok(Some(properties)) = peripheral.properties().await else {
return;
};
if let Some(filter_by_address) = self.config.address_filter.as_ref() {
if filter_by_address(properties.address) {
peripheral.disconnect().await.ok();
}
}
if let Some(filter_by_name) = self.config.name_filter.as_ref() {
if let Some(local_name) = properties.local_name {
if filter_by_name(local_name.as_str()) {
peripheral.disconnect().await.ok();
}
}
}
}
async fn add_peripheral(&mut self, peripheral: Peripheral) {
self.filtered.insert(peripheral.id());
self.matched.insert(peripheral.id());
log::info!("Found device: {:?}", peripheral);
let device = Device::new(self.session.adapter.clone(), peripheral);
match self.event_sender.send(DeviceEvent::Discovered(device)) {
Ok(_) => {
self.result_count += 1;
}
Err(e) => log::error!("Failed to add device: {}", e),
}
}
async fn passes_pre_connect_filters(&mut self, peripheral: &Peripheral) -> Option<bool> {
let mut passed = true;
if let Some(filter_by_addr) = self.config.address_filter.as_ref() {
passed &= filter_by_addr(peripheral.address());
}
if let Some(filter_by_name) = self.config.name_filter.as_ref() {
passed &= match peripheral.properties().await {
Ok(Some(props)) => props.local_name.map(|name| filter_by_name(&name)),
_ => None,
}?;
}
Some(passed)
}
async fn passes_post_connect_filters(&mut self, peripheral: &Peripheral) -> Option<bool> {
let mut passed = true;
if !peripheral.is_connected().await.unwrap_or(false) {
return None;
}
if let Some(filter_by_characteristics) = self.config.characteristics_filter.as_ref() {
let mut characteristics = Vec::new();
characteristics.extend(peripheral.characteristics());
passed &= if characteristics.is_empty() {
let address = peripheral.address();
log::debug!("Discovering characteristics for {}", address);
match peripheral.discover_services().await {
Ok(()) => {
characteristics.extend(peripheral.characteristics());
let characteristics = characteristics
.into_iter()
.map(|c| c.uuid)
.collect::<Vec<_>>();
filter_by_characteristics(characteristics.as_slice())
}
Err(e) => {
log::warn!(
"Error: `{:?}` when discovering characteristics for {}",
e,
address
);
false
}
}
} else {
true
}
}
Some(passed)
}
}
/// Events broadcast by a running scan about devices that passed every configured filter.
#[derive(Clone)]
pub enum DeviceEvent {
    /// A device passed all filters; sent once per device.
    Discovered(Device),
    /// A previously matched device reported a connection.
    Connected(Device),
    /// A previously matched device reported a disconnection.
    Disconnected(Device),
    /// A previously matched device reported an update (`CentralEvent::DeviceUpdated`).
    Updated(Device),
}