use std::collections::HashMap;
use crate::{
Error,
message::{Message, MessageType},
message_stream::MessageStream,
};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader};
/// Connection to the APT client over which method-protocol messages are exchanged.
///
/// Incoming lines are read from `inner` and fed to `message_stream` until it
/// yields a complete message; responses are written to `output_stream`. By
/// default input is stdin (buffered) and output is stdout.
pub struct AptStream {
    /// Buffered source of protocol lines (stdin by default).
    inner: Box<dyn AsyncBufRead + Unpin + Send>,
    /// Assembles pushed lines into complete messages.
    message_stream: MessageStream,
    /// Sink for responses (stdout by default).
    output_stream: Box<dyn AsyncWrite + Unpin + Send>,
    /// Set by the first call to `next`; afterwards the streams can no longer be replaced.
    has_initialized: bool,
}

impl AptStream {
    /// Creates a stream that reads from stdin (through a `BufReader`) and writes to stdout.
    pub fn new() -> Self {
        Self {
            inner: Box::new(BufReader::new(tokio::io::stdin())),
            message_stream: MessageStream::default(),
            output_stream: Box::new(tokio::io::stdout()),
            has_initialized: false,
        }
    }

    /// Replaces the input stream.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StreamAlreadyInitialized`] if [`AptStream::next`] has
    /// already been called.
    pub fn with_input_stream(
        mut self,
        input_stream: Box<dyn AsyncBufRead + Unpin + Send>,
    ) -> Result<Self, Error> {
        if self.has_initialized {
            return Err(Error::StreamAlreadyInitialized);
        }
        // The argument is already a boxed trait object; store it as-is rather than
        // wrapping it in a second `Box` (extra allocation + extra vtable hop per read).
        self.inner = input_stream;
        Ok(self)
    }

    /// Replaces the output stream.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StreamAlreadyInitialized`] if [`AptStream::next`] has
    /// already been called.
    pub fn with_output_stream(
        mut self,
        output_stream: Box<dyn AsyncWrite + Unpin + Send>,
    ) -> Result<Self, Error> {
        if self.has_initialized {
            return Err(Error::StreamAlreadyInitialized);
        }
        self.output_stream = output_stream;
        Ok(self)
    }

    /// Returns the next request.
    ///
    /// The first call reads no input: it returns a synthesized
    /// [`AptRequest::Capabilities`] so the caller can respond with its
    /// capabilities first. Later calls read lines until the message assembler
    /// reports a complete message and convert that message into an
    /// [`AptRequest`].
    ///
    /// Returns `Ok(None)` once the input reaches EOF (a zero-byte read) without
    /// that final read completing a message.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from reading the input, message-assembly errors,
    /// and errors from converting the message into a request.
    pub async fn next<'a>(&'a mut self) -> Result<Option<AptRequest<'a>>, Error> {
        if !self.has_initialized {
            self.has_initialized = true;
            log::debug!("synthesizing capabilities request");
            return Ok(Some(AptRequest::Capabilities(CapabilitiesRequest {
                this: self,
            })));
        }

        // One buffer reused for every line.
        let mut line = String::new();
        loop {
            log::trace!("reading line from input stream");
            line.clear();
            let nread = self.inner.as_mut().read_line(&mut line).await?;
            log::trace!(
                "read {} bytes from input stream: {:?}",
                nread,
                line.as_bytes()
            );

            // The (empty) line is pushed even at EOF, so the assembler sees the final read.
            let Some(message_result) = self.message_stream.push_line(line.as_bytes()) else {
                if nread == 0 {
                    log::debug!("EOF reached on input stream");
                    return Ok(None);
                }
                continue;
            };

            log::debug!("complete message received");
            log::trace!("complete message received: {:?}", message_result);
            let message = match message_result {
                Ok(message) => message,
                Err(e) => {
                    log::debug!("APT message parse error: {e:?}");
                    return Err(e);
                }
            };

            return match AptRequest::<'a>::try_from_message((self, message)).await {
                Ok(req) => {
                    log::debug!("APT request parse OK");
                    Ok(Some(req))
                }
                Err(e) => {
                    log::debug!("APT request parse error: {e:#?}");
                    Err(e)
                }
            };
        }
    }
}
/// A request to be handled by the method, as produced by `AptStream::next`.
pub enum AptRequest<'a> {
/// Synthesized by the first call to `AptStream::next`; answer it through
/// `CapabilitiesRequest::respond`.
Capabilities(CapabilitiesRequest<'a>),
/// A configuration message; its `Config-Item` headers are exposed through
/// `ConfigRequest::options`.
Configuration(ConfigRequest),
/// A URI acquire message carrying the `URI`, `Filename` and `Target-Site` headers.
Uri(UriRequest<'a>),
}
/// The capabilities request that [`AptStream::next`] hands out on its first call.
///
/// It carries no data; it exists so the caller can build and send a
/// capabilities response through the stream it borrows.
pub struct CapabilitiesRequest<'a> {
    this: &'a mut AptStream,
}

