use alloc::collections::VecDeque;
use alloc::string::String;
use alloc::vec::Vec;
use crate::HashMap;
use crate::codec::{self, Decode};
use crate::command::{Command, CommandBuilder};
use crate::error::ConnectionError;
use crate::response::{CommandResponse, ReplyResponse, TrapResponse};
use crate::tag::Tag;
use crate::word::Word;
/// Something the connection observed that the caller should act on.
///
/// Events are queued by [`Connection::receive`] and drained, oldest first, with
/// [`Connection::poll_event`].
#[derive(Clone, Debug)]
pub enum Event {
    /// One `!re` data sentence for a command that is still in flight.
    Reply { tag: Tag, response: ReplyResponse },
    /// The command finished successfully (`!done`) and is no longer in flight.
    Done { tag: Tag },
    /// The command finished without returning data (`!empty`) and is no longer in flight.
    Empty { tag: Tag },
    /// The command failed (`!trap`), or a sentence carrying its tag could not be parsed.
    /// Either way the command is no longer in flight.
    Trap { tag: Tag, response: TrapResponse },
    /// The peer reported an unrecoverable error (`!fatal`); the connection is now dead.
    Fatal { reason: String },
}
/// An encoded chunk of bytes that the caller must write to the transport, in order.
#[derive(Debug)]
pub struct Transmit {
    /// Wire-ready bytes for one command (or cancellation request).
    pub data: Vec<u8>,
}
/// Lifecycle of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// The connection accepts new commands and incoming data.
    Active,
    /// The connection has failed; every further operation is rejected.
    Dead,
}
/// Bookkeeping for a single command that is still in flight.
#[derive(Debug)]
struct CommandState {
    /// Number of `!re` replies received so far for this command.
    reply_count: usize,
}
/// Protocol state machine for one connection.
///
/// The connection performs no I/O of its own. The caller:
/// - feeds received bytes in via [`Connection::receive`];
/// - writes out whatever [`Connection::poll_transmit`] yields;
/// - consumes [`Event`]s from [`Connection::poll_event`].
#[derive(Debug)]
pub struct Connection {
    /// Whether the connection can still be used.
    state: State,
    /// Received bytes not yet consumed as a complete sentence.
    recv_buf: Vec<u8>,
    /// Commands that were sent and have been neither terminated nor cancelled, by tag.
    in_flight: HashMap<Tag, CommandState>,
    /// Events waiting to be returned by `poll_event`.
    events: VecDeque<Event>,
    /// Encoded data waiting to be returned by `poll_transmit`.
    outbound: VecDeque<Transmit>,
}
impl Connection {
    /// Creates an active connection with empty buffers and no commands in flight.
    pub fn new() -> Self {
        Self {
            state: State::Active,
            recv_buf: Vec::new(),
            in_flight: HashMap::new(),
            events: VecDeque::new(),
            outbound: VecDeque::new(),
        }
    }

