//! Shared utilities for the `amqp_client_rust` API module (`api/utils.rs`).

1use std::{
2    cell::LazyCell,
3    collections::HashMap,
4    fmt::{Display, write},
5    pin::Pin, sync::Arc,
6    hash::Hash
7};
8use std::error::Error as StdError;
9use crate::{api::channel::AsyncChannel, errors::{AppError, AppErrorType}};
10use amqprs::{FieldName, FieldTable, FieldValue, LongStr, ShortStr, channel::Channel};
11use dashmap::DashMap;
12use tracing::error;
13
14
/// A message payload plus its optional content type.
///
/// The body is held in an `Arc<[u8]>`, so cloning a `Message` shares the
/// payload bytes rather than copying them.
#[derive(Clone, Debug)]
pub struct Message {
    pub body: Arc<[u8]>,
    pub content_type: Option<String>,
}
20
21pub type Handler = Arc<
22    dyn Fn(
23            Message,
24        )
25            -> Pin<Box<dyn Future<Output = Result<(), Box<dyn StdError + Send + Sync>>> + Send>>
26        + Send
27        + Sync,
28>;
29pub type RPCHandler = Arc<
30    dyn Fn(
31            Message,
32        )
33            -> Pin<Box<dyn Future<Output = Result<Message, Box<dyn StdError + Send + Sync>>> + Send>>
34        + Send
35        + Sync,
36>;
37
/// Publisher-confirm configuration.
///
/// Judging by the variant names: `Disables` turns confirms off, and the others
/// enable publisher confirms for plain publishing, the RPC client side and the
/// RPC server side respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Confirmations {
    Disables,
    PublisherConfirms,
    RPCClientPublisherConfirms,
    RPCServerPublisherConfirms,
}
45
/// Message delivery mode.
///
/// The explicit discriminants match the AMQP `delivery-mode` property values:
/// 1 = transient (non-persistent), 2 = persistent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeliveryMode {
    Transient = 1,
    Persistent = 2,
}
51
/// Exchange kinds supported by this client: direct, fanout and topic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExchangeType {
    Direct,
    Fanout,
    Topic,
}
58
/// Commands sent to a channel's control loop.
///
/// NOTE(review): the `(u64, bool)` payload of `PublishAck` / `PublishNack` is
/// presumably the AMQP delivery tag and the `multiple` flag, and `ReOpen`'s
/// `u16` presumably a channel id — confirm against the code that sends these.
pub enum ChannelCmd {
    PublishAck((u64, bool)),
    PublishNack((u64, bool)),
    ReOpen(u16),
}
/// Compression applied to a message body, identified by a content-encoding string.
///
/// The codec variants exist only when their cargo feature (`zstd`, `lz4_flex`,
/// `flate2`) is enabled; `None` is always available.
///
/// NOTE(review): `Zlib` advertises itself as `application/x-gzip`, yet this
/// module's `compress_zlib` / `decompress_zlib` use flate2's
/// `ZlibEncoder` / `ZlibDecoder` (zlib framing, not gzip). Peers that take the
/// header literally and use a gzip decoder may fail — confirm intended interop.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ContentEncoding {
    #[cfg(feature = "zstd")]
    Zstd,
    #[cfg(feature = "lz4_flex")]
    Lz4,
    #[cfg(feature = "flate2")]
    Zlib,
    None,
}

impl ContentEncoding {
    /// Parses a content-encoding string.
    ///
    /// Accepted values: `application/zstd`, `application/zstandard` (zstd);
    /// `application/lz4` (lz4); `application/x-gzip`, `application/gzip`,
    /// `application/zlib` (zlib); and `none`. Returns `None` for anything else,
    /// including strings for codecs whose feature is disabled.
    pub fn from_str(s: &str) -> Option<ContentEncoding> {
        let encoding = match s {
            #[cfg(feature = "zstd")]
            "application/zstd" | "application/zstandard" => Self::Zstd,
            #[cfg(feature = "lz4_flex")]
            "application/lz4" => Self::Lz4,
            #[cfg(feature = "flate2")]
            "application/x-gzip" | "application/gzip" | "application/zlib" => Self::Zlib,
            "none" => Self::None,
            _ => return None,
        };
        Some(encoding)
    }

