1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll, Waker};
6
7use futures_channel::oneshot;
8use futures_core::ready;
9use futures_core::stream::Stream;
10use futures_sink::Sink;
11use futures_util::future::FutureExt;
12use futures_util::lock::{Mutex, MutexGuard};
13use futures_util::sink::SinkExt;
14use futures_util::stream::{StreamExt, TryStreamExt};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17use crate::packet as pkt;
18use crate::sock::{AttListener, AttStream};
19use crate::Handle;
20pub use crate::{ErrorResponse, Handler};
21use pkt::pack::{self, Unpack};
22
23const DEFAULT_MTU: usize = 23;
24
25#[derive(Debug, thiserror::Error)]
26pub enum Error {
27 #[error(transparent)]
28 Io(#[from] io::Error),
29
30 #[error(transparent)]
31 Pack(#[from] pack::Error),
32}
33
34type Result<R> = std::result::Result<R, Error>;
35
36struct PacketStream<R> {
37 inner: R,
38 rxbuf: Box<[u8]>,
39 txbuf: Box<[u8]>,
40 txlen: usize,
41 txwaker: Vec<Waker>,
42}
43
44impl<R> PacketStream<R> {
45 fn new(inner: R) -> Self {
46 Self {
47 inner,
48 rxbuf: [0; DEFAULT_MTU].into(),
49 txbuf: [0; DEFAULT_MTU].into(),
50 txlen: 0,
51 txwaker: vec![],
52 }
53 }
54
55 fn txmtu(&self) -> usize {
56 self.txbuf.len()
57 }
58
59 fn set_txmtu(&mut self, mtu: usize) {
60 let mut buf = vec![0; mtu];
61 let len = mtu.min(self.txbuf.len());
62 (&mut buf[..len]).copy_from_slice(&self.txbuf[..len]);
63 self.txbuf = buf.into();
64 }
65
66 fn set_rxmtu(&mut self, mtu: usize) {
67 let mut buf = vec![0; mtu];
68 let len = mtu.min(self.rxbuf.len());
69 (&mut buf[..len]).copy_from_slice(&self.rxbuf[..len]);
70 self.rxbuf = buf.into();
71 }
72}
73
74impl<R> Stream for PacketStream<R>
75where
76 R: AsyncRead + Unpin,
77{
78 type Item = Result<pkt::DeviceRecv>;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let Self { inner, rxbuf, .. } = self.get_mut();
82
83 let mut buf = ReadBuf::new(rxbuf);
84 ready!(Pin::new(inner).poll_read(cx, &mut buf))?;
85 let mut filled = buf.filled();
86 if filled.is_empty() {
87 Poll::Ready(None)
88 } else {
89 let item = Unpack::unpack(&mut filled)?;
90 log::trace!("packet recv {:?}", item);
91 Poll::Ready(Some(Ok(item)))
92 }
93 }
94}
95
96impl<W, S> Sink<S> for PacketStream<W>
97where
98 W: AsyncWrite + Unpin,
99 S: pkt::DeviceSend,
100{
101 type Error = Error;
102
103 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
104 let Self { txlen, txwaker, .. } = self.get_mut();
105 if *txlen != 0 {
106 txwaker.push(cx.waker().clone());
107 Poll::Pending
108 } else {
109 Poll::Ready(Ok(()))
110 }
111 }
112
113 fn start_send(self: Pin<&mut Self>, item: S) -> Result<()> {
114 let Self { txlen, txbuf, .. } = self.get_mut();
115 log::trace!("packet send {:?}", item);
116
117 let mut write = txbuf.as_mut();
118 let len = write.len();
119 item.pack_with_code(&mut write)?;
120 *txlen = len - write.len();
121 Ok(())
122 }
123
124 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
125 let Self {
126 inner,
127 txlen,
128 txbuf,
129 ..
130 } = self.get_mut();
131
132 while *txlen != 0 {
133 *txlen -= ready!(Pin::new(&mut *inner).poll_write(cx, &txbuf[..*txlen]))?;
134 }
135 ready!(Pin::new(&mut *inner).poll_flush(cx))?;
136 Poll::Ready(Ok(()))
137 }
138
139 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
140 let this = self.get_mut();
141 ready!(Sink::<S>::poll_flush(Pin::new(this), cx))?;
142 ready!(Pin::new(&mut this.inner).poll_shutdown(cx))?;
143 Poll::Ready(Ok(()))
144 }
145}
146
147struct Inner<IO> {
148 stream: PacketStream<IO>,
149 await_confirmation: Option<oneshot::Sender<()>>,
150 }
152
153impl<IO> Inner<IO> {
154 fn new(io: IO) -> Self {
155 Self {
156 stream: PacketStream::new(io),
157 await_confirmation: Default::default(),
158 }
159 }
160}
161
162enum NotificationState {
163 Write,
164 NeedFlush(usize),
165}
166
167struct NotificationInner<IO> {
168 handle: Handle,
169 inner: Arc<Mutex<Inner<IO>>>,
170 state: NotificationState,
171}
172
173impl<IO> AsyncWrite for NotificationInner<IO>
174where
175 IO: AsyncWrite + Unpin,
176{
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<io::Result<usize>> {
182 let Self {
183 handle,
184 inner,
185 state,
186 ..
187 } = self.get_mut();
188 let mut guard = ready!(inner.lock().poll_unpin(cx));
189
190 loop {
191 match &state {
192 NotificationState::Write => {
193 if let Err(err) =
194 ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_ready(
195 Pin::new(&mut guard.stream),
196 cx
197 ))
198 {
199 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
200 }
201 let item = pkt::HandleValueNotificationBorrow::new(handle.clone(), buf);
202 if let Err(err) = guard.stream.start_send_unpin(item) {
203 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
204 }
205 *state = NotificationState::NeedFlush(buf.len());
206 }
207
208 NotificationState::NeedFlush(len) => {
209 if let Err(err) =
210 ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_flush(
211 Pin::new(&mut guard.stream),
212 cx
213 ))
214 {
215 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
216 }
217 let len = *len;
218 *state = NotificationState::Write;
219 return Poll::Ready(Ok(len));
220 }
221 }
222 }
223 }
224
225 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
226 Poll::Ready(Ok(()))
227 }
228
229 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230 let Self { inner, .. } = self.get_mut();
231 let mut guard = ready!(inner.lock().poll_unpin(cx));
232 if let Err(err) = ready!(Sink::<pkt::HandleValueNotificationBorrow>::poll_close(
233 Pin::new(&mut guard.stream),
234 cx,
235 )) {
236 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
237 }
238 Poll::Ready(Ok(()))
239 }
240}
241
242enum IndicationState {
243 Write,
244 NeedFlush(usize),
245 AwaitConfirmation(usize, oneshot::Receiver<()>),
246}
247
248struct IndicationInner<IO> {
249 handle: Handle,
250 inner: Arc<Mutex<Inner<IO>>>,
251 state: IndicationState,
252}
253
254impl<IO> AsyncWrite for IndicationInner<IO>
255where
256 IO: AsyncWrite + Unpin,
257{
258 fn poll_write(
259 self: Pin<&mut Self>,
260 cx: &mut Context<'_>,
261 buf: &[u8],
262 ) -> Poll<io::Result<usize>> {
263 let Self {
264 state,
265 handle,
266 inner,
267 ..
268 } = self.get_mut();
269 let mut guard = ready!(inner.lock().poll_unpin(cx));
270
271 loop {
272 match state {
273 IndicationState::Write => {
274 if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_ready(
275 Pin::new(&mut guard.stream),
276 cx
277 )) {
278 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
279 }
280 let item = pkt::HandleValueIndicationBorrow::new(handle.clone(), buf);
281 if let Err(err) = guard.stream.start_send_unpin(item) {
282 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
283 }
284 *state = IndicationState::NeedFlush(buf.len());
285 }
286
287 IndicationState::NeedFlush(len) => {
288 if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_flush(
289 Pin::new(&mut guard.stream),
290 cx
291 )) {
292 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
293 }
294 let (tx, rx) = oneshot::channel();
295 guard.await_confirmation = Some(tx); *state = IndicationState::AwaitConfirmation(*len, rx);
297 }
298
299 IndicationState::AwaitConfirmation(len, rx) => {
300 if let Err(err) = ready!(rx.poll_unpin(cx)) {
301 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
302 }
303 let len = *len;
304 *state = IndicationState::Write;
305 return Poll::Ready(Ok(len));
306 }
307 }
308 }
309 }
310
311 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
312 Poll::Ready(Ok(()))
313 }
314
315 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
316 let Self { inner, .. } = self.get_mut();
317 let mut guard = ready!(inner.lock().poll_unpin(cx));
318 if let Err(err) = ready!(Sink::<pkt::HandleValueIndicationBorrow>::poll_close(
319 Pin::new(&mut guard.stream),
320 cx
321 )) {
322 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)));
323 }
324 Poll::Ready(Ok(()))
325 }
326}
327
328struct TryLockNext<'a, IO> {
329 inner: &'a Mutex<Inner<IO>>,
330}
331
332impl<'a, IO> Future for TryLockNext<'a, IO>
333where
334 IO: AsyncRead + Unpin,
335{
336 type Output = (MutexGuard<'a, Inner<IO>>, Option<Result<pkt::DeviceRecv>>);
337
338 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
339 let Self { inner } = self.get_mut();
340
341 let mut guard = ready!(inner.lock().poll_unpin(cx));
342 let item = ready!(guard.stream.poll_next_unpin(cx));
343 Poll::Ready((guard, item))
344 }
345}
346
347async fn respond<IO, R>(
348 stream: &mut PacketStream<IO>,
349 r: std::result::Result<R::Response, crate::handler::ErrorResponse>,
350) -> Result<()>
351where
352 IO: AsyncWrite + Unpin,
353 R: pkt::Request,
354{
355 let mtu = stream.txmtu();
356 match r {
357 Ok(mut r) => {
358 pkt::Response::truncate(&mut r, mtu);
359 stream.send(r).await?;
360 }
361 Err(crate::ErrorResponse(handle, code)) => {
362 let err = pkt::ErrorResponse::new(R::opcode(), handle, code);
363 stream.send(err).await?;
364 }
365 }
366 Ok(())
367}
368
369async fn handle<IO, H>(
370 inner: &mut Inner<IO>,
371 handler: &mut H,
372 request: pkt::DeviceRecv,
373) -> Result<()>
374where
375 IO: AsyncWrite + Unpin,
376 H: crate::Handler,
377{
378 match request {
379 pkt::DeviceRecv::ExchangeMtuRequest(item) => {
380 let response = handler.handle_exchange_mtu_request(&item);
381 if let Ok(response) = &response {
382 let client_rx_mtu = *item.client_rx_mtu() as usize;
383 let server_rx_mtu = *response.server_rx_mtu() as usize;
384 inner.stream.set_txmtu(client_rx_mtu);
385 inner.stream.set_rxmtu(server_rx_mtu);
386 }
387 respond::<_, pkt::ExchangeMtuRequest>(&mut inner.stream, response).await?;
388 }
389
390 pkt::DeviceRecv::FindInformationRequest(item) => {
391 let response = handler.handle_find_information_request(&item);
392 respond::<_, pkt::FindInformationRequest>(&mut inner.stream, response).await?;
393 }
394
395 pkt::DeviceRecv::FindByTypeValueRequest(item) => {
396 let response = handler.handle_find_by_type_value_request(&item);
397 respond::<_, pkt::FindByTypeValueRequest>(&mut inner.stream, response).await?;
398 }
399
400 pkt::DeviceRecv::ReadByTypeRequest(item) => {
401 let response = handler.handle_read_by_type_request(&item);
402 respond::<_, pkt::ReadByTypeRequest>(&mut inner.stream, response).await?;
403 }
404
405 pkt::DeviceRecv::ReadRequest(item) => {
406 let response = handler.handle_read_request(&item);
407 respond::<_, pkt::ReadRequest>(&mut inner.stream, response).await?;
408 }
409
410 pkt::DeviceRecv::ReadBlobRequest(item) => {
411 let response = handler.handle_read_blob_request(&item);
412 respond::<_, pkt::ReadBlobRequest>(&mut inner.stream, response).await?;
413 }
414
415 pkt::DeviceRecv::ReadMultipleRequest(item) => {
416 let response = handler.handle_read_multiple_request(&item);
417 respond::<_, pkt::ReadMultipleRequest>(&mut inner.stream, response).await?;
418 }
419
420 pkt::DeviceRecv::ReadByGroupTypeRequest(item) => {
421 let response = handler.handle_read_by_group_type_request(&item);
422 respond::<_, pkt::ReadByGroupTypeRequest>(&mut inner.stream, response).await?;
423 }
424
425 pkt::DeviceRecv::WriteRequest(item) => {
426 let response = handler.handle_write_request(&item);
427 respond::<_, pkt::WriteRequest>(&mut inner.stream, response).await?;
428 }
429
430 pkt::DeviceRecv::WriteCommand(item) => {
431 handler.handle_write_command(&item);
432 }
433
434 pkt::DeviceRecv::PrepareWriteRequest(item) => {
435 let response = handler.handle_prepare_write_request(&item);
436 respond::<_, pkt::PrepareWriteRequest>(&mut inner.stream, response).await?;
437 }
438
439 pkt::DeviceRecv::ExecuteWriteRequest(item) => {
440 let response = handler.handle_execute_write_request(&item);
441 respond::<_, pkt::ExecuteWriteRequest>(&mut inner.stream, response).await?;
442 }
443
444 pkt::DeviceRecv::SignedWriteCommand(item) => {
445 handler.handle_signed_write_command(&item);
446 }
447
448 pkt::DeviceRecv::HandleValueConfirmation(..) => {
449 if let Some(channel) = inner.await_confirmation.take() {
450 channel.send(()).ok();
451 }
452 }
453 }
454 Ok(())
455}
456
457struct ConnectionInner<IO> {
458 inner: Arc<Mutex<Inner<IO>>>,
459}
460
461impl<IO> ConnectionInner<IO>
462where
463 IO: AsyncRead + AsyncWrite + Unpin,
464{
465 fn notification(&self, handle: Handle) -> NotificationInner<IO> {
466 NotificationInner {
467 handle,
468 inner: self.inner.clone(),
469 state: NotificationState::Write,
470 }
471 }
472
473 fn indication(&self, handle: Handle) -> IndicationInner<IO> {
474 IndicationInner {
475 handle,
476 inner: self.inner.clone(),
477 state: IndicationState::Write,
478 }
479 }
480
481 async fn run<H>(self, mut handler: H) -> Result<()>
482 where
483 H: crate::Handler,
484 {
485 loop {
486 let (mut guard, request) = TryLockNext { inner: &self.inner }.await;
487 let request = if let Some(request) = request {
488 request?
489 } else {
490 return Ok(());
491 };
492
493 handle(&mut *guard, &mut handler, request).await?;
494 }
495 }
496}
497
498pub struct Notification {
499 inner: NotificationInner<AttStream>,
500}
501
502impl AsyncWrite for Notification {
503 fn poll_write(
504 self: Pin<&mut Self>,
505 cx: &mut Context<'_>,
506 buf: &[u8],
507 ) -> Poll<io::Result<usize>> {
508 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
509 }
510
511 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
512 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
513 }
514
515 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
516 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
517 }
518}
519
520pub struct Indication {
521 inner: IndicationInner<AttStream>,
522}
523
524impl AsyncWrite for Indication {
525 fn poll_write(
526 self: Pin<&mut Self>,
527 cx: &mut Context<'_>,
528 buf: &[u8],
529 ) -> Poll<io::Result<usize>> {
530 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
531 }
532
533 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
534 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
535 }
536
537 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
538 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
539 }
540}
541
542struct ServerInner<L> {
543 inner: L,
544}
545
546impl<L, IO> ServerInner<L>
547where
548 L: Stream<Item = io::Result<(IO, socket2::SockAddr)>> + Unpin,
549 IO: AsyncRead + AsyncWrite + Unpin,
550{
551 async fn accept(&mut self) -> io::Result<Option<(ConnectionInner<IO>, socket2::SockAddr)>> {
552 if let Some((sock, addr)) = self.inner.try_next().await? {
553 return Ok(Some((
554 ConnectionInner {
555 inner: Arc::new(Mutex::new(Inner::new(sock))),
556 },
557 addr,
558 )));
559 }
560 Ok(None)
561 }
562}
563
564pub struct Connection {
565 inner: ConnectionInner<AttStream>,
566 addr: crate::Address,
567}
568
569impl Connection {
570 pub fn address(&self) -> &crate::Address {
571 &self.addr
572 }
573
574 pub fn notification(&self, handle: Handle) -> Notification {
575 Notification {
576 inner: self.inner.notification(handle),
577 }
578 }
579
580 pub fn indication(&self, handle: Handle) -> Indication {
581 Indication {
582 inner: self.inner.indication(handle),
583 }
584 }
585
586 pub async fn run<H>(self, handler: H) -> Result<()>
587 where
588 H: crate::Handler,
589 {
590 log::debug!("Start serving.");
591 self.inner.run(handler).await?;
592 log::debug!("Done serving.");
593 Ok(())
594 }
595}
596
597pub struct Server {
598 inner: ServerInner<AttListener>,
599}
600
601impl Server {
602 pub fn new() -> io::Result<Self> {
604 let sock = AttListener::new()?;
605 Ok(Self {
606 inner: ServerInner { inner: sock },
607 })
608 }
609
610 pub fn needs_bond(&self) -> io::Result<()> {
611 self.inner
612 .inner
613 .set_sockopt_bt_security(crate::sock::BT_SECURITY_MEDIUM, 0)
614 }
615
616 pub fn needs_bond_mitm(&self) -> io::Result<()> {
617 self.inner
618 .inner
619 .set_sockopt_bt_security(crate::sock::BT_SECURITY_HIGH, 0)
620 }
621
622 pub async fn accept(&mut self) -> io::Result<Option<(Connection, crate::Address)>> {
623 if let Some((connection, addr)) = self.inner.accept().await? {
624 log::debug!("Connection accepted.");
625 let addr = crate::sock::try_from(addr)?;
626 Ok(Some((
627 Connection {
628 inner: connection,
629 addr: addr.clone(),
630 },
631 addr,
632 )))
633 } else {
634 Ok(None)
635 }
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;
    use tokio::io::AsyncWriteExt;
    use tokio_test::io::{Builder, Mock};

    /// Handler that relies entirely on the trait's default behaviour.
    struct NoopHandler;
    impl Handler for NoopHandler {}

    /// Wraps a scripted mock transport in a server-side connection.
    fn connection_over(mock: Mock) -> ConnectionInner<Mock> {
        ConnectionInner {
            inner: Arc::new(Mutex::new(Inner::new(mock))),
        }
    }

    #[tokio::test]
    async fn test_stream() {
        let mock = Builder::new()
            .read(&[0x02, 0x17, 0x00])
            .write(&[0x03, 0x18, 0x00])
            .build();
        let mut packets = PacketStream::new(mock);

        let received = packets.try_next().await.unwrap().unwrap();
        let request = pkt::ExchangeMtuRequest::try_from(received).unwrap();
        assert_eq!(*request.client_rx_mtu(), 23);

        packets
            .send(pkt::ExchangeMtuResponse::new(0x0018))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_connection() {
        let mock = Builder::new()
            .write(&[0x1B, 0x01, 0x00, 0x6F, 0x6B])
            .read(&[0x02, 0x17, 0x00])
            .write(&[0x03, 0x17, 0x00])
            .build();
        let connection = connection_over(mock);

        let mut notifier = connection.notification(Handle::new(1));
        notifier.write_all(b"ok").await.unwrap();
        connection.run(NoopHandler).await.unwrap();
    }

    #[tokio::test]
    async fn test_indication() {
        let mock = Builder::new()
            .write(&[0x1D, 0x01, 0x00, 0x6F, 0x6B])
            .read(&[0x1E, 0x17, 0x00])
            .build();
        let connection = connection_over(mock);

        let mut indicator = connection.indication(Handle::new(1));
        let server = tokio::spawn(connection.run(NoopHandler));

        indicator.write_all(b"ok").await.unwrap();

        server.await.unwrap().unwrap();
    }
}