Skip to main content

aura_agent/runtime/effects/
network.rs

1use super::AuraEffectSystem;
2use crate::core::default_context_id_for_authority;
3use async_trait::async_trait;
4use aura_core::effects::network::PeerEventStream;
5#[cfg(not(target_arch = "wasm32"))]
6use aura_core::effects::time::PhysicalTimeEffects;
7use aura_core::effects::transport::TransportEnvelope;
8use aura_core::effects::{
9    NetworkCoreEffects, NetworkError, NetworkExtendedEffects, RandomExtendedEffects,
10    TransportEffects, TransportError,
11};
12use aura_core::types::identifiers::AuthorityId;
13#[cfg(not(target_arch = "wasm32"))]
14use aura_core::{execute_with_timeout_budget, TimeoutBudget, TimeoutRunError};
15#[cfg(not(target_arch = "wasm32"))]
16use aura_effects::time::PhysicalTimeHandler;
17use aura_protocol::amp::deserialize_amp_message;
18use cfg_if::cfg_if;
19#[cfg(target_arch = "wasm32")]
20use futures::SinkExt;
21#[cfg(target_arch = "wasm32")]
22use gloo_net::websocket::{futures::WebSocket, Message};
23use std::collections::HashMap;
24use std::collections::HashSet;
25#[cfg(target_arch = "wasm32")]
26use std::future::Future;
27#[cfg(not(target_arch = "wasm32"))]
28use std::net::SocketAddr;
29#[cfg(not(target_arch = "wasm32"))]
30use tokio::io::AsyncWriteExt;
/// Content-type tag stamped on envelopes produced by the generic network
/// effects; `receive` and `receive_from` only accept envelopes carrying it.
const NETWORK_CONTENT_TYPE: &str = "application/aura-network";
/// Prefix placed in front of the UUID in the opaque connection handles that
/// `open` hands back to callers.
const CONNECTION_ID_PREFIX: &str = "conn-";
33
34#[cfg(not(target_arch = "wasm32"))]
35async fn execute_network_timeout<F, Fut, T>(
36    timeout: std::time::Duration,
37    timeout_error: impl Fn() -> NetworkError + Copy,
38    f: F,
39) -> Result<T, NetworkError>
40where
41    F: FnOnce() -> Fut,
42    Fut: std::future::Future<Output = Result<T, NetworkError>>,
43{
44    let time = PhysicalTimeHandler::new();
45    let started_at = time.physical_time().await.map_err(|_| timeout_error())?;
46    let budget =
47        TimeoutBudget::from_start_and_timeout(&started_at, timeout).map_err(|_| timeout_error())?;
48    execute_with_timeout_budget(&time, &budget, f)
49        .await
50        .map_err(|error| match error {
51            TimeoutRunError::Timeout(_) => timeout_error(),
52            TimeoutRunError::Operation(error) => error,
53        })
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
57struct ConnectionId(uuid::Uuid);
58
59impl ConnectionId {
60    fn new(id: uuid::Uuid) -> Self {
61        Self(id)
62    }
63
64    fn as_uuid(&self) -> uuid::Uuid {
65        self.0
66    }
67
68    fn to_wire(self) -> String {
69        format!("{CONNECTION_ID_PREFIX}{}", self.0)
70    }
71
72    fn parse_wire(value: &str) -> Result<Self, NetworkError> {
73        let raw = value.strip_prefix(CONNECTION_ID_PREFIX).unwrap_or(value);
74        let id = uuid::Uuid::parse_str(raw).map_err(|e| {
75            NetworkError::ConnectionFailed(format!("invalid connection id `{value}`: {e}"))
76        })?;
77        Ok(Self(id))
78    }
79}
80
81// Implementation of NetworkEffects
82#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
83#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
84impl NetworkCoreEffects for AuraEffectSystem {
85    async fn send_to_peer(
86        &self,
87        peer_id: uuid::Uuid,
88        message: Vec<u8>,
89    ) -> Result<(), NetworkError> {
90        if self.execution_mode.is_deterministic() {
91            if let Some(shared) = self.transport.shared_transport() {
92                let peer = AuthorityId::from_uuid(peer_id);
93                let mut metadata = HashMap::new();
94                metadata.insert("content-type".to_string(), NETWORK_CONTENT_TYPE.to_string());
95                let envelope = TransportEnvelope {
96                    destination: peer,
97                    source: self.authority_id,
98                    context: default_context_id_for_authority(peer),
99                    payload: message,
100                    metadata,
101                    receipt: None,
102                };
103                shared.route_envelope(envelope);
104                return Ok(());
105            }
106            self.ensure_mock_network()?;
107            return Ok(());
108        }
109
110        let peer = AuthorityId::from_uuid(peer_id);
111        let mut metadata = HashMap::new();
112        metadata.insert("content-type".to_string(), NETWORK_CONTENT_TYPE.to_string());
113        let envelope = TransportEnvelope {
114            destination: peer,
115            source: self.authority_id,
116            context: default_context_id_for_authority(peer),
117            payload: message,
118            metadata,
119            receipt: None,
120        };
121
122        TransportEffects::send_envelope(self, envelope)
123            .await
124            .map_err(|e| NetworkError::SendFailed {
125                peer_id: Some(peer_id),
126                reason: e.to_string(),
127            })?;
128        Ok(())
129    }
130
131    async fn broadcast(&self, message: Vec<u8>) -> Result<(), NetworkError> {
132        if self.execution_mode.is_deterministic() {
133            self.ensure_mock_network()?;
134            let Some(shared) = self.transport.shared_transport() else {
135                return Err(NetworkError::BroadcastFailed {
136                    reason: "shared transport not configured".to_string(),
137                });
138            };
139
140            let wire = deserialize_amp_message(&message).map_err(|e| {
141                NetworkError::SerializationFailed {
142                    error: e.to_string(),
143                }
144            })?;
145
146            let mut metadata = HashMap::new();
147            metadata.insert(
148                "content-type".to_string(),
149                super::AMP_CONTENT_TYPE.to_string(),
150            );
151
152            let source = self.authority_id;
153            let context = wire.header.context;
154
155            for peer in shared.online_peers() {
156                if peer == source {
157                    continue;
158                }
159
160                let envelope = TransportEnvelope {
161                    destination: peer,
162                    source,
163                    context,
164                    payload: message.clone(),
165                    metadata: metadata.clone(),
166                    receipt: None,
167                };
168
169                shared.route_envelope(envelope);
170            }
171
172            return Ok(());
173        }
174
175        let peers: HashSet<uuid::Uuid> = self.connected_peers().await.into_iter().collect();
176        for peer in peers {
177            let _ = self.send_to_peer(peer, message.clone()).await;
178        }
179        Ok(())
180    }
181
182    async fn receive(&self) -> Result<(uuid::Uuid, Vec<u8>), NetworkError> {
183        let envelope = match TransportEffects::receive_envelope(self).await {
184            Ok(env) => env,
185            Err(TransportError::NoMessage) => return Err(NetworkError::NoMessage),
186            Err(e) => {
187                return Err(NetworkError::ReceiveFailed {
188                    reason: e.to_string(),
189                })
190            }
191        };
192
193        let Some(content_type) = envelope.metadata.get("content-type") else {
194            self.requeue_envelope(envelope);
195            return Err(NetworkError::NoMessage);
196        };
197
198        if content_type != NETWORK_CONTENT_TYPE {
199            self.requeue_envelope(envelope);
200            return Err(NetworkError::NoMessage);
201        }
202
203        Ok((envelope.source.uuid(), envelope.payload))
204    }
205}
206
207#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
208#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
209impl NetworkExtendedEffects for AuraEffectSystem {
210    async fn receive_from(&self, _peer_id: uuid::Uuid) -> Result<Vec<u8>, NetworkError> {
211        let peer_id = _peer_id;
212        let envelope = match TransportEffects::receive_envelope(self).await {
213            Ok(env) => env,
214            Err(TransportError::NoMessage) => return Err(NetworkError::NoMessage),
215            Err(e) => {
216                return Err(NetworkError::ReceiveFailed {
217                    reason: e.to_string(),
218                })
219            }
220        };
221
222        let Some(content_type) = envelope.metadata.get("content-type") else {
223            self.requeue_envelope(envelope);
224            return Err(NetworkError::NoMessage);
225        };
226
227        if content_type != NETWORK_CONTENT_TYPE || envelope.source.uuid() != peer_id {
228            self.requeue_envelope(envelope);
229            return Err(NetworkError::NoMessage);
230        }
231
232        Ok(envelope.payload)
233    }
234
235    async fn connected_peers(&self) -> Vec<uuid::Uuid> {
236        if let Some(shared) = self.transport.shared_transport() {
237            return shared
238                .online_peers()
239                .into_iter()
240                .map(|peer| peer.uuid())
241                .collect();
242        }
243
244        if let Some(manager) = self.rendezvous_manager() {
245            return manager
246                .list_cached_peers()
247                .await
248                .into_iter()
249                .map(|peer| peer.uuid())
250                .collect();
251        }
252
253        vec![]
254    }
255
256    async fn is_peer_connected(&self, _peer_id: uuid::Uuid) -> bool {
257        if let Some(shared) = self.transport.shared_transport() {
258            return shared.is_peer_online(AuthorityId::from_uuid(_peer_id));
259        }
260
261        if let Some(manager) = self.rendezvous_manager() {
262            let peer = AuthorityId::from_uuid(_peer_id);
263            let context = default_context_id_for_authority(peer);
264            return manager.get_descriptor(context, peer).await.is_some();
265        }
266
267        false
268    }
269
270    async fn subscribe_to_peer_events(&self) -> Result<PeerEventStream, NetworkError> {
271        self.ensure_mock_network()?;
272        Err(NetworkError::NotImplemented)
273    }
274
275    async fn open(&self, _address: &str) -> Result<String, NetworkError> {
276        if self.execution_mode.is_deterministic() {
277            self.ensure_mock_network()?;
278            return Err(NetworkError::NotImplemented);
279        }
280
281        cfg_if! {
282            if #[cfg(target_arch = "wasm32")] {
283                let connection_id = ConnectionId::new(self.random_uuid().await);
284                let ws_url = normalize_ws_url(_address);
285
286                run_local_ws(move || async move {
287                    let ws = WebSocket::open(&ws_url)
288                        .map_err(|e| format!("WebSocket open failed ({ws_url}): {e}"))?;
289                    ws.close(None, None)
290                        .map_err(|e| format!("WebSocket close failed ({ws_url}): {e}"))?;
291                    Ok(())
292                })
293                .await
294                .map_err(NetworkError::ConnectionFailed)?;
295
296                self.network_connections
297                    .write()
298                    .insert(connection_id.as_uuid(), normalize_ws_url(_address));
299                Ok(connection_id.to_wire())
300            } else {
301                let socket_addr = _address
302                    .parse::<SocketAddr>()
303                    .map_err(|e| NetworkError::ConnectionFailed(e.to_string()))?;
304                let _stream = tokio::net::TcpStream::connect(socket_addr)
305                    .await
306                    .map_err(|e| NetworkError::ConnectionFailed(e.to_string()))?;
307
308                let connection_id = ConnectionId::new(self.random_uuid().await);
309                self.network_connections
310                    .write()
311                    .insert(connection_id.as_uuid(), socket_addr);
312                Ok(connection_id.to_wire())
313            }
314        }
315    }
316
317    async fn send(&self, _connection_id: &str, _data: Vec<u8>) -> Result<(), NetworkError> {
318        if self.execution_mode.is_deterministic() {
319            self.ensure_mock_network()?;
320            return Err(NetworkError::NotImplemented);
321        }
322
323        cfg_if! {
324            if #[cfg(target_arch = "wasm32")] {
325                let connection_id = ConnectionId::parse_wire(_connection_id)?;
326                let ws_url = self
327                    .network_connections
328                    .read()
329                    .get(&connection_id.as_uuid())
330                    .cloned()
331                    .ok_or_else(|| NetworkError::SendFailed {
332                        peer_id: None,
333                        reason: format!("Unknown connection id `{_connection_id}`"),
334                    })?;
335
336                run_local_ws(move || async move {
337                    let mut ws = WebSocket::open(&ws_url)
338                        .map_err(|e| format!("WebSocket open failed ({ws_url}): {e}"))?;
339                    ws.send(Message::Bytes(_data))
340                        .await
341                        .map_err(|e| format!("WebSocket send failed ({ws_url}): {e}"))?;
342                    Ok(())
343                })
344                .await
345                .map_err(|reason| NetworkError::SendFailed {
346                    peer_id: None,
347                    reason,
348                })?;
349                Ok(())
350            } else {
351                let connection_id = ConnectionId::parse_wire(_connection_id)?;
352                let socket_addr = self
353                    .network_connections
354                    .read()
355                    .get(&connection_id.as_uuid())
356                    .copied()
357                    .ok_or_else(|| NetworkError::SendFailed {
358                        peer_id: None,
359                        reason: format!("Unknown connection id `{_connection_id}`"),
360                    })?;
361
362                let config = aura_effects::transport::TransportConfig::default();
363                let mut stream = execute_network_timeout(
364                    config.connect_timeout.get(),
365                    || NetworkError::OperationTimeout {
366                        operation: "network_send_connect".to_string(),
367                        timeout_ms: config.connect_timeout.get().as_millis() as u64,
368                    },
369                    || async {
370                        tokio::net::TcpStream::connect(socket_addr)
371                            .await
372                            .map_err(|e| NetworkError::ConnectionFailed(e.to_string()))
373                    },
374                )
375                .await?;
376
377                let len = (_data.len() as u32).to_be_bytes();
378                execute_network_timeout(
379                    config.write_timeout.get(),
380                    || NetworkError::OperationTimeout {
381                        operation: "network_send_len".to_string(),
382                        timeout_ms: config.write_timeout.get().as_millis() as u64,
383                    },
384                    || async {
385                        stream.write_all(&len).await.map_err(|e| NetworkError::SendFailed {
386                            peer_id: None,
387                            reason: e.to_string(),
388                        })
389                    },
390                )
391                .await?;
392                execute_network_timeout(
393                    config.write_timeout.get(),
394                    || NetworkError::OperationTimeout {
395                        operation: "network_send_payload".to_string(),
396                        timeout_ms: config.write_timeout.get().as_millis() as u64,
397                    },
398                    || async {
399                        stream.write_all(&_data).await.map_err(|e| NetworkError::SendFailed {
400                            peer_id: None,
401                            reason: e.to_string(),
402                        })
403                    },
404                )
405                .await?;
406
407                Ok(())
408            }
409        }
410    }
411
412    async fn close(&self, _connection_id: &str) -> Result<(), NetworkError> {
413        if self.execution_mode.is_deterministic() {
414            self.ensure_mock_network()?;
415        }
416        cfg_if! {
417            if #[cfg(target_arch = "wasm32")] {
418                let connection_id = ConnectionId::parse_wire(_connection_id)?;
419                let removed = self
420                    .network_connections
421                    .write()
422                    .remove(&connection_id.as_uuid());
423                if removed.is_none() {
424                    return Err(NetworkError::ConnectionFailed(format!(
425                        "unknown connection id `{_connection_id}`"
426                    )));
427                }
428            } else {
429                let connection_id = ConnectionId::parse_wire(_connection_id)?;
430                let removed = self
431                    .network_connections
432                    .write()
433                    .remove(&connection_id.as_uuid());
434                if removed.is_none() {
435                    return Err(NetworkError::ConnectionFailed(format!(
436                        "unknown connection id `{_connection_id}`"
437                    )));
438                }
439            }
440        }
441        Ok(())
442    }
443}
444
445#[cfg(target_arch = "wasm32")]
/// Turns a user-supplied address into a WebSocket URL.
///
/// - `ws://` and `wss://` URLs keep their scheme.
/// - `http://` maps to `ws://` and `https://` maps to `wss://`.
/// - Anything without a recognised scheme (e.g. `host:port`) gets `ws://` prepended.
///
/// Scheme matching is ASCII case-insensitive and the emitted scheme is lowercase.
/// The rest of the address is copied verbatim.
fn normalize_ws_url(address: &str) -> String {
    // (scheme prefix accepted from the caller, WebSocket scheme it maps to)
    const SCHEMES: [(&str, &str); 4] = [
        ("ws://", "ws://"),
        ("wss://", "wss://"),
        ("http://", "ws://"),
        ("https://", "wss://"),
    ];

    for (prefix, ws_scheme) in SCHEMES {
        // `get` returns None when `address` is shorter than the prefix or the cut
        // would split a multi-byte character, so the slice below cannot panic.
        if matches!(address.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
        {
            return format!("{ws_scheme}{}", &address[prefix.len()..]);
        }
    }
    format!("ws://{address}")
}
453
454#[cfg(target_arch = "wasm32")]
/// Builds a WebSocket operation with `make_fut` and awaits it in place.
///
/// `make_fut` is called exactly once and the future's result is returned
/// unchanged. The `'static` bounds mean neither the closure nor its future may
/// hold non-`'static` borrows.
async fn run_local_ws<Mk, Fut>(make_fut: Mk) -> Result<(), String>
where
    Mk: FnOnce() -> Fut + 'static,
    Fut: Future<Output = Result<(), String>> + 'static,
{
    let operation = make_fut();
    operation.await
}
462
#[cfg(all(test, not(target_arch = "wasm32")))]
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;
    use crate::core::AgentConfig;
    use tokio::io::AsyncReadExt;

    /// Production agent config whose storage lives under a per-process temp dir.
    fn production_config_for_tests() -> AgentConfig {
        let mut config = AgentConfig::default();
        let path =
            std::env::temp_dir().join(format!("aura-agent-network-test-{}", std::process::id()));
        let _ = std::fs::create_dir_all(&path);
        config.storage.base_path = path;
        config
    }

    #[test]
    fn connection_id_round_trip() {
        let original = ConnectionId::new(uuid::Uuid::from_u128(
            0x1234_5678_9abc_def0_1234_5678_9abc_def0,
        ));
        let wire = original.to_wire();
        let parsed = ConnectionId::parse_wire(&wire).expect("parse connection id");
        assert_eq!(parsed, original);
    }

    #[test]
    fn connection_id_parse_accepts_unprefixed_uuid() {
        let id = uuid::Uuid::from_u128(0x0fed_cba9_8765_4321_0fed_cba9_8765_4321);
        let parsed = ConnectionId::parse_wire(&id.to_string()).expect("parse bare uuid");
        assert_eq!(parsed.as_uuid(), id);
    }

    #[test]
    fn connection_id_parse_rejects_garbage() {
        let err = ConnectionId::parse_wire("conn-not-a-uuid")
            .expect_err("non-uuid handle should fail");
        assert!(matches!(err, NetworkError::ConnectionFailed(_)));
    }

    #[tokio::test]
    async fn send_rejects_invalid_connection_handle() {
        let authority_id = AuthorityId::new_from_entropy([1u8; 32]);
        let effects = AuraEffectSystem::production(production_config_for_tests(), authority_id)
            .expect("create production effects");
        let err = effects
            .send("not-a-connection-id", vec![1, 2, 3])
            .await
            .expect_err("invalid connection handle should fail");
        assert!(matches!(err, NetworkError::ConnectionFailed(_)));
    }

    #[tokio::test]
    async fn open_rejects_unparseable_address() {
        let authority_id = AuthorityId::new_from_entropy([3u8; 32]);
        let effects = AuraEffectSystem::production(production_config_for_tests(), authority_id)
            .expect("create production effects");
        let err = effects
            .open("not-an-address")
            .await
            .expect_err("unparseable address should fail");
        assert!(matches!(err, NetworkError::ConnectionFailed(_)));
    }

    #[tokio::test]
    async fn close_rejects_unknown_handle() {
        let authority_id = AuthorityId::new_from_entropy([4u8; 32]);
        let effects = AuraEffectSystem::production(production_config_for_tests(), authority_id)
            .expect("create production effects");
        let handle = ConnectionId::new(uuid::Uuid::from_u128(42)).to_wire();
        let err = effects
            .close(&handle)
            .await
            .expect_err("closing a never-opened handle should fail");
        assert!(matches!(err, NetworkError::ConnectionFailed(_)));
    }

    #[tokio::test]
    async fn typed_connection_lifecycle_open_send_close() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind listener");
        let addr = listener.local_addr().expect("local addr");

        let receiver = async move {
            // open() does a connectivity check, consuming one accept.
            let (_warmup, _) = listener.accept().await.expect("accept warmup");
            let (mut stream, _) = listener.accept().await.expect("accept payload");
            let mut len = [0u8; 4];
            stream.read_exact(&mut len).await.expect("read length");
            let payload_len = u32::from_be_bytes(len) as usize;
            let mut payload = vec![0u8; payload_len];
            stream.read_exact(&mut payload).await.expect("read payload");
            payload
        };

        let authority_id = AuthorityId::new_from_entropy([2u8; 32]);
        let effects = AuraEffectSystem::production(production_config_for_tests(), authority_id)
            .expect("create production effects");
        let connection_id = effects
            .open(&addr.to_string())
            .await
            .expect("open connection");
        assert!(
            connection_id.starts_with(CONNECTION_ID_PREFIX),
            "open should return opaque connection handle"
        );

        let payload = b"typed-handle-payload".to_vec();
        let sender = async {
            effects
                .send(&connection_id, payload.clone())
                .await
                .expect("send payload");
            effects
                .close(&connection_id)
                .await
                .expect("close connection");
        };

        let (_, received) = tokio::join!(sender, receiver);
        assert_eq!(received, payload);

        let close_err = effects
            .close(&connection_id)
            .await
            .expect_err("closing an already closed handle should fail");
        assert!(matches!(close_err, NetworkError::ConnectionFailed(_)));
    }
}
549            .await
550            .expect_err("closing an already closed handle should fail");
551        assert!(matches!(close_err, NetworkError::ConnectionFailed(_)));
552    }
553}