impl<'a> CapabilitiesRequest<'a> {
    /// Starts a capabilities response with every flag disabled and the crate
    /// version filled in; customise it and finish with `.send().await`.
    #[must_use = "responses must be sent to the APT client with `.send().await`"]
    pub fn respond(self) -> CapabilitiesResponse<'a> {
        let Self { this } = self;
        CapabilitiesResponse::new(this)
    }
}
pub struct CapabilitiesResponse<'a> {
this: &'a mut AptStream,
single_instance: bool,
send_config: bool,
pipeline: bool,
local_only: bool,
removable: bool,
needs_cleanup: bool,
version: String,
}
impl CapabilitiesResponse<'_> {
#[must_use]
fn new(this: &mut AptStream) -> CapabilitiesResponse<'_> {
CapabilitiesResponse {
this,
single_instance: false,
send_config: false,
pipeline: false,
local_only: false,
removable: false,
needs_cleanup: false,
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn single_instance(mut self, enabled: bool) -> Self {
self.single_instance = enabled;
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn send_config(mut self, enabled: bool) -> Self {
self.send_config = enabled;
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn version<S: Into<String>>(mut self, version: S) -> Self {
self.version = version.into();
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn pipeline(mut self, enabled: bool) -> Self {
self.pipeline = enabled;
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn local_only(mut self, enabled: bool) -> Self {
self.local_only = enabled;
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn removable(mut self, enabled: bool) -> Self {
self.removable = enabled;
self
}
#[must_use = "responses must be sent with `.send().await`"]
pub fn needs_cleanup(mut self, enabled: bool) -> Self {
self.needs_cleanup = enabled;
self
}
pub async fn send(self) -> Result<(), Error> {
let msg = Message::new(
MessageType::Capabilities,
vec![
("Version", &self.version),
(
"Send-Config",
if self.send_config { "true" } else { "false" },
),
(
"Single-Instance",
if self.single_instance {
"true"
} else {
"false"
},
),
("Pipeline", if self.pipeline { "true" } else { "false" }),
("Local-Only", if self.local_only { "true" } else { "false" }),
("Removable", if self.removable { "true" } else { "false" }),
(
"Needs-Cleanup",
if self.needs_cleanup { "true" } else { "false" },
),
],
)
.to_string();
log::debug!("sending capabilities response");
log::trace!("sending capabilities response: {}", msg.trim());
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
log::debug!("flushing output stream");
self.this.output_stream.as_mut().flush().await?;
log::debug!("capabilities response sent successfully");
Ok(())
}
}
/// Configuration pushed by APT in a configuration message.
pub struct ConfigRequest {
    /// `Config-Item` entries as tag -> value, percent-decoded.
    options: HashMap<String, String>,
}

impl ConfigRequest {
    /// Collects the `Config-Item` headers of `message` into a map.
    ///
    /// Each item has the form `Tag=Value`. APT percent-encodes both halves
    /// (its `QuoteString` also escapes any `=` inside the tag), and APT's own
    /// method implementation decodes them with `DeQuoteString`. The same is
    /// done here: split at the first literal `=`, then decode each half.
    /// Items without an `=` are skipped; when a tag occurs more than once, the
    /// last occurrence wins.
    fn from(message: Message) -> Self {
        let options = message
            .headers
            .into_iter()
            .filter_map(|(key, value)| {
                if key == "Config-Item" {
                    let (tag, raw_value) = value.split_once('=')?;
                    Some((Self::dequote(tag), Self::dequote(raw_value)))
                } else {
                    None
                }
            })
            .collect::<HashMap<String, String>>();
        ConfigRequest { options }
    }

    /// Returns the configuration items, keyed by (decoded) tag.
    pub fn options(&self) -> &HashMap<String, String> {
        &self.options
    }

    /// Reverses APT's percent-quoting.
    ///
    /// Each `%XX` with two hex digits becomes the byte it encodes. A `%` not
    /// followed by two hex digits is kept literally, as `DeQuoteString` does.
    /// Decoded bytes that do not form valid UTF-8 are replaced with U+FFFD.
    fn dequote(raw: &str) -> String {
        fn hex_digit(byte: u8) -> Option<u8> {
            match byte {
                b'0'..=b'9' => Some(byte - b'0'),
                b'a'..=b'f' => Some(byte - b'a' + 10),
                b'A'..=b'F' => Some(byte - b'A' + 10),
                _ => None,
            }
        }

        // Fast path: nothing is encoded.
        if !raw.contains('%') {
            return raw.to_owned();
        }

        let bytes = raw.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
            }
            decoded.push(bytes[i]);
            i += 1;
        }
        String::from_utf8(decoded)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
    }
}
pub struct UriRequest<'a> {
this: &'a mut AptStream,
uri: String,
repo_uri: String,
filename: String,
}
impl<'a> UriRequest<'a> {
async fn from(this: &'a mut AptStream, message: Message) -> Result<UriRequest<'a>, Error> {
let Ok(uri) = message.header("URI") else {
let msg = Message::uri_failure("", "Missing URI header").to_string();
this.output_stream.write_all(msg.as_bytes()).await?;
this.output_stream.as_mut().flush().await?;
return Err(Error::HeaderNotFound("URI".to_string()));
};
let Ok(filename) = message.header("Filename") else {
let msg = Message::uri_failure(uri, "Missing Filename header").to_string();
this.output_stream.write_all(msg.as_bytes()).await?;
this.output_stream.as_mut().flush().await?;
return Err(Error::HeaderNotFound("Filename".to_string()));
};
let Ok(target_uri) = message.header("Target-Site") else {
let msg = Message::uri_failure(uri, "Missing Target-Site header").to_string();
this.output_stream.write_all(msg.as_bytes()).await?;
this.output_stream.as_mut().flush().await?;
return Err(Error::HeaderNotFound("Target-Site".to_string()));
};
Ok(UriRequest {
this,
uri: uri.to_string(),
repo_uri: target_uri.to_string(),
filename: filename.to_string(),
})
}
pub fn uri(&self) -> &str {
&self.uri
}
pub fn filename(&self) -> &str {
&self.filename
}
pub fn repo_uri(&self) -> &str {
&self.repo_uri
}
pub async fn send_status(&mut self, status: &str) -> Result<(), Error> {
let msg = Message::status(status).to_string();
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
self.this.output_stream.as_mut().flush().await?;
Ok(())
}
pub async fn fail(self, reason: &str) -> Result<(), Error> {
let msg = Message::uri_failure(&self.uri, reason).to_string();
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
self.this.output_stream.as_mut().flush().await?;
Ok(())
}
pub async fn start(
self,
size_in_bytes: u64,
last_modified: &str,
) -> Result<UriResponse<'a>, Error> {
let msg = Message::uri_start(&self.uri, size_in_bytes, last_modified).to_string();
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
self.this.output_stream.as_mut().flush().await?;
Ok(UriResponse {
this: self.this,
uri: self.uri,
filename: self.filename,
})
}
}
pub struct UriResponse<'a> {
this: &'a mut AptStream,
uri: String,
filename: String,
}
impl UriResponse<'_> {
pub async fn complete(self) -> Result<(), Error> {
let msg = Message::uri_success(&self.uri, &self.filename).to_string();
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
self.this.output_stream.as_mut().flush().await?;
Ok(())
}
pub async fn fail(self, reason: &str) -> Result<(), Error> {
let msg = Message::uri_failure(&self.uri, reason).to_string();
self.this
.output_stream
.as_mut()
.write_all(msg.as_bytes())
.await?;
self.this.output_stream.as_mut().flush().await?;
Ok(())
}
pub async fn writer(&self) -> Result<impl tokio::io::AsyncWrite + 'static, Error> {
Ok(tokio::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(self.filename.clone())
.await?)
}
}
trait TryFromMessage<'a>: Sized {
type Error;
async fn try_from_message(msg: (&'a mut AptStream, Message)) -> Result<Self, Self::Error>;
}
impl<'a> TryFromMessage<'a> for AptRequest<'a> {
type Error = Error;
async fn try_from_message(
(this, message): (&'a mut AptStream, Message),
) -> Result<Self, Self::Error> {
match message.message_type {
MessageType::Configuration => {
let config_request = ConfigRequest::from(message);
Ok(AptRequest::Configuration(config_request))
}
MessageType::URIAcquire => {
let uri_request = UriRequest::from(this, message).await?;
Ok(AptRequest::Uri(uri_request))
}
other => Err(Error::UnexpectedMessageType(other, message)),
}
}
}