1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use protosocket::{Connection, MessageReactor};
use tokio::{net::TcpStream, sync::mpsc};

use crate::{ProstClientConnectionBindings, ProstSerializer};

/// Factory that opens client connections to a `protosocket` server.
///
/// Every connection it creates receives the limits stored here, and is spawned as a
/// task onto the stored Tokio runtime handle.
#[derive(Debug, Clone)]
pub struct ClientRegistry {
    /// Maximum read-buffer length handed to each newly created connection.
    max_buffer_length: usize,
    /// Bound on queued outbound messages; also used to size each connection's
    /// outbound channel.
    max_queued_outbound_messages: usize,
    /// Runtime onto which connection tasks are spawned.
    runtime: tokio::runtime::Handle,
}

impl ClientRegistry {
    /// Default maximum read-buffer length: `4 * (2 << 20)`, i.e. 8 MiB.
    // NOTE(review): `2 << 20` is 2 MiB, so this is 8 MiB rather than 4 MiB; if 4 MiB
    // (`4 * (1 << 20)`) was intended, change it here. Value kept as-is for compatibility.
    const DEFAULT_MAX_BUFFER_LENGTH: usize = 4 * (2 << 20);

    /// Default bound on queued outbound messages per connection.
    const DEFAULT_MAX_QUEUED_OUTBOUND_MESSAGES: usize = 256;

    /// Construct a new client registry. Connections will be spawned on the provided runtime.
    ///
    /// The registry starts with a maximum read-buffer length of `4 * (2 << 20)` (8 MiB)
    /// and at most 256 queued outbound messages per connection. Use the setters to
    /// change these before registering clients.
    pub fn new(runtime: tokio::runtime::Handle) -> Self {
        log::trace!("new client registry");
        Self {
            max_buffer_length: Self::DEFAULT_MAX_BUFFER_LENGTH,
            max_queued_outbound_messages: Self::DEFAULT_MAX_QUEUED_OUTBOUND_MESSAGES,
            runtime,
        }
    }

    /// Sets the maximum read buffer length for connections created by this registry after
    /// the setting is applied.
    ///
    /// Connections that were already registered keep the value they were created with.
    pub fn set_max_read_buffer_length(&mut self, max_buffer_length: usize) {
        self.max_buffer_length = max_buffer_length;
    }

    /// Sets the maximum queued outbound messages for connections created by this registry after
    /// the setting is applied.
    ///
    /// A value of `0` is accepted. When sizing the outbound channel it is treated as `1`,
    /// because a zero-capacity Tokio channel cannot be created.
    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
        self.max_queued_outbound_messages = max_queued_outbound_messages;
    }

    /// Get a new connection to a `protosocket` server.
    ///
    /// Connects over TCP to `address` and spawns the connection task on this registry's
    /// runtime. Returns the sender half of the outbound channel. `Request` messages sent
    /// on it are handed to the connection, and inbound `Response` messages are delivered
    /// to `message_reactor`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `address` cannot be parsed into the address type the connection expects
    ///   (presumably `SocketAddr`, so hostnames are not resolved; confirm against
    ///   `Connection::new`);
    /// - the TCP connect fails;
    /// - `TCP_NODELAY` cannot be set on the stream.
    pub async fn register_client<Request, Response, Reactor>(
        &self,
        address: impl Into<String>,
        message_reactor: Reactor,
    ) -> crate::Result<mpsc::Sender<Request>>
    where
        Request: prost::Message + Default + Unpin + 'static,
        Response: prost::Message + Default + Unpin + 'static,
        Reactor: MessageReactor<Inbound = Response>,
    {
        let address = address.into().parse()?;
        let stream = TcpStream::connect(address)
            .await
            .map_err(std::sync::Arc::new)?;
        // Disable Nagle's algorithm so writes are not held back waiting to coalesce.
        stream.set_nodelay(true).map_err(std::sync::Arc::new)?;
        // `tokio::sync::mpsc::channel` panics on a capacity of 0. Clamp the capacity so
        // a registry configured with 0 does not panic here.
        let channel_capacity = self.max_queued_outbound_messages.max(1);
        let (outbound, outbound_messages) = mpsc::channel(channel_capacity);
        let connection =
            Connection::<ProstClientConnectionBindings<Request, Response, Reactor>>::new(
                stream,
                address,
                ProstSerializer::default(),
                ProstSerializer::default(),
                self.max_buffer_length,
                self.max_queued_outbound_messages,
                outbound_messages,
                message_reactor,
            );
        // The task is detached: it runs until the connection ends, and its JoinHandle is dropped.
        self.runtime.spawn(connection);
        Ok(outbound)
    }
}