    /// Canonical content-encoding string for this variant; `from_str` maps it
    /// back to the same variant.
    pub fn as_str(&self) -> &'static str {
        match *self {
            #[cfg(feature = "zstd")]
            Self::Zstd => "application/zstd",
            #[cfg(feature = "lz4_flex")]
            Self::Lz4 => "application/lz4",
            #[cfg(feature = "flate2")]
            Self::Zlib => "application/x-gzip",
            Self::None => "none",
        }
    }
}

impl Display for ContentEncoding {
    /// Writes the same text as [`ContentEncoding::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
105
/// One node of a [`TopicTrie`].
///
/// Children are keyed by a single dot-separated segment of a binding pattern
/// (a literal word, `*` or `#`). `values` holds everything registered for the
/// pattern that ends exactly at this node.
#[derive(Clone, Debug)]
pub struct TopicNode<T> {
    children: HashMap<String, TopicNode<T>>,
    values: Vec<T>,
}

// Implemented by hand so that `T` is not required to implement `Default`.
impl<T> Default for TopicNode<T> {
    fn default() -> Self {
        Self {
            children: HashMap::new(),
            values: Vec::new(),
        }
    }
}

/// Maps topic routing keys to the values registered under matching binding patterns.
///
/// Patterns and keys are split on `.`. In a pattern, `*` matches exactly one
/// word and `#` matches zero or more words. The empty string is treated as zero
/// words.
#[derive(Clone, Debug)]
pub struct TopicTrie<T> {
    root: TopicNode<T>,
}

// `#[derive(Default)]` would add a `T: Default` bound, which handler types such
// as `Arc<dyn Fn(..)>` do not satisfy; this manual impl has no bound on `T`.
impl<T> Default for TopicTrie<T> {
    fn default() -> Self {
        Self {
            root: TopicNode::default(),
        }
    }
}

impl<T> TopicTrie<T> {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `value` under the binding `pattern` (e.g. `"stock.*.nyse"`).
    ///
    /// Several values may share a pattern; all of them are kept, in insertion order.
    pub fn insert(&mut self, pattern: &str, value: T) {
        let mut current = &mut self.root;
        for segment in Self::split_segments(pattern) {
            // Move to the child node, creating it if it doesn't exist.
            current = current.children.entry(segment.to_string()).or_default();
        }
        current.values.push(value);
    }

    /// Splits a pattern or routing key into words. `""` yields no words,
    /// whereas `"".split('.')` would yield a single empty word.
    fn split_segments(key: &str) -> Vec<&str> {
        if key.is_empty() {
            Vec::new()
        } else {
            key.split('.').collect()
        }
    }
}

impl<T: Clone> TopicTrie<T> {
    /// Returns clones of the values of every pattern that matches `routing_key`.
    ///
    /// Each matching pattern contributes its values exactly once. This holds
    /// even when the key matches it in several ways, e.g. `"a.b"` against
    /// `"#.*.#"`, or a literal `*` word matched both exactly and as a wildcard.
    pub fn search(&self, routing_key: &str) -> Vec<T> {
        let segments = Self::split_segments(routing_key);
        let mut visited = std::collections::HashSet::new();
        let mut results = Vec::new();
        Self::search_node(&self.root, &segments, &mut visited, &mut results);
        results
    }

    /// Depth-first match of `segments` (a suffix of the routing key) starting at `node`.
    fn search_node(
        node: &TopicNode<T>,
        segments: &[&str],
        visited: &mut std::collections::HashSet<(*const TopicNode<T>, usize)>,
        results: &mut Vec<T>,
    ) {
        // `segments` is always a suffix of the same key, so (node address,
        // remaining word count) identifies the subproblem. Solving each one
        // once prevents duplicate results and exponential re-exploration when
        // a pattern contains several `#`.
        if !visited.insert((node as *const TopicNode<T>, segments.len())) {
            return;
        }

        let Some((&head, tail)) = segments.split_first() else {
            // Key exhausted: this node's pattern matches.
            results.extend(node.values.iter().cloned());
            // A trailing '#' can match zero words ("stock.#" matches "stock").
            if let Some(hash_child) = node.children.get("#") {
                Self::search_node(hash_child, segments, visited, results);
            }
            return;
        };

        // Exact word match.
        if let Some(child) = node.children.get(head) {
            Self::search_node(child, tail, visited, results);
        }

        // '*' consumes exactly one word.
        if let Some(star_child) = node.children.get("*") {
            Self::search_node(star_child, tail, visited, results);
        }

        // '#' consumes zero or more words: try every possible split point.
        if let Some(hash_child) = node.children.get("#") {
            for skipped in 0..=segments.len() {
                Self::search_node(hash_child, &segments[skipped..], visited, results);
            }
        }
    }
}
201
/// Encodes `data` as a zstd stream at the fixed compression level 1.
#[cfg(feature = "zstd")]
fn compress_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    let level = 1;
    zstd::encode_all(data, level)
}


