// protosocket_server/connection_server.rs
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use protosocket::Connection;
use protosocket::ConnectionBindings;
use protosocket::Serializer;
use tokio::sync::mpsc;

/// Supplies the per-connection pieces that a `ProtosocketServer` needs.
///
/// Each time the server accepts a TCP stream, it asks its connector for:
/// - a fresh reactor,
/// - a deserializer,
/// - a serializer.
///
/// All three are typed by [`ServerConnector::Bindings`].
pub trait ServerConnector: Unpin {
    /// The `protosocket` bindings that name the serializer, deserializer and reactor types.
    type Bindings: ConnectionBindings;

    /// Create the serializer handed to a newly accepted connection.
    fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;

    /// Create the deserializer handed to a newly accepted connection.
    fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;

    /// Create the reactor for a newly accepted connection.
    ///
    /// `optional_outbound` is the sending half of that connection's outbound message channel.
    /// The receiving half is given to the `protosocket::Connection` itself.
    fn new_reactor(
        &self,
        optional_outbound: mpsc::Sender<
            <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
        >,
    ) -> <Self::Bindings as ConnectionBindings>::Reactor;

    /// Maximum message length, in bytes. Defaults to 8 MiB (`8 << 20`).
    ///
    /// NOTE(review): nothing in this file reads this value. The server uses its own
    /// `max_buffer_length` instead, so confirm whether it is consumed elsewhere.
    fn maximum_message_length(&self) -> usize {
        8 << 20
    }

    /// Maximum number of queued outbound messages. Defaults to 256.
    ///
    /// NOTE(review): nothing in this file reads this value. The server uses its own
    /// `max_queued_outbound_messages` (default 128), so confirm whether it is consumed elsewhere.
    fn max_queued_outbound_messages(&self) -> usize {
        256
    }
}

/// A TCP accept loop that turns every accepted stream into a spawned `protosocket::Connection`.
///
/// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll
/// the OS's io primitives, manages read and write buffers, and vends messages to & from connections.
/// Every connection gets its own bounded mpsc channel for outbound messages:
/// - the sending half is handed to the connection's reactor (see `ServerConnector::new_reactor`);
/// - the receiving half is owned by the `Connection`.
///
/// Inbound messages are delivered to the connection via its reactor.
///
/// Protosocket services are monomorphic: you can only have one kind of message per service.
/// The expected way to work with this is to use prost and protocol buffers to encode messages.
/// Of course you can use any encoding you like, as the telnet example shows.
///
/// Protosocket messages are not opinionated about request & reply. If you want that pattern, you
/// need to implement it yourself. This lets you freely choose whether to send fire-and-forget
/// messages sometimes; however, it requires you to write your protocol's rules. Per connection, you
/// get an inbound iterable of `MessageIn` batches and an outbound stream of `MessageOut`. You
/// decide what those mean for you!
///
/// A `ProtosocketServer` is a future: spawn it (or `.await` it) and it accepts connections for as
/// long as it is polled. Like every future, it does nothing until it is polled.
#[must_use = "futures do nothing unless polled"]
pub struct ProtosocketServer<Connector: ServerConnector> {
    /// Supplies the reactor, serializer and deserializer for each accepted connection.
    connector: Connector,
    /// The bound listening socket that connections are accepted from.
    listener: tokio::net::TcpListener,
    /// Passed to every `Connection::new` as its maximum buffer length.
    max_buffer_length: usize,
    /// Capacity of each connection's outbound mpsc channel; also passed to `Connection::new`.
    max_queued_outbound_messages: usize,
    /// The runtime on which every accepted connection is spawned.
    runtime: tokio::runtime::Handle,
}

impl<Connector: ServerConnector> ProtosocketServer<Connector> {
    /// Construct a new `ProtosocketServer` listening on the provided address.
    ///
    /// The address is bound with `tokio::net::TcpListener::bind`, which sets `SO_REUSEADDR` on
    /// Unix platforms. The server uses the provided runtime to spawn each accepted tcp connection
    /// as a `protosocket::Connection`.
    ///
    /// Defaults, changeable via the setters below:
    /// - maximum buffer length: 32 MiB (`16 * (2 << 20)` bytes);
    /// - queued outbound messages per connection: 128.
    ///
    /// NOTE(review): these defaults do not come from `connector.maximum_message_length()` or
    /// `connector.max_queued_outbound_messages()`. Confirm that ignoring the connector's values
    /// is intended.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener fails. The underlying `std::io::Error` is wrapped
    /// in an `Arc` before being converted into the crate's error type.
    pub async fn new(
        address: std::net::SocketAddr,
        runtime: tokio::runtime::Handle,
        connector: Connector,
    ) -> crate::Result<Self> {
        let listener = tokio::net::TcpListener::bind(address)
            .await
            .map_err(Arc::new)?;
        Ok(Self {
            connector,
            listener,
            max_buffer_length: 16 * (2 << 20),
            max_queued_outbound_messages: 128,
            runtime,
        })
    }

    /// Set the maximum buffer length for connections created by this server after the setting is applied.
    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
        self.max_buffer_length = max_buffer_length;
    }

    /// Set the maximum queued outbound messages for connections created by this server after the setting is applied.
    ///
    /// The value is clamped to at least 1. It becomes the capacity of each connection's
    /// `tokio::sync::mpsc` channel, and `mpsc::channel` panics when given a capacity of 0. That
    /// panic would otherwise surface later, inside the server's `poll`, on the first accepted
    /// connection.
    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
        self.max_queued_outbound_messages = max_queued_outbound_messages.max(1);
    }
}

impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
    type Output = Result<(), Error>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            break match self.listener.poll_accept(context) {
                Poll::Ready(result) => match result {
                    Ok((stream, address)) => {
                        stream.set_nodelay(true)?;
                        let (outbound_submission_queue, outbound_messages) =
                            mpsc::channel(self.max_queued_outbound_messages);
                        let reactor = self
                            .connector
                            .new_reactor(outbound_submission_queue.clone());
                        let connection: Connection<Connector::Bindings> = Connection::new(
                            stream,
                            address,
                            self.connector.deserializer(),
                            self.connector.serializer(),
                            self.max_buffer_length,
                            self.max_queued_outbound_messages,
                            outbound_messages,
                            reactor,
                        );
                        self.runtime.spawn(connection);
                        continue;
                    }
                    Err(e) => {
                        log::error!("failed to accept connection: {e:?}");
                        continue;
                    }
                },
                Poll::Pending => Poll::Pending,
            };
        }
    }
}