    /// Feeds bytes read from the transport into the connection.
    ///
    /// Every complete sentence in the receive buffer is decoded and parsed. Each one is
    /// turned into at most one [`Event`], retrievable via [`Connection::poll_event`].
    /// Bytes of a trailing, incomplete sentence stay buffered until more data arrives.
    ///
    /// # Errors
    ///
    /// - Returns [`ConnectionError::Closed`] if the connection is already dead.
    /// - If the buffered bytes cannot be decoded, returns the codec's error and moves the
    ///   connection to [`State::Dead`]. The bad bytes sit at the front of the buffer, so no
    ///   later call could ever get past them.
    pub fn receive(&mut self, data: &[u8]) -> Result<(), ConnectionError> {
        if self.state == State::Dead {
            return Err(ConnectionError::Closed);
        }
        self.recv_buf.extend_from_slice(data);
        loop {
            // Decode inside a scope so the borrow of `recv_buf` ends before it is drained.
            let outcome = {
                let buf = &self.recv_buf;
                match codec::decode_sentence(buf) {
                    Ok(Decode::Complete {
                        value: raw_sentence,
                        bytes_consumed,
                    }) => {
                        let result = CommandResponse::parse(&raw_sentence).map_err(|e| {
                            // Recover the tag, if present, so the failure can be routed
                            // to the command it belongs to.
                            let tag_opt = raw_sentence.words().find_map(|word_bytes| {
                                match Word::try_from(word_bytes) {
                                    Ok(Word::Tag(t)) => Some(t),
                                    _ => None,
                                }
                            });
                            (e, tag_opt)
                        });
                        Ok(Some((result, bytes_consumed)))
                    }
                    Ok(Decode::Incomplete { .. }) => Ok(None),
                    Err(error) => Err(ConnectionError::from(error)),
                }
            };
            let (parsed, bytes_consumed) = match outcome {
                Ok(Some(next)) => next,
                Ok(None) => break,
                Err(error) => {
                    // Framing is lost: the undecodable bytes would make every later call
                    // fail the same way, so tear the connection down.
                    self.state = State::Dead;
                    self.in_flight.clear();
                    self.recv_buf.clear();
                    return Err(error);
                }
            };
            self.recv_buf.drain(..bytes_consumed);
            match parsed {
                Ok(response) => self.dispatch_response(response),
                Err((error, tag_opt)) => self.handle_parse_error(&error, tag_opt),
            }
            if self.state == State::Dead {
                // `!fatal` ends the connection. Discard anything the peer sent after it,
                // consistent with later `receive` calls being rejected as `Closed`.
                self.recv_buf.clear();
                break;
            }
        }
        Ok(())
    }

    /// Queues `command` for transmission and tracks it as in flight.
    ///
    /// Returns the command's tag, which identifies the events produced for it.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionError::Closed`] if the connection is dead.
    pub fn send_command(&mut self, command: Command) -> Result<Tag, ConnectionError> {
        if self.state == State::Dead {
            return Err(ConnectionError::Closed);
        }
        let tag = command.tag;
        self.outbound.push_back(Transmit {
            data: command.into_data(),
        });
        self.in_flight.insert(tag, CommandState { reply_count: 0 });
        Ok(tag)
    }

    /// Stops tracking the command with `tag` and queues a cancel request for it.
    ///
    /// Does nothing if `tag` is not in flight.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionError::Closed`] if the connection is dead.
    pub fn cancel_command(&mut self, tag: Tag) -> Result<(), ConnectionError> {
        if self.state == State::Dead {
            return Err(ConnectionError::Closed);
        }
        if self.in_flight.remove(&tag).is_some() {
            let cancel = CommandBuilder::cancel(tag);
            self.outbound.push_back(Transmit {
                data: cancel.into_data(),
            });
        }
        Ok(())
    }

    /// Queues one cancel request per in-flight command, then forgets all of them.
    pub fn cancel_all(&mut self) {
        let cancels = self.in_flight.keys().map(|&tag| Transmit {
            data: CommandBuilder::cancel(tag).into_data(),
        });
        self.outbound.extend(cancels);
        self.in_flight.clear();
    }

    /// Pops the oldest pending transmission, if any.
    pub fn poll_transmit(&mut self) -> Option<Transmit> {
        self.outbound.pop_front()
    }

    /// Pops the oldest pending event, if any.
    pub fn poll_event(&mut self) -> Option<Event> {
        self.events.pop_front()
    }

    /// Current lifecycle state.
    pub fn state(&self) -> State {
        self.state
    }

    /// `true` while the connection is [`State::Active`].
    pub fn is_active(&self) -> bool {
        self.state == State::Active
    }

    /// Number of commands sent but not yet terminated or cancelled.
    pub fn in_flight_count(&self) -> usize {
        self.in_flight.len()
    }

    /// `true` if [`Connection::poll_transmit`] would return data.
    pub fn has_pending_transmit(&self) -> bool {
        !self.outbound.is_empty()
    }

    /// Number of received bytes buffered but not yet consumed as a complete sentence.
    pub fn recv_buffer_len(&self) -> usize {
        self.recv_buf.len()
    }

    /// `true` if the command with `tag` is still awaiting a terminal response.
    pub fn is_in_flight(&self, tag: Tag) -> bool {
        self.in_flight.contains_key(&tag)
    }

