1use bitcode::{Decode, Encode};
2#[cfg(target_arch = "wasm32")]
4use gloo_utils::format::JsValueSerdeExt;
5use once_cell::sync::Lazy;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::Read;
12use std::sync::Arc;
13#[cfg(target_arch = "wasm32")]
14use tsify::*;
15#[cfg(target_arch = "wasm32")]
16use wasm_bindgen::prelude::*;
17
18#[cfg(not(target_arch = "wasm32"))]
19use zstd;
20
21#[cfg(target_arch = "wasm32")]
22use lz4_flex::block::{compress_prepend_size, decompress_size_prepended};
23
24use crate::stream::StreamHandle;
25use crate::types::PortType;
26
27pub const COMPRESSION_THRESHOLD: usize = 1024; #[derive(Clone, Debug, Serialize, Deserialize, Encode, Decode, PartialEq, Eq)]
30pub struct EncodedMessage(pub Vec<u8>);
31
32impl EncodedMessage {
33 pub fn new(msg: &Message) -> Self {
34 Self(bitcode::encode(msg))
35 }
36
37 pub fn decode(&self) -> Option<Message> {
38 bitcode::decode(&self.0).ok()
39 }
40}
41
42#[derive(Clone, Default, Debug, Serialize, Deserialize, Encode, Decode, PartialEq)]
44#[cfg_attr(target_arch = "wasm32", derive(Tsify))]
45#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi))]
46#[cfg_attr(target_arch = "wasm32", tsify(from_wasm_abi))]
47#[serde(tag = "type", content = "data")]
48pub enum Message {
49 #[default]
52 Flow,
53
54 Event(EncodableValue),
57
58 Boolean(bool),
60
61 Integer(i64),
63
64 Float(f64),
66
67 String(Arc<String>),
69
70 Object(Arc<EncodableValue>),
73
74 Array(Arc<Vec<EncodableValue>>),
76
77 Bytes(Arc<Vec<u8>>),
81
82 StreamHandle(Arc<StreamHandle>),
89
90 Encoded(Arc<Vec<u8>>),
93
94 Optional(Option<Arc<EncodableValue>>),
97
98 Any(Arc<EncodableValue>),
101
102 Error(Arc<String>),
105
106 RemoteReference {
109 network_id: String,
110 actor_id: String,
111 port: String,
112 },
113
114 NetworkEvent {
117 event_type: NetworkEventType,
118 data: EncodableValue,
119 },
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Encode, Decode)]
123pub enum NetworkEventType {
124 ActorRegistered,
125 ActorUnregistered,
126 NetworkConnected,
127 NetworkDisconnected,
128 HeartbeatMissed,
129}
130
131#[derive(Clone)]
148pub struct CompressionConfig {
149 pub size_threshold: usize,
151 pub streaming_threshold: usize,
153 pub enabled: bool,
155 pub level: u32,
157 pub type_strategies: HashMap<String, CompressionStrategy>,
159}
160
161impl std::fmt::Debug for CompressionConfig {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("CompressionConfig")
164 .field("size_threshold", &self.size_threshold)
165 .field("streaming_threshold", &self.streaming_threshold)
166 .field("enabled", &self.enabled)
167 .field("level", &self.level)
168 .field(
169 "type_strategies",
170 &HashMap::<String, CompressionStrategy>::from_iter(
171 self.type_strategies
172 .iter()
173 .map(|(k, v)| (k.clone(), v.clone()))
174 .collect::<Vec<(String, CompressionStrategy)>>(),
175 ),
176 )
177 .finish()
178 }
179}
180
181impl Serialize for CompressionConfig {
182 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
183 where
184 S: serde::Serializer,
185 {
186 let mut map = serde_json::Map::new();
187 let strategy = self
188 .type_strategies
189 .iter()
190 .map(|(k, v)| (k.clone(), v.clone()))
191 .collect::<HashMap<String, CompressionStrategy>>();
192 map.insert("size_threshold".to_string(), json!(self.size_threshold));
193 map.insert(
194 "streaming_threshold".to_string(),
195 json!(self.streaming_threshold),
196 );
197 map.insert("enabled".to_string(), json!(self.enabled));
198 map.insert("level".to_string(), json!(self.level));
199 map.insert("type_strategies".to_string(), json!(strategy));
200 serde_json::Value::Object(map).serialize(serializer)
201 }
202}
203
204impl<'de> Deserialize<'de> for CompressionConfig {
205 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
206 where
207 D: serde::Deserializer<'de>,
208 {
209 let value: Value = Deserialize::deserialize(deserializer)?;
210 let size_threshold = value
211 .get("size_threshold")
212 .and_then(|t| t.as_u64())
213 .map(|t| t as usize)
214 .unwrap_or(1024);
215 let streaming_threshold = value
216 .get("streaming_threshold")
217 .and_then(|t| t.as_u64())
218 .map(|t| t as usize)
219 .unwrap_or(1024 * 1024); let enabled = value
221 .get("enabled")
222 .and_then(|e| e.as_bool())
223 .unwrap_or(true);
224 let level = value
225 .get("level")
226 .and_then(|l| l.as_u64())
227 .map(|l| l as u32)
228 .unwrap_or(6); let mut type_strategies = HashMap::new();
230 if let Some(strategies) = value.get("type_strategies")
231 && let Some(map) = strategies.as_object()
232 {
233 for (type_name, strategy) in map {
234 let strategy = match strategy.as_str() {
235 Some("Never") => CompressionStrategy::Never,
236 Some("Always") => CompressionStrategy::Always,
237 Some("SizeThreshold") => CompressionStrategy::SizeThreshold,
238 Some("Adaptive") => CompressionStrategy::Adaptive,
239 _ => {
240 return Err(serde::de::Error::custom(format!(
241 "Invalid compression strategy: {}",
242 strategy
243 )));
244 }
245 };
246 type_strategies.insert(type_name.to_string(), strategy);
247 }
248 }
249
250 Ok(Self {
251 size_threshold,
252 streaming_threshold,
253 enabled,
254 level,
255 type_strategies,
256 })
257 }
258}
259
260impl<'de> Deserialize<'de> for CompressionStrategy {
261 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
262 where
263 D: serde::Deserializer<'de>,
264 {
265 let value: Value = Deserialize::deserialize(deserializer)?;
266 match value {
267 Value::String(s) => match s.as_str() {
268 "Never" => Ok(Self::Never),
269 "Always" => Ok(Self::Always),
270 "SizeThreshold" => Ok(Self::SizeThreshold),
271 "Adaptive" => Ok(Self::Adaptive),
272 _ => Err(serde::de::Error::custom(format!(
273 "Invalid compression strategy: {}",
274 s
275 ))),
276 },
277 Value::Object(map) => {
278 let strategy = map
279 .get("strategy")
280 .and_then(|s| s.as_str())
281 .ok_or_else(|| {
282 serde::de::Error::custom("Invalid compression strategy object")
283 })?;
284 match strategy {
285 "Never" => Ok(Self::Never),
286 "Always" => Ok(Self::Always),
287 "SizeThreshold" => Ok(Self::SizeThreshold),
288 "Adaptive" => Ok(Self::Adaptive),
289 _ => Err(serde::de::Error::custom(format!(
290 "Invalid compression strategy: {}",
291 strategy
292 ))),
293 }
294 }
295 _ => Err(serde::de::Error::custom(
296 "Invalid compression strategy value",
297 )),
298 }
299 }
300}
301
302impl Default for CompressionConfig {
303 fn default() -> Self {
304 let mut type_strategies = HashMap::new();
305 type_strategies.insert("Bytes".to_string(), CompressionStrategy::Always);
307 type_strategies.insert("Array".to_string(), CompressionStrategy::Adaptive);
308 type_strategies.insert("String".to_string(), CompressionStrategy::SizeThreshold);
309
310 Self {
311 size_threshold: 1024, streaming_threshold: 1024 * 1024, enabled: true,
314 level: 6, type_strategies,
316 }
317 }
318}
319
320#[derive(Clone)]
322pub enum CompressionStrategy {
323 Never,
325 Always,
327 SizeThreshold,
329 Adaptive,
331 Custom(Arc<dyn Fn(&Message) -> bool + Send + Sync>),
333}
334
335impl Serialize for CompressionStrategy {
336 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
337 where
338 S: serde::Serializer,
339 {
340 match self {
341 Self::Never => serializer.serialize_str("Never"),
342 Self::Always => serializer.serialize_str("Always"),
343 Self::SizeThreshold => serializer.serialize_str("SizeThreshold"),
344 Self::Adaptive => serializer.serialize_str("Adaptive"),
345 Self::Custom(_) => serializer.serialize_str("Custom"),
346 }
347 }
348}
349
350impl std::fmt::Debug for CompressionStrategy {
351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352 match self {
353 Self::Never => write!(f, "Never"),
354 Self::Always => write!(f, "Always"),
355 Self::SizeThreshold => write!(f, "SizeThreshold"),
356 Self::Adaptive => write!(f, "Adaptive"),
357 Self::Custom(_) => write!(f, "Custom(_)"),
358 }
359 }
360}
361
/// Running record of trial compressions, used to judge whether data compresses well.
#[derive(Default)]
pub struct CompressionStats {
    /// Sum of the lengths of all samples that were trial-compressed.
    pub total_original: usize,
    /// Sum of the compressed lengths of those samples.
    pub total_compressed: usize,
    /// Number of samples recorded.
    pub samples: usize,
    /// Running average of `compressed_len / sample_len` (lower means more compressible).
    pub average_ratio: f64,
}
370
371impl CompressionStats {
372 #[cfg(not(target_arch = "wasm32"))]
373 pub fn update(&mut self, data: &[u8]) -> bool {
374 const SAMPLE_SIZE: usize = 1024;
375 const MIN_RATIO: f64 = 0.8;
376
377 let sample = if data.len() > SAMPLE_SIZE {
378 &data[..SAMPLE_SIZE]
379 } else {
380 data
381 };
382
383 let compressed = zstd::bulk::compress(sample, 1).unwrap_or_else(|_| sample.to_vec());
384 let ratio = compressed.len() as f64 / sample.len() as f64;
385
386 self.samples += 1;
387 self.total_original += sample.len();
388 self.total_compressed += compressed.len();
389 self.average_ratio =
390 (self.average_ratio * (self.samples - 1) as f64 + ratio) / self.samples as f64;
391
392 self.average_ratio < MIN_RATIO
393 }
394
395 #[cfg(not(target_arch = "wasm32"))]
396 pub fn update_with_threshold(&mut self, data: &[u8], threshold_multiplier: f64) -> bool {
397 const SAMPLE_SIZE: usize = 1024;
398 const BASE_MIN_RATIO: f64 = 0.85;
399
400 let sample = if data.len() > SAMPLE_SIZE {
401 &data[..SAMPLE_SIZE]
402 } else {
403 data
404 };
405
406 let compressed = zstd::bulk::compress(sample, 3).unwrap_or_else(|_| sample.to_vec());
407 let ratio = compressed.len() as f64 / sample.len() as f64;
408
409 const ALPHA: f64 = 0.5;
410 self.samples += 1;
411 self.total_original += sample.len();
412 self.total_compressed += compressed.len();
413 self.average_ratio = (1.0 - ALPHA) * self.average_ratio + ALPHA * ratio;
414
415 let adjusted_threshold = BASE_MIN_RATIO * threshold_multiplier;
416 self.average_ratio < adjusted_threshold
417 }
418
419 #[cfg(target_arch = "wasm32")]
420 pub fn update(&mut self, data: &[u8]) -> bool {
421 const SAMPLE_SIZE: usize = 1024;
422 const MIN_RATIO: f64 = 0.8;
423
424 let sample = if data.len() > SAMPLE_SIZE {
425 &data[..SAMPLE_SIZE]
426 } else {
427 data
428 };
429
430 let compressed = compress_prepend_size(sample);
431 let ratio = compressed.len() as f64 / sample.len() as f64;
432
433 self.samples += 1;
434 self.total_original += sample.len();
435 self.total_compressed += compressed.len();
436 self.average_ratio =
437 (self.average_ratio * (self.samples - 1) as f64 + ratio) / self.samples as f64;
438
439 self.average_ratio < MIN_RATIO
440 }
441
442 #[cfg(target_arch = "wasm32")]
443 pub fn update_with_threshold(&mut self, data: &[u8], threshold_multiplier: f64) -> bool {
444 const SAMPLE_SIZE: usize = 1024;
445 const BASE_MIN_RATIO: f64 = 0.8;
446
447 let sample = if data.len() > SAMPLE_SIZE {
448 &data[..SAMPLE_SIZE]
449 } else {
450 data
451 };
452
453 let compressed = compress_prepend_size(sample);
454 let ratio = compressed.len() as f64 / sample.len() as f64;
455
456 const ALPHA: f64 = 0.1;
457 self.samples += 1;
458 self.total_original += sample.len();
459 self.total_compressed += compressed.len();
460 self.average_ratio = (1.0 - ALPHA) * self.average_ratio + ALPHA * ratio;
461
462 let adjusted_threshold = BASE_MIN_RATIO * threshold_multiplier;
463 self.average_ratio < adjusted_threshold
464 }
465}
466
467impl Message {
468 pub fn encode(&self) -> Result<Vec<u8>, MessageError> {
470 let encoded = bitcode::encode(self);
471
472 Ok(encoded)
473 }
474
475 pub fn decode(bytes: &[u8]) -> Result<Self, MessageError> {
477 bitcode::decode(bytes).map_err(|e| MessageError::Decoding(e.to_string()))
478 }
479
480 pub fn decode_with_config(
481 bytes: &[u8],
482 config: CompressionConfig,
483 ) -> Result<Self, MessageError> {
484 Self::decode_compressed(bytes, &config)
485 }
486
487 pub fn get_type(&self) -> PortType {
489 match self {
490 Message::Flow => PortType::Flow,
491 Message::Event(_) => PortType::Event,
492 Message::Boolean(_) => PortType::Boolean,
493 Message::Integer(_) => PortType::Integer,
494 Message::Float(_) => PortType::Float,
495 Message::String(_) => PortType::String,
496 Message::Object(_v) => {
497 PortType::Object("Dynamic".to_string())
505 }
506 Message::Array(_arr) => PortType::Array(Box::new(PortType::Any)),
507 Message::Bytes(_) => PortType::Bytes,
508 Message::StreamHandle(_) => PortType::Stream,
509 Message::Optional(_opt) => PortType::Option(Box::new(PortType::Any)),
510 Message::Any(_) => PortType::Any,
511 Message::Error(_) => PortType::String,
512 Message::Encoded(..) => PortType::Encoded,
513 Message::RemoteReference { .. } => PortType::Any,
514 Message::NetworkEvent { .. } => PortType::Event,
515 }
516 }
517
518 pub fn validate_type(&self, port_type: &PortType) -> Result<(), MessageError> {
520 match (self, port_type) {
521 (msg, t) if msg.get_type() == *t => Ok(()),
523
524 (Message::Integer(_), PortType::Float) => Ok(()),
526
527 (Message::Array(arr), PortType::Array(_elem_type)) => {
529 arr.iter().try_for_each(|_elem| Ok(()))
530 }
531
532 (Message::Optional(_opt), PortType::Option(_inner_type)) => Ok(()),
534
535 (_, PortType::Any) => Ok(()),
552
553 _ => Err(MessageError::TypeMismatch(format!(
554 "Expected {:?}, got {:?}",
555 port_type,
556 self.get_type()
557 ))),
558 }
559 }
560
561 pub fn encoded_size(&self) -> Result<usize, MessageError> {
563 self.encode().map(|bytes| bytes.len())
564 }
565
566 pub fn encode_with_config(
568 &self,
569 config: &CompressionConfig,
570 ) -> Result<EncodedMessage, MessageError> {
571 if !config.enabled {
572 return Ok(EncodedMessage(bitcode::encode(self)));
573 }
574
575 let strategy = self.get_compression_strategy(config);
576 let encoded = bitcode::encode(self);
577
578 match strategy {
579 CompressionStrategy::Never => Ok(EncodedMessage(encoded)),
580 CompressionStrategy::Always => {
581 Ok(EncodedMessage(self.compress_data(&encoded, config)?))
582 }
583 CompressionStrategy::SizeThreshold => {
584 if encoded.len() >= config.size_threshold {
585 Ok(EncodedMessage(self.compress_data(&encoded, config)?))
586 } else {
587 Ok(EncodedMessage(encoded))
588 }
589 }
590 CompressionStrategy::Adaptive => {
591 if self.should_compress_adaptive(&encoded) {
593 Ok(EncodedMessage(self.compress_data(&encoded, config)?))
594 } else {
595 Ok(EncodedMessage(encoded))
596 }
597 }
598 CompressionStrategy::Custom(strategy_fn) => {
599 if strategy_fn(self) {
600 Ok(EncodedMessage(self.compress_data(&encoded, config)?))
601 } else {
602 Ok(EncodedMessage(encoded))
603 }
604 }
605 }
606 }
607
608 fn get_compression_strategy(&self, config: &CompressionConfig) -> CompressionStrategy {
610 let type_name = self.type_name();
611
612 config
613 .type_strategies
614 .get(type_name)
615 .cloned()
616 .unwrap_or(CompressionStrategy::SizeThreshold)
617 }
618
619 pub fn compress_data(
621 &self,
622 data: &[u8],
623 config: &CompressionConfig,
624 ) -> Result<Vec<u8>, MessageError> {
625 if config.enabled {
626 return if data.len() >= config.streaming_threshold {
627 self.compress_streaming(data, config)
628 } else {
629 self.compress_normal(data, config)
630 };
631 }
632 Ok(data.to_vec())
633 }
634
635 #[cfg(not(target_arch = "wasm32"))]
637 fn compress_normal(
638 &self,
639 data: &[u8],
640 config: &CompressionConfig,
641 ) -> Result<Vec<u8>, MessageError> {
642 let mut encoder = flate2::Compress::new(flate2::Compression::new(config.level), false);
643
644 let mut compressed = Vec::with_capacity(data.len());
645 encoder
646 .compress_vec(data, &mut compressed, flate2::FlushCompress::Finish)
647 .map_err(|e| MessageError::Compression(e.to_string()))?;
648
649 Ok(compressed)
650 }
651
652 #[cfg(target_arch = "wasm32")]
653 fn compress_normal(
654 &self,
655 data: &[u8],
656 config: &CompressionConfig,
657 ) -> Result<Vec<u8>, MessageError> {
658 let compressed = compress_prepend_size(data);
659 Ok(compressed)
660 }
661
662 #[cfg(not(target_arch = "wasm32"))]
664 pub fn compress_streaming(
665 &self,
666 data: &[u8],
667 config: &CompressionConfig,
668 ) -> Result<Vec<u8>, MessageError> {
669 use std::io::Write;
670
671 if !zstd::compression_level_range().contains(&(config.level as i32)) {
672 return Err(MessageError::Compression(format!(
673 "Invalid compression level {}",
674 config.level
675 )));
676 }
677 let mut encoder = zstd::Encoder::new(Vec::new(), config.level as i32)
678 .map_err(|e| MessageError::Compression(e.to_string()))?;
679
680 for chunk in data.chunks(64 * 1024) {
682 encoder
684 .write_all(chunk)
685 .map_err(|e| MessageError::Compression(e.to_string()))?;
686 }
687
688 encoder
689 .finish()
690 .map_err(|e| MessageError::Compression(e.to_string()))
691 }
692
693 fn type_name(&self) -> &'static str {
694 match self {
695 Message::Flow => "Flow",
696 Message::Event(_) => "Event",
697 Message::Boolean(_) => "Boolean",
698 Message::Integer(_) => "Integer",
699 Message::Float(_) => "Float",
700 Message::String(_) => "String",
701 Message::Object(_) => "Object",
702 Message::Array(_) => "Array",
703 Message::Bytes(_) => "Bytes",
704 Message::StreamHandle(_) => "StreamHandle",
705 Message::Optional(_) => "Optional",
706 Message::Any(_) => "Any",
707 Message::Error(_) => "Error",
708 Message::Encoded(..) => "Encoded",
709 Message::RemoteReference { .. } => "NetworkReference",
710 Message::NetworkEvent { .. } => "NetworkEvent",
711 }
712 }
713
714 #[cfg(target_arch = "wasm32")]
715 fn compress_streaming(
716 &self,
717 data: &[u8],
718 config: &CompressionConfig,
719 ) -> Result<Vec<u8>, MessageError> {
720 const CHUNK_SIZE: usize = 64 * 1024; let mut compressed = Vec::new();
722
723 for chunk in data.chunks(CHUNK_SIZE) {
724 let chunk_compressed = compress_prepend_size(chunk);
725 compressed.extend_from_slice(&(chunk_compressed.len() as u32).to_le_bytes());
726 compressed.extend_from_slice(&chunk_compressed);
727 }
728
729 Ok(compressed)
730 }
731
732 pub(crate) fn should_compress_adaptive(&self, data: &[u8]) -> bool {
734 const MAX_HISTORY_SIZE: usize = 1000;
735 const CLEANUP_THRESHOLD: usize = 10000;
736
737 static HISTORY: Lazy<RwLock<HashMap<String, (CompressionStats, std::time::Instant)>>> =
739 Lazy::new(|| RwLock::new(HashMap::new()));
740
741 let type_name = self.type_name();
742 let mut history = HISTORY.write();
743
744 if history.len() > CLEANUP_THRESHOLD {
746 history.retain(|_, (_, last_access)| {
747 last_access.elapsed() < std::time::Duration::from_secs(3600)
748 });
749 }
750
751 let (stats, last_access) = history
753 .entry(type_name.to_string())
754 .or_insert_with(|| (CompressionStats::default(), std::time::Instant::now()));
755
756 *last_access = std::time::Instant::now();
757
758 let threshold_multiplier = match self {
760 Message::Bytes(_) => 0.7, Message::String(_) => 1.5, Message::Array(_) => 0.8, _ => 0.8, };
765
766 if stats.samples > MAX_HISTORY_SIZE {
768 *stats = CompressionStats::default();
769 }
770
771 stats.update_with_threshold(data, threshold_multiplier)
773 }
774
775 #[cfg(not(target_arch = "wasm32"))]
777 pub fn decode_compressed(
778 bytes: &[u8],
779 config: &CompressionConfig,
780 ) -> Result<Self, MessageError> {
781 if !config.enabled {
782 return bitcode::decode(bytes).map_err(|e| MessageError::Decoding(e.to_string()));
783 }
784
785 if bytes.len() > 4 && bytes[0..4] == [0x28, 0xB5, 0x2F, 0xFD] {
787 let decoded =
789 zstd::decode_all(bytes).map_err(|e| MessageError::Compression(e.to_string()))?;
790 bitcode::decode(&decoded).map_err(|e| MessageError::Decoding(e.to_string()))
791 } else {
792 let decoded = &mut Vec::new();
793 match flate2::read::GzDecoder::new(bytes).read_to_end(decoded) {
795 Ok(_) => {
796 bitcode::decode(decoded).map_err(|e| MessageError::Decoding(e.to_string()))
797 }
798 Err(_) => {
799 bitcode::decode(bytes).map_err(|e| MessageError::Decoding(e.to_string()))
801 }
802 }
803 }
804 }
805
806 #[cfg(target_arch = "wasm32")]
807 fn decode_compressed(bytes: &[u8], config: &CompressionConfig) -> Result<Self, MessageError> {
808 if !config.enabled {
809 return bitcode::decode(bytes).map_err(|e| MessageError::Decoding(e.to_string()));
810 }
811
812 match decompress_size_prepended(bytes) {
814 Ok(decompressed) => {
815 bitcode::decode(&decompressed).map_err(|e| MessageError::Decoding(e.to_string()))
816 }
817 Err(_) => {
818 let mut position = 0;
820 let mut decompressed = Vec::new();
821
822 while position + 4 <= bytes.len() {
823 let chunk_size =
824 u32::from_le_bytes(bytes[position..position + 4].try_into().unwrap())
825 as usize;
826 position += 4;
827
828 if position + chunk_size > bytes.len() {
829 break;
830 }
831
832 let chunk = &bytes[position..position + chunk_size];
833 match decompress_size_prepended(chunk) {
834 Ok(chunk_decompressed) => {
835 decompressed.extend_from_slice(&chunk_decompressed)
836 }
837 Err(e) => return Err(MessageError::Compression(e.to_string())),
838 }
839 position += chunk_size;
840 }
841
842 if !decompressed.is_empty() {
843 bitcode::decode(&decompressed)
844 .map_err(|e| MessageError::Decoding(e.to_string()))
845 } else {
846 bitcode::decode(bytes).map_err(|e| MessageError::Decoding(e.to_string()))
848 }
849 }
850 }
851 }
852}
853
/// Errors reported by message encoding, decoding, validation and compression.
///
/// Every variant carries a human-readable description.
#[derive(Debug)]
pub enum MessageError {
    /// The message's type is not accepted by the target port type.
    TypeMismatch(String),
    /// Serialization failed. Not constructed anywhere in this section of the file.
    Encoding(String),
    /// The bytes could not be deserialized into a message.
    Decoding(String),
    /// Validation failed. Not constructed anywhere in this section of the file.
    Validation(String),
    /// Compression or decompression failed, or the compression level was invalid.
    Compression(String),
}
862
863impl From<Value> for Message {
864 fn from(value: Value) -> Self {
865 match value {
866 Value::Null => Message::Optional(None),
867 Value::Bool(b) => Message::Boolean(b),
868 Value::Number(n) => {
869 if n.is_i64() {
870 Message::Integer(n.as_i64().unwrap())
871 } else {
872 Message::Float(n.as_f64().unwrap())
873 }
874 }
875 Value::String(s) => Message::String(Arc::new(s)),
876 Value::Array(vec) => Message::array(vec.into_iter().map(|v| v.into()).collect()),
877 Value::Object(_) => Message::Object(Arc::new(EncodableValue::from(value))),
878 }
879 }
880}
881
882impl From<Message> for Value {
883 fn from(val: Message) -> Self {
884 match val {
885 Message::Flow => Value::String("flow".to_string()),
886 Message::Event(v) => v.into(),
887 Message::Boolean(b) => Value::Bool(b),
888 Message::Integer(i) => Value::Number(i.into()),
889 Message::Float(f) => Value::Number(serde_json::Number::from_f64(f).unwrap()),
890 Message::String(s) => Value::String(s.as_str().to_string()),
891 Message::Object(v) => v.as_ref().clone().into(),
892 Message::Array(arr) => Value::Array(
893 arr.iter()
894 .map(|m| m.clone().into())
896 .collect(),
897 ),
898 Message::Bytes(bytes) => Value::Array(
899 <Vec<u8> as Clone>::clone(&bytes)
900 .into_iter()
901 .map(|b| Value::Number(b.into()))
902 .collect(),
903 ),
904 Message::StreamHandle(handle) => json!({
905 "stream_id": handle.stream_id,
906 "origin_actor": handle.origin_actor,
907 "origin_port": handle.origin_port,
908 "content_type": handle.content_type,
909 "size_hint": handle.size_hint,
910 }),
911 Message::Optional(opt) => match opt {
912 Some(m) => Value::from(m.as_ref().clone()),
913 None => Value::Null,
914 },
915 Message::Any(v) => v.as_ref().clone().into(),
916 Message::Error(e) => Value::String(e.as_str().to_string()),
917 Message::Encoded(encoded) => bitcode::decode::<Message>(&encoded)
918 .expect("Failed to decode message")
919 .into(),
920 Message::RemoteReference {
921 network_id,
922 actor_id,
923 port,
924 } => json!({
925 "network_id": network_id,
926 "actor_id": actor_id,
927 "port": port
928 }),
929 Message::NetworkEvent { event_type, data } => json!({
930 "event_type": event_type,
931 "data": serde_json::Value::from(data)
932 }),
933 }
934 }
935}
936
/// Converts a JavaScript value into a message by first reading it as JSON.
///
/// The JSON-to-message mapping is the one implemented by `From<Value> for Message`.
/// Values that cannot be read as JSON become `Message::Error("Invalid JS value")`.
#[cfg(target_arch = "wasm32")]
impl From<JsValue> for Message {
    fn from(value: JsValue) -> Self {
        match value.into_serde::<Value>() {
            Ok(json) => Message::from(json),
            Err(_) => Message::Error(Arc::new("Invalid JS value".to_string())),
        }
    }
}
965
/// Converts a message into a JavaScript value.
///
/// Scalars map to JS primitives (integers go through `f64`), `Bytes` to a
/// `Uint8Array`, `Array` to a JS array (elements that fail to serialize are skipped),
/// and the structured variants to objects built from their fields. `Encoded` is decoded
/// as a bitcode `Message` first, falling back to the default message on failure.
///
/// NOTE(review): `Optional(Some(_))` is decoded here as a bitcode `Message`, whereas
/// `From<Message> for Value` parses the same payload as JSON — confirm which encoding
/// the payload actually uses.
#[cfg(target_arch = "wasm32")]
impl Into<JsValue> for Message {
    fn into(self) -> JsValue {
        // Serializes a JSON value for JS, yielding JS `null` if serialization fails.
        let json_to_js = |value: Value| JsValue::from_serde(&value).unwrap_or(JsValue::null());

        match self {
            Self::Flow => JsValue::from_str("flow"),
            Self::Event(payload) => JsValue::from_serde(&payload).unwrap_or_default(),
            Self::Boolean(flag) => JsValue::from_bool(flag),
            Self::Integer(int) => JsValue::from_f64(int as f64),
            Self::Float(float) => JsValue::from_f64(float),
            Self::String(text) | Self::Error(text) => JsValue::from_str(&text),
            Self::Object(payload) | Self::Any(payload) => {
                JsValue::from_serde(&payload).unwrap_or_default()
            }
            Self::Array(items) => {
                let js_items = js_sys::Array::new();
                for item in items.iter() {
                    if let Ok(js_item) = JsValue::from_serde(item) {
                        js_items.push(&js_item);
                    }
                }
                js_items.into()
            }
            Self::Bytes(bytes) => {
                let buffer = js_sys::Uint8Array::new_with_length(bytes.len() as u32);
                buffer.copy_from(&bytes);
                buffer.into()
            }
            Self::StreamHandle(handle) => json_to_js(json!({
                "stream_id": handle.stream_id,
                "origin_actor": handle.origin_actor,
                "origin_port": handle.origin_port,
                "content_type": handle.content_type,
                "size_hint": handle.size_hint,
            })),
            Self::Optional(Some(payload)) => match payload.decode::<Message>() {
                Some(inner) => inner.into(),
                None => JsValue::NULL,
            },
            Self::Optional(None) => JsValue::NULL,
            Self::Encoded(encoded) => {
                let inner = bitcode::decode::<Message>(&encoded).unwrap_or_default();
                inner.into()
            }
            Self::RemoteReference {
                network_id,
                actor_id,
                port,
            } => json_to_js(json!({
                "network_id": network_id,
                "actor_id": actor_id,
                "port": port
            })),
            Self::NetworkEvent { event_type, data } => json_to_js(json!({
                "event_type": event_type,
                "data": serde_json::Value::from(data)
            })),
        }
    }
}
1038
1039#[derive(Clone, Debug, Serialize, Deserialize, Encode, Decode, PartialEq, Eq)]
1041pub struct EncodableValue {
1042 pub(crate) data: Vec<u8>,
1043}
1044
1045impl EncodableValue {
1046 pub fn new<T: Encode>(value: &T) -> Self {
1047 Self {
1048 data: bitcode::encode(value),
1049 }
1050 }
1051
1052 pub fn decode<'a, T: Decode<'a>>(&'a self) -> Option<T> {
1053 bitcode::decode(&self.data).ok()
1054 }
1055
1056 #[allow(dead_code)]
1057 pub(crate) fn len(&self) -> usize {
1058 self.data.len()
1059 }
1060}
1061
1062impl From<Value> for EncodableValue {
1064 fn from(v: Value) -> Self {
1065 Self {
1066 data: serde_json::to_vec(&v).unwrap_or_default(),
1067 }
1068 }
1069}
1070
1071impl From<EncodableValue> for Value {
1072 fn from(v: EncodableValue) -> Self {
1073 serde_json::from_slice(&v.data).unwrap_or(Value::Null)
1074 }
1075}
1076
1077impl Message {
1079 pub fn object(value: EncodableValue) -> Self {
1080 Message::Object(Arc::new(value))
1081 }
1082 pub fn any(value: EncodableValue) -> Self {
1083 Message::Any(Arc::new(value))
1084 }
1085 pub fn event(value: EncodableValue) -> Self {
1086 Message::Event(value)
1087 }
1088
1089 pub fn array(messages: Vec<EncodableValue>) -> Self {
1090 Message::Array(Arc::new(messages))
1091 }
1092 pub fn bytes(bytes: Vec<u8>) -> Self {
1093 Message::Bytes(bytes.into())
1094 }
1095 pub fn stream_handle(handle: StreamHandle) -> Self {
1096 Message::StreamHandle(Arc::new(handle))
1097 }
1098 pub fn encoded(encoded: Vec<u8>) -> Self {
1099 Message::Encoded(Arc::new(encoded))
1100 }
1101 pub fn error(msg: String) -> Self {
1102 Message::Error(msg.into())
1103 }
1104 pub fn boolean(value: bool) -> Self {
1105 Message::Boolean(value)
1106 }
1107 pub fn integer(value: i64) -> Self {
1108 Message::Integer(value)
1109 }
1110 pub fn float(value: f64) -> Self {
1111 Message::Float(value)
1112 }
1113 pub fn string(value: String) -> Self {
1114 Message::String(Arc::new(value))
1115 }
1116 pub fn flow() -> Self {
1117 Message::Flow
1118 }
1119
1120 pub fn optional(msg: Option<EncodableValue>) -> Self {
1121 Message::Optional(msg.map(Arc::new))
1122 }
1123}