use crate::Result;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
/// A protocol-agnostic message sent via [`StreamingProtocol::publish`] or
/// yielded by a [`MessageStream`].
///
/// Usually constructed with [`MessageBuilder`], but a struct literal works too.
/// `PartialEq`/`Eq` are derived so messages can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtocolMessage {
    /// Optional message identifier; `None` when none was assigned.
    pub id: Option<String>,
    /// Topic the message is published to or was received from.
    pub topic: String,
    /// Raw message body.
    pub payload: Vec<u8>,
    /// Free-form string headers, e.g. `content-type` (set by [`MessageBuilder::json`]).
    pub metadata: std::collections::HashMap<String, String>,
    /// Optional quality-of-service level; its meaning presumably depends on the
    /// concrete protocol — confirm per handler.
    pub qos: Option<u8>,
    /// Optional UTC timestamp attached to the message.
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
/// Boxed, pinned, `Send` stream of incoming messages (or errors) returned by
/// [`StreamingProtocol::subscribe`].
pub type MessageStream = Pin<Box<dyn Stream<Item = Result<ProtocolMessage>> + Send + 'static>>;
/// Connection-level information reported by a [`StreamingProtocol`] handler via
/// [`StreamingProtocol::get_metadata`].
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared; `Protocol` already
/// implements `Eq` because it is used as a `HashMap` key by the registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamingMetadata {
    /// Protocol the handler speaks.
    pub protocol: super::Protocol,
    /// Identifier of the handler's connection.
    pub connection_id: String,
    /// Optional server description, if the handler provides one.
    pub server_info: Option<String>,
    /// Subscriptions reported by the handler (presumably topic names — confirm).
    pub subscriptions: Vec<String>,
    /// Connection flag; read by the default [`StreamingProtocol::is_connected`].
    pub connected: bool,
}
/// Common interface for streaming protocol handlers stored in a
/// [`StreamingProtocolRegistry`].
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`.
#[async_trait]
pub trait StreamingProtocol: Send + Sync {
    /// Subscribes `consumer_id` to `topic` and returns the stream of messages.
    async fn subscribe(&self, topic: &str, consumer_id: &str) -> Result<MessageStream>;

    /// Publishes `message` to `topic`.
    async fn publish(&self, topic: &str, message: ProtocolMessage) -> Result<()>;

    /// Cancels a subscription made with [`subscribe`](Self::subscribe).
    ///
    /// The default implementation is a no-op that always returns `Ok(())`;
    /// handlers holding per-subscription state should override it.
    async fn unsubscribe(&self, _topic: &str, _consumer_id: &str) -> Result<()> {
        Ok(())
    }

    /// Returns an owned snapshot of the handler's connection metadata.
    fn get_metadata(&self) -> StreamingMetadata;

    /// Reports whether the handler is connected.
    ///
    /// By default this reads the `connected` flag from
    /// [`get_metadata`](Self::get_metadata).
    fn is_connected(&self) -> bool {
        let snapshot = self.get_metadata();
        snapshot.connected
    }
}
/// Maps each [`super::Protocol`] to at most one shared [`StreamingProtocol`] handler.
pub struct StreamingProtocolRegistry {
    /// One handler per protocol; re-registering a protocol replaces its entry.
    handlers: std::collections::HashMap<super::Protocol, Arc<dyn StreamingProtocol + 'static>>,
}
impl StreamingProtocolRegistry {
    /// Creates a registry with no handlers.
    pub fn new() -> Self {
        let handlers = std::collections::HashMap::new();
        Self { handlers }
    }

    /// Registers `handler` for `protocol`.
    ///
    /// Any handler previously registered for the same protocol is replaced and
    /// dropped from the registry without notice.
    pub fn register_handler(
        &mut self,
        protocol: super::Protocol,
        handler: Arc<dyn StreamingProtocol>,
    ) {
        // `insert` returns the displaced handler (if any); it is deliberately discarded.
        self.handlers.insert(protocol, handler);
    }

    /// Returns the handler registered for `protocol`, or `None` if there is none.
    pub fn get_handler(
        &self,
        protocol: &super::Protocol,
    ) -> Option<&Arc<dyn StreamingProtocol>> {
        self.handlers.get(protocol)
    }

    /// Lists every protocol that currently has a handler.
    ///
    /// The order follows `HashMap` iteration and is unspecified.
    pub fn registered_protocols(&self) -> Vec<super::Protocol> {
        self.handlers
            .iter()
            .map(|(protocol, _handler)| protocol.clone())
            .collect()
    }

    /// Returns `true` when a handler is registered for `protocol`.
    pub fn supports_protocol(&self, protocol: &super::Protocol) -> bool {
        self.get_handler(protocol).is_some()
    }
}
impl Default for StreamingProtocolRegistry {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`ProtocolMessage`].
///
/// Start with [`MessageBuilder::new`], chain setters, and finish with
/// [`MessageBuilder::build`]. `#[must_use]` makes the compiler warn when a
/// builder is created or chained and then dropped without being built.
#[derive(Debug, Clone)]
#[must_use = "a MessageBuilder does nothing until `build()` is called"]
pub struct MessageBuilder {
    /// The message being assembled; returned as-is by `build`.
    message: ProtocolMessage,
}
impl MessageBuilder {
    /// Starts a message for `topic`.
    ///
    /// Initial state:
    /// - empty payload;
    /// - no id, metadata or QoS;
    /// - timestamp set to the current UTC time.
    ///
    /// Use [`timestamp`](Self::timestamp) or
    /// [`without_timestamp`](Self::without_timestamp) to change the timestamp.
    pub fn new(topic: impl Into<String>) -> Self {
        Self {
            message: ProtocolMessage {
                id: None,
                topic: topic.into(),
                payload: Vec::new(),
                metadata: std::collections::HashMap::new(),
                qos: None,
                timestamp: Some(chrono::Utc::now()),
            },
        }
    }

