use crate::error::{BleError, Result};
use crate::gatt::{
u16_le, AdvertSource, GattCharacteristic, GattConnection, GattService, RawAdvert,
HAP_INSTANCE_ID_DESC, HAP_SERVICE_ID_CHAR,
};
use async_trait::async_trait;
use bluest::error::ErrorKind;
use bluest::{Adapter, Characteristic, Device};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
/// Time allowed for connecting to the device and rediscovering its services
/// when a dropped link is re-established.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Service-signature characteristic UUID, lowercase so it compares directly
/// against the lowercased UUIDs produced during discovery.
const SERVICE_SIGNATURE_CHAR: &str = "000000a5-0000-1000-8000-0026bb765291";
/// Protocol-information service UUID (lowercase); the only service whose
/// service-signature characteristic handle is cached.
const PROTOCOL_INFO_SERVICE: &str = "000000a2-0000-1000-8000-0026bb765291";
/// Maximum number of reconnects a single operation may trigger before it
/// gives up with `BleError::Disconnected`.
const MAX_OP_RECONNECTS: u32 = 8;
/// Translates a `bluest` error into the crate's [`BleError`].
///
/// Error kinds meaning the link (or the adapter behind it) is gone, stale, or
/// unresponsive collapse into `BleError::Disconnected`, the variant the retry
/// loops in this file react to by reconnecting. Every other kind becomes
/// `BleError::Backend` carrying the error's text.
// `e` is taken by value so the function can be handed straight to `map_err`.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn be(e: bluest::Error) -> BleError {
    let link_lost = matches!(
        e.kind(),
        ErrorKind::NotConnected
            | ErrorKind::AdapterUnavailable
            | ErrorKind::ConnectionFailed
            | ErrorKind::ServiceChanged
            | ErrorKind::NotReady
            | ErrorKind::Timeout
    );
    if link_lost {
        BleError::Disconnected
    } else {
        BleError::Backend(e.to_string())
    }
}
/// Returns `true` when `e` is the link-lost error that should trigger a
/// reconnect-and-retry rather than being returned to the caller.
fn is_disconnect(e: &BleError) -> bool {
    matches!(*e, BleError::Disconnected)
}
/// UUID-only description of one discovered GATT service.
///
/// Holds no device handles, so it can be cloned and walked without touching
/// the peer.
#[derive(Clone)]
struct ServiceShape {
    /// Service UUID as rendered by the UUID's `Display` implementation.
    uuid: String,
    /// Lowercase UUIDs of every characteristic in the service, in discovery
    /// order, including ones whose handles were not cached.
    char_uuids: Vec<String>,
}
/// A GATT connection to a single device backed by `bluest`, with
/// transparent reconnect-and-retry for reads and writes.
pub struct BluestConnection {
    /// Adapter used to connect, disconnect and scan.
    adapter: Adapter,
    /// The peer this connection talks to.
    device: Device,
    /// Cached characteristic handles keyed by lowercase UUID; replaced
    /// wholesale after every successful reconnect.
    chars: Mutex<HashMap<String, Characteristic>>,
    /// Service layout recorded when the connection was created; reconnects
    /// do not refresh it.
    shape: Vec<ServiceShape>,
    /// Incremented at the start of every reconnect so callers can detect
    /// that the underlying link was replaced.
    generation: AtomicU64,
}
impl BluestConnection {
    /// Wraps an already-connected `device`, discovering its GATT database
    /// once up front.
    ///
    /// # Errors
    /// Propagates any service/characteristic discovery error (mapped by `be`).
    pub async fn new(adapter: Adapter, device: Device) -> Result<Self> {
        let (chars, shape) = Self::discover(&device).await?;
        Ok(Self {
            adapter,
            device,
            chars: Mutex::new(chars),
            shape,
            generation: AtomicU64::new(0),
        })
    }

    /// Best-effort disconnect; any adapter error is ignored.
    pub async fn disconnect(&self) {
        let _ = self.adapter.disconnect_device(&self.device).await;
    }

