// att/server.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll, Waker};
6
7use futures_channel::oneshot;
8use futures_core::ready;
9use futures_core::stream::Stream;
10use futures_sink::Sink;
11use futures_util::future::FutureExt;
12use futures_util::lock::{Mutex, MutexGuard};
13use futures_util::sink::SinkExt;
14use futures_util::stream::{StreamExt, TryStreamExt};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17use crate::packet as pkt;
18use crate::sock::{AttListener, AttStream};
19use crate::Handle;
20pub use crate::{ErrorResponse, Handler};
21use pkt::pack::{self, Unpack};
22
23const DEFAULT_MTU: usize = 23;
24
25#[derive(Debug, thiserror::Error)]
26pub enum Error {
27    #[error(transparent)]
28    Io(#[from] io::Error),
29
30    #[error(transparent)]
31    Pack(#[from] pack::Error),
32}
33
34type Result<R> = std::result::Result<R, Error>;
35
36struct PacketStream<R> {
37    inner: R,
38    rxbuf: Box<[u8]>,
39    txbuf: Box<[u8]>,
40    txlen: usize,
41    txwaker: Vec<Waker>,
42}
43
44impl<R> PacketStream<R> {
45    fn new(inner: R) -> Self {
46        Self {
47            inner,
48            rxbuf: [0; DEFAULT_MTU].into(),
49            txbuf: [0; DEFAULT_MTU].into(),
50            txlen: 0,
51            txwaker: vec![],
52        }
53    }
54
55    fn txmtu(&self) -> usize {
56        self.txbuf.len()
57    }
58
59    fn set_txmtu(&mut self, mtu: usize) {
60        let mut buf = vec![0; mtu];
61        let len = mtu.min(self.txbuf.len());
62        (&mut buf[..len]).copy_from_slice(&self.txbuf[..len]);
63        self.txbuf = buf.into();
64    }
65
66    fn set_rxmtu(&mut self, mtu: usize) {
67        let mut buf = vec![0; mtu];
68        let len = mtu.min(self.rxbuf.len());
69        (&mut buf[..len]).copy_from_slice(&self.rxbuf[..len]);
70        self.rxbuf = buf.into();
71    }
72}
73
74impl<R> Stream for PacketStream<R>
75where
76    R: AsyncRead + Unpin,
77{
78    type Item = Result<pkt::DeviceRecv>;
79
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        let Self { inner, rxbuf, .. } = self.get_mut();
82
83        let mut buf = ReadBuf::new(rxbuf);
84        ready!(Pin::new(inner).poll_read(cx, &mut buf))?;
85        let mut filled = buf.filled();
86        if filled.is_empty() {
87            Poll::Ready(None)
88        } else {
89            let item = Unpack::unpack(&mut filled)?;
90            log::trace!("packet recv {:?}", item);
91            Poll::Ready(Some(Ok(item)))
92        }
93    }
94}
95
96impl<W, S> Sink<S> for PacketStream<W>
97where
98    W: AsyncWrite + Unpin,
99    S: pkt::DeviceSend,
100{
101    type Error = Error;
102
103    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
104        let Self { txlen, txwaker, .. } = self.get_mut();
105        if *txlen != 0 {
106            txwaker.push(cx.waker().clone());
107            Poll::Pending
108        } else {
109            Poll::Ready(Ok(()))
110        }
111    }
112
113    fn start_send(self: Pin<&mut Self>, item: S) -> Result<()> {
114        let Self { txlen, txbuf, .. } = self.get_mut();
115        log::trace!("packet send {:?}", item);
116
117        let mut write = txbuf.as_mut();
118        let len = write.len();
119        item.pack_with_code(&mut write)?;
120        *txlen = len - write.len();
121        Ok(())
122    }
123
124    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
125        let Self {
126            inner,
127            txlen,
128            txbuf,
129            ..
130        } = self.get_mut();
131
132        while *txlen != 0 {
133            *txlen -= ready!(Pin::new(&mut *inner).poll_write(cx, &txbuf[..*txlen]))?;
134        }
135        ready!(Pin::new(&mut *inner).poll_flush(cx))?;
136        Poll::Ready(Ok(()))
137    }
138
139    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
140        let this = self.get_mut();
141        ready!(Sink::<S>::poll_flush(Pin::new(this), cx))?;
142        ready!(Pin::new(&mut this.inner).poll_shutdown(cx))?;
143        Poll::Ready(Ok(()))
144    }
145}
146
147struct Inner<IO> {
148    stream: PacketStream<IO>,
149    await_confirmation: Option<oneshot::Sender<()>>,
150    // TODO used notification / indication handles
151}
152
153impl<IO> Inner<IO> {
154    fn new(io: IO) -> Self {
155        Self {
156            stream: PacketStream::new(io),
157            await_confirmation: Default::default(),
158        }
159    }
160}
161
162enum NotificationState {
163    Write,
164    NeedFlush(usize),
165}
166
167struct NotificationInner<IO> {
168    handle: Handle,
169    inner: Arc<Mutex<Inner<IO>>>,
170    state: NotificationState,
171}
172
173impl<IO> AsyncWrite for NotificationInner<IO>
174where
175    IO: AsyncWrite + Unpin,
176{
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<io::Result<usize>> {
182        let Self {
183            handle,
184            inner,
185            state,
186            ..
187        } = self.get_mut();
188        let mut guard = ready!(inner.lock().poll_unpin(cx));
189
190        loop {
191            match &state {
192                NotificationState::Write => {
193                    if let Err(err) =
194                        ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_ready(
195                            Pin::new(&mut guard.stream),
196                            cx
197                        ))
198                    {
199                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
200                    }
201                    let item = pkt::HandleValueNotificationBorrow::new(handle.clone(), buf);
202                    if let Err(err) = guard.stream.start_send_unpin(item) {
203                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
204                    }
205                    *state = NotificationState::NeedFlush(buf.len());
206                }
207
208                NotificationState::NeedFlush(len) => {
209                    if let Err(err) =
210                        ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_flush(
211                            Pin::new(&mut guard.stream),
212                            cx
213                        ))
214                    {
215                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
216                    }
217                    let len = *len;
218                    *state = NotificationState::Write;
219                    return Poll::Ready(Ok(len));
220                }
221            }
222        }
223    }
224
225    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
226        Poll::Ready(Ok(()))
227    }
228
229    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230        let Self { inner, .. } = self.get_mut();
231        let mut guard = ready!(inner.lock().poll_unpin(cx));
232        if let Err(err) = ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_close(
233            Pin::new(&mut guard.stream),
234            cx,
235        )) {
236            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
237        }
238        Poll::Ready(Ok(()))
239    }
240}
241
242enum IndicationState {
243    Write,
244    NeedFlush(usize),
245    AwaitConfirmation(usize, oneshot::Receiver<()>),
246}
247
248struct IndicationInner<IO> {
249    handle: Handle,
250    inner: Arc<Mutex<Inner<IO>>>,
251    state: IndicationState,
252}
253
254impl<IO> AsyncWrite for IndicationInner<IO>
255where
256    IO: AsyncWrite + Unpin,
257{
258    fn poll_write(
259        self: Pin<&mut Self>,
260        cx: &mut Context<'_>,
261        buf: &[u8],
262    ) -> Poll<io::Result<usize>> {
263        let Self {
264            state,
265            handle,
266            inner,
267            ..
268        } = self.get_mut();
269        let mut guard = ready!(inner.lock().poll_unpin(cx));
270
271        loop {
272            match state {
273                IndicationState::Write => {
274                    if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_ready(
275                        Pin::new(&mut guard.stream),
276                        cx
277                    )) {
278                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
279                    }
280                    let item = pkt::HandleValueIndicationBorrow::new(handle.clone(), buf);
281                    if let Err(err) = guard.stream.start_send_unpin(item) {
282                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
283                    }
284                    *state = IndicationState::NeedFlush(buf.len());
285                }
286
287                IndicationState::NeedFlush(len) => {
288                    if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_flush(
289                        Pin::new(&mut guard.stream),
290                        cx
291                    )) {
292                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
293                    }
294                    let (tx, rx) = oneshot::channel();
295                    guard.await_confirmation = Some(tx); // TODO check existence
296                    *state = IndicationState::AwaitConfirmation(*len, rx);
297                }
298
299                IndicationState::AwaitConfirmation(len, rx) => {
300                    if let Err(err) = ready!(rx.poll_unpin(cx)) {
301                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
302                    }
303                    let len = *len;
304                    *state = IndicationState::Write;
305                    return Poll::Ready(Ok(len));
306                }
307            }
308        }
309    }
310
311    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
312        Poll::Ready(Ok(()))
313    }
314
315    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
316        let Self { inner, .. } = self.get_mut();
317        let mut guard = ready!(inner.lock().poll_unpin(cx));
318        if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_close(
319            Pin::new(&mut guard.stream),
320            cx
321        )) {
322            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
323        }
324        Poll::Ready(Ok(()))
325    }
326}
327
328struct TryLockNext<'a, IO> {
329    inner: &'a Mutex<Inner<IO>>,
330}
331
332impl<'a, IO> Future for TryLockNext<'a, IO>
333where
334    IO: AsyncRead + Unpin,
335{
336    type Output = (MutexGuard<'a, Inner<IO>>, Option<Result<pkt::DeviceRecv>>);
337
338    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
339        let Self { inner } = self.get_mut();
340
341        let mut guard = ready!(inner.lock().poll_unpin(cx));
342        let item = ready!(guard.stream.poll_next_unpin(cx));
343        Poll::Ready((guard, item))
344    }
345}
346
347async fn respond<IO, R>(
348    stream: &mut PacketStream<IO>,
349    r: std::result::Result<R::Response, crate::handler::ErrorResponse>,
350) -> Result<()>
351where
352    IO: AsyncWrite + Unpin,
353    R: pkt::Request,
354{
355    let mtu = stream.txmtu();
356    match r {
357        Ok(mut r) => {
358            pkt::Response::truncate(&mut r, mtu);
359            stream.send(r).await?;
360        }
361        Err(crate::ErrorResponse(handle, code)) => {
362            let err = pkt::ErrorResponse::new(R::opcode(), handle, code);
363            stream.send(err).await?;
364        }
365    }
366    Ok(())
367}
368
369async fn handle<IO, H>(
370    inner: &mut Inner<IO>,
371    handler: &mut H,
372    request: pkt::DeviceRecv,
373) -> Result<()>
374where
375    IO: AsyncWrite + Unpin,
376    H: crate::Handler,
377{
378    match request {
379        pkt::DeviceRecv::ExchangeMtuRequest(item) => {
380            let response = handler.handle_exchange_mtu_request(&item);
381            if let Ok(response) = &response {
382                let client_rx_mtu = *item.client_rx_mtu() as usize;
383                let server_rx_mtu = *response.server_rx_mtu() as usize;
384                inner.stream.set_txmtu(client_rx_mtu);
385                inner.stream.set_rxmtu(server_rx_mtu);
386            }
387            respond::<_, pkt::ExchangeMtuRequest>(&mut inner.stream, response).await?;
388        }
389
390        pkt::DeviceRecv::FindInformationRequest(item) => {
391            let response = handler.handle_find_information_request(&item);
392            respond::<_, pkt::FindInformationRequest>(&mut inner.stream, response).await?;
393        }
394
395        pkt::DeviceRecv::FindByTypeValueRequest(item) => {
396            let response = handler.handle_find_by_type_value_request(&item);
397            respond::<_, pkt::FindByTypeValueRequest>(&mut inner.stream, response).await?;
398        }
399
400        pkt::DeviceRecv::ReadByTypeRequest(item) => {
401            let response = handler.handle_read_by_type_request(&item);
402            respond::<_, pkt::ReadByTypeRequest>(&mut inner.stream, response).await?;
403        }
404
405        pkt::DeviceRecv::ReadRequest(item) => {
406            let response = handler.handle_read_request(&item);
407            respond::<_, pkt::ReadRequest>(&mut inner.stream, response).await?;
408        }
409
410        pkt::DeviceRecv::ReadBlobRequest(item) => {
411            let response = handler.handle_read_blob_request(&item);
412            respond::<_, pkt::ReadBlobRequest>(&mut inner.stream, response).await?;
413        }
414
415        pkt::DeviceRecv::ReadMultipleRequest(item) => {
416            let response = handler.handle_read_multiple_request(&item);
417            respond::<_, pkt::ReadMultipleRequest>(&mut inner.stream, response).await?;
418        }
419
420        pkt::DeviceRecv::ReadByGroupTypeRequest(item) => {
421            let response = handler.handle_read_by_group_type_request(&item);
422            respond::<_, pkt::ReadByGroupTypeRequest>(&mut inner.stream, response).await?;
423        }
424
425        pkt::DeviceRecv::WriteRequest(item) => {
426            let response = handler.handle_write_request(&item);
427            respond::<_, pkt::WriteRequest>(&mut inner.stream, response).await?;
428        }
429
430        pkt::DeviceRecv::WriteCommand(item) => {
431            handler.handle_write_command(&item);
432        }
433
434        pkt::DeviceRecv::PrepareWriteRequest(item) => {
435            let response = handler.handle_prepare_write_request(&item);
436            respond::<_, pkt::PrepareWriteRequest>(&mut inner.stream, response).await?;
437        }
438
439        pkt::DeviceRecv::ExecuteWriteRequest(item) => {
440            let response = handler.handle_execute_write_request(&item);
441            respond::<_, pkt::ExecuteWriteRequest>(&mut inner.stream, response).await?;
442        }
443
444        pkt::DeviceRecv::SignedWriteCommand(item) => {
445            handler.handle_signed_write_command(&item);
446        }
447
448        pkt::DeviceRecv::HandleValueConfirmation(..) => {
449            if let Some(channel) = inner.await_confirmation.take() {
450                channel.send(()).ok();
451            }
452        }
453    }
454    Ok(())
455}
456
457struct ConnectionInner<IO> {
458    inner: Arc<Mutex<Inner<IO>>>,
459}
460
461impl<IO> ConnectionInner<IO>
462where
463    IO: AsyncRead + AsyncWrite + Unpin,
464{
465    fn notification(&self, handle: Handle) -> NotificationInner<IO> {
466        NotificationInner {
467            handle,
468            inner: self.inner.clone(),
469            state: NotificationState::Write,
470        }
471    }
472
473    fn indication(&self, handle: Handle) -> IndicationInner<IO> {
474        IndicationInner {
475            handle,
476            inner: self.inner.clone(),
477            state: IndicationState::Write,
478        }
479    }
480
481    async fn run<H>(self, mut handler: H) -> Result<()>
482    where
483        H: crate::Handler,
484    {
485        loop {
486            let (mut guard, request) = TryLockNext { inner: &self.inner }.await;
487            let request = if let Some(request) = request {
488                request?
489            } else {
490                return Ok(());
491            };
492
493            handle(&mut *guard, &mut handler, request).await?;
494        }
495    }
496}
497
498pub struct Notification {
499    inner: NotificationInner<AttStream>,
500}
501
502impl AsyncWrite for Notification {
503    fn poll_write(
504        self: Pin<&mut Self>,
505        cx: &mut Context<'_>,
506        buf: &[u8],
507    ) -> Poll<io::Result<usize>> {
508        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
509    }
510
511    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
512        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
513    }
514
515    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
516        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
517    }
518}
519
520pub struct Indication {
521    inner: IndicationInner<AttStream>,
522}
523
524impl AsyncWrite for Indication {
525    fn poll_write(
526        self: Pin<&mut Self>,
527        cx: &mut Context<'_>,
528        buf: &[u8],
529    ) -> Poll<io::Result<usize>> {
530        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
531    }
532
533    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
534        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
535    }
536
537    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
538        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
539    }
540}
541
542struct ServerInner<L> {
543    inner: L,
544}
545
546impl<L, IO> ServerInner<L>
547where
548    L: Stream<Item = io::Result<(IO, socket2::SockAddr)>> + Unpin,
549    IO: AsyncRead + AsyncWrite + Unpin,
550{
551    async fn accept(&mut self) -> io::Result<Option<(ConnectionInner<IO>, socket2::SockAddr)>> {
552        if let Some((sock, addr)) = self.inner.try_next().await? {
553            return Ok(Some((
554                ConnectionInner {
555                    inner: Arc::new(Mutex::new(Inner::new(sock))),
556                },
557                addr,
558            )));
559        }
560        Ok(None)
561    }
562}
563
564pub struct Connection {
565    inner: ConnectionInner<AttStream>,
566    addr: crate::Address,
567}
568
569impl Connection {
570    pub fn address(&self) -> &crate::Address {
571        &self.addr
572    }
573
574    pub fn notification(&self, handle: Handle) -> Notification {
575        Notification {
576            inner: self.inner.notification(handle),
577        }
578    }
579
580    pub fn indication(&self, handle: Handle) -> Indication {
581        Indication {
582            inner: self.inner.indication(handle),
583        }
584    }
585
586    pub async fn run<H>(self, handler: H) -> Result<()>
587    where
588        H: crate::Handler,
589    {
590        log::debug!("Start serving.");
591        self.inner.run(handler).await?;
592        log::debug!("Done serving.");
593        Ok(())
594    }
595}
596
597pub struct Server {
598    inner: ServerInner<AttListener>,
599}
600
601impl Server {
602    /// Constract Instance.
603    pub fn new() -> io::Result<Self> {
604        let sock = AttListener::new()?;
605        Ok(Self {
606            inner: ServerInner { inner: sock },
607        })
608    }
609
610    pub fn needs_bond(&self) -> io::Result<()> {
611        self.inner
612            .inner
613            .set_sockopt_bt_security(crate::sock::BT_SECURITY_MEDIUM, 0)
614    }
615
616    pub fn needs_bond_mitm(&self) -> io::Result<()> {
617        self.inner
618            .inner
619            .set_sockopt_bt_security(crate::sock::BT_SECURITY_HIGH, 0)
620    }
621
622    pub async fn accept(&mut self) -> io::Result<Option<(Connection, crate::Address)>> {
623        if let Some((connection, addr)) = self.inner.accept().await? {
624            log::debug!("Connection accepted.");
625            let addr = crate::sock::try_from(addr)?;
626            Ok(Some((
627                Connection {
628                    inner: connection,
629                    addr: addr.clone(),
630                },
631                addr,
632            )))
633        } else {
634            Ok(None)
635        }
636    }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;
    use tokio::io::AsyncWriteExt;
    use tokio_test::io::Builder;

    /// Handler that keeps every default response of [`Handler`].
    struct DefaultHandler;

    impl Handler for DefaultHandler {}

    /// Wraps `io` in a fresh connection with default state.
    fn connect<IO>(io: IO) -> ConnectionInner<IO> {
        ConnectionInner {
            inner: Arc::new(Mutex::new(Inner::new(io))),
        }
    }

    #[tokio::test]
    async fn test_stream() {
        // Inbound: Exchange MTU Request with client Rx MTU 0x0017 (23).
        // Outbound: Exchange MTU Response with server Rx MTU 0x0018.
        let mock = Builder::new()
            .read(&[0x02, 0x17, 0x00])
            .write(&[0x03, 0x18, 0x00])
            .build();
        let mut packets = PacketStream::new(mock);

        let received = packets.try_next().await.unwrap().unwrap();
        let request = pkt::ExchangeMtuRequest::try_from(received).unwrap();
        assert_eq!(*request.client_rx_mtu(), 23);

        packets
            .send(pkt::ExchangeMtuResponse::new(0x0018))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_connection() {
        // A notification "ok" on handle 1, then an MTU exchange answered with defaults.
        let mock = Builder::new()
            .write(&[0x1B, 0x01, 0x00, 0x6F, 0x6B])
            .read(&[0x02, 0x17, 0x00])
            .write(&[0x03, 0x17, 0x00])
            .build();
        let conn = connect(mock);

        let mut notification = conn.notification(Handle::new(1));
        notification.write_all(b"ok").await.unwrap();
        conn.run(DefaultHandler).await.unwrap();
    }

    #[tokio::test]
    async fn test_indication() {
        // An indication "ok" on handle 1, confirmed by the peer. The confirmation
        // is only delivered while `run` is driven, hence the spawned task.
        let mock = Builder::new()
            .write(&[0x1D, 0x01, 0x00, 0x6F, 0x6B])
            .read(&[0x1E, 0x17, 0x00])
            .build();
        let conn = connect(mock);

        let mut indication = conn.indication(Handle::new(1));
        let server = tokio::spawn(conn.run(DefaultHandler));

        indication.write_all(b"ok").await.unwrap();
        server.await.unwrap().unwrap();
    }
}