use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::prelude::*;
use slog::{error, info};
use tokio::sync::{mpsc, oneshot};
use ts_bookkeeping::ChannelId;
#[cfg(feature = "audio")]
use tsproto_packets::packets::InAudioBuf;
#[cfg(feature = "unstable")]
use tsproto_packets::packets::OutCommand;
use crate::{
events, AudioEvent, DisconnectOptions, Error, InMessage, Result, StreamItem,
TemporaryDisconnectReason,
};
/// A request sent from a [`SyncConnectionHandle`] to its [`SyncConnection`].
///
/// Requests are only processed while the `SyncConnection` stream is being polled.
/// Every variant except `RunFn` carries the sending half of a oneshot channel through
/// which the connection answers the request.
enum SyncConMessage {
    /// Run a closure with mutable access to the connection (see
    /// `SyncConnectionHandle::with_connection`).
    RunFn(Box<dyn FnOnce(&mut SyncConnection) + Send>),
    /// Send a command. The sender is answered immediately if sending fails, otherwise
    /// when the matching message result arrives on the connection stream.
    #[cfg(feature = "unstable")]
    SendCommand(OutCommand, oneshot::Sender<Result<()>>),
    /// Wait for the connection. Answered immediately if the connection state is
    /// already available, otherwise when the next book events arrive.
    WaitConnected(oneshot::Sender<Result<()>>),
    /// Disconnect. Answered immediately if starting the disconnect fails, otherwise
    /// with `Ok(())` once the connection stream ends.
    Disconnect(DisconnectOptions, oneshot::Sender<Result<()>>),
    /// Start a file download. Answered when the connection stream reports the
    /// download result or a failed transfer for it.
    DownloadFile {
        channel_id: ChannelId,
        path: String,
        channel_password: Option<String>,
        seek_position: Option<u64>,
        send: oneshot::Sender<Result<super::FileDownloadResult>>,
    },
    /// Start a file upload. Answered when the connection stream reports the upload
    /// result or a failed transfer for it.
    UploadFile {
        channel_id: ChannelId,
        path: String,
        channel_password: Option<String>,
        size: u64,
        overwrite: bool,
        resume: bool,
        send: oneshot::Sender<Result<super::FileUploadResult>>,
    },
}
/// An item yielded by the [`SyncConnection`] stream.
///
/// This is the subset of `StreamItem` that is passed on to the user. Items that only
/// answer requests made through a [`SyncConnectionHandle`] (command results and file
/// transfer results) are consumed by the connection and never show up here.
pub enum SyncStreamItem {
    /// Forwarded from `StreamItem::BookEvents`.
    BookEvents(Vec<events::Event>),
    /// Forwarded from `StreamItem::MessageEvent`.
    MessageEvent(InMessage),
    /// Forwarded from `StreamItem::Audio`.
    #[cfg(feature = "audio")]
    Audio(InAudioBuf),
    /// Forwarded from `StreamItem::IdentityLevelIncreasing`.
    IdentityLevelIncreasing(u8),
    /// Forwarded from `StreamItem::IdentityLevelIncreased`.
    IdentityLevelIncreased,
    /// Forwarded from `StreamItem::DisconnectedTemporarily`.
    DisconnectedTemporarily(TemporaryDisconnectReason),
    /// Forwarded from `StreamItem::NetworkStatsUpdated`.
    NetworkStatsUpdated,
    /// Forwarded from `StreamItem::AudioChange`.
    AudioChange(AudioEvent),
}
/// A cloneable handle that queues requests for a [`SyncConnection`].
///
/// Obtained through `SyncConnection::get_handle`. Requests are only processed while
/// the `SyncConnection` stream is being polled.
#[derive(Clone)]
pub struct SyncConnectionHandle {
    /// Sending side of the connection's request queue.
    send: mpsc::Sender<SyncConMessage>,
}
/// A connection wrapper that can be driven as a stream while other tasks talk to it
/// through [`SyncConnectionHandle`]s.
///
/// Dereferences to the wrapped `Connection`.
pub struct SyncConnection {
    /// The wrapped connection.
    con: super::Connection,
    /// Receiving side of the request queue; drained on every poll of the stream.
    recv: mpsc::Receiver<SyncConMessage>,
    /// Sending side of the request queue, cloned into every handle. Keeping it here
    /// also means `recv` never sees the queue close.
    send: mpsc::Sender<SyncConMessage>,
    /// Pending command results, answered when the matching message result arrives.
    commands: HashMap<super::MessageHandle, oneshot::Sender<Result<()>>>,
    /// Pending `wait_until_connected` requests, answered on the next book events.
    connects: Vec<oneshot::Sender<Result<()>>>,
    /// Pending `disconnect` requests, answered with `Ok(())` when the stream ends.
    disconnects: Vec<oneshot::Sender<Result<()>>>,
    /// Pending downloads, answered on a download result or a failed transfer.
    downloads:
        HashMap<super::FiletransferHandle, oneshot::Sender<Result<super::FileDownloadResult>>>,
    /// Pending uploads, answered on an upload result or a failed transfer.
    uploads: HashMap<super::FiletransferHandle, oneshot::Sender<Result<super::FileUploadResult>>>,
}
impl From<super::Connection> for SyncConnection {
fn from(con: super::Connection) -> Self {
let (send, recv) = mpsc::channel(1);
Self {
con,
recv,
send,
commands: Default::default(),
connects: Default::default(),
disconnects: Default::default(),
downloads: Default::default(),
uploads: Default::default(),
}
}
}
impl Deref for SyncConnection {
    type Target = super::Connection;

