use crate::error::{Result, TakError};
use crate::framing::{decode_tak_header, encode_tak_message};
use crate::proto::{CotEvent, TakMessage};
use crate::tls::TlsConfig;
use bytes::BytesMut;
use prost::Message;
use rustls::pki_types::ServerName;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
/// Transport underneath a [`TakClient`]: either a raw TCP socket or a TLS session layered on one.
enum Connection {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// TLS client session over a TCP stream (established via `TlsConnector::connect`).
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
impl Connection {
async fn read_buf(&mut self, buf: &mut BytesMut) -> std::io::Result<usize> {
match self {
Connection::Plain(stream) => stream.read_buf(buf).await,
Connection::Tls(stream) => stream.read_buf(buf).await,
}
}
async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
Connection::Plain(stream) => stream.write_all(buf).await,
Connection::Tls(stream) => stream.write_all(buf).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Connection::Plain(stream) => stream.flush().await,
Connection::Tls(stream) => stream.flush().await,
}
}
async fn shutdown(&mut self) -> std::io::Result<()> {
match self {
Connection::Plain(stream) => stream.shutdown().await,
Connection::Tls(stream) => stream.shutdown().await,
}
}
}
/// Client for a single TAK server connection, over plain TCP or TLS.
///
/// Incoming bytes are accumulated internally until a complete message can be decoded.
pub struct TakClient {
    /// Active transport (plain TCP or TLS).
    connection: Connection,
    /// Bytes received from `connection` that have not yet been consumed as a message.
    read_buffer: BytesMut,
}
impl TakClient {
pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> Result<Self> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
connection: Connection::Plain(stream),
read_buffer: BytesMut::with_capacity(8192),
})
}
pub async fn connect_tls(
addr: impl tokio::net::ToSocketAddrs,
server_name: &str,
tls_config: TlsConfig,
) -> Result<Self> {
let tcp_stream = TcpStream::connect(addr).await?;
let connector = TlsConnector::from(tls_config.config);
let server_name = ServerName::try_from(server_name.to_owned())
.map_err(|e| TakError::Tls(format!("Invalid server name: {}", e)))?;
let tls_stream = connector.connect(server_name, tcp_stream).await?;
Ok(Self {
connection: Connection::Tls(tls_stream),
read_buffer: BytesMut::with_capacity(8192),
})
}
pub async fn send_cot_event(&mut self, event: CotEvent) -> Result<()> {
let tak_message = TakMessage {
tak_control: None,
cot_event: Some(event),
};
self.send_tak_message(tak_message).await
}
pub async fn send_tak_message(&mut self, message: TakMessage) -> Result<()> {
let frame = encode_tak_message(&message)?;
self.connection.write_all(&frame).await?;
self.connection.flush().await?;
Ok(())
}
pub async fn receive_tak_message(&mut self) -> Result<Option<TakMessage>> {
loop {
if let Some(message) = self.try_decode_message()? {
return Ok(Some(message));
}
let n = self.connection.read_buf(&mut self.read_buffer).await?;
if n == 0 {
if self.read_buffer.is_empty() {
return Ok(None);
} else {
return Err(TakError::ConnectionClosed);
}
}
}
}
fn try_decode_message(&mut self) -> Result<Option<TakMessage>> {
let payload_len = match decode_tak_header(&mut self.read_buffer)? {
Some(len) => len,
None => return Ok(None), };
if self.read_buffer.len() < payload_len {
return Ok(None); }
let payload = self.read_buffer.split_to(payload_len);
let message = TakMessage::decode(&payload[..])?;
Ok(Some(message))
}
pub async fn send_cot_event_xml(&mut self, event: CotEvent) -> Result<()> {
let xml = crate::xml::encode_cot_event_xml(&event);
self.connection.write_all(xml.as_bytes()).await?;
self.connection.flush().await?;
Ok(())
}
pub async fn negotiate_protocol(&mut self, version: u32, timeout_secs: u64) -> Result<()> {
use tokio::time::{timeout, Duration};
let xml_buffer = timeout(
Duration::from_secs(timeout_secs),
self.read_xml_messages_until(|msg| crate::xml::is_protocol_support(msg)),
)
.await
.map_err(|_| TakError::NegotiationFailed("Timeout waiting for protocol support".to_string()))??;
if xml_buffer.is_empty() {
return Err(TakError::NegotiationFailed(
"Server did not advertise protocol support".to_string(),
));
}
let request_xml = crate::xml::create_protocol_request(version);
self.connection.write_all(request_xml.as_bytes()).await?;
self.connection.flush().await?;
let response_xml = timeout(
Duration::from_secs(timeout_secs),
self.read_xml_messages_until(|msg| crate::xml::is_protocol_response(msg)),
)
.await
.map_err(|_| TakError::NegotiationFailed("Timeout waiting for protocol response".to_string()))??;
if response_xml.is_empty() {
return Err(TakError::NegotiationFailed(
"No protocol response received".to_string(),
));
}
if !crate::xml::is_protocol_response_success(&response_xml) {
return Err(TakError::NegotiationFailed(
"Server rejected protocol request".to_string(),
));
}
Ok(())
}
async fn read_xml_messages_until<F>(&mut self, mut condition: F) -> Result<String>
where
F: FnMut(&str) -> bool,
{
let mut buffer = String::new();
let mut temp_buf = BytesMut::with_capacity(4096);
loop {
temp_buf.clear();
let n = self.connection.read_buf(&mut temp_buf).await?;
if n == 0 {
return Err(TakError::ConnectionClosed);
}
buffer.push_str(&String::from_utf8_lossy(&temp_buf[..]));
while let Some(end_pos) = buffer.find("</event>") {
let message = &buffer[..end_pos + 8]; if condition(message) {
return Ok(message.to_string());
}
buffer = buffer[end_pos + 8..].to_string();
}
}
}
pub async fn close(mut self) -> Result<()> {
self.connection.shutdown().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a `CotEvent` with explicit fields (the rest defaulted) and checks that the
    /// identifying fields read back unchanged.
    #[test]
    fn test_cot_event_creation() {
        let cot = CotEvent {
            uid: "TEST-1".to_string(),
            r#type: "a-f-G-U-C".to_string(),
            how: "m-g".to_string(),
            send_time: 1000,
            start_time: 1000,
            stale_time: 2000,
            lat: 37.7749,
            lon: -122.4194,
            hae: 10.0,
            ce: 9.9,
            le: 9.9,
            ..Default::default()
        };
        assert_eq!(cot.r#type, "a-f-G-U-C");
        assert_eq!(cot.uid, "TEST-1");
    }
}