Skip to main content

infinite_db/infinitedb_server/
runtime.rs

1//! Tokio TCP server wiring [`crate::InfiniteDb`] to the API layer.
2
3use std::io;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::sync::Semaphore;
10
11use crate::infinitedb_core::branch::BranchId;
12use crate::infinitedb_core::snapshot::SnapshotId;
13use crate::InfiniteDb;
14
15use super::api::{handle_request, Request, Response};
16use super::session::{AccessLevel, Session, SpaceGrant};
17/// TCP server configuration.
18#[derive(Debug, Clone)]
19pub struct ServerConfig {
20    pub max_connections: usize,
21    pub default_branch: BranchId,
22}
23
24impl Default for ServerConfig {
25    fn default() -> Self {
26        Self {
27            max_connections: 128,
28            default_branch: BranchId::MAIN,
29        }
30    }
31}
32
33/// Length-framed TCP server over a shared [`crate::InfiniteDb`].
34pub struct Server {
35    listener: TcpListener,
36    db: Arc<InfiniteDb>,
37    config: ServerConfig,
38    grants: Vec<SpaceGrant>,
39    limiter: Arc<Semaphore>,
40}
41
42impl Server {
43    /// Bind `addr` and prepare to accept connections.
44    pub async fn bind(
45        addr: SocketAddr,
46        db: Arc<InfiniteDb>,
47        config: ServerConfig,
48        grants: Vec<SpaceGrant>,
49    ) -> io::Result<Self> {
50        let listener = TcpListener::bind(addr).await?;
51        let limiter = Arc::new(Semaphore::new(config.max_connections));
52        Ok(Self {
53            listener,
54            db,
55            config,
56            grants,
57            limiter,
58        })
59    }
60
61    pub fn local_addr(&self) -> io::Result<SocketAddr> {
62        self.listener.local_addr()
63    }
64
65    /// Accept connections until the listener is dropped.
66    pub async fn run(self) -> io::Result<()> {
67        loop {
68            let (stream, _) = self.listener.accept().await?;
69            let permit = Arc::clone(&self.limiter)
70                .acquire_owned()
71                .await
72                .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
73            let db = Arc::clone(&self.db);
74            let grants = self.grants.clone();
75            let branch = self.config.default_branch;
76            tokio::spawn(async move {
77                let _permit = permit;
78                let _ = serve_connection(stream, db, branch, grants).await;
79            });
80        }
81    }
82}
83
84async fn serve_connection(
85    mut stream: TcpStream,
86    db: Arc<InfiniteDb>,
87    branch: BranchId,
88    grants: Vec<SpaceGrant>,
89) -> io::Result<()> {
90    let pinned = db
91        .branch_head(branch)
92        .unwrap_or(SnapshotId(0));
93    let opened_at = db.revision();
94    let session = Session::open_at_revision(branch, pinned, opened_at, grants);
95
96    loop {
97        let request: Request = read_frame_async(&mut stream).await?;
98        let response = handle_request(&db, &session, request);
99        write_frame_async(&mut stream, &response).await?;
100        if matches!(response, Response::Error(_)) {
101            // keep connection alive for clients
102        }
103    }
104}
105
106async fn read_frame_async<T: bincode::Decode<()> + Send + 'static>(
107    stream: &mut TcpStream,
108) -> io::Result<T> {
109    let mut len_buf = [0u8; 8];
110    stream.read_exact(&mut len_buf).await?;
111    let len = u64::from_le_bytes(len_buf) as usize;
112    if len > 64 * 1024 * 1024 {
113        return Err(io::Error::new(
114            io::ErrorKind::InvalidData,
115            "frame too large",
116        ));
117    }
118    let mut payload = vec![0u8; len];
119    stream.read_exact(&mut payload).await?;
120    let (msg, _) = bincode::decode_from_slice::<T, _>(&payload, bincode::config::standard())
121        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
122    Ok(msg)
123}
124
125async fn write_frame_async<T: bincode::Encode + Send + Sync>(
126    stream: &mut TcpStream,
127    msg: &T,
128) -> io::Result<()> {
129    let payload = bincode::encode_to_vec(msg, bincode::config::standard())
130        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
131    let len = payload.len() as u64;
132    stream.write_all(&len.to_le_bytes()).await?;
133    stream.write_all(&payload).await?;
134    stream.flush().await
135}
136
137/// One-shot client helper for integration tests.
138pub async fn client_roundtrip(
139    addr: SocketAddr,
140    request: Request,
141) -> io::Result<Response> {
142    let mut stream = TcpStream::connect(addr).await?;
143    write_frame_async(&mut stream, &request).await?;
144    read_frame_async(&mut stream).await
145}
146
147/// Build admin grants for every registered space id.
148pub fn admin_grants(space_ids: &[u64]) -> Vec<SpaceGrant> {
149    space_ids
150        .iter()
151        .map(|id| SpaceGrant {
152            space: crate::infinitedb_core::address::SpaceId(*id),
153            level: AccessLevel::Admin,
154        })
155        .collect()
156}