    /// Turns one parsed response into an event and updates in-flight bookkeeping.
    ///
    /// Replies for tags that are not in flight (e.g. cancelled commands) are dropped.
    /// Terminal responses (`Done`, `Empty`, `Trap`) are always reported. `Fatal` kills the
    /// connection.
    fn dispatch_response(&mut self, response: CommandResponse) {
        match response {
            CommandResponse::Reply(reply) => {
                let tag = reply.tag;
                if let Some(cmd_state) = self.in_flight.get_mut(&tag) {
                    cmd_state.reply_count += 1;
                    self.events.push_back(Event::Reply {
                        tag,
                        response: reply,
                    });
                }
            }
            CommandResponse::Done(done) => {
                let tag = done.tag;
                self.in_flight.remove(&tag);
                self.events.push_back(Event::Done { tag });
            }
            CommandResponse::Empty(empty) => {
                let tag = empty.tag;
                self.in_flight.remove(&tag);
                self.events.push_back(Event::Empty { tag });
            }
            CommandResponse::Trap(trap) => {
                let tag = trap.tag;
                self.in_flight.remove(&tag);
                self.events.push_back(Event::Trap {
                    tag,
                    response: trap,
                });
            }
            CommandResponse::Fatal(reason) => {
                self.in_flight.clear();
                self.state = State::Dead;
                self.events.push_back(Event::Fatal { reason });
            }
        }
    }

    /// Reports a sentence that decoded but failed to parse.
    ///
    /// If the sentence carried a tag, that command is terminated with a synthetic
    /// [`Event::Trap`] describing the protocol error. Untagged failures cannot be
    /// attributed to any command and are dropped.
    fn handle_parse_error(&mut self, error: &crate::error::ProtocolError, tag_opt: Option<Tag>) {
        if let Some(tag) = tag_opt {
            self.in_flight.remove(&tag);
            self.events.push_back(Event::Trap {
                tag,
                response: TrapResponse {
                    tag,
                    category: None,
                    message: alloc::format!("Protocol error: {error}"),
                },
            });
        }
    }
}
impl Default for Connection {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec;
    use alloc::format;
    use alloc::string::String;
    use alloc::vec;

    // Wire-format builders: encode a sentence from raw words plus a terminator.
    fn build_sentence(words: &[&[u8]]) -> Vec<u8> {
        let mut data = Vec::new();
        for word in words {
            codec::encode_word(word, &mut data);
        }
        codec::encode_terminator(&mut data);
        data
    }

    fn build_done(tag: Tag) -> Vec<u8> {
        let tag_word = format!(".tag={tag}");
        build_sentence(&[b"!done", tag_word.as_bytes()])
    }

    fn build_empty(tag: Tag) -> Vec<u8> {
        let tag_word = format!(".tag={tag}");
        build_sentence(&[b"!empty", tag_word.as_bytes()])
    }

    fn build_reply(tag: Tag, attrs: &[(&str, &str)]) -> Vec<u8> {
        let tag_word = format!(".tag={tag}");
        let mut words: Vec<Vec<u8>> = vec![b"!re".to_vec(), tag_word.into_bytes()];
        for (k, v) in attrs {
            words.push(format!("={k}={v}").into_bytes());
        }
        let word_refs: Vec<&[u8]> = words.iter().map(|w| w.as_slice()).collect();
        build_sentence(&word_refs)
    }

    fn build_trap(tag: Tag, message: &str) -> Vec<u8> {
        let tag_word = format!(".tag={tag}");
        let msg_word = format!("=message={message}");
        build_sentence(&[b"!trap", tag_word.as_bytes(), msg_word.as_bytes()])
    }

    fn build_fatal(reason: &str) -> Vec<u8> {
        build_sentence(&[b"!fatal", reason.as_bytes()])
    }

