use std::{
io,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
thread::{self, JoinHandle},
};
use lunar_lib::log::error;
use crate::{
ConnectionType, HandshakeType,
client::ConnectError,
common::{
HandshakeError, IPC_PROTOCOL_VERSION, SessionHandshakeRequest, SessionHandshakeResponse,
SessionPacket, SessionTarget, SessionToken, Stream,
},
ipc_common::{Packetable, RemoteEvent},
};
/// Boxed callback that receives each [`RemoteEvent`] decoded from the session stream.
///
/// The callback is moved onto the background listener thread when a session is
/// connected, which is why it has to be `Send` and `'static`.
type EventCallbackFn = Box<dyn FnMut(RemoteEvent) + Send + Sync + 'static>;
/// Session-specific helpers layered on top of any [`Stream`].
trait SessionStreamExt: Stream {
    /// Runs the client side of the session handshake on this stream.
    ///
    /// Steps, in order:
    /// 1. write the `HandshakeType::SESSION` marker,
    /// 2. read a 4-byte protocol version from the stream and compare it with
    ///    `IPC_PROTOCOL_VERSION`,
    /// 3. on a match, send `handshake` and decode the reply.
    ///
    /// # Returns
    /// * `Err(io_error)` when reading from or writing to the stream fails.
    /// * `Ok(Err(HandshakeError::WrongVersion { .. }))` when the version read from
    ///   the stream differs from `IPC_PROTOCOL_VERSION`; the request is not sent.
    /// * Otherwise whatever `Packetable::deserialize_packet` yields for the reply.
    fn handshake(
        &mut self,
        handshake: SessionHandshakeRequest,
    ) -> io::Result<Result<SessionHandshakeResponse, HandshakeError>> {
        // The handshake-type marker always goes out first.
        self.write_all(&HandshakeType::SESSION)?;

        let mut peer_version = [0u8; 4];
        self.read_exact(&mut peer_version)?;

        if peer_version == IPC_PROTOCOL_VERSION {
            // Versions agree: exchange the actual request/response pair.
            handshake.serialize_into_writer(self)?;
            let raw_response = self.read_data()?;
            return Packetable::deserialize_packet(&raw_response);
        }

        // Version mismatch is a handshake-level failure, not an I/O failure.
        let mismatch = HandshakeError::WrongVersion {
            expected: u32::from_be_bytes(IPC_PROTOCOL_VERSION),
            connected: u32::from_be_bytes(peer_version),
        };
        Ok(Err(mismatch))
    }
}

// Blanket impl: every `Stream`, sized or not (e.g. `dyn Stream`), gets `handshake`.
impl<T: Stream + ?Sized> SessionStreamExt for T {}
/// Handle to a connected session whose incoming events are processed on a
/// background thread.
///
/// Created by `SessionEventListener::connect`. Dropping the handle sets the shutdown
/// flag without joining the thread; `disconnect` sets the flag and joins.
pub struct SessionEventListener {
/// Token taken from the handshake response in `connect`.
session_token: SessionToken,
/// Shutdown flag shared with the background thread; `true` asks it to stop.
shutdown_sig: Arc<AtomicBool>,
/// Join handle of the background thread; taken out of the `Option` by `disconnect`.
thread_handle: Option<JoinHandle<()>>,
}
/// Everything `SessionEventListener::connect` needs to establish a session.
pub struct SessionConnectOptions {
/// Used to open the underlying stream (`connection_type.connect()`).
connection_type: ConnectionType,
/// Sent to the other side inside the `SessionHandshakeRequest`.
session_target: SessionTarget,
/// Invoked on the listener thread for every received event.
event_callback: EventCallbackFn,
}
impl SessionConnectOptions {
#[must_use]
pub fn new(
connection_type: ConnectionType,
session_target: SessionTarget,
event_callback: EventCallbackFn,
) -> Self {
Self {
connection_type,
session_target,
event_callback,
}
}
}
impl SessionEventListener {
pub fn connect(options: SessionConnectOptions) -> Result<SessionEventListener, ConnectError> {
let mut stream = options.connection_type.connect()?;
let handshake = SessionHandshakeRequest {
session_target: options.session_target,
};
let handshake_response = stream.handshake(handshake)??;
let shutdown_sig = Arc::new(AtomicBool::new(false));
let handle_sig = shutdown_sig.clone();
let handle = thread::spawn(|| {
let thread = SessionClientThread {
event_callback: options.event_callback,
stream,
shutdown_sig: handle_sig,
};
if let Err(err) = thread.run() {
error!("Client returned with error: {err}");
}
});
Ok(SessionEventListener {
session_token: handshake_response.token,
shutdown_sig,
thread_handle: Some(handle),
})
}
#[must_use]
pub fn session_token(&self) -> SessionToken {
self.session_token
}
pub fn disconnect(mut self) {
self.shutdown_sig.store(true, Ordering::Relaxed);
let _ = self.thread_handle.take().unwrap().join();
}
}
/// Signals the listener thread to stop when the handle goes away.
impl Drop for SessionEventListener {
// Only sets the shutdown flag; the thread is not joined here, so it keeps
// running until it next checks the flag.
fn drop(&mut self) {
self.shutdown_sig.store(true, Ordering::Relaxed);
}
}
/// State owned by the background listener thread.
struct SessionClientThread {
/// Called with every event decoded from the stream.
event_callback: EventCallbackFn,
/// Stream the session packets are read from.
stream: Box<dyn Stream>,
/// Shutdown flag; the read loop stops once this is `true`.
shutdown_sig: Arc<AtomicBool>,
}
impl SessionClientThread {
    /// Reads session packets and dispatches them until told to stop.
    ///
    /// The loop ends, returning `Ok(())`, when either
    /// * the shutdown flag is set (checked before every read), or
    /// * reading from the stream fails - the read error itself is discarded.
    ///
    /// # Errors
    /// Returns the decoding error if a received payload is not a valid
    /// `SessionPacket`.
    fn run(mut self) -> anyhow::Result<()> {
        loop {
            // Check the flag first so no further read is started after shutdown.
            if self.shutdown_sig.load(Ordering::Relaxed) {
                break;
            }

            // Any read failure ends the loop without reporting an error.
            let Ok(payload) = self.stream.read_data() else {
                break;
            };

            let packet: SessionPacket = postcard::from_bytes(&payload)?;
            match packet {
                SessionPacket::Event(event) => (self.event_callback)(event),
            }
        }

        Ok(())
    }
}