/// Decodes a complete zstd stream back into the original bytes.
#[cfg(feature = "zstd")]
fn decompress_zstd(compressed_data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    let decoded = zstd::decode_all(compressed_data)?;
    Ok(decoded)
}
212
/// LZ4-compresses `data` with `lz4_flex::compress_prepend_size`. Per its
/// name, this prefixes the output with the uncompressed size, which
/// `decompress_lz4` relies on.
#[cfg(feature = "lz4_flex")]
fn compress_lz4(data: &[u8]) -> Vec<u8> {
    lz4_flex::compress_prepend_size(data)
}
/// Inverse of `compress_lz4`; the `lz4_flex` error is converted with
/// `AppError`'s `From` implementation.
#[cfg(feature = "lz4_flex")]
fn decompress_lz4(compressed_data: &[u8]) -> Result<Vec<u8>, AppError> {
    lz4_flex::decompress_size_prepended(compressed_data).map_err(AppError::from)
}
221
/// Compresses `data` with flate2's zlib encoder at `Compression::default()`.
///
/// # Errors
/// An encoder I/O failure becomes an `AppErrorType::InternalError` whose
/// message starts with `"Zlib compression failed: "`.
#[cfg(feature = "flate2")]
fn compress_zlib(data: &[u8]) -> Result<Vec<u8>, AppError> {
    use std::io::Read;

    let mut compressed = Vec::new();
    let mut encoder = flate2::read::ZlibEncoder::new(data, flate2::Compression::default());
    if let Err(err) = encoder.read_to_end(&mut compressed) {
        return Err(AppError::new(
            Some(format!("Zlib compression failed: {}", err)),
            None,
            AppErrorType::InternalError,
        ));
    }
    Ok(compressed)
}
/// Decompresses zlib-framed `compressed_data` with flate2's zlib decoder.
///
/// NOTE(review): the whole stream is read into memory with no size cap;
/// consider a limit if message payloads are untrusted.
///
/// # Errors
/// A decoder I/O failure (e.g. corrupt input) becomes an
/// `AppErrorType::InternalError` whose message starts with
/// `"Zlib decompression failed: "`.
#[cfg(feature = "flate2")]
fn decompress_zlib(compressed_data: &[u8]) -> Result<Vec<u8>, AppError> {
    use std::io::Read;

    let mut decompressed = Vec::new();
    let mut decoder = flate2::read::ZlibDecoder::new(compressed_data);
    if let Err(err) = decoder.read_to_end(&mut decompressed) {
        return Err(AppError::new(
            Some(format!("Zlib decompression failed: {}", err)),
            None,
            AppErrorType::InternalError,
        ));
    }
    Ok(decompressed)
}
253
254
255pub fn decompress(content: Vec<u8>, content_encoding: Option<&str>) -> Result<Vec<u8>, AppError> {
256    if let Some(ct) = content_encoding {
257        match ct {
258            #[cfg(feature = "zstd")]
259            "application/zstd" | "application/zstandard" => {
260                match decompress_zstd(&content[..]) {
261                    Ok(decompressed) => {
262                        Ok(decompressed)
263                    }
264                    Err(e) => {
265                        error!("Failed to create gzip decoder: {}", e);
266                        Err(AppError::new(Some("Failed to create gzip decoder".to_string()), None, AppErrorType::InternalError).into())
267                    }
268                }
269            },
270            #[cfg(feature = "lz4_flex")]
271            "application/lz4" => {
272                match decompress_lz4(&content[..]) {
273                    Ok(decompressed) => Ok(decompressed),
274                    Err(e) => {
275                        error!("Failed to decompress LZ4 content: {}", e);
276                        Err(AppError::new(Some("Failed to decompress LZ4 content".to_string()), None, AppErrorType::InternalError))
277                    }
278                }
279            },
280            #[cfg(feature = "flate2")]
281            "application/x-gzip" | "application/gzip" | "application/zlib" => {
282                match decompress_zlib(&content[..]) {
283                    Ok(decompressed) => Ok(decompressed),
284                    Err(e) => {
285                        error!("Failed to create gzip decoder: {}", e);
286                        Err(AppError::new(Some("Failed to create gzip decoder".to_string()), None, AppErrorType::InternalError).into())
287                    }
288                }
289            },
290            _ => Err(AppError::new(Some(format!("Unsupported content encoding: {}", ct)), None, AppErrorType::InternalError))
291        }
292    } else {
293        Ok(content)
294    }
295}
296
297pub fn compress(content: impl Into<Vec<u8>>, content_type: ContentEncoding) -> Result<Vec<u8>, AppError> {
298    match content_type {
299        #[cfg(feature = "zstd")]
300        ContentEncoding::Zstd => {
301            match compress_zstd(&content.into()) {
302                Ok(compressed) => Ok(compressed),
303                Err(e) => {
304                    error!("Failed to compress with zstd: {}", e);
305                    Err(AppError::new(Some("Failed to compress with zstd".to_string()), None, AppErrorType::InternalError))
306                }
307            }
308        },
309        #[cfg(feature = "lz4_flex")]
310        ContentEncoding::Lz4 => Ok(compress_lz4(&content.into())),
311        #[cfg(feature = "flate2")]
312        ContentEncoding::Zlib => Ok(compress_zlib(&mut content.into())?),
313        ContentEncoding::None => Ok(content.into()),
314    }
315}
316
/// Options used when declaring a queue.
///
/// The four flags are public. Extra string arguments are private and are added
/// through the `argument` / `arguments` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct QueueOptions {
    pub auto_delete: bool,
    pub durable: bool,
    pub exclusive: bool,
    pub no_create: bool,
    arguments: HashMap<String, String>,
}