    #[test]
    fn test_send_command_queues_transmit() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let expected_data = cmd.data().to_vec();
        let tag = conn.send_command(cmd).unwrap();
        assert!(conn.is_in_flight(tag));
        assert_eq!(conn.in_flight_count(), 1);
        assert!(conn.has_pending_transmit());
        let transmit = conn.poll_transmit().unwrap();
        assert_eq!(transmit.data, expected_data);
        assert!(!conn.has_pending_transmit());
    }

    #[test]
    fn test_done_response_completes_command() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        let wire = build_done(tag);
        conn.receive(&wire).unwrap();
        match conn.poll_event().unwrap() {
            Event::Done { tag: t } => assert_eq!(t, tag),
            other => panic!("expected Done, got {other:?}"),
        }
        assert_eq!(conn.in_flight_count(), 0);
        assert!(conn.poll_event().is_none());
    }

    #[test]
    fn test_empty_response_completes_command() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        let wire = build_empty(tag);
        conn.receive(&wire).unwrap();
        match conn.poll_event().unwrap() {
            Event::Empty { tag: t } => assert_eq!(t, tag),
            other => panic!("expected Empty, got {other:?}"),
        }
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_streaming_replies_then_done() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/interface/print").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.receive(&build_reply(tag, &[("name", "ether1")]))
            .unwrap();
        conn.receive(&build_reply(tag, &[("name", "ether2")]))
            .unwrap();
        conn.receive(&build_done(tag)).unwrap();
        match conn.poll_event().unwrap() {
            Event::Reply { tag: t, response } => {
                assert_eq!(t, tag);
                assert_eq!(
                    response.attributes.get("name"),
                    Some(&Some(String::from("ether1")))
                );
            }
            other => panic!("expected Reply, got {other:?}"),
        }
        match conn.poll_event().unwrap() {
            Event::Reply { tag: t, response } => {
                assert_eq!(t, tag);
                assert_eq!(
                    response.attributes.get("name"),
                    Some(&Some(String::from("ether2")))
                );
            }
            other => panic!("expected Reply, got {other:?}"),
        }
        match conn.poll_event().unwrap() {
            Event::Done { tag: t } => assert_eq!(t, tag),
            other => panic!("expected Done, got {other:?}"),
        }
        assert_eq!(conn.in_flight_count(), 0);
        assert!(conn.poll_event().is_none());
    }

    #[test]
    fn test_trap_response_terminates_command() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.receive(&build_trap(tag, "no such command")).unwrap();
        match conn.poll_event().unwrap() {
            Event::Trap { tag: t, response } => {
                assert_eq!(t, tag);
                assert_eq!(response.message, "no such command");
            }
            other => panic!("expected Trap, got {other:?}"),
        }
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_fatal_kills_all_commands() {
        let mut conn = Connection::new();
        let cmd1 = CommandBuilder::new().command("/test1").build();
        let cmd2 = CommandBuilder::new().command("/test2").build();
        conn.send_command(cmd1).unwrap();
        conn.send_command(cmd2).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.receive(&build_fatal("out of memory")).unwrap();
        match conn.poll_event().unwrap() {
            Event::Fatal { reason } => assert_eq!(reason, "out of memory"),
            other => panic!("expected Fatal, got {other:?}"),
        }
        assert_eq!(conn.state(), State::Dead);
        assert_eq!(conn.in_flight_count(), 0);
        let cmd3 = CommandBuilder::new().command("/test3").build();
        assert!(conn.send_command(cmd3).is_err());
        assert!(conn.receive(&[]).is_err());
    }

    #[test]
    fn test_partial_receive() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        let wire = build_done(tag);
        for &byte in &wire {
            conn.receive(&[byte]).unwrap();
        }
        match conn.poll_event().unwrap() {
            Event::Done { tag: t } => assert_eq!(t, tag),
            other => panic!("expected Done, got {other:?}"),
        }
    }

    #[test]
    fn test_partial_sentence_is_buffered_until_complete() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        let wire = build_done(tag);
        let (head, tail) = wire.split_at(wire.len() - 1);
        conn.receive(head).unwrap();
        // Nothing may be consumed or emitted until the terminator arrives.
        assert_eq!(conn.recv_buffer_len(), head.len());
        assert!(conn.poll_event().is_none());
        assert!(conn.is_in_flight(tag));
        conn.receive(tail).unwrap();
        assert_eq!(conn.recv_buffer_len(), 0);
        assert!(matches!(conn.poll_event(), Some(Event::Done { tag: t }) if t == tag));
    }

    #[test]
    fn test_cancel_command() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/test").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.cancel_command(tag).unwrap();
        assert!(conn.has_pending_transmit());
        let cancel_transmit = conn.poll_transmit().unwrap();
        assert!(!cancel_transmit.data.is_empty());
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_cancel_unknown_tag_is_noop() {
        let mut conn = Connection::new();
        conn.cancel_command(Tag::new()).unwrap();
        assert!(!conn.has_pending_transmit());
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_cancel_all() {
        let mut conn = Connection::new();
        let cmd1 = CommandBuilder::new().command("/test1").build();
        let cmd2 = CommandBuilder::new().command("/test2").build();
        conn.send_command(cmd1).unwrap();
        conn.send_command(cmd2).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.cancel_all();
        let mut cancel_count = 0;
        while conn.poll_transmit().is_some() {
            cancel_count += 1;
        }
        assert_eq!(cancel_count, 2);
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_cancel_all_with_nothing_in_flight_queues_nothing() {
        let mut conn = Connection::new();
        conn.cancel_all();
        assert!(!conn.has_pending_transmit());
    }

    #[test]
    fn test_dead_connection_rejects_cancel() {
        let mut conn = Connection::new();
        conn.receive(&build_fatal("bye")).unwrap();
        assert_eq!(conn.state(), State::Dead);
        assert!(!conn.is_active());
        assert!(conn.cancel_command(Tag::new()).is_err());
    }

    #[test]
    fn test_multiple_sentences_in_single_receive() {
        let mut conn = Connection::new();
        let cmd1 = CommandBuilder::new().command("/test1").build();
        let cmd2 = CommandBuilder::new().command("/test2").build();
        let tag1 = conn.send_command(cmd1).unwrap();
        let tag2 = conn.send_command(cmd2).unwrap();
        while conn.poll_transmit().is_some() {}
        let mut combined = build_done(tag1);
        combined.extend_from_slice(&build_done(tag2));
        conn.receive(&combined).unwrap();
        match conn.poll_event().unwrap() {
            Event::Done { tag } => assert_eq!(tag, tag1),
            other => panic!("expected Done for tag1, got {other:?}"),
        }
        match conn.poll_event().unwrap() {
            Event::Done { tag } => assert_eq!(tag, tag2),
            other => panic!("expected Done for tag2, got {other:?}"),
        }
        assert_eq!(conn.in_flight_count(), 0);
    }

    #[test]
    fn test_reply_count_tracks_streamed_replies() {
        let mut conn = Connection::new();
        let cmd = CommandBuilder::new().command("/interface/print").build();
        let tag = conn.send_command(cmd).unwrap();
        while conn.poll_transmit().is_some() {}
        conn.receive(&build_reply(tag, &[("name", "ether1")]))
            .unwrap();
        conn.receive(&build_reply(tag, &[("name", "ether2")]))
            .unwrap();
        assert_eq!(conn.in_flight.get(&tag).map(|s| s.reply_count), Some(2));
        assert!(conn.is_in_flight(tag));
    }

    #[test]
    fn test_reply_for_unknown_tag_is_ignored() {
        let mut conn = Connection::new();
        let unknown_tag = Tag::new();
        conn.receive(&build_reply(unknown_tag, &[("name", "test")]))
            .unwrap();
        assert!(conn.poll_event().is_none());
    }

    #[test]
    fn test_connection_starts_active() {
        let conn = Connection::new();
        assert_eq!(conn.state(), State::Active);
        assert!(conn.is_active());
        assert_eq!(conn.in_flight_count(), 0);
        assert!(!conn.has_pending_transmit());
        assert_eq!(conn.recv_buffer_len(), 0);
    }

    #[test]
    fn test_default_is_equivalent_to_new() {
        let conn = Connection::default();
        assert_eq!(conn.state(), State::Active);
        assert_eq!(conn.in_flight_count(), 0);
        assert!(!conn.has_pending_transmit());
        assert_eq!(conn.recv_buffer_len(), 0);
    }
}