    /// Borrow the wrapped connection.
    #[inline]
    fn deref(&self) -> &super::Connection {
        let con = &self.con;
        con
    }
}
impl DerefMut for SyncConnection {
#[inline]
fn deref_mut(&mut self) -> &mut <Self as Deref>::Target { &mut self.con }
}
impl Stream for SyncConnection {
type Item = Result<SyncStreamItem>;
fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
if let Poll::Ready(msg) = self.recv.poll_recv(ctx) {
if let Some(msg) = msg {
match msg {
SyncConMessage::RunFn(f) => f(&mut *self),
#[cfg(feature = "unstable")]
SyncConMessage::SendCommand(arg, send) => {
let handle = match self.con.send_command_with_result(arg) {
Ok(r) => r,
Err(e) => {
let _ = send.send(Err(e));
continue;
}
};
self.commands.insert(handle, send);
}
SyncConMessage::WaitConnected(send) => {
if self.con.get_state().is_ok() {
let _ = send.send(Ok(()));
} else {
self.connects.push(send);
}
}
SyncConMessage::Disconnect(arg, send) => {
match self.con.disconnect(arg) {
Ok(r) => r,
Err(e) => {
let _ = send.send(Err(e));
continue;
}
}
self.disconnects.push(send);
}
SyncConMessage::DownloadFile {
channel_id,
path,
channel_password,
seek_position,
send,
} => {
let handle = match self.con.download_file(
channel_id,
&path,
channel_password.as_deref(),
seek_position,
) {
Ok(r) => r,
Err(e) => {
let _ = send.send(Err(e));
continue;
}
};
self.downloads.insert(handle, send);
}
SyncConMessage::UploadFile {
channel_id,
path,
channel_password,
size,
overwrite,
resume,
send,
} => {
let handle = match self.con.upload_file(
channel_id,
&path,
channel_password.as_deref(),
size,
overwrite,
resume,
) {
Ok(r) => r,
Err(e) => {
let _ = send.send(Err(e));
continue;
}
};
self.uploads.insert(handle, send);
}
}
continue;
} else {
error!(self.con.logger, "Message stream ended unexpectedly");
}
}
break;
}
loop {
break if let Poll::Ready(item) = self.con.poll_next(ctx) {
Poll::Ready(match item {
Some(Ok(item)) => Some(Ok(match item {
StreamItem::BookEvents(i) => {
self.connects.drain(..).for_each(|send| {
let _ = send.send(Ok(()));
});
SyncStreamItem::BookEvents(i)
}
StreamItem::MessageEvent(i) => SyncStreamItem::MessageEvent(i),
#[cfg(feature = "audio")]
StreamItem::Audio(i) => SyncStreamItem::Audio(i),
StreamItem::IdentityLevelIncreasing(i) => {
SyncStreamItem::IdentityLevelIncreasing(i)
}
StreamItem::IdentityLevelIncreased => {
SyncStreamItem::IdentityLevelIncreased
}
StreamItem::DisconnectedTemporarily(reason) => {
SyncStreamItem::DisconnectedTemporarily(reason)
}
StreamItem::MessageResult(handle, res) => {
if let Some(send) = self.commands.remove(&handle) {
let _ = send.send(res.map_err(|e| e.into()));
} else {
info!(self.con.logger, "Got untracked message result");
}
continue;
}
StreamItem::FileDownload(handle, res) => {
if let Some(send) = self.downloads.remove(&handle) {
let _ = send.send(Ok(res));
} else {
info!(self.con.logger, "Got untracked download");
}
continue;
}
StreamItem::FileUpload(handle, res) => {
if let Some(send) = self.uploads.remove(&handle) {
let _ = send.send(Ok(res));
} else {
info!(self.con.logger, "Got untracked upload");
}
continue;
}
StreamItem::FiletransferFailed(handle, res) => {
if let Some(send) = self.downloads.remove(&handle) {
let _ = send.send(Err(res));
} else if let Some(send) = self.uploads.remove(&handle) {
let _ = send.send(Err(res));
} else {
info!(self.con.logger, "Got untracked file transfer");
}
continue;
}
StreamItem::NetworkStatsUpdated => SyncStreamItem::NetworkStatsUpdated,
StreamItem::AudioChange(change) => SyncStreamItem::AudioChange(change),
})),
Some(Err(e)) => Some(Err(e)),
None => {
self.disconnects.drain(..).for_each(|send| {
let _ = send.send(Ok(()));
});
None
}
})
} else {
Poll::Pending
};
}
}
}
impl SyncConnection {
    /// Create a new handle that queues requests for this connection.
    ///
    /// Requests sent through the handle are processed while this stream is polled.
    #[inline]
    pub fn get_handle(&self) -> SyncConnectionHandle {
        let send = self.send.clone();
        SyncConnectionHandle { send }
    }
}
impl SyncConnectionHandle {
pub async fn with_connection<
T: Send + 'static,
F: FnOnce(&mut SyncConnection) -> T + Send + 'static,
>(
&mut self, f: F,
) -> Result<T> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::RunFn(Box::new(move |con| {
let _ = send.send(f(con));
})))
.await
.map_err(|_| Error::ConnectionGone)?;
Ok(recv.await.map_err(|_| Error::ConnectionGone)?)
}
#[cfg(feature = "unstable")]
pub async fn send_command(&mut self, arg: OutCommand) -> Result<()> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::SendCommand(arg, send))
.await
.map_err(|_| Error::ConnectionGone)?;
recv.await.map_err(|_| Error::ConnectionGone)?
}
pub async fn wait_until_connected(&mut self) -> Result<()> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::WaitConnected(send))
.await
.map_err(|_| Error::ConnectionGone)?;
recv.await.map_err(|_| Error::ConnectionGone)?
}
pub async fn disconnect(&mut self, arg: DisconnectOptions) -> Result<()> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::Disconnect(arg, send))
.await
.map_err(|_| Error::ConnectionGone)?;
recv.await.map_err(|_| Error::ConnectionGone)?
}
pub async fn download_file(
&mut self, channel_id: ChannelId, path: String, channel_password: Option<String>,
seek_position: Option<u64>,
) -> Result<super::FileDownloadResult> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::DownloadFile {
channel_id,
path,
channel_password,
seek_position,
send,
})
.await
.map_err(|_| Error::ConnectionGone)?;
recv.await.map_err(|_| Error::ConnectionGone)?
}
pub async fn upload_file(
&mut self, channel_id: ChannelId, path: String, channel_password: Option<String>,
size: u64, overwrite: bool, resume: bool,
) -> Result<super::FileUploadResult> {
let (send, recv) = oneshot::channel();
self.send
.send(SyncConMessage::UploadFile {
channel_id,
path,
channel_password,
size,
overwrite,
resume,
send,
})
.await
.map_err(|_| Error::ConnectionGone)?;
recv.await.map_err(|_| Error::ConnectionGone)?
}
}