    /// Walks every service and characteristic exposed by `device`.
    ///
    /// Returns a map from lowercase characteristic UUID to its handle, plus
    /// the ordered service layout (all characteristic UUIDs per service).
    /// The map holds one handle per UUID, so a UUID seen in several services
    /// keeps the last one discovered. The service-signature characteristic is
    /// cached only for the protocol-information service.
    async fn discover(
        device: &Device,
    ) -> Result<(HashMap<String, Characteristic>, Vec<ServiceShape>)> {
        let mut chars = HashMap::new();
        let mut shape = Vec::new();
        for svc in device.discover_services().await.map_err(be)? {
            let svc_uuid = svc.uuid().to_string().to_ascii_lowercase();
            let is_protocol_info = svc_uuid == PROTOCOL_INFO_SERVICE;
            let mut char_uuids = Vec::new();
            for ch in svc.discover_characteristics().await.map_err(be)? {
                let uuid = ch.uuid().to_string().to_ascii_lowercase();
                char_uuids.push(uuid.clone());
                // Every service may carry a signature characteristic; only the
                // protocol-information one is kept so it isn't overwritten.
                if uuid == SERVICE_SIGNATURE_CHAR && !is_protocol_info {
                    continue;
                }
                chars.insert(uuid, ch);
            }
            shape.push(ServiceShape {
                uuid: svc.uuid().to_string(),
                char_uuids,
            });
        }
        Ok((chars, shape))
    }

    /// Tears the link down and re-establishes it, replacing the cached handles.
    ///
    /// The generation counter is bumped first so callers can tell that state
    /// tied to the previous link is stale. Waiting for the adapter,
    /// connecting, and rediscovering services are all bounded together by
    /// `CONNECT_TIMEOUT`, so a powered-off adapter cannot block a caller
    /// indefinitely.
    ///
    /// # Errors
    /// `BleError::Disconnected` when the timeout expires; otherwise whatever
    /// connecting or discovery returns.
    async fn reconnect(&self) -> Result<()> {
        self.generation.fetch_add(1, Ordering::SeqCst);
        // Best effort: the link is most likely already gone.
        let _ = self.adapter.disconnect_device(&self.device).await;
        let establish = async {
            // Inside the timeout on purpose. Its error is ignored because
            // `connect_device` reports the real failure.
            let _ = self.adapter.wait_available().await;
            self.adapter
                .connect_device(&self.device)
                .await
                .map_err(be)?;
            Self::discover(&self.device).await
        };
        // The freshly discovered layout is dropped; `self.shape` is fixed at construction.
        let (fresh, _shape) = tokio::time::timeout(CONNECT_TIMEOUT, establish)
            .await
            .map_err(|_| BleError::Disconnected)??;
        *self.chars.lock().await = fresh;
        Ok(())
    }

    /// Charges one reconnect against `attempts` and performs it.
    ///
    /// # Errors
    /// `BleError::Disconnected` once more than `MAX_OP_RECONNECTS` reconnects
    /// have been requested for the same operation; otherwise any error from
    /// `reconnect`, returned as-is.
    async fn reconnect_bounded(&self, attempts: &mut u32) -> Result<()> {
        *attempts += 1;
        if *attempts > MAX_OP_RECONNECTS {
            return Err(BleError::Disconnected);
        }
        self.reconnect().await
    }

    /// Looks up the cached handle for `char_uuid` (case-insensitive).
    ///
    /// # Errors
    /// `BleError::MalformedPdu` when no handle with that UUID is cached.
    async fn handle(&self, char_uuid: &str) -> Result<Characteristic> {
        self.chars
            .lock()
            .await
            .get(&char_uuid.to_ascii_lowercase())
            .cloned()
            .ok_or(BleError::MalformedPdu("gatt characteristic not found"))
    }