    /// Sets the message id.
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.message.id = Some(id.into());
        self
    }

    /// Replaces the payload with raw bytes.
    pub fn payload(mut self, payload: impl Into<Vec<u8>>) -> Self {
        self.message.payload = payload.into();
        self
    }

    /// Replaces the payload with the UTF-8 bytes of `text`.
    pub fn text(mut self, text: impl AsRef<str>) -> Self {
        self.message.payload = text.as_ref().as_bytes().to_vec();
        self
    }

    /// Replaces the payload with the JSON serialization of `value`.
    ///
    /// Also sets the `content-type` metadata entry to `application/json`.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` cannot be serialized to JSON. The builder is
    /// consumed in that case.
    pub fn json<T: serde::Serialize>(mut self, value: &T) -> Result<Self> {
        self.message.payload = serde_json::to_vec(value)?;
        self.message
            .metadata
            .insert("content-type".to_string(), "application/json".to_string());
        Ok(self)
    }

    /// Inserts a metadata entry, overwriting any existing value for `key`.
    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.message.metadata.insert(key.into(), value.into());
        self
    }

    /// Sets the quality-of-service level.
    pub fn qos(mut self, qos: u8) -> Self {
        self.message.qos = Some(qos);
        self
    }

    /// Overrides the timestamp that [`new`](Self::new) initialised to "now".
    ///
    /// Useful for replaying stored messages or for deterministic tests.
    pub fn timestamp(mut self, timestamp: chrono::DateTime<chrono::Utc>) -> Self {
        self.message.timestamp = Some(timestamp);
        self
    }

    /// Clears the timestamp so the built message carries `None`.
    pub fn without_timestamp(mut self) -> Self {
        self.message.timestamp = None;
        self
    }

    /// Finishes the builder and returns the assembled message.
    #[must_use]
    pub fn build(self) -> ProtocolMessage {
        self.message
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every chained setter is reflected in the built message, and `new`
    /// stamps a timestamp.
    #[test]
    fn test_message_builder() {
        let built = MessageBuilder::new("test-topic")
            .id("msg-123")
            .text("Hello, World!")
            .metadata("priority", "high")
            .qos(1)
            .build();

        assert_eq!(built.topic, "test-topic");
        assert_eq!(built.id.as_deref(), Some("msg-123"));
        assert_eq!(built.payload.as_slice(), b"Hello, World!");
        assert_eq!(built.metadata.get("priority").map(String::as_str), Some("high"));
        assert_eq!(built.qos, Some(1));
        assert!(built.timestamp.is_some());
    }

    /// `json` serializes into the payload and tags the content type.
    #[test]
    fn test_message_builder_json() {
        #[derive(serde::Serialize)]
        struct TestData {
            name: String,
            value: i32,
        }

        let sample = TestData {
            name: String::from("test"),
            value: 42,
        };
        let built = MessageBuilder::new("json-topic")
            .json(&sample)
            .expect("TestData is always serializable")
            .build();

        assert_eq!(built.topic, "json-topic");
        assert_eq!(
            built.metadata.get("content-type").map(String::as_str),
            Some("application/json")
        );
        assert!(!built.payload.is_empty());
    }

    /// A fresh registry knows no protocols.
    #[test]
    fn test_streaming_registry() {
        let registry = StreamingProtocolRegistry::new();
        let mqtt = crate::protocol_abstraction::Protocol::Mqtt;

        assert!(!registry.supports_protocol(&mqtt));
        assert!(registry.registered_protocols().is_empty());
    }
}