1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use datacake_rpc::Channel;
use parking_lot::RwLock;
use tracing::trace;

/// A shared pool of RPC client channels, keyed by remote socket address.
///
/// Cloning an `RpcNetwork` only clones the inner `Arc`. Every clone is therefore
/// a handle onto the same pool, and channels created through one handle can be
/// reused through any other.
#[derive(Clone, Default)]
pub struct RpcNetwork {
    /// Cached channels, at most one per remote address.
    clients: Arc<RwLock<HashMap<SocketAddr, Channel>>>,
}

impl RpcNetwork {
    /// Returns the cached channel for `addr`, creating and caching one if none exists yet.
    ///
    /// The common case (address already known) only takes the shared read lock, so
    /// concurrent lookups do not contend. On a miss, the write lock is taken and the
    /// map is re-checked through the entry API. As a result, two callers racing on the
    /// same new address end up sharing a single channel, instead of the second one
    /// silently replacing the channel the first caller already received.
    pub fn get_or_connect(&self, addr: SocketAddr) -> Channel {
        {
            let guard = self.clients.read();
            if let Some(channel) = guard.get(&addr) {
                return channel.clone();
            }
        }

        // Another caller may have inserted `addr` between releasing the read lock and
        // acquiring the write lock, so only create a channel if it is still missing.
        // NOTE(review): `Channel::connect` runs while the write lock is held. It is a
        // synchronous call, presumably cheap (lazy connection setup); confirm against
        // datacake_rpc if contention becomes an issue.
        let mut guard = self.clients.write();
        guard
            .entry(addr)
            .or_insert_with(|| {
                trace!(addr = %addr, "Connect client to network.");
                Channel::connect(addr)
            })
            .clone()
    }

    /// Creates a new channel to `addr` and stores it in the network.
    ///
    /// Any channel previously cached for `addr` is replaced. The returned value is
    /// a clone of the channel now stored in the network.
    pub fn connect(&self, addr: SocketAddr) -> Channel {
        // Build the channel before locking so the write lock covers only the insert.
        let channel = Channel::connect(addr);
        self.clients.write().insert(addr, channel.clone());
        channel
    }

    /// Removes the cached channel for `addr` from the network, if present.
    ///
    /// This only drops the network's own copy. Clones previously returned by
    /// [`RpcNetwork::get_or_connect`] or [`RpcNetwork::connect`] are separate
    /// values and are not touched here.
    pub fn disconnect(&self, addr: SocketAddr) {
        self.clients.write().remove(&addr);
    }
}