use super::message::*;
use serde::Serialize;
use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Serializes skill protocol messages onto an underlying writer.
///
/// The methods that do the writing are implemented for `W: AsyncWrite + Unpin`.
/// `Debug` is derived so the writer can be embedded in other `Debug` types and
/// shown in diagnostics; the derived impl is only available when `W: Debug`.
#[derive(Debug)]
pub struct SkillMessageWriter<W> {
    /// Destination that receives the encoded frames.
    writer: W,
}
impl<W: AsyncWrite + Unpin> SkillMessageWriter<W> {
    /// Creates a message writer that emits frames to `writer`.
    pub fn new(writer: W) -> Self {
        Self { writer }
    }

    /// Writes one frame and flushes the underlying writer.
    ///
    /// Frame layout, in order:
    /// 1. the message type as a big-endian `u32`,
    /// 2. the payload length in bytes as a big-endian `u32`,
    /// 3. the payload bytes (omitted when the payload is empty).
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` if the payload is longer than
    /// `u32::MAX` bytes, since the length field could not represent it. Any
    /// error from the underlying writer is propagated unchanged.
    async fn write_message(&mut self, msg: &SkillMessage) -> std::io::Result<()> {
        // Check the length before emitting anything: a plain `as u32` cast would
        // silently truncate, producing a header that disagrees with the bytes
        // that follow it. The fully qualified path keeps this independent of
        // whether `TryFrom` is in the prelude for the crate's edition.
        let payload_len = <u32 as std::convert::TryFrom<usize>>::try_from(msg.payload.len())
            .map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "skill message payload exceeds u32::MAX bytes",
                )
            })?;

        let type_bytes = (msg.msg_type as u32).to_be_bytes();
        let len_bytes = payload_len.to_be_bytes();

        self.writer.write_all(&type_bytes).await?;
        self.writer.write_all(&len_bytes).await?;
        if !msg.payload.is_empty() {
            self.writer.write_all(&msg.payload).await?;
        }
        self.writer.flush().await?;
        Ok(())
    }

    /// Serializes `payload` as JSON and writes it as a frame of type `msg_type`.
    ///
    /// # Errors
    ///
    /// A JSON serialization failure is reported as `ErrorKind::InvalidData`;
    /// errors from writing the frame are propagated.
    async fn write_typed<T: Serialize>(
        &mut self,
        msg_type: SkillMessageType,
        payload: &T,
    ) -> std::io::Result<()> {
        let json = serde_json::to_vec(payload)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        self.write_message(&SkillMessage::new(msg_type, json)).await
    }

    /// Writes an `Execute` frame carrying `payload` encoded as JSON.
    pub async fn write_execute(&mut self, payload: &ExecutePayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Execute, payload).await
    }

    /// Writes a `Cancel` frame.
    ///
    /// The payload is the raw UTF-8 bytes of `execution_id`, not JSON.
    pub async fn write_cancel(&mut self, execution_id: &str) -> std::io::Result<()> {
        let payload = execution_id.as_bytes().to_vec();
        self.write_message(&SkillMessage::new(SkillMessageType::Cancel, payload))
            .await
    }

    /// Writes a `StdinData` frame carrying `payload` encoded as JSON.
    pub async fn write_stdin_data(&mut self, payload: &StdinDataPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::StdinData, payload).await
    }

    /// Writes a `Resize` frame carrying `payload` encoded as JSON.
    pub async fn write_resize(&mut self, payload: &ResizePayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Resize, payload).await
    }

    /// Writes a `Signal` frame carrying `payload` encoded as JSON.
    pub async fn write_signal(&mut self, payload: &SignalPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Signal, payload).await
    }

    /// Writes a `StartSession` frame carrying `payload` encoded as JSON.
    pub async fn write_start_session(
        &mut self,
        payload: &StartSessionPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::StartSession, payload)
            .await
    }

    /// Writes a `Shutdown` frame, which has an empty payload.
    pub async fn write_shutdown(&mut self) -> std::io::Result<()> {
        self.write_message(&SkillMessage::new(SkillMessageType::Shutdown, vec![]))
            .await
    }

    /// Writes an `Ack` frame carrying `payload` encoded as JSON.
    pub async fn write_ack(&mut self, payload: &AckPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Ack, payload).await
    }

    /// Writes a `StdoutChunk` frame carrying `payload` encoded as JSON.
    pub async fn write_stdout_chunk(&mut self, payload: &DataChunkPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::StdoutChunk, payload)
            .await
    }

    /// Writes a `StderrChunk` frame carrying `payload` encoded as JSON.
    pub async fn write_stderr_chunk(&mut self, payload: &DataChunkPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::StderrChunk, payload)
            .await
    }

    /// Writes a `Progress` frame carrying `payload` encoded as JSON.
    pub async fn write_progress(&mut self, payload: &ProgressPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Progress, payload).await
    }

    /// Writes a `Completed` frame carrying `payload` encoded as JSON.
    pub async fn write_completed(&mut self, payload: &CompletedPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Completed, payload).await
    }

    /// Writes an `Error` frame carrying `payload` encoded as JSON.
    pub async fn write_error(&mut self, payload: &ErrorPayload) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::Error, payload).await
    }

    /// Writes a `SessionStarted` frame carrying `payload` encoded as JSON.
    pub async fn write_session_started(
        &mut self,
        payload: &SessionStartedPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::SessionStarted, payload)
            .await
    }

    /// Writes a `ProxySubmit` frame carrying `payload` encoded as JSON.
    pub async fn write_proxy_submit(
        &mut self,
        payload: &ProxySubmitPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::ProxySubmit, payload)
            .await
    }

    /// Writes a `ProxyCancel` frame.
    ///
    /// The payload is the raw UTF-8 bytes of `proxy_id`, not JSON.
    pub async fn write_proxy_cancel(&mut self, proxy_id: &str) -> std::io::Result<()> {
        let payload = proxy_id.as_bytes().to_vec();
        self.write_message(&SkillMessage::new(SkillMessageType::ProxyCancel, payload))
            .await
    }

    /// Writes a `ProxyStdoutChunk` frame carrying `payload` encoded as JSON.
    pub async fn write_proxy_stdout_chunk(
        &mut self,
        payload: &ProxyChunkPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::ProxyStdoutChunk, payload)
            .await
    }

    /// Writes a `ProxyStderrChunk` frame carrying `payload` encoded as JSON.
    pub async fn write_proxy_stderr_chunk(
        &mut self,
        payload: &ProxyChunkPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::ProxyStderrChunk, payload)
            .await
    }

    /// Writes a `ProxyCompleted` frame carrying `payload` encoded as JSON.
    pub async fn write_proxy_completed(
        &mut self,
        payload: &ProxyCompletedPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::ProxyCompleted, payload)
            .await
    }

    /// Writes a `ProxyRejected` frame carrying `payload` encoded as JSON.
    pub async fn write_proxy_rejected(
        &mut self,
        payload: &ProxyRejectedPayload,
    ) -> std::io::Result<()> {
        self.write_typed(SkillMessageType::ProxyRejected, payload)
            .await
    }
}
// Unit tests for this module are kept in `writer_tests.rs` (located via the
// `#[path]` attribute) and are compiled only when `cfg(test)` is active.
#[cfg(test)]
#[path = "writer_tests.rs"]
mod tests;