    /// Reads the HAP instance-ID descriptor of `char_uuid`.
    ///
    /// Returns `Ok(None)` when the characteristic has no such descriptor;
    /// otherwise whatever `u16_le` yields for the descriptor bytes. A
    /// disconnect triggers a bounded reconnect, after which the read is retried.
    ///
    /// # Errors
    /// `MalformedPdu` if the UUID is not cached, `Disconnected` once the
    /// reconnect budget is spent, or any other backend error.
    async fn read_iid(&self, char_uuid: &str) -> Result<Option<u16>> {
        let mut attempts = 0;
        loop {
            // Re-fetched every iteration: a reconnect replaces the handles.
            let ch = self.handle(char_uuid).await?;
            let attempt = async {
                let descriptors = ch.discover_descriptors().await.map_err(be)?;
                let Some(desc) = descriptors.iter().find(|d| {
                    d.uuid()
                        .to_string()
                        .eq_ignore_ascii_case(HAP_INSTANCE_ID_DESC)
                }) else {
                    return Ok(None);
                };
                Ok(u16_le(&desc.read().await.map_err(be)?))
            }
            .await;
            match attempt {
                Ok(v) => return Ok(v),
                Err(ref e) if is_disconnect(e) => self.reconnect_bounded(&mut attempts).await?,
                Err(e) => return Err(e),
            }
        }
    }
}
#[async_trait]
impl GattConnection for BluestConnection {
async fn instance_id(&self, char_uuid: &str) -> Result<u16> {
self.read_iid(char_uuid)
.await?
.ok_or(BleError::MalformedPdu("no instance id descriptor"))
}
async fn max_write(&self) -> usize {
let ch = self.chars.lock().await.values().next().cloned();
ch.and_then(|c| c.max_write_len().ok())
.map_or(crate::gatt::DEFAULT_FRAGMENT_SIZE, |n| n.clamp(20, 512))
}
async fn generation(&self) -> u64 {
self.generation.load(Ordering::SeqCst)
}
async fn write(&self, char_uuid: &str, value: &[u8]) -> Result<()> {
let mut attempts = 0;
loop {
let ch = self.handle(char_uuid).await?;
match ch.write(value).await.map_err(be) {
Ok(()) => return Ok(()),
Err(ref e) if is_disconnect(e) => self.reconnect_bounded(&mut attempts).await?,
Err(e) => return Err(e),
}
}
}
async fn read(&self, char_uuid: &str) -> Result<Vec<u8>> {
let mut attempts = 0;
loop {
let ch = self.handle(char_uuid).await?;
match ch.read().await.map_err(be) {
Ok(v) => return Ok(v),
Err(ref e) if is_disconnect(e) => self.reconnect_bounded(&mut attempts).await?,
Err(e) => return Err(e),
}
}
}
async fn subscribe(&self, char_uuid: &str) -> Result<mpsc::Receiver<Vec<u8>>> {
let ch = self.handle(char_uuid).await?;
let (tx, rx) = mpsc::channel(16);
tokio::spawn(async move {
use tokio_stream::StreamExt as _;
if let Ok(mut stream) = ch.notify().await {
while let Some(item) = stream.next().await {
let Ok(v) = item else { break };
if tx.send(v).await.is_err() {
break;
}
}
}
});
Ok(rx)
}
async fn enumerate(&self) -> Result<Vec<GattService>> {
let mut services = Vec::new();
for svc in &self.shape {
let mut characteristics = Vec::new();
for char_uuid in &svc.char_uuids {
if char_uuid.eq_ignore_ascii_case(HAP_SERVICE_ID_CHAR) {
continue;
}
if char_uuid.eq_ignore_ascii_case(SERVICE_SIGNATURE_CHAR) {
continue;
}
if let Some(iid) = self.read_iid(char_uuid).await? {
characteristics.push(GattCharacteristic {
uuid: char_uuid.clone(),
iid,
});
}
}
services.push(GattService {
uuid: svc.uuid.clone(),
iid: 0,
characteristics,
});
}
Ok(services)
}
}
/// Bluetooth company identifier whose manufacturer data `watch_adverts` forwards.
const APPLE_COMPANY_ID: u16 = 0x004c;
#[async_trait]
impl AdvertSource for BluestConnection {
async fn watch_adverts(&self) -> Result<mpsc::Receiver<RawAdvert>> {
let adapter = self.adapter.clone();
let (tx, rx) = mpsc::channel(32);
tokio::spawn(async move {
use tokio_stream::StreamExt as _;
let Ok(mut scan) = adapter.scan(&[]).await else {
return;
};
while let Some(adv) = scan.next().await {
let Some(md) = adv.adv_data.manufacturer_data else {
continue;
};
if md.company_id == APPLE_COMPANY_ID
&& tx
.send(RawAdvert {
manufacturer_data: md.data,
})
.await
.is_err()
{
return; }
}
});
Ok(rx)
}
}