use std::collections::HashMap;
use std::hash::Hash;
use std::io;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use att::packet as pkt;
use att::server::{
Connection as AttConnection, Error as AttError, ErrorResponse, Handler, Server as AttServer,
};
pub use att::server::{Indication, Notification};
use att::Handle;
use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use crate::database::Database;
use crate::Registration;
/// ATT request handler for a single connection, answering requests from an
/// attribute [`Database`].
#[derive(Debug)]
struct GattHandler<T> {
    /// Attribute database used to answer discovery, read and write requests.
    db: Database,
    /// Attribute handle -> application token reported in [`Event::Write`] when
    /// a client writes that handle.
    write_tokens: HashMap<Handle, T>,
    /// One sender per [`Events`] stream handed out by the owning connection.
    events_txs: Vec<mpsc::UnboundedSender<Event<T>>>,
    /// Shared flag (also held by [`Authenticator`]s); its current value is
    /// passed to the database on lookups and reads.
    authenticated: Arc<AtomicBool>,
}
impl<T> GattHandler<T> {
fn new(
db: Database,
write_tokens: HashMap<Handle, T>,
events_txs: Vec<mpsc::UnboundedSender<Event<T>>>,
authenticated: Arc<AtomicBool>,
) -> Self {
Self {
db,
write_tokens,
events_txs,
authenticated,
}
}
fn authenticated(&self) -> bool {
self.authenticated.load(Ordering::SeqCst)
}
}
impl<T> Handler for GattHandler<T>
where
T: Clone,
{
fn handle_exchange_mtu_request(
&mut self,
item: &pkt::ExchangeMtuRequest,
) -> Result<pkt::ExchangeMtuResponse, ErrorResponse> {
Ok(pkt::ExchangeMtuResponse::new(*item.client_rx_mtu()))
}
fn handle_find_information_request(
&mut self,
item: &pkt::FindInformationRequest,
) -> Result<pkt::FindInformationResponse, ErrorResponse> {
let r = match self
.db
.find_information(item.starting_handle().clone()..=item.ending_handle().clone())
{
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
Ok(r.into_iter().map(Into::into).collect())
}
fn handle_find_by_type_value_request(
&mut self,
item: &pkt::FindByTypeValueRequest,
) -> Result<pkt::FindByTypeValueResponse, ErrorResponse> {
let r = match self.db.find_by_type_value(
item.starting_handle().clone()..=item.ending_handle().clone(),
item.attribute_type(),
item.attribute_value(),
false,
self.authenticated(),
) {
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
Ok(r.into_iter().map(Into::into).collect())
}
fn handle_read_by_type_request(
&mut self,
item: &pkt::ReadByTypeRequest,
) -> Result<pkt::ReadByTypeResponse, ErrorResponse> {
let r = match self.db.read_by_type(
item.starting_handle().clone()..=item.ending_handle().clone(),
item.attribute_type(),
false,
self.authenticated(),
) {
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
Ok(r.into_iter().map(Into::into).collect())
}
fn handle_read_request(
&mut self,
item: &pkt::ReadRequest,
) -> Result<pkt::ReadResponse, ErrorResponse> {
let r = match self
.db
.read(item.attribute_handle(), false, self.authenticated())
{
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
Ok(pkt::ReadResponse::new(r))
}
fn handle_read_blob_request(
&mut self,
item: &pkt::ReadBlobRequest,
) -> Result<pkt::ReadBlobResponse, ErrorResponse> {
let r = match self
.db
.read(item.attribute_handle(), false, self.authenticated())
{
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
let offset = *item.attribute_offset() as usize;
Ok(pkt::ReadBlobResponse::new(r[offset..].into()))
}
fn handle_read_by_group_type_request(
&mut self,
item: &pkt::ReadByGroupTypeRequest,
) -> Result<pkt::ReadByGroupTypeResponse, ErrorResponse> {
let r = match self.db.read_by_group_type(
item.starting_handle().clone()..=item.ending_handle().clone(),
item.attribute_group_type(),
false,
self.authenticated(),
) {
Ok(v) => v,
Err((h, e)) => return Err(ErrorResponse::new(h, e)),
};
Ok(r.into_iter().map(Into::into).collect())
}
fn handle_write_request(
&mut self,
item: &pkt::WriteRequest,
) -> Result<pkt::WriteResponse, ErrorResponse> {
let value = item.attribute_value();
if let Some(token) = self.write_tokens.get(item.attribute_handle()) {
for tx in &self.events_txs {
tx.unbounded_send(Event::Write(token.clone(), value.to_vec().into()))
.ok();
}
}
match self.db.write(item.attribute_handle(), value, false, false) {
Ok(_) => Ok(pkt::WriteResponse::new()),
Err((h, e)) => Err(ErrorResponse::new(h, e)),
}
}
fn handle_write_command(&mut self, item: &pkt::WriteCommand) {
let value = item.attribute_value();
if let Some(token) = self.write_tokens.get(item.attribute_handle()) {
for tx in &self.events_txs {
tx.unbounded_send(Event::Write(token.clone(), value.to_vec().into()))
.ok();
}
}
if let Err(err) = self.db.write(
item.attribute_handle(),
item.attribute_value(),
false,
false,
) {
log::warn!("{:?}", err);
};
}
fn handle_signed_write_command(&mut self, item: &pkt::SignedWriteCommand) {
let value = item.attribute_value();
if let Some(token) = self.write_tokens.get(item.attribute_handle()) {
for tx in &self.events_txs {
tx.unbounded_send(Event::Write(token.clone(), value.to_vec().into()))
.ok();
}
}
if let Err(err) =
self.db
.write(item.attribute_handle(), item.attribute_value(), false, true)
{
log::warn!("{:?}", err);
};
}
}
/// Unit error whose `Display` text is "channel error".
///
/// NOTE(review): not constructed in the code visible here; presumably used by
/// other modules — confirm before relying on its meaning.
#[derive(Debug, thiserror::Error)]
#[error("channel error")]
pub struct ChannelError;
/// Handle that lets the application mark its connection as authenticated.
///
/// It shares one flag with the connection's request handler. After
/// [`Authenticator::mark_authenticated`] is called, the handler observes the
/// flag as `true`.
#[derive(Debug)]
pub struct Authenticator {
    authenticated: Arc<AtomicBool>,
}

impl Authenticator {
    /// Sets the shared authentication flag to `true`.
    ///
    /// Nothing in this type clears the flag again.
    pub fn mark_authenticated(&self) {
        let flag = &*self.authenticated;
        flag.store(true, Ordering::SeqCst);
    }
}
/// Notification delivered to [`Events`] subscribers.
#[derive(Debug)]
pub enum Event<T> {
    /// A client wrote an attribute that was registered with a write token.
    /// Carries that token and a copy of the bytes the client sent.
    Write(T, Box<[u8]>),
}
#[derive(Debug)]
pub struct Events<T>(mpsc::UnboundedReceiver<Event<T>>);
impl<T> Events<T> {
pub async fn next(&mut self) -> Option<Event<T>> {
self.0.next().await
}
}
/// Returned by `Connection::notification` / `Connection::indication` when no
/// notify/indicate handle is registered for the given token.
#[derive(Debug, thiserror::Error)]
#[error("handle not found.")]
pub struct HandleNotFound;

/// Error returned by `Connection::run`. It wraps the underlying ATT server
/// error and forwards its `Display`/`source` transparently.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct RunError(#[from] AttError);
/// An accepted client connection together with the attribute database and
/// token maps built from its `Registration`.
///
/// Call `Connection::run` to serve requests.
pub struct Connection<T> {
    /// Underlying ATT connection.
    inner: AttConnection,
    /// One sender per [`Events`] stream handed out by `Connection::events`.
    event_txs: Vec<mpsc::UnboundedSender<Event<T>>>,
    /// Attribute database produced by `Registration::build`.
    db: Database,
    /// Attribute handle -> application token reported on writes.
    write_tokens: HashMap<Handle, T>,
    /// Application token -> handle used for notifications and indications.
    notify_or_indicate_handles: HashMap<T, Handle>,
    /// Authentication flag shared with every [`Authenticator`] handed out.
    authenticated: Arc<AtomicBool>,
}
impl<T> Connection<T>
where
T: Eq + Hash + Clone,
{
fn new(inner: AttConnection, registration: Registration<T>) -> Self {
let (db, write_tokens, notify_or_indicate_handles) = registration.build();
Self {
inner,
event_txs: vec![],
db,
write_tokens,
notify_or_indicate_handles,
authenticated: Arc::new(AtomicBool::from(false)),
}
}
pub fn authenticator(&self) -> Authenticator {
Authenticator {
authenticated: self.authenticated.clone(),
}
}
pub fn events(&mut self) -> Events<T> {
let (tx, rx) = mpsc::unbounded();
self.event_txs.push(tx);
Events(rx)
}
pub fn notification(&self, token: &T) -> Result<Notification, HandleNotFound> {
if let Some(handle) = self.notify_or_indicate_handles.get(token) {
let notification = self.inner.notification(handle.clone());
Ok(notification)
} else {
Err(HandleNotFound)
}
}
pub fn indication(&self, token: &T) -> Result<Indication, HandleNotFound> {
if let Some(handle) = self.notify_or_indicate_handles.get(token) {
let indication = self.inner.indication(handle.clone());
Ok(indication)
} else {
Err(HandleNotFound)
}
}
pub fn address(&self) -> &att::Address {
self.inner.address()
}
pub async fn run(self) -> Result<(), RunError> {
let Self {
db,
write_tokens,
event_txs,
authenticated,
..
} = self;
self.inner
.run(GattHandler::<T>::new(
db,
write_tokens,
event_txs,
authenticated,
))
.await?;
Ok(())
}
}
/// GATT server accepting connections through an underlying ATT server.
pub struct Server {
    inner: AttServer,
}

impl Server {
    /// Creates the underlying ATT server.
    ///
    /// # Errors
    /// Propagates any error from `AttServer::new`.
    pub fn bind() -> io::Result<Self> {
        Ok(Server {
            inner: AttServer::new()?,
        })
    }

    /// Waits for the next client and wraps it in a [`Connection`] serving the
    /// attributes described by `registration`.
    ///
    /// Returns `Ok(None)` when the underlying `accept` yields nothing.
    pub async fn accept<T>(
        &mut self,
        registration: Registration<T>,
    ) -> io::Result<Option<Connection<T>>>
    where
        T: Eq + Hash + Clone,
    {
        let accepted = self.inner.accept().await?;
        Ok(accepted.map(|(connection, _)| Connection::new(connection, registration)))
    }

    /// Delegates to the ATT server's `needs_bond`.
    ///
    /// NOTE(review): presumably this requires bonding before access; confirm
    /// in the `att` crate.
    pub fn needs_bond(&self) -> io::Result<()> {
        match self.inner.needs_bond() {
            Ok(_) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// Delegates to the ATT server's `needs_bond_mitm`.
    ///
    /// NOTE(review): presumably this requires MITM-protected bonding; confirm
    /// in the `att` crate.
    pub fn needs_bond_mitm(&self) -> io::Result<()> {
        match self.inner.needs_bond_mitm() {
            Ok(_) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }
}