// `HashMap` has no `Hash` impl, so the arguments are hashed in key order. This
// keeps equal values (per the derived, order-insensitive `PartialEq`) hashing
// equally.
impl Hash for QueueOptions {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let flags = [self.auto_delete, self.durable, self.exclusive, self.no_create];
        // Hash each flag individually; hashing the array itself would add a length prefix.
        for flag in flags.iter() {
            flag.hash(state);
        }

        let mut entries: Vec<(&String, &String)> = self.arguments.iter().collect();
        // Map keys are unique, so an unstable sort yields a deterministic order.
        entries.sort_unstable_by_key(|&(key, _)| key);
        for (key, value) in entries {
            key.hash(state);
            value.hash(state);
        }
    }
}
340impl QueueOptions {
341    pub fn new() -> Self {
342        Self {
343            auto_delete: false,
344            durable: false,
345            exclusive: false,
346            no_create: false,
347            arguments: HashMap::new(),
348        }
349    }
350
351    pub fn build() -> Self {
352        Self::new()
353    }
354
355    pub fn auto_delete(mut self, auto_delete: bool) -> Self {
356        self.auto_delete = auto_delete;
357        self
358    }
359    pub fn durable(mut self, durable: bool) -> Self {
360        self.durable = durable;
361        self
362    }
363    pub fn exclusive(mut self, exclusive: bool) -> Self {
364        self.exclusive = exclusive;
365        self
366    }
367    pub fn no_create(mut self, no_create: bool) -> Self {
368        self.no_create = no_create;
369        self
370    }
371    pub fn argument(mut self, key: String, value: String) -> Result<Self, AppError> {
372        self.arguments.insert(key.try_into().map_err(|_| AppError::new(Some("key must be short".to_owned()), None, AppErrorType::InternalError))?, value);
373        Ok(self)
374    }
375    pub fn arguments(mut self, arguments: &HashMap<String, String>) -> Result<Self, AppError> {
376        for (key, value) in arguments.iter() {
377            let key_2 = key.to_owned();
378            let _: ShortStr = key_2.try_into().map_err(|_| AppError::new(Some(format!("key '{}' must be short", key)), None, AppErrorType::InternalError))?;
379            let value = value.to_owned();
380            self.arguments.insert(key.to_owned(), value);
381        }
382        Ok(self)
383    }
384}
385
386impl Into<FieldTable> for QueueOptions {
387    fn into(self) -> FieldTable {
388        let mut table = FieldTable::new();
389        for (key, value) in self.arguments.into_iter() {
390            table.insert(key.try_into().unwrap(), value.into());
391        }
392        table
393    }
394}
395
396pub const QUEUES: LazyCell<DashMap<String, (AsyncChannel, QueueOptions)>> = LazyCell::new(|| DashMap::new());